Build a CIFAR‑10 Image Classifier with PyTorch – A Java Developer’s Guide
This tutorial walks Java developers through building, training, evaluating, and deploying a CIFAR‑10 image classifier using PyTorch, covering data loading, preprocessing, network definition, loss and optimizer setup, GPU acceleration, model saving, and per‑class accuracy analysis.
Goal
Understand the classifier task and data format.
Learn how to implement an image classifier with PyTorch.
Classifier Task and Data Introduction
Task Description
Construct a neural‑network classifier that assigns each input image to one of the ten classes in the CIFAR‑10 dataset.
Data Introduction: CIFAR‑10
The dataset contains 60,000 small colour images (3 × 32 × 32) split across 10 categories (50,000 training, 10,000 test): plane, car, bird, cat, deer, dog, frog, horse, ship, truck.
Steps to Train the Classifier
The following sections detail a complete PyTorch workflow.
Download CIFAR‑10 with torchvision
import torch
import torchvision
import torchvision.transforms as transforms
# Define preprocessing: convert PIL image to tensor and normalize to [-1, 1]
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Load training and test sets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Why transforms.ToTensor()? It converts a PIL image (H, W, C) to a tensor (C, H, W) and scales pixel values from [0, 255] to [0, 1], which is the format PyTorch models expect.
Why transforms.Normalize()? It applies (x − mean) / std to each channel; with mean and std both set to (0.5, 0.5, 0.5), values in [0, 1] are mapped to [−1, 1]. Centring inputs around zero speeds up gradient‑descent optimisation and stabilises training.
Visualise a Batch of Images
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
img = img / 2 + 0.5 # un‑normalize to [0, 1]
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# Get one batch
dataiter = iter(trainloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print(' '.join(f'{classes[labels[j]]:5}' for j in range(4)))
Define the Convolutional Neural Network
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5) # 3→6 channels, 5×5 kernel
self.pool = nn.MaxPool2d(2, 2) # 2×2 pooling
self.conv2 = nn.Conv2d(6, 16, 5) # 6→16 channels
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 16 feature maps of 5×5: 32→28 (conv1) →14 (pool) →10 (conv2) →5 (pool)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) # 10 output classes
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5) # flatten feature maps to a vector
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
Define Loss Function and Optimiser
import torch.optim as optim
criterion = nn.CrossEntropyLoss() # suitable for multi‑class classification
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
Train the Model
for epoch in range(2): # iterate over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad() # zero the parameter gradients
outputs = net(inputs) # forward pass
loss = criterion(outputs, labels) # compute loss
loss.backward() # backward pass
optimizer.step() # optimise
running_loss += loss.item()
if (i + 1) % 2000 == 0:
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
running_loss = 0.0
print('Finished Training')
Save and Load the Model
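One practical note on the save/load step that follows: the tutorial stores only the state_dict (the parameter tensors), not the whole module. If a checkpoint written on a GPU machine is later loaded on CPU‑only hardware, torch.load accepts a map_location argument to remap the stored tensors. A minimal sketch with a stand‑in nn.Linear model and a throwaway file path (the tutorial's Net behaves the same way):

```python
import torch
import torch.nn as nn

# Stand-in model for illustration; the path is a throwaway file
model = nn.Linear(4, 2)
PATH = './tmp_model.pth'
torch.save(model.state_dict(), PATH)

# map_location remaps tensors saved on one device onto another at load time
state = torch.load(PATH, map_location=torch.device('cpu'))
model.load_state_dict(state)
print(sorted(state.keys()))  # ['bias', 'weight']
```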
# Save the trained parameters
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
# Load the parameters later
net.load_state_dict(torch.load(PATH))
Test the Model on the Test Set
# Visualise a few test images
dataiter = iter(testloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5}' for j in range(4)))
# Predict with the saved model
net = Net()
net.load_state_dict(torch.load(PATH))
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5}' for j in range(4)))
Compute Overall Accuracy
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1) # .data is unnecessary inside torch.no_grad()
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
Per‑Class Accuracy
class_correct = [0. for _ in range(10)]
class_total = [0. for _ in range(10)]
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4): # batch_size = 4; the test set (10,000 images) divides evenly into such batches
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
Training on GPU
Accelerate training by moving the model and tensors to a CUDA device when available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device) # prints "cuda:0" or "cpu"
# Transfer model
net.to(device)
# Inside the training loop, move inputs and labels
inputs, labels = data[0].to(device), data[1].to(device)
Summary
Task : Classify CIFAR‑10 images into 10 categories using a CNN.
Data : 10 classes, each image is 3 × 32 × 32 colour.
Workflow :
Download and preprocess the dataset with torchvision.
Define a simple CNN (two conv layers, pooling, three fully‑connected layers).
Set up cross‑entropy loss and SGD optimiser.
Train for several epochs, printing loss periodically.
Save the model state, then load it for inference.
Evaluate overall accuracy and per‑class accuracy on the test set.
Optionally move the model to GPU for faster training.
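The workflow above can be condensed into a short device‑aware inference sketch. The random batch below is a stand‑in for real test data; note the call to net.eval(), which the tutorial omits. It has no effect on this particular architecture (no dropout or batch‑norm layers) but is a good habit before inference:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same architecture as defined in the tutorial
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = Net().to(device)
net.eval()  # switch to inference mode (matters for dropout/batch-norm)
with torch.no_grad():
    batch = torch.randn(4, 3, 32, 32, device=device)  # stand-in for a test batch
    outputs = net(batch)
    _, predicted = torch.max(outputs, 1)
print(outputs.shape)    # torch.Size([4, 10])
print(predicted.shape)  # torch.Size([4])
```

Because both the model and the batch are created on the same device, the sketch runs unchanged on CPU or GPU.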
JavaEdge
Hands‑on development experience at several leading tech companies; now a software architect at a Shanghai state‑owned enterprise and founder of Programming Yanxuan, with nearly 300k online followers. Expertise in distributed system design, AIGC application development, and quantitative finance investing.