Comprehensive PyTorch Code Snippets: Configuration, Tensor Operations, Model Definition, Training, and Best Practices
This article is a curated collection of frequently used PyTorch code snippets covering environment setup, reproducibility, GPU configuration, tensor manipulation, model building, data preprocessing, training and evaluation loops, custom loss functions, regularization, learning‑rate scheduling, checkpointing, and practical tips for efficient deep‑learning development.
1. Basic Configuration
Import packages and query versions
import torch
import torch.nn as nn
import torchvision
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))

Reproducibility
Full reproducibility across different hardware is not guaranteed, but on the same device you can fix the random seeds for torch and numpy.
import numpy as np
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

GPU Settings
If you need only one GPU:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
To specify multiple GPUs (e.g., GPU 0 and 1):
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'You can also set the visible devices from the command line: CUDA_VISIBLE_DEVICES=0,1 python train.py Clear GPU memory: torch.cuda.empty_cache() Reset GPU via command line:
nvidia-smi --gpu-reset -i [gpu_id]2. Tensor (Tensor) Processing
Tensor data types
PyTorch provides 9 CPU tensor types and 9 GPU tensor types.
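Each dtype corresponds to a distinct tensor class name; a minimal sketch showing a few of them:

```python
import torch

# A few of the common dtypes and their tensor type names
float_t = torch.zeros(2, dtype=torch.float32)  # 32-bit float
long_t = torch.zeros(2, dtype=torch.int64)     # 64-bit signed integer
bool_t = torch.zeros(2, dtype=torch.bool)      # boolean

print(float_t.type())  # torch.FloatTensor
print(long_t.type())   # torch.LongTensor
print(bool_t.type())   # torch.BoolTensor
```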
Basic tensor information
tensor = torch.randn(3,4,5)
print(tensor.type()) # data type
print(tensor.size()) # shape (torch.Size, a subclass of tuple)
print(tensor.dim())  # number of dimensions

Named tensors
Using named dimensions improves readability and reduces errors.
# Before PyTorch 1.3, use comments
# Tensor[N, C, H, W]
images = torch.randn(32, 3, 56, 56)
images.sum(dim=1)
images.select(dim=1, index=0)
# After PyTorch 1.3
NCHW = ['N', 'C', 'H', 'W']
images = torch.randn(32, 3, 56, 56, names=NCHW)
images.sum('C')
images.select('C', index=0)
tensor = torch.rand(3,4,1,2, names=('C','N','H','W'))
tensor = tensor.align_to('N','C','H','W')

Data type conversion
# Set default type (FloatTensor is faster than DoubleTensor)
torch.set_default_tensor_type(torch.FloatTensor)
# Type conversion examples
tensor = tensor.cuda()
tensor = tensor.cpu()
tensor = tensor.float()
tensor = tensor.long()

torch.Tensor ↔ np.ndarray conversion
ndarray = tensor.cpu().numpy()
tensor = torch.from_numpy(ndarray).float()
# If ndarray has negative stride
tensor = torch.from_numpy(ndarray.copy()).float()

torch.Tensor ↔ PIL.Image conversion
import PIL.Image
# Tensor → PIL.Image (tensor shape [C,H,W] in [0,1])
image = PIL.Image.fromarray(torch.clamp(tensor*255, min=0, max=255).byte().permute(1,2,0).cpu().numpy())
image = torchvision.transforms.functional.to_pil_image(tensor)
# PIL.Image → Tensor
path = './figure.jpg'
tensor = torch.from_numpy(np.asarray(PIL.Image.open(path))).permute(2,0,1).float() / 255
tensor = torchvision.transforms.functional.to_tensor(PIL.Image.open(path))

Extract value from a single‑element tensor
value = torch.rand(1).item()

Tensor reshaping
# Reshape for fully‑connected layers after convolutions
tensor = torch.rand(2,3,4)
shape = (6,4)
tensor = torch.reshape(tensor, shape)

Shuffle order
tensor = tensor[torch.randperm(tensor.size(0))] # shuffle the first dimension

Horizontal flip
# PyTorch does not support negative step slicing; use indexing
tensor = tensor[:,:,:,torch.arange(tensor.size(3)-1, -1, -1).long()]

Tensor copy
# Operation               | New/Shared memory | Still in computation graph
tensor.clone()            | New               | Yes
tensor.detach()           | Shared            | No
tensor.detach().clone()   | New               | No

Tensor concatenation
# torch.cat concatenates along a given dimension; torch.stack adds a new dimension
tensor = torch.cat(list_of_tensors, dim=0)
tensor = torch.stack(list_of_tensors, dim=0)

One‑hot encoding of integer labels
tensor = torch.tensor([0, 2, 1, 3])
N = tensor.size(0)
num_classes = 4
one_hot = torch.zeros(N, num_classes).long()
one_hot.scatter_(1, tensor.unsqueeze(1), 1)

Non‑zero element indices
torch.nonzero(tensor) # indices of non‑zero elements
torch.nonzero(tensor==0) # indices of zero elements
torch.nonzero(tensor).size(0) # count of non‑zero elements
torch.nonzero(tensor==0).size(0) # count of zero elements

Tensor equality
torch.allclose(tensor1, tensor2) # for float tensors
torch.equal(tensor1, tensor2) # for int tensors

Tensor expansion
# Expand shape (64,512) to (64,512,7,7)
tensor = torch.rand(64,512)
tensor = torch.reshape(tensor, (64,512,1,1)).expand(64,512,7,7)

Matrix multiplication
# (m×n) * (n×p) → (m×p)
result = torch.mm(tensor1, tensor2)
# Batch matrix multiplication (b×m×n) * (b×n×p) → (b×m×p)
result = torch.bmm(tensor1, tensor2)
# Element‑wise multiplication
result = tensor1 * tensor2

Pairwise Euclidean distance between two sets
dist = torch.sqrt(torch.sum((X1[:,None,:] - X2) ** 2, dim=2))

3. Model Definition and Operations
Simple two‑layer convolutional network
# convolutional neural network (2 convolutional layers)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

model = ConvNet(num_classes).to(device)

Bilinear pooling
X = torch.reshape(X, (N, D, H * W))                   # assume X has shape N×D×H×W
X = torch.bmm(X, torch.transpose(X, 1, 2)) / (H * W)  # bilinear pooling
assert X.size() == (N, D, D)
X = torch.reshape(X, (N, D * D))
X = torch.sign(X) * torch.sqrt(torch.abs(X) + 1e-5)   # signed‑sqrt normalization
X = torch.nn.functional.normalize(X)                  # L2 normalization

Multi‑GPU synchronized BatchNorm
sync_bn = torch.nn.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

Replace all BN layers with SyncBN
def convertBNtoSyncBN(module, process_group=None):
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        sync_bn = torch.nn.SyncBatchNorm(module.num_features, module.eps, module.momentum,
                                         module.affine, module.track_running_stats, process_group)
        sync_bn.running_mean = module.running_mean
        sync_bn.running_var = module.running_var
        if module.affine:
            # Wrap in Parameter so the attributes remain registered parameters
            sync_bn.weight = torch.nn.Parameter(module.weight.clone().detach())
            sync_bn.bias = torch.nn.Parameter(module.bias.clone().detach())
        return sync_bn
    else:
        for name, child_module in module.named_children():
            setattr(module, name, convertBNtoSyncBN(child_module, process_group=process_group))
        return module

Model parameter counting
num_parameters = sum(torch.numel(p) for p in model.parameters())

Inspect model parameters
params = list(model.named_parameters())
name, param = params[28]  # e.g. inspect the 29th named parameter
print(name)
print(param.grad)

Model weight initialization
# Common practice for initialization
for layer in model.modules():
    if isinstance(layer, torch.nn.Conv2d):
        torch.nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
        if layer.bias is not None:
            torch.nn.init.constant_(layer.bias, 0.0)
    elif isinstance(layer, torch.nn.BatchNorm2d):
        torch.nn.init.constant_(layer.weight, 1.0)
        torch.nn.init.constant_(layer.bias, 0.0)
    elif isinstance(layer, torch.nn.Linear):
        torch.nn.init.xavier_normal_(layer.weight)
        if layer.bias is not None:
            torch.nn.init.constant_(layer.bias, 0.0)
# Initialize with a given tensor
layer.weight = torch.nn.Parameter(tensor)

Extract a specific layer
# Get the first two layers
new_model = nn.Sequential(*list(model.children())[:2])
# Extract all Conv2d layers
conv_model = nn.Module()
for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        # add_module forbids '.' in names, so replace the dots
        conv_model.add_module(name.replace('.', '_'), layer)

Load part of a pretrained model
# Load weights from another model (partial match)
# model_saved is a state dict, e.g. loaded via torch.load(...)
model_new_dict = model_new.state_dict()
model_common_dict = {k: v for k, v in model_saved.items() if k in model_new_dict}
model_new_dict.update(model_common_dict)
model_new.load_state_dict(model_new_dict)

4. Data Processing
Compute dataset mean and std
def compute_mean_and_std(dataset):
    mean_r = mean_g = mean_b = 0
    for img, _ in dataset:
        img = np.asarray(img)
        mean_b += np.mean(img[:,:,0])
        mean_g += np.mean(img[:,:,1])
        mean_r += np.mean(img[:,:,2])
    mean_b /= len(dataset)
    mean_g /= len(dataset)
    mean_r /= len(dataset)
    diff_r = diff_g = diff_b = 0
    N = 0
    for img, _ in dataset:
        img = np.asarray(img)
        diff_b += np.sum((img[:,:,0] - mean_b) ** 2)
        diff_g += np.sum((img[:,:,1] - mean_g) ** 2)
        diff_r += np.sum((img[:,:,2] - mean_r) ** 2)
        N += np.prod(img[:,:,0].shape)
    std_b = np.sqrt(diff_b / N)
    std_g = np.sqrt(diff_g / N)
    std_r = np.sqrt(diff_r / N)
    mean = (mean_b/255.0, mean_g/255.0, mean_r/255.0)
    std = (std_b/255.0, std_g/255.0, std_r/255.0)
    return mean, std

Video basic info
import cv2
video = cv2.VideoCapture(mp4_path)
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
fps = int(video.get(cv2.CAP_PROP_FPS))
video.release()

TSN segment sampling
K = self._num_segments
if is_train:
    if num_frames > K:
        # Random sampling: one frame from each of K equal segments
        frame_indices = torch.randint(high=num_frames // K, size=(K,), dtype=torch.long)
        frame_indices += num_frames // K * torch.arange(K)
    else:
        frame_indices = torch.randint(high=num_frames, size=(K - num_frames,), dtype=torch.long)
        frame_indices = torch.sort(torch.cat((torch.arange(num_frames), frame_indices)))[0]
else:
    if num_frames > K:
        # Uniform sampling: the middle frame of each segment
        frame_indices = num_frames // K // 2
        frame_indices += num_frames // K * torch.arange(K)
    else:
        frame_indices = torch.sort(torch.cat((torch.arange(num_frames), torch.arange(K - num_frames))))[0]
assert frame_indices.size() == (K,)
return [frame_indices[i] for i in range(K)]

Common training and validation transforms
train_transform = torchvision.transforms.Compose([
torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))
])
val_transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))
])

5. Model Training and Testing
Classification model training code
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i+1) % 100 == 0:
            print('Epoch: [{}/{}], Step: [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

Classification model testing code
model.eval()
with torch.no_grad():
    correct = total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test accuracy: {} %'.format(100 * correct / total))

Custom loss example
class MyLoss(torch.nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, x, y):
        loss = torch.mean((x - y) ** 2)
        return loss

Label smoothing (LSR)
class LSR(nn.Module):
    def __init__(self, e=0.1, reduction='mean'):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.e = e
        self.reduction = reduction

    def _one_hot(self, labels, classes, value=1):
        one_hot = torch.zeros(labels.size(0), classes, device=labels.device)
        labels = labels.view(-1, 1)
        one_hot.scatter_(1, labels, value)
        return one_hot

    def _smooth_label(self, target, length, smooth_factor):
        one_hot = self._one_hot(target, length, value=1 - smooth_factor)
        one_hot += smooth_factor / (length - 1)
        return one_hot

    def forward(self, x, target):
        if x.size(0) != target.size(0):
            raise ValueError('Batch size mismatch')
        smoothed_target = self._smooth_label(target, x.size(1), self.e)
        x = self.log_softmax(x)
        loss = torch.sum(-x * smoothed_target, dim=1)
        if self.reduction == 'none':
            return loss
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'mean':
            return loss.mean()
        else:
            raise ValueError('Invalid reduction')

Mixup training
beta_distribution = torch.distributions.beta.Beta(alpha, alpha)
for images, labels in train_loader:
    images, labels = images.cuda(), labels.cuda()
    # Mixup the images and keep both labels
    lam = beta_distribution.sample([]).item()
    index = torch.randperm(images.size(0)).cuda()
    mixed_images = lam * images + (1 - lam) * images[index]
    label_a, label_b = labels, labels[index]
    # Mixup loss
    scores = model(mixed_images)
    loss = lam * loss_function(scores, label_a) + (1 - lam) * loss_function(scores, label_b)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

L1 regularization
l1_lambda = 1e-4  # regularization strength
loss = ...
for param in model.parameters():
    loss = loss + l1_lambda * torch.sum(torch.abs(param))
loss.backward()

Weight decay without bias decay
bias_list = [p for name, p in model.named_parameters() if name.endswith('bias')]
others_list = [p for name, p in model.named_parameters() if not name.endswith('bias')]
parameters = [
{'params': bias_list, 'weight_decay': 0},
{'params': others_list}
]
optimizer = torch.optim.SGD(parameters, lr=1e-2, momentum=0.9, weight_decay=1e-4)

Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20)

Get current learning rate
# Single global LR
lr = next(iter(optimizer.param_groups))['lr']
# Multiple LRs
all_lr = [pg['lr'] for pg in optimizer.param_groups]

Learning‑rate scheduling
# Reduce LR on plateau
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5, verbose=True)
for t in range(80):
    train(...)
    val_acc = val(...)
    scheduler.step(val_acc)
# Cosine annealing
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=80)
# Multi‑step LR decay; call scheduler.step() after optimizer.step()
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50,70], gamma=0.1)
for t in range(80):
    train(...)
    val(...)
    scheduler.step()
# Warmup LR: linearly scale up the LR over the first 10 epochs
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda t: (t + 1) / 10)
for t in range(10):
    train(...)
    val(...)
    scheduler.step()

Chained schedulers (PyTorch ≥1.4)
model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = torch.optim.SGD(model, 0.1)
scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
for epoch in range(4):
    print(epoch, scheduler2.get_last_lr()[0])
    optimizer.step()
    scheduler1.step()
    scheduler2.step()

Training visualization with TensorBoard
# Install and launch TensorBoard (shell commands):
#   pip install tensorboard
#   tensorboard --logdir=runs
from torch.utils.tensorboard import SummaryWriter
import numpy as np

writer = SummaryWriter()
for n_iter in range(100):
    writer.add_scalar('Loss/train', np.random.random(), n_iter)
    writer.add_scalar('Loss/test', np.random.random(), n_iter)
    writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
    writer.add_scalar('Accuracy/test', np.random.random(), n_iter)

Checkpoint saving and loading
import shutil

start_epoch = 0
# Load checkpoint
if resume:
    checkpoint = torch.load('model/best_checkpoint.pth.tar')
    best_acc = checkpoint['best_acc']
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print('Load checkpoint at epoch {}.'.format(start_epoch))
    print('Best accuracy so far {}.'.format(best_acc))

for epoch in range(start_epoch, num_epochs):
    ...  # training loop
    ...  # validation loop
    # Save checkpoint
    is_best = current_acc > best_acc
    best_acc = max(current_acc, best_acc)
    checkpoint = {
        'best_acc': best_acc,
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, 'model/checkpoint.pth.tar')
    if is_best:
        shutil.copy('model/checkpoint.pth.tar', 'model/best_checkpoint.pth.tar')

Extracting ImageNet pretrained features
# VGG‑16 relu5‑3 feature
model = torchvision.models.vgg16(pretrained=True).features[:-1]
# VGG‑16 pool5 feature
model = torchvision.models.vgg16(pretrained=True).features
# VGG‑16 fc7 feature
model = torchvision.models.vgg16(pretrained=True)
model.classifier = torch.nn.Sequential(*list(model.classifier.children())[:-3])
# ResNet GAP feature
import collections
model = torchvision.models.resnet18(pretrained=True)
model = torch.nn.Sequential(collections.OrderedDict(list(model.named_children())[:-1]))
with torch.no_grad():
    model.eval()
    conv_representation = model(image)

Feature extractor for multiple layers
class FeatureExtractor(torch.nn.Module):
    """Helper class to extract several convolution features from a pretrained model.

    Args:
        pretrained_model (torch.nn.Module): the model
        layers_to_extract (list or set of str): layer names to extract
    """
    def __init__(self, pretrained_model, layers_to_extract):
        super().__init__()
        self._model = pretrained_model
        self._model.eval()
        self._layers_to_extract = set(layers_to_extract)

    def forward(self, x):
        with torch.no_grad():
            features = []
            for name, layer in self._model.named_children():
                x = layer(x)
                if name in self._layers_to_extract:
                    features.append(x)
        return features

Fine‑tune fully connected layer
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(512, 100)  # replace the last FC layer
optimizer = torch.optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)

Fine‑tune with different LR for FC and conv layers
model = torchvision.models.resnet18(pretrained=True)
finetuned_params = list(map(id, model.fc.parameters()))
conv_params = [p for p in model.parameters() if id(p) not in finetuned_params]
parameters = [
{'params': conv_params, 'lr': 1e-3},
{'params': model.fc.parameters()}
]
optimizer = torch.optim.SGD(parameters, lr=1e-2, momentum=0.9, weight_decay=1e-4)

6. Additional Practical Tips
Avoid using excessively large linear layers because they consume a lot of memory and may exceed GPU capacity.
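To see why, count the parameters of a single large linear layer (a quick sketch; the 4096×4096 size is only an illustration):

```python
import torch

fc = torch.nn.Linear(4096, 4096)
num_params = sum(p.numel() for p in fc.parameters())
# 4096*4096 weights + 4096 biases ≈ 16.8M parameters ≈ 64 MB in float32
print(num_params)
```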
Do not apply RNNs to very long sequences; back‑propagation through time (BPTT) memory usage grows linearly with sequence length.
Switch between training and evaluation modes with model.train() and model.eval() before calling model(x).
Wrap inference code that does not require gradients in with torch.no_grad(): to save memory and speed up computation.
model.eval() changes layer behavior (e.g., BatchNorm, Dropout), while torch.no_grad() only disables gradient tracking.
model.zero_grad() clears gradients for all model parameters, whereas optimizer.zero_grad() clears gradients only for the parameters managed by that optimizer.
torch.nn.CrossEntropyLoss expects raw logits; it internally applies log_softmax and NLLLoss.
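The distinction between model.eval() and torch.no_grad() can be checked directly (a minimal sketch with a stand-in one-layer model):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Dropout(p=0.5))
model.eval()               # switches Dropout/BatchNorm to inference behavior
x = torch.randn(3, 4)
with torch.no_grad():      # disables gradient tracking only
    out = model(x)
print(out.requires_grad)   # False: no graph was recorded
```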
Always call optimizer.zero_grad() before loss.backward() to prevent gradient accumulation.
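A tiny demonstration of the accumulation behavior (a sketch with a single scalar parameter):

```python
import torch

w = torch.ones(1, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)
for _ in range(2):
    loss = (2 * w).sum()
    loss.backward()        # without zero_grad, gradients add up: 2 + 2 = 4
accumulated = w.grad.item()
optimizer.zero_grad()      # resets (or sets to None) before the next step
```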
Set pin_memory=True in DataLoader for large datasets; for tiny datasets like MNIST, pin_memory=False may be faster. Tune num_workers experimentally.
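A minimal DataLoader configuration showing where these knobs live (the random dataset is just a stand-in):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset,
                    batch_size=16,
                    shuffle=True,
                    num_workers=0,     # tune experimentally per machine
                    pin_memory=False)  # True usually helps with large data + GPU
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([16, 3])
```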
Use del to delete unused intermediate tensors and free GPU memory.
In‑place operations (e.g., torch.nn.functional.relu(x, inplace=True)) reduce memory consumption.
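The in-place variant writes into the input's storage instead of allocating a new tensor, which you can verify by comparing data pointers (sketch):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1000)
y = F.relu(x, inplace=True)          # reuses x's memory
print(y.data_ptr() == x.data_ptr())  # True: same underlying storage
```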
Avoid frequent CPU‑GPU transfers; accumulate metrics on GPU and transfer them back only once per epoch.
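One way to follow this advice is to keep the accuracy counter as a tensor on the device and call .item() only once at the end (a sketch; it falls back to CPU when no GPU is present):

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
correct = torch.zeros((), device=device)    # accumulator stays on the device
for _ in range(5):
    predicted = torch.randint(0, 2, (8,), device=device)
    labels = torch.randint(0, 2, (8,), device=device)
    correct += (predicted == labels).sum()  # no transfer inside the loop
total_correct = correct.item()              # single transfer per epoch
```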
Half‑precision ( .half()) can speed up training on supported GPUs, but watch out for numerical stability.
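float16 halves the per-element memory footprint but has a much smaller dynamic range (max ≈ 65504), so values can silently overflow (sketch):

```python
import torch

x = torch.randn(4, 4)
h = x.half()
print(x.element_size(), h.element_size())  # 4 bytes vs 2 bytes per element
print(torch.tensor(70000.0).half())        # overflows to inf in float16
```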
Use assert tensor.size() == (N, D, H, W) to debug tensor shapes.
Prefer 2‑D tensors of shape (N,1) over 1‑D tensors to avoid unexpected broadcasting issues.
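The pitfall this tip guards against, in a minimal sketch:

```python
import torch

a = torch.randn(5)      # 1-D, shape (5,)
b = torch.randn(5, 1)   # 2-D, shape (5, 1)
print((a - b).shape)    # torch.Size([5, 5]) — silent broadcast, likely a bug
print((a.view(5, 1) - b).shape)  # torch.Size([5, 1]) — what was intended
```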
Profile code sections with torch.autograd.profiler.profile or python -m torch.utils.bottleneck for performance analysis.
Debug with TorchSnooper to automatically print tensor shapes, dtypes, devices, and gradient requirements.
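A short profiling sketch using the autograd profiler mentioned above (the matrix sizes are arbitrary):

```python
import torch

x = torch.randn(128, 256)
w = torch.randn(256, 256)
with torch.autograd.profiler.profile() as prof:
    for _ in range(10):
        y = x @ w
# Per-op CPU time summary, sorted by self time
table = prof.key_averages().table(sort_by='self_cpu_time_total', row_limit=5)
print(table)
```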