Build a CIFAR‑10 Image Classifier with PyTorch – A Java Developer’s Guide

This tutorial walks Java developers through building, training, and evaluating a CIFAR‑10 image classifier in PyTorch. It covers data loading, preprocessing, network definition, loss and optimiser setup, GPU acceleration, saving and reloading the model, and per‑class accuracy analysis.

JavaEdge

Goal

Understand the classifier task and data format.

Learn how to implement an image classifier with PyTorch.

Classifier Task and Data Introduction

Task Description

Construct a neural‑network classifier that assigns each input image to one of the ten classes in the CIFAR‑10 dataset.

Data Introduction: CIFAR‑10

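The dataset contains 60,000 small colour images (3 × 32 × 32: three RGB channels of 32 × 32 pixels), split evenly across 10 classes: plane, car, bird, cat, deer, dog, frog, horse, ship, truck. torchvision provides it as a 50,000‑image training set and a 10,000‑image test set.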

Figure: CIFAR‑10 sample images

Steps to Train the Classifier

The following sections detail a complete PyTorch workflow.

Download CIFAR‑10 with torchvision

import torch
import torchvision
import torchvision.transforms as transforms

# Define preprocessing: convert PIL image to tensor and normalize to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

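# Load the training and test sets (downloaded to ./data on the first run)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# num_workers=2 loads batches in two worker processes; on Windows and macOS, put the
# script's top-level code under `if __name__ == '__main__':` (or use num_workers=0)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)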

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Why transforms.ToTensor()? It converts a PIL image of shape (H, W, C) with 8‑bit pixel values in [0, 255] into a float tensor of shape (C, H, W) with values in [0.0, 1.0], the channels‑first layout that PyTorch's convolution layers expect.
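As a rough illustration, the NumPy sketch below reproduces the layout change and scaling that ToTensor performs on an 8‑bit RGB image. The function name to_chw_unit_range is hypothetical, and it assumes a uint8 array in (H, W, C) order; the real transform also handles other PIL modes and input types.

```python
import numpy as np


def to_chw_unit_range(hwc: np.ndarray) -> np.ndarray:
    """Mimic ToTensor's layout and scaling for an 8-bit (H, W, C) array (sketch, uint8 only)."""
    if hwc.dtype != np.uint8 or hwc.ndim != 3:
        raise ValueError("expected an (H, W, C) uint8 array")
    chw = np.transpose(hwc, (2, 0, 1))        # (H, W, C) -> (C, H, W)
    return chw.astype(np.float32) / 255.0     # [0, 255] -> [0.0, 1.0]


# A fake 32x32 RGB image: red channel fully on, green and blue off
img = np.zeros((32, 32, 3), dtype=np.uint8)
img[..., 0] = 255
out = to_chw_unit_range(img)
print(out.shape)  # (3, 32, 32)
```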

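Why transforms.Normalize()? It applies (x - mean) / std to each channel. With mean and std both set to (0.5, 0.5, 0.5), this maps the [0, 1] range produced by ToTensor to [-1, 1], roughly centring the inputs around zero. It does not give each channel exactly mean 0 and standard deviation 1; that would require CIFAR‑10's actual per‑channel statistics. Zero‑centred, similarly scaled inputs generally make gradient‑descent optimisation better conditioned and training more stable.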
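The arithmetic behind Normalize is easy to check by hand. This plain‑Python sketch (the helper names are illustrative, not part of torchvision) applies the per‑channel formula with mean = std = 0.5 and its inverse, which is the img / 2 + 0.5 step used by imshow in the next section.

```python
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-channel formula applied by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std


def unnormalize(y: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse mapping; with mean = std = 0.5 this equals y / 2 + 0.5."""
    return y * std + mean


for x in (0.0, 0.5, 1.0):
    print(x, "->", normalize(x))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```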

Visualise a Batch of Images

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5  # un‑normalize to [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Get one batch
dataiter = iter(trainloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print(' '.join(f'{classes[labels[j]]:5}' for j in range(4)))
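make_grid tiles the batch into a single (C, H, W) image with padding between tiles. The helper below is a hypothetical re‑derivation of its sizing rule, assuming the defaults nrow=8 and padding=2 and a 3‑channel batch. It shows why a batch of four 32 × 32 images becomes a 3 × 36 × 138 grid, which np.transpose(npimg, (1, 2, 0)) then turns into the 36 × 138 × 3 layout that plt.imshow expects.

```python
import math


def make_grid_shape(n_images: int, channels: int, height: int, width: int,
                    nrow: int = 8, padding: int = 2) -> tuple:
    """Sketch of the output shape of torchvision.utils.make_grid (assumed defaults)."""
    cols = min(nrow, n_images)              # images per row
    rows = math.ceil(n_images / cols)       # number of rows
    cell_h, cell_w = height + padding, width + padding
    # One extra padding strip closes the bottom and right edges of the grid
    return (channels, rows * cell_h + padding, cols * cell_w + padding)


print(make_grid_shape(4, 3, 32, 32))  # (3, 36, 138)
```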

Define the Convolutional Neural Network

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)   # 3→6 channels, 5×5 kernel: 32×32 → 28×28
        self.pool = nn.MaxPool2d(2, 2)    # 2×2 max pooling, stride 2: halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6→16 channels, 5×5 kernel: 14×14 → 10×10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 feature maps of 5×5 after the second pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)      # one output (logit) per class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (N, 6, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # (N, 16, 5, 5)
        x = torch.flatten(x, 1)               # (N, 400): flatten every dimension except the batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)                       # raw logits; CrossEntropyLoss applies softmax internally
        return x

net = Net()
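The 16 * 5 * 5 input size of fc1 follows from the usual output‑size formula for convolution and pooling layers, floor((size + 2 · padding - kernel) / stride) + 1. The sketch below (helper names are illustrative; it assumes dilation 1 and no ceil_mode) traces a 32 × 32 input through the network and also totals the learnable parameters. For this architecture, sum(p.numel() for p in net.parameters()) should report the same count.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of one spatial dimension for Conv2d or MaxPool2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + c_out


def linear_params(n_in: int, n_out: int) -> int:
    """Weight matrix plus one bias per output unit."""
    return n_in * n_out + n_out


size = 32
size = conv_out(size, 5)             # conv1: 32 -> 28
size = conv_out(size, 2, stride=2)   # pool:  28 -> 14
size = conv_out(size, 5)             # conv2: 14 -> 10
size = conv_out(size, 2, stride=2)   # pool:  10 -> 5
flat = 16 * size * size
print(size, flat)                    # 5 400

total = (conv_params(3, 6, 5) + conv_params(6, 16, 5)
         + linear_params(flat, 120) + linear_params(120, 84) + linear_params(84, 10))
print(total)                         # 62006
```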

Define Loss Function and Optimiser

import torch.optim as optim

criterion = nn.CrossEntropyLoss()               # multi-class loss: takes raw logits and integer class labels
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
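To make these two objects less of a black box, here is a plain‑Python sketch of what they compute for a single example. nn.CrossEntropyLoss takes raw logits and an integer class index and returns the negative log of the softmax probability of the correct class (averaged over the batch by default); SGD with momentum keeps a velocity buffer per parameter. The helper names are hypothetical, and the update follows the form in PyTorch's SGD documentation with dampening 0, no Nesterov momentum, and no weight decay.

```python
import math


def cross_entropy(logits: list[float], target: int) -> float:
    """-log(softmax(logits)[target]) for one sample, via the numerically stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def sgd_momentum_step(param: float, grad: float, velocity: float,
                      lr: float = 0.001, momentum: float = 0.9) -> tuple[float, float]:
    """One scalar update: v <- momentum * v + grad, then param <- param - lr * v."""
    velocity = momentum * velocity + grad
    return param - lr * velocity, velocity


# Two equally scored classes: the loss is ln 2 whichever one is the target
print(round(cross_entropy([0.0, 0.0], 0), 4))   # 0.6931

# Two steps with a constant gradient of 1.0: the velocity builds up (1.0, then 1.9)
p, v = 0.0, 0.0
for _ in range(2):
    p, v = sgd_momentum_step(p, 1.0, v)
print(p, v)
```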

Train the Model

for epoch in range(2):  # loop over the whole training set twice
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data                # a mini-batch of 4 images and their labels
        optimizer.zero_grad()                # clear gradients left over from the previous step
        outputs = net(inputs)                # forward pass
        loss = criterion(outputs, labels)    # compute the loss
        loss.backward()                      # backward pass: compute gradients
        optimizer.step()                     # update the weights
        running_loss += loss.item()
        if (i + 1) % 2000 == 0:              # report the mean loss of the last 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
print('Finished Training')
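With batch_size=4, it helps to work out how often the loop above prints. This small calculation assumes the 50,000‑image CIFAR‑10 training split and the tutorial's values. Each printed figure is the mean loss of the most recent 2,000 mini‑batches (8,000 images), and the final 500 mini‑batches of each epoch are never reported because running_loss is reset at the start of the next epoch.

```python
import math

train_images, batch_size, log_every = 50_000, 4, 2000   # values used in this tutorial

batches_per_epoch = math.ceil(train_images / batch_size)
# The loop prints whenever (i + 1) is a multiple of log_every
log_points = list(range(log_every, batches_per_epoch + 1, log_every))

print(batches_per_epoch)  # 12500
print(log_points)         # [2000, 4000, 6000, 8000, 10000, 12000]
```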

Save and Load the Model

# Save the trained parameters
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

# Load the parameters later
net.load_state_dict(torch.load(PATH))
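A checkpoint saved during a GPU run holds CUDA tensors, so loading it on a CPU‑only machine needs torch.load's map_location argument. The sketch below shows the full save/load round trip with map_location="cpu". It uses a small nn.Linear as a hypothetical stand‑in for Net and a temporary directory instead of ./cifar_net.pth, so it runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in model so the example runs without training Net
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save(model.state_dict(), path)        # saves only the parameter tensors

    restored = nn.Linear(4, 2)                  # rebuild the same architecture first
    # map_location="cpu" remaps any CUDA tensors in the checkpoint onto the CPU
    state = torch.load(path, map_location="cpu")
    restored.load_state_dict(state)
    restored.eval()                             # switch to inference mode before predicting

same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                               restored.state_dict().values()))
print(same)  # True
```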

Test the Model on the Test Set

# Visualise a few test images
dataiter = iter(testloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5}' for j in range(4)))

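# Predict with the saved model: rebuild the architecture, then load the weights
net = Net()
net.load_state_dict(torch.load(PATH))
net.eval()                            # inference mode (no effect on this net, but good practice)
with torch.no_grad():                 # gradients are not needed for prediction
    outputs = net(images)
_, predicted = torch.max(outputs, 1)  # index of the highest score = predicted class
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5}' for j in range(4)))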
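torch.max(outputs, 1) returns two tensors: the largest value in each row and the index where it occurs. Only the indices matter for classification, since they are the predicted class numbers. The toy logits below are made up for illustration; argmax gives the same indices, and applying softmax first would not change them because softmax preserves the ordering of scores.

```python
import torch

# Hypothetical logits for a batch of two images over three classes
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.0, 1.0]])

values, indices = torch.max(logits, 1)            # reduce over dim 1, the class dimension
print(values.tolist())                            # [2.0, 3.0]
print(indices.tolist())                           # [1, 0]
print(torch.equal(indices, logits.argmax(dim=1)))  # True

probs = torch.softmax(logits, dim=1)              # optional: turn logits into probabilities
print(torch.equal(probs.argmax(dim=1), indices))  # True
```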

Compute Overall Accuracy

correct = 0
total = 0
with torch.no_grad():  # no gradients needed for evaluation
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.1f} %')

Per‑Class Accuracy

class_correct = [0] * 10
class_total = [0] * 10
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        # Iterate over the actual batch instead of assuming exactly 4 images
        for label, pred in zip(labels.tolist(), predicted.tolist()):
            class_correct[label] += int(label == pred)
            class_total[label] += 1
for i in range(10):
    print(f'Accuracy of {classes[i]:5s} : {100 * class_correct[i] / class_total[i]:.1f} %')
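A confusion matrix generalises the per‑class counts above: row t, column p counts test images of true class t predicted as class p, so the diagonal holds the correct predictions and the off‑diagonal cells show which classes get mistaken for each other. The plain‑Python sketch below uses hypothetical helper names and a made‑up three‑class example; with the real test loop you would feed it labels.tolist() and predicted.tolist() accumulated over all batches.

```python
def confusion_matrix(labels, predictions, num_classes=10):
    """cm[t][p] counts samples whose true class is t and predicted class is p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, predictions):
        cm[t][p] += 1
    return cm


def per_class_accuracy(cm):
    """Diagonal entry divided by row sum for each class; None if a class has no samples."""
    return [row[i] / sum(row) if sum(row) else None for i, row in enumerate(cm)]


# Hypothetical labels and predictions for a three-class toy problem
labels = [0, 0, 1, 1, 2, 2]
predictions = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(labels, predictions, num_classes=3)
print(cm)                      # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(per_class_accuracy(cm))  # [0.5, 1.0, 0.5]
```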

Training on GPU

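Training can be accelerated by moving the model and every input batch to a CUDA device when one is available. For a network as small as this one, the speed‑up over the CPU may be modest.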

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)  # prints "cuda:0" or "cpu"

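# Transfer the model's parameters and buffers to the device
net.to(device)

# Inside the training loop, move each mini-batch to the same device as the model
for i, data in enumerate(trainloader, 0):
    inputs, labels = data[0].to(device), data[1].to(device)
    # ... zero_grad, forward, loss, backward and step exactly as before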
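Putting the pieces together, this sketch runs one training step with device placement. It uses a hypothetical one‑layer stand‑in for Net and random tensors shaped like a CIFAR‑10 batch, so it runs without the dataset; without CUDA it falls back to the CPU. It also builds the optimiser after calling .to(device), as the torch.optim documentation recommends.

```python
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for Net: flatten a 3x32x32 image, then one linear layer
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Random batch shaped like CIFAR-10: 4 images of 3x32x32 and 4 integer labels
inputs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))

# Move the batch to the model's device; mixing CPU and CUDA tensors raises a RuntimeError
inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()
loss = criterion(model(inputs), labels)
loss.backward()
optimizer.step()

print(next(model.parameters()).device, inputs.device)  # both report the same device
print(f"loss: {loss.item():.3f}")
```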

Summary

Task: Classify CIFAR‑10 images into 10 categories using a CNN.

Data: 10 classes; each image is a 3 × 32 × 32 colour image.

Workflow:

Download and preprocess the dataset with torchvision.

Define a simple CNN (two conv layers, pooling, three fully‑connected layers).

Set up cross‑entropy loss and SGD optimiser.

Train for a couple of epochs (two in the example), printing the average loss periodically.

Save the model state, then load it for inference.

Evaluate overall accuracy and per‑class accuracy on the test set.

Optionally move the model to GPU for faster training.

Written by JavaEdge

First‑line development experience at multiple leading tech firms; now a software architect at a Shanghai state‑owned enterprise and founder of Programming Yanxuan. Nearly 300k followers online; expertise in distributed system design, AIGC application development, and quantitative finance investing.
