
Build a Handwritten Digit Recognizer in Java with TensorFlow

This article walks through the complete process of creating, training, evaluating, saving, and loading an MNIST handwritten digit recognition model with TensorFlow in Java, compares it with the equivalent Python implementation, and covers the required background knowledge, environment setup, and code details.


Introduction: In the spirit of teaching people to fish rather than handing them one, this guide shows how to implement a simple handwritten digit recognition model on the MNIST dataset in Java with TensorFlow, bridging the gap for backend Java developers who find that most tutorials are written in Python.

Goal

Train a model on the MNIST dataset to recognize hand‑written numbers.

Required knowledge

Machine‑learning basics (supervised, unsupervised, reinforcement learning)

Data processing and analysis (cleaning, feature engineering, visualization)

Programming language (Python is common, but this guide focuses on Java)

Mathematics (linear algebra, probability, calculus)

ML algorithms (linear regression, decision trees, neural networks, SVM)

Deep‑learning frameworks (TensorFlow, PyTorch)

Model evaluation and optimization (cross‑validation, hyper‑parameter tuning, metrics)

Practical experience through projects and competitions

The Hello‑World example uses the TensorFlow framework.

Main requirements

Understand the MNIST data shape (60000, 28, 28) and label format

Know the role of activation functions

Understand forward and backward propagation

Train and save the model

Load and use the saved model
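As a refresher on the "role of activation functions" item above, here is a minimal plain-Java sketch (illustrative only, not part of the article's TensorFlow code) of the two activations this model relies on, ReLU and softmax:

```java
public class Activations {
    // ReLU: max(0, x), applied element-wise in hidden layers
    static double relu(double x) {
        return Math.max(0.0, x);
    }

    // Softmax: turns raw scores (logits) into a probability distribution
    // over the classes; subtracting the max improves numerical stability
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }

    public static void main(String[] args) {
        System.out.println(relu(-3.0) + " " + relu(2.5));
        double[] p = softmax(new double[]{1.0, 2.0, 3.0});
        double total = 0.0;
        for (double v : p) total += v;
        System.out.println(total); // probabilities sum to ~1
    }
}
```

In the MNIST model, softmax over the 10 output scores gives the probability that an image shows each digit, and the highest-probability class is taken as the prediction.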

Java vs Python code comparison

Python code for loading the data (gzip decoding partially elided for brevity), followed by the Java counterpart:

<code>def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
"t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    # ... (gzip loading and reshaping) ...
    return (train_x, train_y), (test_x, test_y)
(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
</code>

Java constants for archive paths:

<code>private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
</code>
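The Java side has no Keras-style loader, so the archives named by these constants must be parsed by hand. As a hedged illustration of the IDX format those .gz files use (the class and method names here are hypothetical, not taken from the article's MnistDataset), a minimal parser might look like:

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

public class IdxDemo {
    // Parses an IDX3 image stream: magic (0x00000803), image count,
    // row count, column count, then one byte per pixel.
    // For the real archives, wrap the file stream first:
    // new DataInputStream(new GZIPInputStream(new FileInputStream(path)))
    static byte[][] readImages(DataInputStream in) throws IOException {
        int magic = in.readInt();
        if (magic != 0x00000803) throw new IOException("not an IDX3 image file");
        int count = in.readInt();
        int rows = in.readInt();
        int cols = in.readInt();
        byte[][] images = new byte[count][rows * cols];
        for (byte[] image : images) in.readFully(image);
        return images;
    }

    public static void main(String[] args) throws IOException {
        // Build a tiny synthetic IDX3 stream in memory: 2 images of 2x2 pixels.
        ByteBuffer buf = ByteBuffer.allocate(24);
        buf.putInt(0x00000803).putInt(2).putInt(2).putInt(2);
        buf.put(new byte[]{1, 2, 3, 4, 5, 6, 7, 8});
        byte[][] images = readImages(
            new DataInputStream(new ByteArrayInputStream(buf.array())));
        System.out.println(images.length + " images, "
            + images[0].length + " pixels each");
    }
}
```

The label archives follow the same pattern with magic 0x00000801 and one byte per label, no row/column header.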

Model construction

Python (Keras) model:

<code>model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
</code>

Java (TensorFlow Java API) model building:

<code>Ops tf = Ops.create(graph);
Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);
Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));
Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));
MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
Add<TFloat32> add = tf.math.add(matMul, biases);
Softmax<TFloat32> softmax = tf.nn.softmax(add);
Mean<TFloat32> crossEntropy = tf.math.mean(
    tf.math.neg(tf.reduceSum(tf.math.mul(labels, tf.math.log(softmax)), tf.array(1))),
    tf.array(0));
Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
Op minimize = optimizer.minimize(crossEntropy);
Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));
</code>
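The crossEntropy node built above computes mean(-sum(labels * log(softmax))) across the batch. To make that concrete, here is a plain-Java sketch of the same formula on ordinary arrays (a worked toy example, not part of the TensorFlow graph):

```java
public class CrossEntropyDemo {
    // Mean over the batch of -sum(label * log(prob)) per example, mirroring
    // the tf.math.mean(tf.math.neg(tf.reduceSum(...))) nodes in the graph.
    static double crossEntropy(double[][] labels, double[][] probs) {
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double perExample = 0.0;
            for (int j = 0; j < labels[i].length; j++) {
                perExample += labels[i][j] * Math.log(probs[i][j]);
            }
            total += -perExample;
        }
        return total / labels.length;
    }

    public static void main(String[] args) {
        // One-hot labels for classes 1 and 0 in a 3-class toy setup.
        double[][] labels = {{0, 1, 0}, {1, 0, 0}};
        double[][] probs  = {{0.1, 0.8, 0.1}, {0.6, 0.2, 0.2}};
        // -log(0.8) ~ 0.223, -log(0.6) ~ 0.511, mean ~ 0.367
        System.out.printf("%.3f%n", crossEntropy(labels, probs));
    }
}
```

Only the probability assigned to the true class matters for each example, which is why confident correct predictions drive the loss toward zero.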

Training the model

Python:

<code>history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)
</code>

Java:

<code>for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
    try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
         TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
        session.runner()
            .addTarget(minimize)
            .feed(images.asOutput(), batchImages)
            .feed(labels.asOutput(), batchLabels)
            .run();
    }
}
</code>

Model evaluation

Python:

<code>test_loss, test_acc = model.evaluate(test_x, test_y)
print('Test loss: %.3f' % test_loss)
print('Test accuracy: %.3f' % test_acc)
</code>

Java:

<code>ImageBatch testBatch = dataset.testBatch();
try (TFloat32 testImages = preprocessImages(testBatch.images());
     TFloat32 testLabels = preprocessLabels(testBatch.labels());
     TFloat32 accuracyValue = (TFloat32) session.runner()
         .fetch(accuracy)
         .fetch(predicted)
         .fetch(expected)
         .feed(images.asOutput(), testImages)
         .feed(labels.asOutput(), testLabels)
         .run()
         .get(0)) {
    System.out.println("Accuracy: " + accuracyValue.getFloat());
}
</code>
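The accuracy value fetched above is simply the fraction of positions where argMax(softmax) matches argMax(labels). A plain-Java sketch of that calculation (illustrative only, with made-up digit values):

```java
public class AccuracyDemo {
    // Mirrors tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), ...)):
    // the fraction of predictions equal to the expected class.
    static float accuracy(long[] predicted, long[] expected) {
        int correct = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] == expected[i]) correct++;
        }
        return (float) correct / predicted.length;
    }

    public static void main(String[] args) {
        long[] predicted = {7, 2, 1, 0, 4};
        long[] expected  = {7, 2, 1, 0, 9};
        System.out.println("Accuracy: " + accuracy(predicted, expected)); // 4 of 5 correct
    }
}
```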

Saving the model

Python:

<code># raw string: otherwise \n in the Windows path is read as a newline
save_model(model, r'D:\pythonProject\mnistDemo\number_model', save_format='pb')
</code>

Java:

<code>SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
Signature.Builder builder = Signature.builder();
builder.input("images", images);
builder.input("labels", labels);
builder.output("accuracy", accuracy);
builder.output("expected", expected);
builder.output("predicted", predicted);
Signature signature = builder.build();
SessionFunction sessionFunction = SessionFunction.create(signature, session);
exporter.withFunction(sessionFunction);
exporter.export();
</code>

Loading the model

Python:

<code># use a distinct variable name so the load_model function is not shadowed,
# and a raw string so \n in the Windows path is not read as a newline
loaded_model = load_model(r'D:\pythonProject\mnistDemo\number_model')
loaded_model.summary()
predictValue = loaded_model.predict(input_data)
</code>

Java:

<code>SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
printSignature(model);
Result run = model.session().runner()
    .feed("Placeholder:0", testImages)
    .feed("Placeholder_1:0", testLabels)
    .fetch("ArgMax:0")
    .fetch("ArgMax_1:0")
    .fetch("Mean_1:0")
    .run();
// process outputs …
</code>

Full Python code (mnistTrainDemo.py)

<code>import gzip, os.path, tensorflow as tf, matplotlib.pyplot as plt, numpy as np
# load_data, model definition, compile, fit, evaluate, save (see source for details)
</code>

Full Java code

Dependencies (Maven):

<code>&lt;dependency&gt;
  &lt;groupId&gt;org.tensorflow&lt;/groupId&gt;
  &lt;artifactId&gt;tensorflow-core-platform&lt;/artifactId&gt;
  &lt;version&gt;0.6.0-SNAPSHOT&lt;/version&gt;
&lt;/dependency&gt;
&lt;dependency&gt;
  &lt;groupId&gt;org.tensorflow&lt;/groupId&gt;
  &lt;artifactId&gt;tensorflow-framework&lt;/artifactId&gt;
  &lt;version&gt;0.6.0-SNAPSHOT&lt;/version&gt;
&lt;/dependency&gt;
</code>

Key classes:

MnistDataset – reads MNIST gzip archives, provides training/validation/test tensors.

SimpleMnist – builds the graph, trains, evaluates, saves, and loads the model.

<code>package org.example.tensorDemo.datasets.mnist;
public class MnistDataset { /* code as in source (readArchive, getOneValidationImage, etc.) */ }
</code>
<code>package org.example.tensorDemo.dense;
public class SimpleMnist implements Runnable { /* full implementation from source */ }
</code>

Running results

Pending improvements

Add a web service to accept image input and perform binary preprocessing.

Replace the simple linear model with convolutional neural networks for higher accuracy.

Explore deeper network architectures and hyper‑parameter tuning.

Tags: java, machine learning, deep learning, TensorFlow, MNIST, handwritten digit recognition
Written by

JD Cloud Developers

JD Cloud Developers (Developer of JD Technology) is a JD Technology Group platform offering technical sharing and communication for AI, cloud computing, IoT and related developers. It publishes JD product technical information, industry content, and tech event news. Embrace technology and partner with developers to envision the future.
