How to Build a 3D CNN for CT Scan Classification with TensorFlow

This tutorial walks through constructing, training, and evaluating a 3D convolutional neural network in TensorFlow to classify CT scans for viral pneumonia, covering data preprocessing, dynamic learning rates, early stopping, and single‑scan prediction with full code examples.

Python Crawling & Data Mining

CT Scan 3D Image Classification

In this tutorial we build a 3D convolutional neural network (CNN) with TensorFlow 2.4.1 to predict the presence of viral pneumonia in CT scans. The dataset consists of NIfTI (.nii) files of normal and abnormal lungs; each volume is rotated, HU-normalized, and resized to 128×128×64.

Environment

Python 3.6.5

Jupyter Notebook

TensorFlow 2.4.1

GPU: NVIDIA GeForce RTX 3080

Data Loading and Pre‑processing

import os, zipfile
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
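# Use the first GPU if one is available, and let TensorFlow grow its
# memory allocation on demand instead of reserving it all up front.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
    tf.config.set_visible_devices([gpus[0]], "GPU")
print(gpus)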

CT files are read with nibabel (install via pip install nibabel). The preprocessing steps are:

Rotate the volume 90° to fix orientation.

Clamp HU values to [-1000, 400] and scale to [0, 1].

Resize to a uniform shape (128, 128, 64).

import nibabel as nib
from scipy import ndimage

def read_nifti_file(filepath):
    scan = nib.load(filepath)
    return scan.get_fdata()

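def normalize(volume):
    # Clamp to the [-1000, 400] HU window, then scale linearly to [0, 1].
    # The bounds are named min_hu/max_hu to avoid shadowing the built-ins.
    min_hu = -1000
    max_hu = 400
    volume = np.clip(volume, min_hu, max_hu)
    volume = (volume - min_hu) / (max_hu - min_hu)
    return volume.astype("float32")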

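def resize_volume(img):
    # Target shape is (width, height, depth) = (128, 128, 64).
    desired_depth, desired_width, desired_height = 64, 128, 128
    # Volumes are laid out as (width, height, depth), so depth is the last axis.
    current_depth, current_width, current_height = img.shape[-1], img.shape[0], img.shape[1]
    depth_factor = desired_depth / current_depth
    width_factor = desired_width / current_width
    height_factor = desired_height / current_height
    # Rotate 90° in the plane of the first two axes to fix the orientation.
    img = ndimage.rotate(img, 90, reshape=False)
    # Resample to the target shape with linear interpolation (order=1).
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img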

def process_scan(path):
    volume = read_nifti_file(path)
    volume = normalize(volume)
    volume = resize_volume(volume)
    return volume
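
To make the arithmetic concrete, the following sketch checks the clamp-and-scale step on a toy array and computes the zoom factors for one scan. The helper name normalize_hu and the 512×512×40 input shape are illustrative assumptions, not values taken from the dataset.

```python
import numpy as np

HU_MIN, HU_MAX = -1000.0, 400.0  # HU window used in this tutorial

def normalize_hu(volume):
    """Clamp to [HU_MIN, HU_MAX], then scale linearly to [0, 1]."""
    clipped = np.clip(volume, HU_MIN, HU_MAX)
    return ((clipped - HU_MIN) / (HU_MAX - HU_MIN)).astype("float32")

# Values outside the window saturate at 0 and 1; -300 HU sits mid-window.
toy = np.array([-2000.0, -1000.0, -300.0, 400.0, 1000.0])
scaled = normalize_hu(toy)
print(scaled)

# Zoom factors for a hypothetical 512 x 512 x 40 scan (width, height, depth).
current_shape = (512, 512, 40)
target_shape = (128, 128, 64)
factors = tuple(t / c for t, c in zip(target_shape, current_shape))
print(factors)
```

A factor below 1 downsamples an axis and a factor above 1 upsamples it, so in this assumed case the slices shrink in-plane while the slice axis is interpolated up to 64.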

Paths for normal and abnormal scans are collected, processed, and labeled (0 for normal, 1 for abnormal). The dataset is split 70/30 into training and validation sets.
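
The collection and labeling step is not shown in the article, so here is a minimal sketch of how it might look. The helper names list_scan_paths and load_labeled_scans, the file-extension filter, and the stand-in loader are all assumptions made for illustration; in the tutorial the loader would be process_scan and the paths would point at the real scan directories.

```python
import os
import numpy as np

def list_scan_paths(directory):
    """Return the sorted paths of all NIfTI files directly inside `directory`."""
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith((".nii", ".nii.gz"))
    )

def load_labeled_scans(normal_paths, abnormal_paths, load_fn):
    """Load both classes with `load_fn`; label normal as 0 and abnormal as 1."""
    normal_scans = np.array([load_fn(p) for p in normal_paths])
    abnormal_scans = np.array([load_fn(p) for p in abnormal_paths])
    normal_labels = np.zeros(len(normal_scans), dtype="int64")
    abnormal_labels = np.ones(len(abnormal_scans), dtype="int64")
    return normal_scans, normal_labels, abnormal_scans, abnormal_labels

def fake_loader(path):
    """Stand-in for process_scan so the sketch runs without scan files."""
    return np.zeros((4, 4, 2), dtype="float32")

# Hypothetical file names; with real data these would come from
# list_scan_paths(...) applied to the normal and abnormal directories.
normal_scans, normal_labels, abnormal_scans, abnormal_labels = load_labeled_scans(
    ["n0.nii", "n1.nii", "n2.nii"], ["a0.nii", "a1.nii"], fake_loader
)
print(normal_scans.shape, normal_labels, abnormal_labels)
```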

# Split datasets
x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
print("Number of samples in train and validation are %d and %d." % (x_train.shape[0], x_val.shape[0]))

Data Augmentation

During training, volumes are randomly rotated by angles chosen from [-20, -10, -5, 5, 10, 20] degrees. Values are clipped to [0, 1] after rotation.

import random
from scipy import ndimage
import tensorflow as tf

@tf.function
def rotate(volume):
    def scipy_rotate(volume):
        angles = [-20, -10, -5, 5, 10, 20]
        angle = random.choice(angles)
        vol = ndimage.rotate(volume, angle, reshape=False)
        vol[vol < 0] = 0
        vol[vol > 1] = 1
        return vol
    return tf.numpy_function(scipy_rotate, [volume], tf.float32)

def train_preprocessing(volume, label):
    volume = rotate(volume)
    volume = tf.expand_dims(volume, axis=3)
    return volume, label

def validation_preprocessing(volume, label):
    volume = tf.expand_dims(volume, axis=3)
    return volume, label
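
The rotation step can be tried outside the tf.data pipeline. The sketch below applies the same idea to a small synthetic volume; the function name random_rotate and the toy 16×16×4 shape are assumptions for illustration. The default spline interpolation in ndimage.rotate can produce values slightly outside [0, 1], which is why the result is clipped.

```python
import random
import numpy as np
from scipy import ndimage

ANGLES = (-20, -10, -5, 5, 10, 20)  # degrees, as in the tutorial

def random_rotate(volume, angles=ANGLES):
    """Rotate in the plane of the first two axes by a randomly chosen angle."""
    angle = random.choice(angles)
    # reshape=False keeps the output the same shape as the input.
    rotated = ndimage.rotate(volume, angle, reshape=False)
    # Clip interpolation overshoot back into the normalized range.
    return np.clip(rotated, 0.0, 1.0), angle

# Toy volume: a bright square on a dark background, repeated over 4 slices.
toy = np.zeros((16, 16, 4), dtype="float32")
toy[4:12, 4:12, :] = 1.0

augmented, angle = random_rotate(toy)
print(augmented.shape, augmented.dtype, angle)
```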

Data loaders are created with shuffling, batching (batch size = 2), and prefetching.

train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))
batch_size = 2
train_dataset = (train_loader.shuffle(len(x_train))
                 .map(train_preprocessing)
                 .batch(batch_size)
                 .prefetch(2))
validation_dataset = (validation_loader.shuffle(len(x_val))
                    .map(validation_preprocessing)
                    .batch(batch_size)
                    .prefetch(2))

Model Construction

def get_model(width=128, height=128, depth=64):
    inputs = keras.Input((width, height, depth, 1))
    x = layers.Conv3D(64, 3, activation="relu")(inputs)
    x = layers.MaxPool3D(2)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv3D(64, 3, activation="relu")(x)
    x = layers.MaxPool3D(2)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv3D(128, 3, activation="relu")(x)
    x = layers.MaxPool3D(2)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv3D(256, 3, activation="relu")(x)
    x = layers.MaxPool3D(2)(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs, outputs, name="3dcnn")
    return model

model = get_model()
model.summary()
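
It can help to trace how the spatial dimensions shrink through the four convolution blocks before reading the summary. This sketch assumes the Keras defaults used above: stride-1 convolutions with "valid" padding and max pooling whose stride equals its window. The helper names are illustrative.

```python
def conv_valid(size, kernel=3):
    """Output length of a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1

def pool(size, window=2):
    """Output length of max pooling with stride equal to the window."""
    return size // window

shape = (128, 128, 64)  # (width, height, depth) of the input volume
for block in range(1, 5):
    shape = tuple(pool(conv_valid(s)) for s in shape)
    print(f"after block {block}: {shape}")
# The last line printed is: after block 4: (6, 6, 2)
```

GlobalAveragePooling3D therefore averages each of the 256 channels over a 6×6×2 grid, leaving a 256-element vector for the dense layers.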

Training

The learning rate follows an exponential-decay schedule, and early stopping halts training once validation accuracy has not improved for 15 consecutive epochs (patience = 15). The model is trained for up to 100 epochs.

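# Multiply the learning rate by 0.96 after every 30 optimizer steps.
initial_learning_rate = 1e-4
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=30, decay_rate=0.96, staircase=True)
model.compile(loss="binary_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
              metrics=["acc"])
# Keep only the best checkpoint (ModelCheckpoint monitors val_loss by default).
checkpoint_cb = keras.callbacks.ModelCheckpoint("3d_image_classification.h5", save_best_only=True)
# Stop once val_acc has not improved for 15 consecutive epochs.
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)
# The shuffle argument is ignored for tf.data.Dataset inputs; the training
# dataset is already shuffled in the input pipeline.
model.fit(train_dataset, validation_data=validation_dataset, epochs=100,
          shuffle=True, verbose=2, callbacks=[checkpoint_cb, early_stopping_cb])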
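
To see what the schedule does in practice, the staircase decay can be reproduced in plain Python. This is a sketch of the formula initial_lr × decay_rate^floor(step / decay_steps), not the Keras implementation; the function name is illustrative, and the 70 steps per epoch follow from 140 training scans at batch size 2.

```python
import math

def staircase_exponential_decay(step, initial_lr=1e-4, decay_steps=30, decay_rate=0.96):
    """Learning rate after `step` optimizer steps with staircase decay."""
    return initial_lr * decay_rate ** (step // decay_steps)

steps_per_epoch = math.ceil(140 / 2)  # 140 training scans, batch size 2

# Learning rate at the start of a few selected epochs.
for epoch in (0, 1, 10, 50):
    step = epoch * steps_per_epoch
    print(epoch, step, staircase_exponential_decay(step))
```

Because decay_steps counts optimizer steps rather than epochs, the rate drops roughly twice per epoch with these settings.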

Because the dataset contains only 200 samples, results may vary significantly between runs.
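
Given that variance, it is useful to know exactly when early stopping would trigger for a given validation curve. The following is a simplified model of the patience logic, assuming a higher-is-better metric and ignoring options such as min_delta; it is not the Keras implementation, and the accuracy values are made up for illustration.

```python
def stopping_epoch(val_acc_history, patience=15):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float("-inf")
    wait = 0  # consecutive epochs without improvement
    for epoch, acc in enumerate(val_acc_history):
        if acc > best:
            best, wait = acc, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Made-up curve: improves for three epochs, then plateaus below the best.
history = [0.50, 0.55, 0.60] + [0.58] * 20
print(stopping_epoch(history))  # → 17
```

The best value occurs at epoch 2, so the fifteenth non-improving epoch is epoch 17.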

Model Evaluation

Single Scan Prediction

# Load best weights
model.load_weights("3d_image_classification.h5")
prediction = model.predict(np.expand_dims(x_val[0], axis=0))[0]
scores = [1 - prediction[0], prediction[0]]
class_names = ["normal", "abnormal"]
for score, name in zip(scores, class_names):
    print("This model is %.2f percent confident that CT scan is %s" % (100 * score, name))

Example output: "This model is 27.88 percent confident that CT scan is normal" and "This model is 72.12 percent confident that CT scan is abnormal".
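
The two percentages come from a single sigmoid output p, read as the probability of the abnormal class, with 1 − p assigned to normal. A small sketch of that conversion follows; the helper name class_confidences is an assumption, and 0.7212 is chosen to match the example output above.

```python
def class_confidences(p_abnormal, class_names=("normal", "abnormal")):
    """Map one sigmoid output to a percentage per class."""
    scores = (1.0 - p_abnormal, p_abnormal)
    return {name: 100.0 * score for name, score in zip(class_names, scores)}

confidences = class_confidences(0.7212)
for name, pct in confidences.items():
    print("%.2f percent confident that CT scan is %s" % (pct, name))

# The predicted class is simply the one with the larger share.
predicted = max(confidences, key=confidences.get)
print(predicted)
```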

Note: The code is adapted, with several modifications, from the official Keras example on 3D image classification from CT scans.

Written by

Python Crawling & Data Mining
