Applying UNETR Transformer for 3D Medical Image Segmentation
This article walks through using the UNETR transformer architecture to segment 3D brain MRI scans from the BRATS dataset, detailing environment setup, data preprocessing with MONAI, model construction, training with DiceCE loss, validation metrics, and visualizing the best‑performing model outputs.
The UNETR model, the first Transformer‑based architecture for 3D medical image segmentation, is applied to the multi‑modal BRATS brain tumor MRI dataset. The goal is to match the performance of a conventional UNet while demonstrating the transformer workflow.
Environment setup
!pip install monai tqdm
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
!pip install self-attention-cv==1.2.3
Library imports
import os, shutil, tempfile
import matplotlib.pyplot as plt
import numpy as np
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import (
Activations, AsChannelFirstd, AsDiscrete, CenterSpatialCropd,
Compose, LoadImaged, MapTransform, NormalizeIntensityd,
Orientationd, RandFlipd, RandScaleIntensityd, RandShiftIntensityd,
RandSpatialCropd, Spacingd, ToTensord,
)
from monai.utils import set_determinism
import torch
print_config()
Custom transform for BRATS label conversion
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
"""Convert BRATS labels to three channels: TC, WT, ET.
label 1 → edema, label 2 → enhancing tumor, label 3 → necrotic core.
"""
def __call__(self, data):
d = dict(data)
for key in self.keys:
result = []
# TC = label 2 ∪ label 3
result.append(np.logical_or(d[key] == 2, d[key] == 3))
# WT = label 1 ∪ label 2 ∪ label 3
result.append(np.logical_or(np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1))
# ET = label 2 only
result.append(d[key] == 2)
d[key] = np.stack(result, axis=0).astype(np.float32)
return dData transforms for training and validation
# Crop size and target voxel spacing shared by the train and val pipelines.
roi_size = [128, 128, 64]
pixdim = (1.5, 1.5, 2.0)

