Training MNIST with Burn on wgpu: From PyTorch to Rust Backend
This tutorial demonstrates how to train an MNIST digit‑recognition model using the Rust‑based Burn framework on top of the cross‑platform wgpu API, covering model export from PyTorch to ONNX, code generation, data loading, the training loop, and a performance comparison across CPU, GPU, and other backends.
When first encountering PyTorch on a modest laptop, the author struggled with the GPU requirements for training CNNs, which prompted a search for a cross‑platform, driver‑agnostic GPU framework.
1. wgpu
wgpu is a pure‑Rust, cross‑platform graphics API that implements the WebGPU standard and runs on Vulkan, Metal, D3D12, OpenGL, and WebGL2/WebGPU in browsers. Its broad driver support makes it a solid foundation for GPU‑accelerated workloads, as used by Firefox and Deno.
2. burn
Burn is a Rust deep‑learning framework emphasizing flexibility, computational efficiency, and portability. Here it serves as the intermediary layer that lets arbitrary models train on the wgpu backend, offering strong device compatibility and easy model import.
3. Code Walkthrough
The example starts with a standard PyTorch MNIST model:
```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 3)
        self.conv2 = nn.Conv2d(8, 16, 3)
        self.conv3 = nn.Conv2d(16, 24, 3)
        self.norm1 = nn.BatchNorm2d(24)
        self.dropout1 = nn.Dropout(0.3)
        self.fc1 = nn.Linear(24 * 22 * 22, 32)
        self.fc2 = nn.Linear(32, 10)
        self.norm2 = nn.BatchNorm1d(10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.norm1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = self.norm2(self.fc2(x))
        return F.log_softmax(x, dim=1)
```
The parameters are saved with torch.save(model.state_dict(), "mnist.pt"), and Burn re‑creates an equivalent model in Rust:
```rust
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    conv3: Conv2d<B>,
    norm1: BatchNorm<B, 2>,
    fc1: Linear<B>,
    fc2: Linear<B>,
    norm2: BatchNorm<B, 0>,
    phantom: core::marker::PhantomData<B>,
}
```
The model is exported to ONNX with torch.onnx.export(..., "mnist.onnx"), then converted to Rust source via burn-import:
```rust
ModelGen::new()
    .input("./model/mnist.onnx")
    .out_dir("./model/")
    .run_from_script();
```
After generation, the model can be embedded directly into the binary or loaded from the generated mnist.rs, mnist.bin, and mnist.mpk files.
4. Training Pipeline
Data loading mirrors the PyTorch workflow using Burn’s MNISTDataset and a custom ClassificationBatcher. The training loop employs CrossEntropyLoss and the AdaGrad optimizer, with the TrainStep and ValidStep traits implemented for the model.
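To make the loss concrete: cross-entropy, as used here, conceptually combines log-softmax with negative log-likelihood over the 10 digit classes. A dependency-free sketch for a single example (not Burn's implementation):

```rust
// Cross-entropy for one example: -log_softmax(logits)[target].
// Subtracting the max before exp() keeps the computation numerically stable.
fn cross_entropy(logits: &[f64], target: usize) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = logits.iter().map(|&z| (z - max).exp()).sum::<f64>().ln() + max;
    log_sum_exp - logits[target]
}

fn main() {
    // Uniform logits over 10 classes give loss ln(10) ≈ 2.3026,
    // the expected starting loss for an untrained 10-class classifier.
    let logits = [0.0; 10];
    let loss = cross_entropy(&logits, 3);
    assert!((loss - 10.0_f64.ln()).abs() < 1e-12);
    println!("{loss:.4}"); // 2.3026
}
```

Watching the reported loss fall below ln(10) early in training is a cheap check that the pipeline is learning at all.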
```rust
pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
    // create dataloaders, learner, and run training
    let model_trained = learner.fit(dataloader_train, dataloader_test);
    model_trained
        .save_file(format!("{ARTIFACT_DIR}/model"), &CompactRecorder::new())
        .expect("Trained model should be saved successfully");
}
```
Configuration examples show how to run on CPU, wgpu (GPU), or other backends such as LibTorch with CUDA or Apple Silicon.
5. Results and Observations
Running on the laptop’s integrated GPU cuts training time to roughly 37 minutes and improves accuracy compared with the CPU run, while monitoring tools confirm the GPU is being utilized under only modest CPU load.
Conclusion
Burn, combined with wgpu, provides a viable path to train deep‑learning models in pure Rust across diverse hardware, though some advanced loss functions and optimizers remain unsupported.
Rare Earth Juejin Tech Community