Applying UNETR Transformer for 3D Medical Image Segmentation

This article walks through using the UNETR transformer architecture to segment 3D brain MRI scans from the BRATS dataset, detailing environment setup, data preprocessing with MONAI, model construction, training with DiceCE loss, validation metrics, and visualizing the best‑performing model outputs.

Code DAO
Code DAO
Code DAO
Applying UNETR Transformer for 3D Medical Image Segmentation

The UNETR model, the first Transformer‑based architecture for 3D medical image segmentation, is applied to the multi‑modal BRATS brain tumor MRI dataset. The goal is to match the performance of a conventional UNet while demonstrating the transformer workflow.

Environment setup

# Install MONAI plus tqdm for progress bars.
!pip install monai tqdm
# Fall back to the weekly MONAI build (with NIfTI support via nibabel) if the import fails.
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
# Make sure matplotlib is available for the visualization cells below.
!python -c "import matplotlib" || pip install -q matplotlib
# self-attention-cv provides the UNETR implementation used in this tutorial.
!pip install self-attention-cv==1.2.3

Library imports

# Standard library.
import os, shutil, tempfile
# Plotting and numerics.
import matplotlib.pyplot as plt
import numpy as np
# MONAI: dataset download, runtime config, data loading, losses, metrics, networks.
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
# Dictionary-based transforms used to build the preprocessing pipelines.
from monai.transforms import (
    Activations, AsChannelFirstd, AsDiscrete, CenterSpatialCropd,
    Compose, LoadImaged, MapTransform, NormalizeIntensityd,
    Orientationd, RandFlipd, RandScaleIntensityd, RandShiftIntensityd,
    RandSpatialCropd, Spacingd, ToTensord,
)
from monai.utils import set_determinism
import torch
# Print MONAI/library versions for reproducibility.
print_config()

Custom transform for BRATS label conversion

class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """Dictionary transform: expand a BRATS segmentation into three binary channels.

    Channel 0 — tumor core (TC): union of labels 2 and 3.
    Channel 1 — whole tumor (WT): union of labels 1, 2 and 3.
    Channel 2 — enhancing tumor (ET): label 2 only.

    Label convention assumed here (Decathlon Task01): label 1 = peritumoral
    edema, label 2 = GD-enhancing tumor, label 3 = necrotic/non-enhancing core.
    Output is stacked channel-first and cast to float32.
    """

    def __call__(self, data):
        out = dict(data)
        for key in self.keys:
            seg = out[key]
            enhancing = seg == 2                 # ET
            core = enhancing | (seg == 3)        # TC = labels {2, 3}
            whole = core | (seg == 1)            # WT = labels {1, 2, 3}
            out[key] = np.stack([core, whole, enhancing], axis=0).astype(np.float32)
        return out

Data transforms for training and validation

# Spatial settings shared by the train and validation pipelines:
# crop/patch size in voxels, and target voxel spacing in mm.
roi_size = [128, 128, 64]
pixdim = (1.5, 1.5, 2.0)

# Training pipeline: load -> channel-first image -> 3-channel labels ->
# resample -> reorient -> random crop/flip -> normalize -> random intensity
# augmentation -> tensors.
_train_keys = ["image", "label"]
train_transform = Compose(
    [
        LoadImaged(keys=_train_keys),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=_train_keys, pixdim=pixdim, mode=("bilinear", "nearest")),
        Orientationd(keys=_train_keys, axcodes="RAS"),
        RandSpatialCropd(keys=_train_keys, roi_size=roi_size, random_size=False),
        RandFlipd(keys=_train_keys, prob=0.5, spatial_axis=0),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=_train_keys),
    ]
)

# Validation pipeline: fully deterministic — same resampling/reorientation as
# training, but a center crop and no random augmentation.
_val_keys = ["image", "label"]
val_transform = Compose(
    [
        LoadImaged(keys=_val_keys),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=_val_keys, pixdim=pixdim, mode=("bilinear", "nearest")),
        Orientationd(keys=_val_keys, axcodes="RAS"),
        CenterSpatialCropd(keys=_val_keys, roi_size=roi_size),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=_val_keys),
    ]
)

Dataset loading with MONAI's DecathlonDataset

