How to Build a 3D CNN for CT Scan Classification with TensorFlow
This tutorial walks through constructing, training, and evaluating a 3D convolutional neural network in TensorFlow to classify CT scans for viral pneumonia, covering data preprocessing, dynamic learning rates, early stopping, and single‑scan prediction with full code examples.
CT Scan 3D Image Classification
In this tutorial we build a 3D convolutional neural network (CNN) using TensorFlow 2.4.1 to predict viral pneumonia from CT scans. The dataset consists of NIfTI (.nii) files for normal and abnormal lungs; each scan is preprocessed by rotation, Hounsfield unit (HU) normalization, and resizing to 128×128×64.
Environment
Python 3.6.5
Jupyter Notebook
TensorFlow 2.4.1
GPU: NVIDIA GeForce RTX 3080
Data Loading and Pre‑processing
import os, zipfile
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

# Use only the first GPU and let TensorFlow grow its memory allocation on demand
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
    tf.config.set_visible_devices([gpus[0]], "GPU")
print(gpus)

CT files are read with nibabel (install via pip install nibabel). The preprocessing steps are:
Rotate the volume 90° to fix orientation.
Clamp HU values to [-1000, 400] and scale to [0, 1].
Resize to a uniform shape (128, 128, 64).
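To make the clamp-and-scale step in the list above concrete, here is a small NumPy-only sketch of the arithmetic on a handful of made-up HU values. The helper name `clamp_and_scale` and the sample values are illustrative and not part of the original code. It uses `np.clip` instead of in-place masking, which gives the same numbers without modifying the input array.

```python
import numpy as np

def clamp_and_scale(volume, hu_min=-1000.0, hu_max=400.0):
    """Clamp HU values to [hu_min, hu_max], then map them linearly to [0, 1]."""
    clipped = np.clip(volume, hu_min, hu_max)
    scaled = (clipped - hu_min) / (hu_max - hu_min)
    return scaled.astype("float32")

# Made-up HU values: below range, lower bound, midpoint, upper bound, above range
sample_hu = np.array([-2000.0, -1000.0, -300.0, 400.0, 1500.0])
scaled = clamp_and_scale(sample_hu)
print(scaled)  # values are 0, 0, 0.5, 1, 1
```

Everything below -1000 HU collapses to 0 and everything above 400 HU collapses to 1.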
import nibabel as nib
from scipy import ndimage

def read_nifti_file(filepath):
    """Read a NIfTI file and return the raw voxel data."""
    scan = nib.load(filepath)
    return scan.get_fdata()

def normalize(volume):
    """Clamp HU values to [-1000, 400] and scale them to [0, 1]."""
    # Named hu_min / hu_max to avoid shadowing the built-in min() and max()
    hu_min = -1000
    hu_max = 400
    volume[volume < hu_min] = hu_min
    volume[volume > hu_max] = hu_max
    volume = (volume - hu_min) / (hu_max - hu_min)
    return volume.astype("float32")

def resize_volume(img):
    """Rotate the volume by 90 degrees and resize it to 128 x 128 x 64."""
    desired_depth, desired_width, desired_height = 64, 128, 128
    current_depth, current_width, current_height = img.shape[-1], img.shape[0], img.shape[1]
    # Zoom factor per axis = target size / current size
    depth_factor = desired_depth / current_depth
    width_factor = desired_width / current_width
    height_factor = desired_height / current_height
    img = ndimage.rotate(img, 90, reshape=False)
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img

def process_scan(path):
    """Read, normalize, and resize a single scan."""
    volume = read_nifti_file(path)
    volume = normalize(volume)
    volume = resize_volume(volume)
    return volume

Paths for normal and abnormal scans are collected, processed, and labeled (0 for normal, 1 for abnormal). The dataset is then split 70/30 into training and validation sets.
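The split code that follows relies on four arrays (`abnormal_scans`, `normal_scans`, `abnormal_labels`, `normal_labels`) whose construction is described above but not shown. Here is a minimal sketch of one way to build them. The helpers `list_scan_paths` and `load_class` and the directory names in the usage comment are hypothetical; `process_fn` stands in for the `process_scan` function defined above.

```python
import os
import numpy as np

def list_scan_paths(directory):
    """Return sorted paths of NIfTI files (.nii or .nii.gz) in a directory."""
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith((".nii", ".nii.gz"))
    )

def load_class(paths, label, process_fn):
    """Process every scan in `paths` and pair it with a constant label."""
    scans = np.array([process_fn(path) for path in paths])
    labels = np.full(len(paths), label, dtype="int64")
    return scans, labels

# Usage with the tutorial's functions (directory names are hypothetical):
# normal_scans, normal_labels = load_class(list_scan_paths("data/normal"), 0, process_scan)
# abnormal_scans, abnormal_labels = load_class(list_scan_paths("data/abnormal"), 1, process_scan)
```

Stacking the processed volumes with `np.array` only works because every scan has already been resized to the same shape.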
# Split datasets
x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
print("Number of samples in train and validation are %d and %d." % (x_train.shape[0], x_val.shape[0]))

Data Augmentation
During training, volumes are randomly rotated by angles chosen from [-20, -10, -5, 5, 10, 20] degrees. Values are clipped to [0, 1] after rotation.
import random
from scipy import ndimage
import tensorflow as tf

@tf.function
def rotate(volume):
    """Rotate the volume by a few degrees."""
    def scipy_rotate(volume):
        # Pick one rotation angle at random
        angles = [-20, -10, -5, 5, 10, 20]
        angle = random.choice(angles)
        vol = ndimage.rotate(volume, angle, reshape=False)
        # Spline interpolation can leave values slightly outside [0, 1], so clip them back
        vol[vol < 0] = 0
        vol[vol > 1] = 1
        return vol
    return tf.numpy_function(scipy_rotate, [volume], tf.float32)

def train_preprocessing(volume, label):
    """Training: rotate the volume, then add a channel axis."""
    volume = rotate(volume)
    volume = tf.expand_dims(volume, axis=3)
    return volume, label

def validation_preprocessing(volume, label):
    """Validation: only add a channel axis."""
    volume = tf.expand_dims(volume, axis=3)
    return volume, label

Data loaders are created with shuffling, batching (batch size = 2), and prefetching.
train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))
batch_size = 2
train_dataset = (train_loader.shuffle(len(x_train))
                 .map(train_preprocessing)
                 .batch(batch_size)
                 .prefetch(2))
validation_dataset = (validation_loader.shuffle(len(x_val))
                      .map(validation_preprocessing)
                      .batch(batch_size)
                      .prefetch(2))

Model Construction
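As a sanity check on the architecture defined next, the spatial size of the feature maps can be traced by hand. The sketch below assumes the Keras defaults that the model code relies on (3×3×3 kernels with "valid" padding and stride 1, each followed by 2×2×2 max pooling). The helper `trace_shapes` is illustrative and not part of the original code.

```python
def trace_shapes(input_shape=(128, 128, 64), filters=(64, 64, 128, 256), kernel=3, pool=2):
    """Return the (width, height, depth, channels) shape after each conv + pool stage."""
    shape = tuple(input_shape)
    stages = []
    for n_filters in filters:
        # A "valid" convolution with stride 1 shrinks each axis by kernel - 1
        shape = tuple(size - (kernel - 1) for size in shape)
        # Max pooling with "valid" padding floors the division
        shape = tuple(size // pool for size in shape)
        stages.append(shape + (n_filters,))
    return stages

for i, stage in enumerate(trace_shapes(), start=1):
    print("after block %d: %s" % (i, stage))
# after block 1: (63, 63, 31, 64)
# after block 2: (30, 30, 14, 64)
# after block 3: (14, 14, 6, 128)
# after block 4: (6, 6, 2, 256)
```

Batch normalization leaves these shapes unchanged, and global average pooling then reduces the final 6×6×2 grid to a 256-element vector for the dense layers.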
def get_model(width=128, height=128, depth=64):
    inputs = keras.Input((width, height, depth, 1))

    x = layers.Conv3D(64, 3, activation="relu")(inputs)
    x = layers.MaxPool3D(2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(64, 3, activation="relu")(x)
    x = layers.MaxPool3D(2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(128, 3, activation="relu")(x)
    x = layers.MaxPool3D(2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(256, 3, activation="relu")(x)
    x = layers.MaxPool3D(2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)

    outputs = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs, outputs, name="3dcnn")
    return model

model = get_model()
model.summary()

Training
A dynamic learning rate with exponential decay and early stopping (patience = 15) are used. The model is trained for up to 100 epochs.
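To see what the schedule does in practice, the staircase form of exponential decay can be computed by hand as `initial_learning_rate * decay_rate ** floor(step / decay_steps)`. The sketch below uses the values from the training code that follows. It assumes 70 optimizer steps per epoch (140 training scans at batch size 2, which follows from the split above if each class has 100 scans). The function name is illustrative.

```python
import math

def staircase_exponential_decay(step, initial_lr=1e-4, decay_steps=30, decay_rate=0.96):
    """Learning rate at a given optimizer step with staircase exponential decay."""
    return initial_lr * decay_rate ** math.floor(step / decay_steps)

steps_per_epoch = 70  # assumption: 140 training scans / batch size 2
for epoch in (0, 1, 10, 50):
    step = epoch * steps_per_epoch
    print("epoch %3d  step %5d  lr %.3e" % (epoch, step, staircase_exponential_decay(step)))
```

With `decay_steps=30`, the rate drops by 4% every 30 batches, which under this assumption is a little more than twice per epoch on average.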
initial_learning_rate = 1e-4
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=30, decay_rate=0.96, staircase=True)
model.compile(loss="binary_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
              metrics=["acc"])
checkpoint_cb = keras.callbacks.ModelCheckpoint("3d_image_classification.h5", save_best_only=True)
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)
model.fit(train_dataset, validation_data=validation_dataset, epochs=100,
          shuffle=True, verbose=2, callbacks=[checkpoint_cb, early_stopping_cb])

Because the dataset contains only 200 samples, results may vary significantly between runs.
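The EarlyStopping callback used above halts training once `val_acc` has failed to improve for 15 consecutive epochs. The following pure-Python sketch mimics that patience rule on a made-up accuracy history. It is a simplification for illustration, not Keras's implementation, and the function name and sample values are hypothetical.

```python
def epochs_until_stop(val_acc_history, patience=15):
    """Return the 1-based epoch at which patience-based early stopping would trigger,
    or None if training runs through the whole history."""
    best = float("-inf")
    wait = 0
    for epoch, acc in enumerate(val_acc_history, start=1):
        if acc > best:
            best = acc   # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1    # no improvement this epoch
            if wait >= patience:
                return epoch
    return None

# Made-up history: accuracy improves for 3 epochs, then plateaus below the best value
history = [0.50, 0.55, 0.60] + [0.58] * 20
print(epochs_until_stop(history, patience=15))  # 18
```

On a small validation set, accuracy moves in coarse steps, so a long plateau like this one is plausible and a generous patience value helps avoid stopping on noise.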
Model Evaluation
Single Scan Prediction
# Load best weights
model.load_weights("3d_image_classification.h5")
prediction = model.predict(np.expand_dims(x_val[0], axis=0))[0]
scores = [1 - prediction[0], prediction[0]]
class_names = ["normal", "abnormal"]
for score, name in zip(scores, class_names):
    print("This model is %.2f percent confident that CT scan is %s" % (100 * score, name))

Example output: "This model is 27.88 percent confident that CT scan is normal" and "This model is 72.12 percent confident that CT scan is abnormal".
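The two percentages come from a single sigmoid output: the model returns the probability of the abnormal class, and the normal score is its complement. Below is a small sketch of that arithmetic, using the probability implied by the example output above. The helper name `class_scores` is illustrative.

```python
def class_scores(p_abnormal, class_names=("normal", "abnormal")):
    """Turn one sigmoid probability into a {class name: percent} mapping."""
    scores = (1.0 - p_abnormal, p_abnormal)
    return {name: round(100 * score, 2) for name, score in zip(class_names, scores)}

print(class_scores(0.7212))  # {'normal': 27.88, 'abnormal': 72.12}
```

Because the scores always sum to 100 percent, a prediction near 50 percent means the model is close to undecided.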
Note: The code is adapted from the official example with several modifications.
