
Visualizing PyTorch Neural Network Architecture and Training Process with HiddenLayer, torchviz, TensorBoardX, and Visdom

This tutorial explains how to visualize a PyTorch convolutional neural network's architecture and training dynamics using tools such as HiddenLayer, torchviz, TensorBoardX, and Visdom, providing step‑by‑step code examples for each method.


This article demonstrates how to visualize both the structure and training process of a simple convolutional neural network built with PyTorch.

1. Network structure visualization

We first define a basic ConvNet class consisting of two convolutional blocks, a fully‑connected block, and an output layer, then print the model to show its architecture.

<code>import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1),
            nn.ReLU(),
            nn.AvgPool2d(2, 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        output = self.out(x)
        return output</code>
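As a quick sanity check on the 32 * 7 * 7 input size of the first Linear layer: each 2x2 pooling halves the 28x28 spatial size (28 → 14 → 7), and the second block outputs 32 channels. A minimal standalone sketch using the same layer settings confirms this:

```python
import torch
import torch.nn as nn

# Same settings as conv1 and conv2 above: kernel 3, stride 1, padding 1
# keeps the spatial size; each 2x2 pool halves it (28 -> 14 -> 7).
convs = nn.Sequential(
    nn.Conv2d(1, 16, 3, 1, 1), nn.ReLU(), nn.AvgPool2d(2, 2),
    nn.Conv2d(16, 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
)
out = convs(torch.zeros(1, 1, 28, 28))
print(out.shape)  # torch.Size([1, 32, 7, 7]) -> flattens to 32 * 7 * 7
```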

Printing the model yields a detailed hierarchical description of each layer.

<code>MyConvNet = ConvNet()
print(MyConvNet)</code>
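A common complement to printing the module tree (not shown in the original article) is counting trainable parameters. Sketched here with a small stand-in model so the snippet runs on its own; for the tutorial's network, pass MyConvNet instead:

```python
import torch.nn as nn

# Stand-in model: two Linear layers with a ReLU in between.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

# Sum the element counts of all trainable parameter tensors.
total = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total)  # (10*5 + 5) + (5*2 + 2) = 67
```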

1.1 Visualizing with HiddenLayer

Install the library and generate a graph object for the model.

<code>pip install hiddenlayer</code>
<code>import hiddenlayer as h
vis_graph = h.build_graph(MyConvNet, torch.zeros([1, 1, 28, 28]))
vis_graph.theme = h.graph.THEMES["blue"].copy()
vis_graph.save("./demo1.png")</code>

The resulting PNG shows the network topology.

1.2 Visualizing with torchviz

Install torchviz and use make_dot to create a Graphviz representation.

<code>pip install torchviz</code>
<code>from torchviz import make_dot
x = torch.randn(1, 1, 28, 28).requires_grad_(True)
y = MyConvNet(x)
MyConvNetVis = make_dot(y, params=dict(list(MyConvNet.named_parameters()) + [('x', x)]))
MyConvNetVis.format = "png"
MyConvNetVis.directory = "data"
MyConvNetVis.view()</code>

This generates a .gv script and a rendered .png image.

2. Training process visualization

Monitoring loss and accuracy during training helps assess model performance. The following sections show how to log these metrics with tensorboardX and HiddenLayer.

2.1 Using tensorboardX

Install the required packages: tensorboardX writes the event files, and tensorboard provides the web interface for viewing them.

<code>pip install tensorboardX
pip install tensorboard</code>
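The training loops below reference train_loader, test_data_x, and test_data_y, which the article does not define. A minimal stand-in (synthetic tensors shaped like MNIST, purely an assumption for illustration) could be:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for an MNIST-style dataset: grayscale 28x28 images
# with labels in 0..9. Swap in torchvision.datasets.MNIST for real runs.
train_x = torch.randn(512, 1, 28, 28)
train_y = torch.randint(0, 10, (512,))
train_loader = DataLoader(TensorDataset(train_x, train_y),
                          batch_size=64, shuffle=True)

test_data_x = torch.randn(128, 1, 28, 28)
test_data_y = torch.randint(0, 10, (128,))
```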

Typical training loop with logging:

<code>from tensorboardX import SummaryWriter
from sklearn.metrics import accuracy_score
import torchvision.utils as vutils

logger = SummaryWriter(log_dir="data/log")
optimizer = torch.optim.Adam(MyConvNet.parameters(), lr=3e-4)
loss_func = nn.CrossEntropyLoss()
log_step_interval = 100

for epoch in range(5):
    for step, (x, y) in enumerate(train_loader):
        predict = MyConvNet(x)
        loss = loss_func(predict, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_iter_num = epoch * len(train_loader) + step + 1
        if global_iter_num % log_step_interval == 0:
            print(f"global_step:{global_iter_num}, loss:{loss.item():.4f}")
            # Scalar: training loss
            logger.add_scalar("train loss", loss.item(), global_step=global_iter_num)
            # Scalar: accuracy on the held-out test set
            with torch.no_grad():
                test_predict = MyConvNet(test_data_x)
            _, predict_idx = torch.max(test_predict, 1)
            acc = accuracy_score(test_data_y, predict_idx)  # plain float
            logger.add_scalar("test accuracy", acc, global_step=global_iter_num)
            # Image: a grid of the current training batch
            img = vutils.make_grid(x, nrow=12)
            logger.add_image("train image sample", img, global_step=global_iter_num)
            # Histograms: every parameter tensor in the network
            for name, param in MyConvNet.named_parameters():
                logger.add_histogram(name, param.data.numpy(), global_step=global_iter_num)</code>

Run <code>tensorboard --logdir="./data/log"</code> in a terminal and open the printed URL (http://localhost:6006 by default) to browse the logged scalars, images, and histograms.

2.2 Visualizing training with HiddenLayer

HiddenLayer can dynamically plot loss, accuracy, and weight matrices during training.

<code>import hiddenlayer as hl
from sklearn.metrics import accuracy_score

history = hl.History()   # stores the logged metrics
canvas = hl.Canvas()     # renders them as live plots
optimizer = torch.optim.Adam(MyConvNet.parameters(), lr=3e-4)
loss_func = nn.CrossEntropyLoss()
log_step_interval = 100

for epoch in range(5):
    for step, (x, y) in enumerate(train_loader):
        predict = MyConvNet(x)
        loss = loss_func(predict, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_iter_num = epoch * len(train_loader) + step + 1
        if global_iter_num % log_step_interval == 0:
            with torch.no_grad():
                test_predict = MyConvNet(test_data_x)
            _, predict_idx = torch.max(test_predict, 1)
            acc = accuracy_score(test_data_y, predict_idx)
            # Log the loss, test accuracy, and the weights of the second
            # Linear layer inside self.fc (index 2 in the Sequential).
            history.log((epoch, step), train_loss=loss.item(), test_acc=acc,
                        hidden_weight=MyConvNet.fc[2].weight)
            with canvas:
                canvas.draw_plot(history["train_loss"])
                canvas.draw_plot(history["test_acc"])
                canvas.draw_image(history["hidden_weight"])</code>

3. Using Visdom for visualization

Visdom, a visualization tool from Facebook Research that works with PyTorch and NumPy data, provides a flexible web‑based interface with a plotting API reminiscent of Matplotlib.

<code>pip install visdom</code>

Basic usage example:

<code>from visdom import Visdom
from sklearn.datasets import load_iris
import numpy as np
import torch
from PIL import Image

vis = Visdom()  # assumes a running server (python -m visdom.server)

# Line plot: three activation functions over a shared x axis
x = torch.linspace(-6, 6, 100).view([-1, 1])
sigmoid_y = torch.nn.Sigmoid()(x)
tanh_y = torch.nn.Tanh()(x)
relu_y = torch.nn.ReLU()(x)
plot_x = torch.cat([x, x, x], dim=1)
plot_y = torch.cat([sigmoid_y, tanh_y, relu_y], dim=1)
vis.line(X=plot_x, Y=plot_y, win="line plot", env="main",
         opts={"legend": ["Sigmoid", "Tanh", "ReLU"]})

# Scatter plot: first two iris features, colored by class
# (Visdom expects class labels to be 1-indexed, hence the +1)
iris_x, iris_y = load_iris(return_X_y=True)
vis.scatter(iris_x[:, :2], Y=iris_y + 1, win="scatter2d", env="main")

# Stem plot: sine and cosine
x = torch.linspace(-6, 6, 100).view([-1, 1])
y1 = torch.sin(x)
y2 = torch.cos(x)
plot_x = torch.cat([x, x], dim=1)
plot_y = torch.cat([y1, y2], dim=1)
vis.stem(X=plot_x, Y=plot_y, win="stem plot", env="main",
         opts={"legend": ["sin", "cos"], "title": "Stem Plot"})

# Heatmap: feature correlation matrix of the iris data
iris_corr = torch.from_numpy(np.corrcoef(iris_x, rowvar=False))
vis.heatmap(iris_corr, win="heatmap", env="main",
            opts={"title": "Correlation Heatmap"})

# Image: a grayscale image as an H x W tensor (use any local image path)
img = Image.open("./example.jpg").convert("L")
img_tensor = torch.from_numpy(np.array(img, dtype=np.float32))
vis.image(img_tensor, win="one image", env="MyPlotEnv",
          opts={"title": "Sample Image"})

# Text: plain strings can be displayed in a window as well
vis.text("hello world", win="text plot", env="MyPlotEnv",
         opts={"title": "Text Visualization"})</code>

Start the server with <code>python -m visdom.server</code> and open the printed URL (http://localhost:8097 by default) in a browser to explore the visualizations.

Additional notes cover saving and reloading Visdom environments, handling environment names, and retrieving window data via the Visdom API.

Tags: deep learning, PyTorch, network visualization, HiddenLayer, tensorboardX, Visdom
Written by

Python Programming Learning Circle

A global community of Chinese Python developers offering technical articles, columns, original video tutorials, and problem sets. Topics include web full‑stack development, web scraping, data analysis, natural language processing, image processing, machine learning, automated testing, DevOps automation, and big data.
