A Coding Implementation on MONAI for End-to-End 3D Spleen Segmentation Using UNet on Medical CT Volumes

In this tutorial, we build an end-to-end 3D medical image segmentation pipeline with MONAI to segment the spleen in the Medical Segmentation Decathlon Task09 dataset. We work with volumetric CT scans and apply medical imaging transforms: orientation alignment, voxel-spacing normalization, intensity windowing, foreground cropping, and patch-based sampling. We then train a 3D UNet for binary organ segmentation using mixed-precision training and DiceCE loss, evaluate it with sliding-window inference and Dice-based validation, and visualize the results to see how the predictions compare with the ground-truth masks. By the end, we have moved from raw medical volumes to a complete train–validate–visualize segmentation system.

!pip install -q "monai[nibabel,tqdm,matplotlib]==1.5.2" 2>/dev/null
import os, time, glob, tempfile, warnings
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
    RandFlipd, RandRotate90d, RandShiftIntensityd, AsDiscrete,
)
warnings.filterwarnings("ignore")

We start by installing MONAI with the required medical imaging and visualization dependencies. We then import PyTorch, NumPy, Matplotlib, and the main MONAI modules needed for datasets, transforms, model training, metrics, and inference. We also suppress warnings to keep the notebook output clean while we focus on the segmentation workflow.

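QUICK_RUN = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root_dir = tempfile.mkdtemp()
roi_size = (96, 96, 96)
num_samples = 4
batch_size = 2
max_epochs = 15 if QUICK_RUN else 200
val_every = 3
train_cache = 8 if QUICK_RUN else 24
val_cache = 2 if QUICK_RUN else 6
set_determinism(seed=0)
print(f"Device: {device} | epochs: {max_epochs} | data dir: {root_dir}")
# NOTE: the definition of `common` is missing from this listing. The list below is a
# reconstruction from the steps named in the text; the orientation code, voxel spacing,
# and intensity window are ASSUMED values, not confirmed settings from the original.
common = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0),
             mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164,
                         b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
]
train_transforms = Compose(common + [
    # The opening of this call is truncated in the listing; its first two lines are reconstructed.
    RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
                           spatial_size=roi_size, num_samples=num_samples,
                           image_key="image", image_threshold=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
    RandRotate90d(keys=["image", "label"], prob=0.2, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.5),
    EnsureTyped(keys=["image", "label"]),
])
val_transforms = Compose(common + [EnsureTyped(keys=["image", "label"])])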

We define the main configuration for the tutorial, including the device, dataset directory, patch size, batch size, number of epochs, and cache settings. We then create the preprocessing pipeline for CT volumes by loading images, aligning orientation, resampling voxel spacing, scaling intensities, and cropping the foreground. We also define the training and validation transforms, with the training pipeline including random crops, flips, rotations, and intensity shifts.
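To make the intensity-scaling step concrete, here is a minimal NumPy sketch of CT windowing: values are clipped to a Hounsfield window and mapped linearly onto [0, 1]. The helper name `window_intensity`, the sample values, and the (-57, 164) window are assumptions chosen for illustration; we believe this mirrors the arithmetic of `ScaleIntensityRanged` with `clip=True`, but it is a sketch rather than MONAI's implementation.

```python
import numpy as np

def window_intensity(hu, a_min, a_max, b_min=0.0, b_max=1.0):
    """Clip values to [a_min, a_max], then map that range linearly onto [b_min, b_max]."""
    hu = np.clip(np.asarray(hu, dtype=np.float32), a_min, a_max)
    return (hu - a_min) / (a_max - a_min) * (b_max - b_min) + b_min

# Hypothetical Hounsfield values: air, fat, soft tissue, dense bone.
sample_hu = [-1000.0, -100.0, 50.0, 1000.0]

# Assumed soft-tissue window; not a value confirmed by this tutorial.
windowed = window_intensity(sample_hu, a_min=-57.0, a_max=164.0)
print(windowed)
```

Everything below the window collapses to 0 and everything above it to 1, so the network sees contrast only in the soft-tissue range where the spleen lies.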

train_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="training",
    transform=train_transforms, download=True, val_frac=0.2,
    cache_num=train_cache, num_workers=2, seed=0)
val_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="validation",
    transform=val_transforms, download=False, val_frac=0.2,
    cache_num=val_cache, num_workers=2, seed=0)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                        num_workers=1, pin_memory=torch.cuda.is_available())
print(f"Train volumes: {len(train_ds)} | Val volumes: {len(val_ds)}")
# NOTE: the model definition is missing from this listing. This is a reconstruction:
# one input channel (CT) and two output classes follow from the loss and post-processing
# below, while the channel widths, strides, and residual units are ASSUMED values.
model = UNet(
    spatial_dims=3, in_channels=1, out_channels=2,
    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
    num_res_units=2, norm=Norm.BATCH,
).to(device)
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
scaler = GradScaler("cuda", enabled=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=False, reduction="mean")
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

We load the Medical Segmentation Decathlon Task09 Spleen dataset using MONAI’s DecathlonDataset. We split the data into training and validation sections, apply the appropriate transforms, and wrap both datasets with PyTorch-style data loaders. We then create a 3D UNet model, define the DiceCE loss, set up the AdamW optimizer, learning-rate scheduler, mixed-precision scaler, Dice metric, and post-processing steps.
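To see what the cosine schedule does to the learning rate, we can evaluate its closed form directly. This is a standalone sketch that assumes the quick-run settings from the configuration (base rate 1e-4, `T_max` of 15) and a minimum rate of zero; the function name `cosine_annealing_lr` is ours, not a PyTorch API.

```python
import math

def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: decays from base_lr to eta_min over t_max steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))

# Learning rate after each scheduler step, assuming lr=1e-4 and T_max=15.
schedule = [cosine_annealing_lr(e, base_lr=1e-4, t_max=15) for e in range(16)]
for e in (0, 5, 10, 15):
    print(f"after {e:2d} steps: lr = {schedule[e]:.2e}")
```

Because the scheduler steps once at the end of each epoch, the final epoch still trains with a small non-zero rate; the rate only reaches zero after the last step.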

best_dice, best_epoch = -1.0, -1
loss_hist, dice_hist, dice_epochs = [], [], []
best_path = os.path.join(root_dir, "best_spleen_unet.pth")
for epoch in range(1, max_epochs + 1):
    model.train(); epoch_loss, t0 = 0.0, time.time()
    for batch in train_loader:
        x, y = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=torch.cuda.is_available()):
            logits = model(x)
            loss = loss_fn(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer); scaler.update()
        epoch_loss += loss.item()
    scheduler.step()
    epoch_loss /= len(train_loader); loss_hist.append(epoch_loss)
    print(f"[{epoch:3d}/{max_epochs}] loss={epoch_loss:.4f} "
          f"lr={scheduler.get_last_lr()[0]:.2e} ({time.time()-t0:.0f}s)")
    if epoch % val_every == 0 or epoch == max_epochs:
        model.eval(); dice_metric.reset()
        with torch.no_grad():
            for vb in val_loader:
                vx, vy = vb["image"].to(device), vb["label"].to(device)
                with autocast("cuda", enabled=torch.cuda.is_available()):
                    vout = sliding_window_inference(vx, roi_size, 4, model,
                                                    overlap=0.5)
                vout = [post_pred(o) for o in decollate_batch(vout)]
                vlab = [post_label(o) for o in decollate_batch(vy)]
                dice_metric(y_pred=vout, y=vlab)
        d = dice_metric.aggregate().item()
        dice_hist.append(d); dice_epochs.append(epoch)
        if d > best_dice:
            best_dice, best_epoch = d, epoch
            torch.save(model.state_dict(), best_path)
        print(f"    >> val Dice={d:.4f} (best={best_dice:.4f} @ {best_epoch})")
print(f"\nDone. Best mean Dice {best_dice:.4f} at epoch {best_epoch}.")

We run the full training loop, where each epoch trains the 3D UNet on cropped volumetric patches from the spleen dataset. We use automatic mixed precision to reduce memory usage and speed up training when a GPU is available. We also validate the model at regular intervals using sliding-window inference, track the Dice score, and save the best-performing checkpoint.
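Since the Dice score decides which checkpoint we keep, it helps to see the quantity on a toy example. The sketch below computes Dice for two small binary masks with NumPy; the helper `dice_score` and the synthetic masks are illustrative assumptions, and MONAI's `DiceMetric` additionally handles batching, one-hot channels, and empty masks in ways this sketch does not.

```python
import numpy as np

def dice_score(pred, target, eps=1e-8):
    """Dice = 2 * |pred AND target| / (|pred| + |target|) for binary masks."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    return float(2.0 * intersection / (pred.sum() + target.sum() + eps))

# Synthetic 4x4x4 volume: the "organ" is a 2x2x2 cube (8 voxels).
target = np.zeros((4, 4, 4), dtype=np.uint8)
target[1:3, 1:3, 1:3] = 1

# The prediction covers the cube plus one extra layer (12 voxels, 8 overlapping).
pred = np.zeros_like(target)
pred[1:3, 1:3, 1:4] = 1

print(f"Dice = {dice_score(pred, target):.3f}")
```

Here the overlap is 8 voxels out of 12 predicted and 8 true, so Dice is 2 * 8 / 20 = 0.8. Over-segmentation and under-segmentation both lower the score, which is why it is a common choice for small organs.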

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(range(1, len(loss_hist)+1), loss_hist, "-o", ms=3)
ax[0].set(title="Training loss", xlabel="epoch", ylabel="DiceCE loss")
ax[1].plot(dice_epochs, dice_hist, "-o", color="seagreen", ms=4)
ax[1].set(title="Validation mean Dice", xlabel="epoch", ylabel="Dice"); ax[1].set_ylim(0, 1)
plt.tight_layout(); plt.show()
model.load_state_dict(torch.load(best_path, map_location=device)); model.eval()
with torch.no_grad():
    sample = next(iter(val_loader))
    img = sample["image"].to(device)
    with autocast("cuda", enabled=torch.cuda.is_available()):
        pred = sliding_window_inference(img, roi_size, 4, model, overlap=0.5)
    pred = torch.argmax(pred, dim=1).cpu().numpy()[0]
img_np, lab_np = img.cpu().numpy()[0, 0], sample["label"].numpy()[0, 0]
# Pick the axial slice with the most labeled spleen voxels.
z = int(np.argmax(lab_np.sum(axis=(0, 1))))
fig, ax = plt.subplots(1, 3, figsize=(13, 5))
ax[0].imshow(img_np[:, :, z], cmap="gray"); ax[0].set_title("CT slice")
ax[1].imshow(lab_np[:, :, z], cmap="viridis"); ax[1].set_title("Ground truth")
ax[2].imshow(pred[:, :, z], cmap="viridis"); ax[2].set_title("Prediction")
for a in ax: a.axis("off")
plt.tight_layout(); plt.show()

We first plot the training loss and validation Dice score to see how the model improves over time. We then reload the best-saved model checkpoint and run inference on a single validation volume using sliding-window prediction. We visualize the CT slice, ground-truth mask, and predicted segmentation side by side to inspect the model’s qualitative performance.
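To get a feel for the cost of sliding-window prediction, we can estimate how many patches cover a volume. The sketch below assumes a step of `roi * (1 - overlap)` along each axis and a hypothetical volume shape; the helper names are ours, and MONAI's internal tiling may treat padding and edge cases differently, so take the numbers as an approximation.

```python
import math

def windows_per_axis(image_len, roi_len, overlap):
    """Number of window positions needed to cover one axis."""
    if image_len <= roi_len:
        return 1
    step = max(int(roi_len * (1.0 - overlap)), 1)
    return math.ceil((image_len - roi_len) / step) + 1

def count_windows(image_shape, roi_shape, overlap=0.5):
    """Per-axis window counts and their product (total patches)."""
    counts = [windows_per_axis(i, r, overlap) for i, r in zip(image_shape, roi_shape)]
    return counts, math.prod(counts)

volume_shape = (240, 200, 150)  # hypothetical resampled CT volume, not from the dataset
per_axis, total = count_windows(volume_shape, (96, 96, 96), overlap=0.5)
forward_passes = math.ceil(total / 4)  # 4 patches per forward pass, as in the calls above
print(per_axis, total, forward_passes)
```

For this hypothetical shape the estimate is 4 x 4 x 3 = 48 patches, or 12 forward passes at 4 patches each. Raising the overlap smooths patch seams but increases the patch count quickly, since the step shrinks along all three axes.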

In conclusion, we built a practical MONAI-based workflow for 3D spleen segmentation with a 3D UNet. We prepared the Medical Segmentation Decathlon dataset, transformed and augmented the CT volumes, trained the model with DiceCE loss, validated it using sliding-window inference, and tracked both loss and Dice score over time. We also inspected the final prediction visually by comparing the CT slice, ground-truth label, and model output side by side. Together, these steps show how MONAI supports medical segmentation from data loading and preprocessing through model training, evaluation, checkpointing, and qualitative analysis.

The post A Coding Implementation on MONAI for End-to-End 3D Spleen Segmentation Using UNet on Medical CT Volumes appeared first on MarkTechPost.


