How Java Developers Can Build Neural Networks with PyTorch: A Step‑by‑Step Guide
This tutorial walks Java developers through the complete workflow of building, training, and evaluating a neural network in PyTorch, covering network definition, data iteration, forward and backward passes, loss calculation, and parameter updates with detailed code examples and Java‑centric analogies.
Learning Goals
Define a network, iterate over data, forward‑propagate inputs, compute loss, back‑propagate gradients, and update parameters using PyTorch.
Define a Neural Network
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 1 input channel, 6 output channels, 3x3 convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # 16 feature maps of size 6x6 remain after two conv + pool stages on a 32x32 input
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Convolution -> ReLU -> 2x2 max pooling
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # A single number suffices when the pooling window is square
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = Net()
print(net)

Key points:
Class definition: subclass nn.Module to create a model, much as a Java class extends an abstract base class and overrides its hook methods.
Constructor (__init__) defines layers such as nn.Conv2d and nn.Linear.
Forward method specifies the data flow through convolutions, ReLU activations, pooling, and fully-connected layers.
Parameter access: net.parameters() returns an iterator over the trainable tensors, as the sketch below shows.
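For instance, net.named_parameters() pairs each tensor with its attribute name, which is handy when checking what the model will actually train (a minimal inspection sketch using the Net defined above):

for name, p in net.named_parameters():
    print(name, tuple(p.shape))  # e.g. conv1.weight (6, 1, 3, 3), conv1.bias (6,)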
Testing the Network
# Random input (batch=1, channel=1, 32x32)
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
# Inspect parameters
params = list(net.parameters())
print(len(params))       # 10 parameter tensors: a weight and a bias for each of the five layers
print(params[0].size())  # conv1's weight: torch.Size([6, 1, 3, 3])
# Zero gradients and back‑propagate a dummy loss
net.zero_grad()
out.backward(torch.randn(1, 10))

Note: nn.Conv2d expects a 4-D tensor (N, C, H, W). Use input.unsqueeze(0) to add a batch dimension if needed.
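For example, a single 32x32 grayscale image stored as a 3-D (C, H, W) tensor can be batched like this (a small illustration, not part of the original code):

single = torch.randn(1, 32, 32)   # (C, H, W): one channel, no batch dimension
batched = single.unsqueeze(0)     # (N, C, H, W): now shaped (1, 1, 32, 32)
out = net(batched)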
Loss Function
# Forward pass
output = net(input)
# Random target matching output shape
target = torch.randn(10).view(1, -1)
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)
print(loss.grad_fn)  # computation graph node

The MSE loss computes the mean of (output - target)^2. Calling loss.backward() triggers autograd.
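To make the formula concrete, the criterion's result can be checked against a direct computation (a quick sanity-check sketch, reusing the output and target tensors above):

manual = ((output - target) ** 2).mean()  # mean of squared element-wise differences
print(torch.allclose(loss, manual))       # True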
Backpropagation
# Zero gradients before a new iteration
net.zero_grad()
print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)
# Backward pass
loss.backward()
print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

loss.backward() traverses the computation graph and populates each parameter's .grad attribute.
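The graph that backward() walks can be explored one step at a time through grad_fn (a short exploration snippet; the exact node names vary across PyTorch versions):

print(loss.grad_fn)                       # the MSE loss node
print(loss.grad_fn.next_functions[0][0])  # the linear layer (fc3) behind it
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # the ReLU before that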
Parameter Update
Manual SGD
learning_rate = 0.01
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)  # in-place update: weight -= lr * gradient
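The .data idiom works but predates current autograd conventions; an equivalent loop that recent PyTorch releases favor performs the update under torch.no_grad() (an alternative sketch, not from the original tutorial):

with torch.no_grad():
    for f in net.parameters():
        f -= learning_rate * f.grad  # same rule, without reaching around autograd via .data

Using a PyTorch Optimizer (recommended)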
import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.01)
optimizer.zero_grad()
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()

The optimizer encapsulates the update rule weight = weight - lr * gradient, handling gradient clearing and the update step for you.
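Because the update rule is encapsulated, swapping in a different algorithm is a one-line change (illustrative variations, not from the original text; the hyperparameters shown are common defaults):

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)  # SGD with momentum
optimizer = optim.Adam(net.parameters(), lr=0.001)              # adaptive per-parameter step sizes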
Typical Training Loop
for epoch in range(num_epochs):
    for data, target in dataloader:
        optimizer.zero_grad()             # clear gradients from the previous batch
        output = net(data)                # forward pass
        loss = criterion(output, target)  # scalar loss
        loss.backward()                   # back-propagate gradients
        optimizer.step()                  # update parameters

Key Technical Points
Loss function: nn.MSELoss produces a scalar; loss.backward() triggers gradient computation.
Gradient handling: call optimizer.zero_grad() (or net.zero_grad()) before each iteration to avoid gradient accumulation.
Parameter update: use an optimizer such as optim.SGD for concise, correct weight updates. A complete runnable sketch follows below.
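Putting the pieces together, here is a self-contained loop that trains the Net defined earlier on random stand-in data (the data here is fabricated purely for illustration; a real task would iterate over a DataLoader built from an actual dataset):

import torch
import torch.nn as nn
import torch.optim as optim

net = Net()  # the class defined at the top of this guide
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

for epoch in range(3):
    # Fabricated batch: 4 random 32x32 single-channel images and random targets
    data = torch.randn(4, 1, 32, 32)
    target = torch.randn(4, 10)

    optimizer.zero_grad()             # clear old gradients
    output = net(data)                # forward pass
    loss = criterion(output, target)  # scalar loss
    loss.backward()                   # compute gradients
    optimizer.step()                  # apply the SGD update
    print(f'epoch {epoch}: loss = {loss.item():.4f}')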
Reference: http://www.javaedge.cn/