# Training pipeline: load, channel-first, relabel to TC/WT/ET, resample,
# reorient, random crop + flip, normalize, intensity augmentation, to tensor.
_train_steps = [
    LoadImaged(keys=["image", "label"]),
    AsChannelFirstd(keys="image"),
    ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
    Spacingd(keys=["image", "label"], pixdim=pixdim, mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    RandSpatialCropd(keys=["image", "label"], roi_size=roi_size, random_size=False),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
    RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ToTensord(keys=["image", "label"]),
]
train_transform = Compose(_train_steps)
val_transform = Compose([
LoadImaged(keys=["image", "label"]),
AsChannelFirstd(keys="image"),
ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
Spacingd(keys=["image", "label"], pixdim=pixdim, mode=("bilinear", "nearest")),
Orientationd(keys=["image", "label"], axcodes="RAS"),
CenterSpatialCropd(keys=["image", "label"], roi_size=roi_size),
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
ToTensord(keys=["image", "label"]),
])Dataset loading with MONAI's DecathlonDataset
root_dir = "./"
cache_num = 8
train_ds = DecathlonDataset(root_dir=root_dir, task="Task01_BrainTumour", transform=train_transform, section="training", download=True, num_workers=4, cache_num=cache_num)
val_ds = DecathlonDataset(root_dir=root_dir, task="Task01_BrainTumour", transform=val_transform, section="validation", download=False, num_workers=4, cache_num=cache_num)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2)Visualization of a sample image and its three‑channel label
slice_id = 32
print(f"image shape: {val_ds[2]['image'].shape}")
plt.figure("image", (24, 6))
for i in range(4):
plt.subplot(1, 4, i+1)
plt.title(f"image channel {i}")
plt.imshow(val_ds[2]["image"][i, :, :, slice_id].detach().cpu(), cmap="gray")
plt.show()
print(f"label shape: {val_ds[2]['label'].shape}")
plt.figure("label", (24, 6))
for i in range(3):
plt.subplot(1, 3, i+1)
plt.title(f"label channel {i}")
plt.imshow(val_ds[2]["label"][i, :, :, slice_id].detach().cpu())
plt.show()UNETR model definition (Transformer‑based)
from self_attention_cv import UNETR
device = torch.device("cuda:0")
num_heads = 10
embed_dim = 512
model = UNETR(
img_shape=tuple(roi_size),
input_dim=4,
output_dim=3,
embed_dim=embed_dim,
patch_size=16,
num_heads=num_heads,
ext_layers=[3, 6, 9, 12],
norm='instance',
base_filters=16,
dim_linear_block=2048,
).to(device)Reference UNet model for comparison
model_unet = UNet(
dimensions=3,
in_channels=4,
out_channels=3,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
).to(device)
print('Parameters in millions:', sum(p.numel() for p in model_unet.parameters())/1e6)Training loop with DiceCE loss
loss_function = DiceCELoss(to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
max_epochs = 180
val_interval = 5
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
for epoch in range(max_epochs):
print("-" * 10)
print(f"epoch {epoch+1}/{max_epochs}")
model.train()
epoch_loss = 0.0
step = 0
for batch_data in train_loader:
step += 1
inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch+1} average loss: {epoch_loss:.4f}")
if (epoch+1) % val_interval == 0:
model.eval()
with torch.no_grad():
dice_metric = DiceMetric(include_background=True, reduction="mean")
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
metric_sum = metric_sum_tc = metric_sum_wt = metric_sum_et = 0.0
metric_count = metric_count_tc = metric_count_wt = metric_count_et = 0
for val_data in val_loader:
val_inputs, val_labels = val_data["image"].to(device), val_data["label"].to(device)
val_outputs = post_trans(model(val_inputs))
# overall dice
value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
not_nans = not_nans.mean().item()
metric_count += not_nans
metric_sum += value.mean().item() * not_nans
# TC dice
value_tc, _ = dice_metric(y_pred=val_outputs[:,0:1], y=val_labels[:,0:1])
metric_count_tc += not_nans
metric_sum_tc += value_tc.item() * not_nans
# WT dice
value_wt, _ = dice_metric(y_pred=val_outputs[:,1:2], y=val_labels[:,1:2])
metric_count_wt += not_nans
metric_sum_wt += value_wt.item() * not_nans
# ET dice
value_et, _ = dice_metric(y_pred=val_outputs[:,2:3], y=val_labels[:,2:3])
metric_count_et += not_nans
metric_sum_et += value_et.item() * not_nans
metric = metric_sum / metric_count
metric_tc = metric_sum_tc / metric_count_tc
metric_wt = metric_sum_wt / metric_count_wt
metric_et = metric_sum_et / metric_count_et
if metric > best_metric:
best_metric = metric
best_metric_epoch = epoch + 1
torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
print("saved new best metric model")
print(f"current epoch: {epoch+1} mean dice: {metric:.4f} tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}")
print(f"best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}")
# Save final model
torch.save(model.state_dict(), "./last.pth")
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")Plotting training loss and validation Dice scores
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
plt.plot(range(1, len(epoch_loss_values)+1), epoch_loss_values, color="red")
plt.xlabel("epoch")
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
plt.plot([val_interval * (i+1) for i in range(len(metric_values))], metric_values, color="green")
plt.xlabel("epoch")
plt.show()
plt.figure("train", (18, 6))
for i, (vals, title, col) in enumerate([
(metric_values_tc, "Val Mean Dice TC", "blue"),
(metric_values_wt, "Val Mean Dice WT", "brown"),
(metric_values_et, "Val Mean Dice ET", "purple"),
]):
plt.subplot(1, 3, i+1)
plt.title(title)
plt.plot([val_interval * (j+1) for j in range(len(vals))], vals, color=col)
plt.xlabel("epoch")
plt.show()Loading the best model and visualizing its output
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
model.eval()
with torch.no_grad():
val_input = val_ds[6]["image"].unsqueeze(0).to(device)
val_output = model(val_input)
# visualize input channels
plt.figure("image", (24, 6))
for i in range(4):
plt.subplot(1, 4, i+1)
plt.title(f"image channel {i}")
plt.imshow(val_ds[6]["image"][i, :, :, 20].detach().cpu(), cmap="gray")
plt.show()
# visualize ground‑truth label channels
plt.figure("label", (18, 6))
for i in range(3):
plt.subplot(1, 3, i+1)
plt.title(f"label channel {i}")
plt.imshow(val_ds[6]["label"][i, :, :, 20].detach().cpu())
plt.show()
# visualize model output channels
plt.figure("output", (18, 6))
for i in range(3):
plt.subplot(1, 3, i+1)
plt.title(f"output channel {i}")
out_tensor = torch.sigmoid(val_output[0, i, :, :, 20]).detach().cpu()
plt.imshow(out_tensor)
plt.show()The implementation follows the methodology described in UNETR: Transformers for 3D Medical Image Segmentation by Hatamizadeh et al. (2021) (arXiv:2103.10504v3).
Signed-in readers can open the original source through BestHub's protected redirect.
This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please contact us and we will review it promptly.
Code DAO
We deliver AI algorithm tutorials and the latest news, curated by a team of researchers from Peking University, Shanghai Jiao Tong University, Central South University, and leading AI companies such as Huawei, Kuaishou, and SenseTime. Join us in the AI alchemy—making life better!
How this landed with the community
Was this worth your time?
0 Comments
Thoughtful readers leave field notes, pushback, and hard-won operational detail here.
