Deep Learning Semantic Segmentation: FCN Source Code Analysis
This article presents a step-by-step walkthrough of implementing Fully Convolutional Networks (FCN) for semantic segmentation in PyTorch, from dataset preparation to inference. It covers VOC dataset loading, data augmentation transforms, the batch collate function, model construction, the training loop, cross-entropy loss with ignore-index handling, and prediction, with full code snippets for each step.
Dataset Reading – my_dataset.py
A VOCSegmentation class is defined to load image‑mask pairs from the VOC dataset, verify paths, and build lists of file names.
class VOCSegmentation(data.Dataset):
    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        super(VOCSegmentation, self).__init__()
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        assert os.path.exists(root), "path '{}' does not exist.".format(root)
        image_dir = os.path.join(root, 'JPEGImages')
        mask_dir = os.path.join(root, 'SegmentationClass')
        txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)
        assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)
        with open(txt_path, "r") as f:
            file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        self.transforms = transforms

    def __getitem__(self, index):
        # Standard VOC loading: RGB image plus a "P"-mode (palette) mask
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        return len(self.images)

Example usage sets voc_root="D:\数据集\VOC\VOCtrainval_11-May-2012" so that self.images and self.masks contain the full paths.
Data Transforms
Two transform classes are provided: one for training (random resize, horizontal flip, random crop, tensor conversion, and normalization) and one for evaluation (deterministic resize, tensor conversion, normalization).
# Training transforms
class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(2.0 * base_size)
        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend([
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)

# Evaluation transforms
class SegmentationPresetEval:
    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.RandomResize(base_size, base_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)

Training uses crop_size=480, while validation skips random cropping.
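A subtlety worth spelling out: unlike torchvision's stock transforms, these take (img, target) pairs, so every random decision stays synchronized between image and mask. A minimal self-contained sketch of that pattern (plain nested lists stand in for images here; this is not the repo's transforms module):

```python
import random

class PairedRandomHorizontalFlip:
    """Apply the *same* flip decision to image and mask.

    If the two were flipped independently, pixels and labels
    would desynchronize and the loss would be meaningless.
    """
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            img = [row[::-1] for row in img]        # flip each image row
            target = [row[::-1] for row in target]  # identical flip on the mask
        return img, target

flip = PairedRandomHorizontalFlip(p=1.0)  # always flip, for a deterministic demo
img, mask = flip([[1, 2, 3]], [[10, 20, 30]])
# img  -> [[3, 2, 1]]
# mask -> [[30, 20, 10]]
```

The same idea explains why RandomResize and RandomCrop in this codebase also accept (img, target): the sampled size and crop box must be shared.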
Collate Function
The custom collate_fn packs a batch of images and masks into tensors of uniform size by padding.
def collate_fn(batch):
    images, targets = list(zip(*batch))
    batched_imgs = cat_list(images, fill_value=0)
    batched_targets = cat_list(targets, fill_value=255)
    return batched_imgs, batched_targets

def cat_list(images, fill_value=0):
    # Compute max channel, height, width across the batch
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs

Model Training – train.py
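To see what cat_list builds, take a hypothetical batch of two 3-channel images with different spatial sizes; the element-wise maximum over their shape tuples gives the padded batch shape:

```python
# Hypothetical shapes of two images after random resize/crop.
shapes = [(3, 480, 360), (3, 320, 480)]

# The same reduction cat_list performs over the img.shape tuples:
max_size = tuple(max(s) for s in zip(*shapes))
batch_shape = (len(shapes),) + max_size
print(batch_shape)  # (2, 3, 480, 480)
```

Each image is then copied into the top-left corner of its slot. Images are padded with 0 and masks with 255, so the padded regions are later skipped by the loss via ignore_index=255.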
Data loaders are created for training and validation using the dataset class and the custom collate function.
train_dataset = VOCSegmentation(args.data_path, year="2012", transforms=get_transform(train=True), txt_name="train.txt")
val_dataset = VOCSegmentation(args.data_path, year="2012", transforms=get_transform(train=False), txt_name="val.txt")

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    pin_memory=True,
    collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    num_workers=num_workers,
    pin_memory=True,
    collate_fn=val_dataset.collate_fn)

The network is instantiated with a VGG backbone (or a ResNet backbone in newer versions) via create_model(aux=args.aux, num_classes=num_classes).
model = create_model(aux=args.aux, num_classes=num_classes)

Optimizer and learning-rate scheduler are set up using SGD.
# Optimizer
optimizer = torch.optim.SGD(
    params_to_optimize,
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay)

The training loop iterates over epochs, calling train_one_epoch, which performs the forward pass, loss computation, backward pass (with an optional AMP scaler), an optimizer step, and LR scheduling.
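Note that lr_scheduler.step() is called once per iteration, not per epoch. FCN training commonly pairs this with a per-step "poly" decay; a minimal sketch of that idea, where the helper name and the conventional power of 0.9 are assumptions here, not taken from the article:

```python
import torch

def poly_lr_scheduler(optimizer, total_steps, power=0.9):
    # lr(step) = base_lr * (1 - step / total_steps) ** power,
    # stepped once per iteration rather than once per epoch.
    def f(step):
        return (1 - step / total_steps) ** power
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)

# Toy usage: one parameter, base lr 0.01, 100 total iterations.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.01)
sched = poly_lr_scheduler(opt, total_steps=100)

opt.step()
sched.step()  # after 1 step: lr = 0.01 * (1 - 1/100) ** 0.9
```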
def train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler, print_freq=10, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = f'Epoch: [{epoch}]'
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        lr_scheduler.step()
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss.item(), lr=lr)

    return metric_logger.meters["loss"].global_avg, lr

Loss Function
The loss is computed with nn.functional.cross_entropy, ignoring the label value 255 (used for unlabeled and object-border pixels in VOC masks).
def criterion(inputs, target):
    losses = {}
    for name, x in inputs.items():
        # ignore_index=255 skips the unlabeled/border pixels
        losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)

    if len(losses) == 1:
        return losses['out']

    return losses['out'] + 0.5 * losses['aux']

Detailed examples illustrate how ignore_index works, how the loss is averaged over pixels, and how to compute it manually for verification.
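The averaging behavior of ignore_index is easy to verify by hand: pixels labeled 255 are dropped from both the numerator and the denominator of the mean. A pure-Python check (no PyTorch required) over a few toy "pixels":

```python
import math

def manual_cross_entropy(logits, targets, ignore_index=255):
    # Per-pixel CE: log-sum-exp of the logits minus the logit of the
    # true class, averaged over non-ignored pixels only -- matching
    # F.cross_entropy(..., ignore_index=255, reduction='mean').
    per_pixel = []
    for logit_vec, t in zip(logits, targets):
        if t == ignore_index:
            continue  # contributes nothing, not even to the denominator
        log_sum = math.log(sum(math.exp(v) for v in logit_vec))
        per_pixel.append(log_sum - logit_vec[t])
    return sum(per_pixel) / len(per_pixel)

# Three pixels, two classes; the last pixel is a border pixel (255).
logits = [[2.0, 0.5], [0.2, 1.5], [3.0, 3.0]]
targets = [0, 1, 255]
loss = manual_cross_entropy(logits, targets)  # averaged over 2 pixels, not 3
```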
Model Inference – predict.py
During inference the model's output tensor of shape [1, C, H, W] is reduced to class indices with argmax, converted to a PIL image, colorized with a palette, and saved.

output = model(img.to(device))
prediction = output['out'].argmax(1).squeeze(0)
prediction = prediction.to("cpu").numpy().astype(np.uint8)
mask = Image.fromarray(prediction)
mask.putpalette(palette)  # 'palette' maps each class index to an RGB color
mask.save("test_result.png")

Cross-Entropy Details (Appendix)
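How putpalette colorizes the class-index mask can be seen on a toy 2×2 prediction; the three-color palette below is a made-up example, not the actual VOC color map:

```python
from PIL import Image

# Hypothetical palette: index 0 -> black, 1 -> red, 2 -> green.
palette = [0, 0, 0,   255, 0, 0,   0, 255, 0]

mask = Image.new("P", (2, 2))
mask.putdata([0, 1, 2, 1])   # toy argmax result, row-major class indices
mask.putpalette(palette)

rgb = mask.convert("RGB")
print(rgb.getpixel((1, 0)))  # (255, 0, 0) -- class 1 rendered red
```

Saving a "P"-mode image as PNG preserves both the indices and the palette, which is why the stored mask files remain valid class-index maps.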
The appendix walks through the mathematical definition of cross‑entropy, shows 1‑D and 2‑D examples, demonstrates how to ignore a specific class index, and verifies manual calculations against PyTorch's implementation.
Overall, the article provides a comprehensive, reproducible guide for building, training, and evaluating an FCN model on the VOC segmentation benchmark, making it a valuable resource for practitioners in computer vision and deep learning.
Published on Juejin (Rare Earth Juejin Tech Community).