# Decathlon Task01 (brain tumour): the archive is downloaded once with the
# training split and reused (download=False) for validation.
root_dir = "./"
cache_num = 8  # items kept cached in memory per dataset

train_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    transform=train_transform,
    section="training",
    download=True,
    num_workers=4,
    cache_num=cache_num,
)
val_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="validation",
    download=False,
    num_workers=4,
    cache_num=cache_num,
)

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2)

Visualization of a sample image and its three‑channel label

# Show one axial slice of every image modality (4 channels) and every
# label channel (3 channels: TC, WT, ET) for validation sample #2.
slice_id = 32
for key, n_channels, extra in (("image", 4, {"cmap": "gray"}), ("label", 3, {})):
    print(f"{key} shape: {val_ds[2][key].shape}")
    plt.figure(key, (24, 6))
    for ch in range(n_channels):
        plt.subplot(1, n_channels, ch + 1)
        plt.title(f"{key} channel {ch}")
        plt.imshow(val_ds[2][key][ch, :, :, slice_id].detach().cpu(), **extra)
    plt.show()

UNETR model definition (Transformer‑based)

from self_attention_cv import UNETR

device = torch.device("cuda:0")

# ViT-style encoder hyper-parameters.
embed_dim = 512
num_heads = 10

model = UNETR(
    img_shape=tuple(roi_size),   # must match the crop size of the pipelines
    input_dim=4,                 # four MRI modalities
    output_dim=3,                # TC / WT / ET channels
    embed_dim=embed_dim,
    patch_size=16,
    num_heads=num_heads,
    ext_layers=[3, 6, 9, 12],    # transformer layers tapped for skip connections
    norm='instance',
    base_filters=16,
    dim_linear_block=2048,
).to(device)

Reference UNet model for comparison

