A Practical Guide to PyTorch Visualization Tools for Deep Learning
This article walks through the core PyTorch visualization utilities—making image grids, drawing bounding boxes, segmentation masks, and keypoints—explaining why they are needed, how to set up the pipeline, and providing complete code examples for each computer‑vision task.
Training neural‑network models for computer‑vision tasks requires clear visual feedback; PyTorch provides a set of visualization functions that integrate directly into the training pipeline without external libraries.
Why Use PyTorch Visualization
Using OpenCV for annotation is common, but PyTorch visualization offers three advantages: it eliminates the dependency on a separate annotation library, keeps the entire workflow inside a single PyTorch pipeline, and operates directly on image tensors (using PIL under the hood for the actual drawing).
Preparing the Workspace
A simple directory structure holds three input images and a Jupyter notebook (visualization_utilities.ipynb) containing the code. The notebook imports the necessary modules:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import os
from torchvision.utils import make_grid
from torchvision.io import read_image
from PIL import Image
make_grid builds an image grid from a list of tensors, while read_image loads images as [C, H, W] uint8 tensors.
Visualizing an Image Grid
Three image paths are defined, read with read_image, resized to 450×450 using F.resize, and combined:
image_1_path = os.path.join('input', 'image_1.jpg')
image_2_path = os.path.join('input', 'image_2.jpg')
image_3_path = os.path.join('input', 'image_3.jpg')
image_1 = read_image(image_1_path)
image_2 = read_image(image_2_path)
image_3 = read_image(image_3_path)
image_1 = F.resize(image_1, (450, 450))
image_2 = F.resize(image_2, (450, 450))
image_3 = F.resize(image_3, (450, 450))
grid = make_grid([image_1, image_2, image_3])
def show(image):
    plt.figure(figsize=(12, 9))
    plt.imshow(np.transpose(image, [1, 2, 0]))
    plt.axis('off')

show(grid)
The resulting grid is displayed without any explicit loops.
Drawing Bounding Boxes
After importing draw_bounding_boxes, a tensor of box coordinates and a color list are defined:
from torchvision.utils import draw_bounding_boxes

boxes = torch.tensor([[135, 50, 210, 365], [210, 59, 280, 370], [300, 240, 375, 380]])
colors = ['red', 'red', 'green']
result = draw_bounding_boxes(image=image_1, boxes=boxes, colors=colors, width=3)
show(result)
Colors can also be supplied as RGB tuples, e.g., [(255, 0, 0), (255, 0, 0), (0, 255, 0)], producing the same visual output.
Bounding Boxes from a Detection Model
The Faster RCNN ResNet50 FPN model is loaded, an image is transformed, and forward‑propagation yields predictions. Scores below a threshold of 0.8 are filtered, and the remaining boxes are visualized with random colors and class labels:
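The snippet below relies on a helper read_transform_return and a threshold variable that the article never shows. The following is a hypothetical reconstruction, assuming the helper returns the image both as a uint8 tensor for the drawing utilities and as a normalized, batched float tensor for the model; COCO_INSTANCE_CATEGORY_NAMES is the standard COCO label list (not reproduced here):

```python
import numpy as np
import torch
from PIL import Image

# Score cutoff used throughout the article's detection examples.
detection_threshold = 0.8

def read_transform_return(image_path):
    """Hypothetical reconstruction: return the image as a uint8 [C, H, W]
    tensor (for the drawing utilities) and as a normalized float
    [1, C, H, W] batch (for the detection model)."""
    image = Image.open(image_path).convert('RGB')
    arr = np.array(image)                                     # [H, W, C], uint8
    int_input = torch.from_numpy(arr).permute(2, 0, 1)        # [C, H, W], uint8
    tensor_input = int_input.float().div(255.0).unsqueeze(0)  # [1, C, H, W] in [0, 1]
    return int_input, tensor_input
```

The split matters because draw_bounding_boxes expects uint8 pixel values, while the detection models expect float tensors scaled to [0, 1].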
int_input, tensor_input = read_transform_return(image_1_path)
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(pretrained=True, min_size=800)
model.eval()
outputs = model(tensor_input)
pred_scores = outputs[0]['scores'].detach().cpu().numpy()
pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in outputs[0]['labels'].cpu().numpy()]
pred_bboxes = outputs[0]['boxes'].detach().cpu().numpy()
boxes = pred_bboxes[pred_scores >= detection_threshold].astype(np.int32)
pred_classes = pred_classes[:len(boxes)]
colors = np.random.randint(0, 255, size=(len(boxes), 3))
colors = [tuple(c) for c in colors]
result_with_boxes = draw_bounding_boxes(image=int_input, boxes=torch.tensor(boxes), colors=colors, labels=pred_classes, width=4)
show(result_with_boxes)
Setting fill=True produces filled boxes with the same colors.
Visualizing Segmentation Masks
For semantic segmentation, the pretrained FCN ResNet50 model is used. The draw_segmentation_masks function overlays either a boolean mask or an RGB‑colored mask based on the PASCAL VOC class map.
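The RGB-colored overlay below uses a label_color_map that the article never defines. Assuming it is the standard 21-entry PASCAL VOC palette (one RGB tuple per class, background first), it can be written out as:

```python
# Standard PASCAL VOC color palette: one RGB tuple per class,
# index 0 is background, the 20 foreground classes follow in VOC order.
label_color_map = [
    (0, 0, 0),        # background
    (128, 0, 0),      # aeroplane
    (0, 128, 0),      # bicycle
    (128, 128, 0),    # bird
    (0, 0, 128),      # boat
    (128, 0, 128),    # bottle
    (0, 128, 128),    # bus
    (128, 128, 128),  # car
    (64, 0, 0),       # cat
    (192, 0, 0),      # chair
    (64, 128, 0),     # cow
    (192, 128, 0),    # dining table
    (64, 0, 128),     # dog
    (192, 0, 128),    # horse
    (64, 128, 128),   # motorbike
    (192, 128, 128),  # person
    (0, 64, 0),       # potted plant
    (128, 64, 0),     # sheep
    (0, 192, 0),      # sofa
    (128, 192, 0),    # train
    (0, 64, 128),     # tv/monitor
]
```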
from torchvision.models.segmentation import fcn_resnet50
from torchvision.utils import draw_segmentation_masks
model = fcn_resnet50(pretrained=True)
model.eval()
int_input, tensor_input = read_transform_return(image_1_path)
outputs = model(tensor_input)
# Boolean mask example
labels = torch.argmax(outputs['out'][0].squeeze(), dim=0).detach().cpu().numpy()
boolean_mask = torch.tensor(labels, dtype=torch.bool)
seg_result = draw_segmentation_masks(image=int_input, masks=boolean_mask, alpha=0.5)
show(seg_result)
# RGB mask example
num_classes = outputs['out'].shape[1]
masks = outputs['out'][0]
all_masks = masks.argmax(0) == torch.arange(num_classes)[:, None, None]
seg_result = draw_segmentation_masks(int_input, all_masks, colors=label_color_map, alpha=0.5)
show(seg_result)
Instance Segmentation
Mask RCNN ResNet50 FPN provides both boxes and masks. After filtering by the same confidence threshold, boxes are drawn, then the masks (thresholded at 0.5) are overlaid using the previously generated colors.
int_input, tensor_input = read_transform_return(image_3_path)
from torchvision.models.detection import maskrcnn_resnet50_fpn

model = maskrcnn_resnet50_fpn(pretrained=True)
model.eval()
outputs = model(tensor_input)
output = outputs[0]
boxes = output['boxes'].detach().cpu().numpy()
scores = output['scores'].detach().cpu().numpy()
masks = output['masks']
keep = scores >= detection_threshold
boxes = boxes[keep].astype(np.int32)
pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in output['labels'].cpu().numpy()[keep]]
colors = [tuple(c) for c in np.random.randint(0, 255, size=(len(boxes), 3))]
result_with_boxes = draw_bounding_boxes(int_input, boxes=torch.tensor(boxes), colors=colors, labels=pred_classes, width=4)
final_masks = (masks[keep] > 0.5).squeeze(1)
seg_result = draw_segmentation_masks(result_with_boxes, final_masks, colors=colors, alpha=0.8)
show(seg_result)
Keypoint Visualization
The Keypoint RCNN ResNet50 FPN model outputs keypoint coordinates and scores. After applying the detection threshold, draw_keypoints visualizes points, and with the connectivity argument it draws skeleton lines using the predefined CONNECT_POINTS list.
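The article refers to a predefined CONNECT_POINTS list without showing it. A plausible definition for the 17 COCO keypoints that keypointrcnn_resnet50_fpn predicts is sketched below; each pair names two keypoint indices to connect with a line (the article's own list may differ):

```python
# COCO keypoint order used by keypointrcnn_resnet50_fpn:
# 0 nose, 1 left_eye, 2 right_eye, 3 left_ear, 4 right_ear,
# 5 left_shoulder, 6 right_shoulder, 7 left_elbow, 8 right_elbow,
# 9 left_wrist, 10 right_wrist, 11 left_hip, 12 right_hip,
# 13 left_knee, 14 right_knee, 15 left_ankle, 16 right_ankle
CONNECT_POINTS = [
    (0, 1), (0, 2), (1, 3), (2, 4),          # head
    (5, 6),                                  # shoulders
    (5, 7), (7, 9), (6, 8), (8, 10),         # arms
    (5, 11), (6, 12), (11, 12),              # torso
    (11, 13), (13, 15), (12, 14), (14, 16),  # legs
]
```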
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.utils import draw_keypoints
model = keypointrcnn_resnet50_fpn(pretrained=True)
model.eval()
int_input, tensor_input = read_transform_return(image_1_path)
outputs = model(tensor_input)
keypoints = outputs[0]['keypoints']
scores = outputs[0]['scores']
idx = torch.where(scores > detection_threshold)
keypoints = keypoints[idx]
# Points only
points_res = draw_keypoints(image=int_input, keypoints=keypoints, colors=(255,0,0), radius=2)
show(points_res)
# Skeleton
skeleton_res = draw_keypoints(image=int_input, keypoints=keypoints, connectivity=CONNECT_POINTS, colors=(255,0,0), radius=4, width=3)
show(skeleton_res)
Summary
The article demonstrates how PyTorch's built-in visualization utilities (make_grid, draw_bounding_boxes, draw_segmentation_masks, and draw_keypoints) can be used to annotate and inspect outputs for object detection, semantic and instance segmentation, and keypoint detection, all within a single, streamlined PyTorch workflow.