# Conventional 3D UNet baseline with the same input/output channel layout,
# built only to compare parameter counts and performance.
model_unet = UNet(
    dimensions=3,
    in_channels=4,   # four MRI modalities
    out_channels=3,  # TC / WT / ET
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

param_total = sum(parameter.numel() for parameter in model_unet.parameters())
print('Parameters in millions:', param_total / 1e6)

Training loop with DiceCE loss

# DiceCE loss on raw logits: sigmoid is applied inside the loss, and the labels
# are already multi-channel binary (TC/WT/ET), so no one-hot conversion.
loss_function = DiceCELoss(to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

max_epochs = 180
val_interval = 5  # run validation every 5 epochs

best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
# FIX: these histories are consumed by the plotting section after training but
# were never created or populated, which would raise NameError there.
metric_values = []
metric_values_tc = []
metric_values_wt = []
metric_values_et = []

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch+1}/{max_epochs}")
    model.train()
    epoch_loss = 0.0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch+1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            dice_metric = DiceMetric(include_background=True, reduction="mean")
            # Binarize sigmoid probabilities before computing Dice.
            post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
            metric_sum = metric_sum_tc = metric_sum_wt = metric_sum_et = 0.0
            metric_count = metric_count_tc = metric_count_wt = metric_count_et = 0
            for val_data in val_loader:
                val_inputs, val_labels = val_data["image"].to(device), val_data["label"].to(device)
                val_outputs = post_trans(model(val_inputs))
                # Overall Dice across all three channels, weighted by the
                # number of non-NaN items in the batch.
                value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                not_nans = not_nans.mean().item()
                metric_count += not_nans
                metric_sum += value.mean().item() * not_nans
                # Per-sub-region Dice. NOTE(review): the overall `not_nans`
                # weight is reused for every sub-region (per-channel NaN counts
                # are discarded), mirroring the original recipe.
                value_tc, _ = dice_metric(y_pred=val_outputs[:, 0:1], y=val_labels[:, 0:1])
                metric_count_tc += not_nans
                metric_sum_tc += value_tc.item() * not_nans
                value_wt, _ = dice_metric(y_pred=val_outputs[:, 1:2], y=val_labels[:, 1:2])
                metric_count_wt += not_nans
                metric_sum_wt += value_wt.item() * not_nans
                value_et, _ = dice_metric(y_pred=val_outputs[:, 2:3], y=val_labels[:, 2:3])
                metric_count_et += not_nans
                metric_sum_et += value_et.item() * not_nans
            metric = metric_sum / metric_count
            metric_tc = metric_sum_tc / metric_count_tc
            metric_wt = metric_sum_wt / metric_count_wt
            metric_et = metric_sum_et / metric_count_et
            # FIX: record the validation histories for the plots below.
            metric_values.append(metric)
            metric_values_tc.append(metric_tc)
            metric_values_wt.append(metric_wt)
            metric_values_et.append(metric_et)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(f"current epoch: {epoch+1} mean dice: {metric:.4f} tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}")
            print(f"best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}")

# Save the last-epoch weights alongside the best-metric checkpoint.
torch.save(model.state_dict(), "./last.pth")
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

Plotting training loss and validation Dice scores

# Left: per-epoch training loss. Right: overall validation Dice.
# NOTE: requires `metric_values` to have been populated during validation.
plt.figure("train", (12, 6))

plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
loss_x = list(range(1, len(epoch_loss_values) + 1))
plt.plot(loss_x, epoch_loss_values, color="red")
plt.xlabel("epoch")

plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
dice_x = [val_interval * (i + 1) for i in range(len(metric_values))]
plt.plot(dice_x, metric_values, color="green")
plt.xlabel("epoch")

plt.show()

# One panel per tumor sub-region (TC / WT / ET) validation Dice curve.
plt.figure("train", (18, 6))
subregion_curves = (
    ("Val Mean Dice TC", metric_values_tc, "blue"),
    ("Val Mean Dice WT", metric_values_wt, "brown"),
    ("Val Mean Dice ET", metric_values_et, "purple"),
)
for panel, (panel_title, curve, line_color) in enumerate(subregion_curves, start=1):
    plt.subplot(1, 3, panel)
    plt.title(panel_title)
    plt.plot([val_interval * (j + 1) for j in range(len(curve))], curve, color=line_color)
    plt.xlabel("epoch")
plt.show()

Loading the best model and visualizing its output

# Restore the best checkpoint and visualize inputs, ground truth, and the
# model's sigmoid probability maps for one axial slice of validation sample #6.
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
model.eval()
with torch.no_grad():
    val_input = val_ds[6]["image"].unsqueeze(0).to(device)
    val_output = model(val_input)
    slice_idx = 20

    # Input modalities (4 channels).
    plt.figure("image", (24, 6))
    for ch in range(4):
        plt.subplot(1, 4, ch + 1)
        plt.title(f"image channel {ch}")
        plt.imshow(val_ds[6]["image"][ch, :, :, slice_idx].detach().cpu(), cmap="gray")
    plt.show()

    # Ground-truth label channels (TC / WT / ET).
    plt.figure("label", (18, 6))
    for ch in range(3):
        plt.subplot(1, 3, ch + 1)
        plt.title(f"label channel {ch}")
        plt.imshow(val_ds[6]["label"][ch, :, :, slice_idx].detach().cpu())
    plt.show()

    # Model predictions as sigmoid probabilities, one panel per channel.
    plt.figure("output", (18, 6))
    for ch in range(3):
        plt.subplot(1, 3, ch + 1)
        plt.title(f"output channel {ch}")
        prob_map = torch.sigmoid(val_output[0, ch, :, :, slice_idx]).detach().cpu()
        plt.imshow(prob_map)
    plt.show()

The implementation follows the methodology described in UNETR: Transformers for 3D Medical Image Segmentation by Hatamizadeh et al. (2021) (arXiv:2103.10504v3).

Original Source

Signed-in readers can open the original source through BestHub's protected redirect.

Sign in to view source
Republication Notice

This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please contact admin@besthub.dev and we will review it promptly.

Transformer · PyTorch · 3D segmentation · BRATS · MONAI · UNETR
Code DAO
Written by

Code DAO

We deliver AI algorithm tutorials and the latest news, curated by a team of researchers from Peking University, Shanghai Jiao Tong University, Central South University, and leading AI companies such as Huawei, Kuaishou, and SenseTime. Join us in the AI alchemy—making life better!

0 followers
Reader feedback

How this landed with the community

Sign in to like

Rate this article

Was this worth your time?

Sign in to rate
Discussion

0 Comments

Thoughtful readers leave field notes, pushback, and hard-won operational detail here.