
UNet for Medical Image Segmentation

INFO

Computer Vision, Medical Imaging, Image Segmentation


UNet is a convolutional neural network (CNN) architecture for medical image segmentation, proposed by Ronneberger et al. in 2015. In this article, we walk through training a UNet model on a brain tumor segmentation dataset with the PyTorch framework, while monitoring the training process with SwanLab, so that the model learns to automatically localize lesion areas or organ structures.

Code: Full code available in Section 5 or on Github
Training Logs: Unet-Medical-Segmentation - SwanLab
Model: UNet (implemented directly in PyTorch)
Dataset: brain-tumor-image-dataset-semantic-segmentation - Kaggle
SwanLab: https://swanlab.cn


1. Environment Setup

The environment setup consists of three steps:

  1. Ensure your computer has at least one NVIDIA GPU with CUDA installed.
  2. Install Python (version ≥ 3.8) and PyTorch with CUDA support (a quick verification snippet follows the commands below).
  3. Install the third-party libraries required for this tutorial using the following commands:
bash
git clone https://github.com/Zeyi-Lin/UNet-Medical.git  
cd UNet-Medical  
pip install -r requirements.txt
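
To confirm that PyTorch can see your GPU before training (an optional quick check, not part of the repository):

python
import torch

print(torch.__version__)          # installed PyTorch version
print(torch.cuda.is_available())  # should print True if CUDA is set up correctly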

2. Dataset Preparation

This section uses the Brain Tumor Segmentation Dataset, which is specifically designed for medical image segmentation tasks.

Dataset Description: The Brain Tumor Segmentation Dataset is tailored for semantic segmentation in medical imaging, aiming to accurately identify brain tumor regions. It contains binary annotations (tumor/non-tumor) and enables pixel-level classification for fine-grained tumor segmentation. This dataset is suitable for training and evaluating medical image segmentation models, providing automated analysis support for brain tumor diagnosis.

For this task, we will download and extract the dataset for subsequent training.

Download and Extract the Dataset:

bash
python download.py  
unzip dataset/Brain_Tumor_Image_DataSet.zip -d dataset/

After completing these steps, the dataset/ directory should contain train, valid, and test subfolders, each holding the image files (in jpg format) and a COCO-format annotation file (in json format). At this point, the dataset preparation is complete.
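
As an optional sanity check, you can inspect one of the annotation files with pycocotools (the same paths are used later in train.py):

python
from pycocotools.coco import COCO

# Load the COCO-format annotations for the training split
coco = COCO('dataset/train/_annotations.coco.json')
print(f"{len(coco.imgs)} images, {len(coco.anns)} annotations")
print("categories:", [c['name'] for c in coco.loadCats(coco.getCatIds())])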

Below are some detailed code snippets. If you want to start training immediately, skip to Section 5.

3. Model Implementation

Here, we implement the UNet model in PyTorch (in net.py). The code is shown below:

python
import torch  
import torch.nn as nn  

# Define the downsampling block for the U-Net model  
class DownBlock(nn.Module):  
    def __init__(self, in_channels, out_channels, dropout_prob=0, max_pooling=True):  
        super(DownBlock, self).__init__()  
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)  
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)  
        self.relu = nn.ReLU(inplace=True)  
        self.maxpool = nn.MaxPool2d(2) if max_pooling else None  
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else None  

    def forward(self, x):  
        x = self.relu(self.conv1(x))  
        x = self.relu(self.conv2(x))  
        if self.dropout:  
            x = self.dropout(x)  
        skip = x  
        if self.maxpool:  
            x = self.maxpool(x)  
        return x, skip  

# Define the upsampling block for the U-Net model  
class UpBlock(nn.Module):  
    def __init__(self, in_channels, out_channels):  
        super(UpBlock, self).__init__()  
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)  
        self.conv1 = nn.Conv2d(out_channels * 2, out_channels, 3, padding=1)  
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)  
        self.relu = nn.ReLU(inplace=True)  

    def forward(self, x, skip):  
        x = self.up(x)  
        x = torch.cat([x, skip], dim=1)  
        x = self.relu(self.conv1(x))  
        x = self.relu(self.conv2(x))  
        return x  

# Define the complete U-Net model  
class UNet(nn.Module):  
    def __init__(self, n_channels=3, n_classes=1, n_filters=32):  
        super(UNet, self).__init__()  
        
        # Encoder path  
        self.down1 = DownBlock(n_channels, n_filters)  
        self.down2 = DownBlock(n_filters, n_filters * 2)  
        self.down3 = DownBlock(n_filters * 2, n_filters * 4)  
        self.down4 = DownBlock(n_filters * 4, n_filters * 8)  
        self.down5 = DownBlock(n_filters * 8, n_filters * 16)  
        
        # Bottleneck layer (remove final max-pooling)  
        self.bottleneck = DownBlock(n_filters * 16, n_filters * 32, dropout_prob=0.4, max_pooling=False)  
        
        # Decoder path  
        self.up1 = UpBlock(n_filters * 32, n_filters * 16)  
        self.up2 = UpBlock(n_filters * 16, n_filters * 8)  
        self.up3 = UpBlock(n_filters * 8, n_filters * 4)  
        self.up4 = UpBlock(n_filters * 4, n_filters * 2)  
        self.up5 = UpBlock(n_filters * 2, n_filters)  
        
        # Output layer  
        self.outc = nn.Conv2d(n_filters, n_classes, 1)  
        self.sigmoid = nn.Sigmoid()  

    def forward(self, x):  
        # Encoder path  
        x1, skip1 = self.down1(x)      # 128  
        x2, skip2 = self.down2(x1)     # 64  
        x3, skip3 = self.down3(x2)     # 32  
        x4, skip4 = self.down4(x3)     # 16  
        x5, skip5 = self.down5(x4)     # 8  
        
        # Bottleneck layer  
        x6, skip6 = self.bottleneck(x5)  # 8 (no downsampling)  
        
        # Decoder path  
        x = self.up1(x6, skip5)    # 16  
        x = self.up2(x, skip4)     # 32  
        x = self.up3(x, skip3)     # 64  
        x = self.up4(x, skip2)     # 128  
        x = self.up5(x, skip1)     # 256  
        
        x = self.outc(x)  
        x = self.sigmoid(x)  
        return x
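
A quick shape and size sanity check for the model (not part of the repository; assumes the 256×256 inputs used during training):

python
model = UNet()
x = torch.randn(1, 3, 256, 256)  # one RGB image at training resolution
print(model(x).shape)            # expected: torch.Size([1, 1, 256, 256])

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # ~31M floats * 4 bytes/float ≈ 124 MB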

The saved model checkpoint (a .pth file) is approximately 124 MB.

4. Experiment Tracking with SwanLab

SwanLab is an open-source tool for training experiment tracking. Designed for AI researchers, SwanLab provides training visualization, automatic logging, hyperparameter recording, experiment comparison, and multi-user collaboration. With SwanLab, researchers can identify training issues through intuitive visualizations, compare multiple experiments for inspiration, and share results via online links to facilitate team collaboration.

For this training session, we configure SwanLab with the project name Unet-Medical-Segmentation, experiment name bs32-epoch40, and the following hyperparameters:

python
swanlab.init(  
    project="Unet-Medical-Segmentation",  
    experiment_name="bs32-epoch40",  
    config={  
        "batch_size": 32,  
        "learning_rate": 1e-4,  
        "num_epochs": 40,  
        "device": "cuda" if torch.cuda.is_available() else "cpu",  
    },  
)

Here, the batch size is 32, the learning rate is 1e-4, and the training runs for 40 epochs.

For first-time SwanLab users, register an account on the official website, copy your API Key from the user settings page, and paste it when prompted at the start of training. Subsequent runs will not require this step.
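
You can also authenticate ahead of time from the terminal (using the swanlab command-line tool installed together with the Python package):

bash
swanlab login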

5. Start Training

View the training visualization: Unet-Medical-Segmentation

This section accomplishes the following:

  1. Loads the UNet model.
  2. Prepares the dataset (training, validation, and test sets) with resizing to (256, 256) and normalization.
  3. Uses SwanLab to log training metrics, hyperparameters, and final model outputs.
  4. Trains for 40 epochs.
  5. Generates final prediction visualizations.

The directory structure before execution should be:

|———— dataset/  
|———————— train/  
|———————— valid/  
|———————— test/  
|———— readme_files/  
|———— train.py  
|———— data.py  
|———— net.py  
|———— download.py  
|———— requirements.txt
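
Here, train.py is the training entry point, net.py contains the UNet implementation from Section 3, data.py defines the COCOSegmentationDataset class that reads the COCO annotations (a sketch of it follows the full code below), and download.py fetches the dataset from Section 2.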

Full Code

train.py:

python
import torch  
import torch.nn as nn  
import torch.optim as optim  
import matplotlib.pyplot as plt  
from pycocotools.coco import COCO  
from torch.utils.data import DataLoader  
import torchvision.transforms as transforms  
import random  
import swanlab  
from net import UNet  
from data import COCOSegmentationDataset  


# Dataset paths  
train_dir = './dataset/train'  
val_dir = './dataset/valid'  
test_dir = './dataset/test'  

train_annotation_file = './dataset/train/_annotations.coco.json'  
test_annotation_file = './dataset/test/_annotations.coco.json'  
val_annotation_file = './dataset/valid/_annotations.coco.json'  

# Load COCO datasets  
train_coco = COCO(train_annotation_file)  
val_coco = COCO(val_annotation_file)  
test_coco = COCO(test_annotation_file)  

# Loss functions
def dice_loss(pred, target, smooth=1e-6):
    """Soft Dice loss: 1 minus the Dice coefficient, computed over all pixels."""
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    intersection = (pred_flat * target_flat).sum()
    return 1 - ((2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth))

def combined_loss(pred, target):
    """Weighted sum of Dice loss (region overlap) and BCE (per-pixel classification)."""
    dice = dice_loss(pred, target)
    bce = nn.BCELoss()(pred, target)
    return 0.6 * dice + 0.4 * bce

# Training function  
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):  
    best_val_loss = float('inf')  
    patience = 8  
    patience_counter = 0  

    for epoch in range(num_epochs):  
        model.train()  
        train_loss = 0  
        train_acc = 0  
        
        for images, masks in train_loader:  
            images, masks = images.to(device), masks.to(device)  
            
            optimizer.zero_grad()  
            outputs = model(images)  
            loss = criterion(outputs, masks)  
            
            loss.backward()  
            optimizer.step()  
            
            train_loss += loss.item()  
            train_acc += (outputs.round() == masks).float().mean().item()  # pixel-wise accuracy

        train_loss /= len(train_loader)  
        train_acc /= len(train_loader)  
        
        # Validation  
        model.eval()  
        val_loss = 0  
        val_acc = 0  
        
        with torch.no_grad():  
            for images, masks in val_loader:  
                images, masks = images.to(device), masks.to(device)  
                outputs = model(images)  
                loss = criterion(outputs, masks)  
                
                val_loss += loss.item()  
                val_acc += (outputs.round() == masks).float().mean().item()  
        
        val_loss /= len(val_loader)  
        val_acc /= len(val_loader)  
        
        swanlab.log(  
            {  
                "train/loss": train_loss,  
                "train/acc": train_acc,  
                "train/epoch": epoch+1,  
                "val/loss": val_loss,  
                "val/acc": val_acc,  
            },  
            step=epoch+1)  
        
        print(f'Epoch {epoch+1}/{num_epochs}:')  
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')  
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')  
        
        # Early stopping  
        if val_loss < best_val_loss:  
            best_val_loss = val_loss  
            patience_counter = 0  
            torch.save(model.state_dict(), 'best_model.pth')  
        else:  
            patience_counter += 1  
            if patience_counter >= patience:  
                print("Early stopping triggered")  
                break  

def main():  
    swanlab.init(  
        project="Unet-Medical-Segmentation",  
        experiment_name="bs32-epoch40",  
        config={  
            "batch_size": 32,  
            "learning_rate": 1e-4,  
            "num_epochs": 40,  
            "device": "cuda" if torch.cuda.is_available() else "cpu",  
        },  
    )  
    
    # Device setup  
    device = torch.device(swanlab.config["device"])  
    
    # Data preprocessing  
    transform = transforms.Compose([  
        transforms.ToTensor(),  
        transforms.Resize((256, 256)),  
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  
    ])  
    
    # Create datasets  
    train_dataset = COCOSegmentationDataset(train_coco, train_dir, transform=transform)  
    val_dataset = COCOSegmentationDataset(val_coco, val_dir, transform=transform)  
    test_dataset = COCOSegmentationDataset(test_coco, test_dir, transform=transform)  
    
    # Data loaders  
    BATCH_SIZE = swanlab.config["batch_size"]  
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)  
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)  
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)  
    
    # Initialize model  
    model = UNet(n_filters=32).to(device)  
    
    # Optimizer and learning rate  
    optimizer = optim.Adam(model.parameters(), lr=swanlab.config["learning_rate"])  
    
    # Train model  
    train_model(  
        model=model,  
        train_loader=train_loader,  
        val_loader=val_loader,  
        criterion=combined_loss,  
        optimizer=optimizer,  
        num_epochs=swanlab.config["num_epochs"],  
        device=device,  
    )  
    
    # Test set evaluation (note: this evaluates the final-epoch weights; you could also reload best_model.pth)
    model.eval()  
    test_loss = 0  
    test_acc = 0  
    
    with torch.no_grad():  
        for images, masks in test_loader:  
            images, masks = images.to(device), masks.to(device)  
            outputs = model(images)  
            loss = combined_loss(outputs, masks)  
            test_loss += loss.item()  
            test_acc += (outputs.round() == masks).float().mean().item()  
    
    test_loss /= len(test_loader)  
    test_acc /= len(test_loader)  
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")  
    swanlab.log({"test/loss": test_loss, "test/acc": test_acc})  
    
    # Visualization  
    visualize_predictions(model, test_loader, device, num_samples=10)  
    

def visualize_predictions(model, test_loader, device, num_samples=5, threshold=0.5):  
    model.eval()  
    with torch.no_grad():  
        # Get a batch of data  
        images, masks = next(iter(test_loader))  
        images, masks = images.to(device), masks.to(device)  
        predictions = model(images)  
        
        # Convert predictions to binary masks  
        binary_predictions = (predictions > threshold).float()  
        
        # Select up to 8 random samples (the 4x8 subplot grid holds 8 samples of 4 panels each)
        indices = random.sample(range(len(images)), min(num_samples, len(images)))
        indices = indices[:8]
        
        # Create a large figure
        plt.figure(figsize=(15, 8))
        plt.suptitle(f'Epoch {swanlab.config["num_epochs"]} Predictions ({len(indices)} random samples)')
        
        for i, idx in enumerate(indices):  
            # Original image  
            plt.subplot(4, 8, i*4 + 1)  
            img = images[idx].cpu().numpy().transpose(1, 2, 0)  
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]).clip(0, 1)  
            plt.imshow(img)  
            plt.title('Original Image')  
            plt.axis('off')  
            
            # Ground truth mask  
            plt.subplot(4, 8, i*4 + 2)  
            plt.imshow(masks[idx].cpu().squeeze(), cmap='gray')  
            plt.title('True Mask')  
            plt.axis('off')  
            
            # Predicted mask  
            plt.subplot(4, 8, i*4 + 3)  
            plt.imshow(binary_predictions[idx].cpu().squeeze(), cmap='gray')  
            plt.title('Predicted Mask')  
            plt.axis('off')  

            # Overlay  
            plt.subplot(4, 8, i*4 + 4)  
            plt.imshow(img)  
            plt.imshow(binary_predictions[idx].cpu().squeeze(), cmap='Reds', alpha=0.3)  
            plt.title('Overlay')  
            plt.axis('off')  
        
        # Log to SwanLab  
        swanlab.log({"predictions": swanlab.Image(plt)})  

if __name__ == '__main__':  
    main()
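
train.py imports COCOSegmentationDataset from data.py, which is not listed here; the actual implementation is in the GitHub repository. For reference, below is a minimal sketch of what such a dataset class can look like, assuming the COCO polygon annotations are rasterized into binary masks with pycocotools:

python
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset

class COCOSegmentationDataset(Dataset):
    """Hypothetical sketch of data.py; see the repository for the real version."""

    def __init__(self, coco, image_dir, transform=None):
        self.coco = coco
        self.image_dir = image_dir
        self.image_ids = list(coco.imgs.keys())
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_info = self.coco.imgs[self.image_ids[idx]]
        image = Image.open(os.path.join(self.image_dir, img_info['file_name'])).convert('RGB')

        # Merge all annotations for this image into one binary (tumor / non-tumor) mask
        mask = np.zeros((img_info['height'], img_info['width']), dtype=np.float32)
        for ann in self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id'])):
            mask = np.maximum(mask, self.coco.annToMask(ann).astype(np.float32))

        if self.transform:
            image = self.transform(image)

        # Resize the mask to the model input size; nearest-neighbor keeps it binary
        mask = torch.from_numpy(mask)[None, None]                        # -> (1, 1, H, W)
        mask = F.interpolate(mask, size=(256, 256), mode='nearest')[0]   # -> (1, 256, 256)
        return image, mask

The masks come back as (1, 256, 256) float tensors, matching the single-channel sigmoid output of the UNet.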

Run Training

bash
python train.py

Training has started once the script begins printing per-epoch loss and accuracy values.

6. Training Results

View detailed training progress here: Unet-Medical-Segmentation

From the SwanLab charts, we observe that both training and validation losses decrease with epochs, while accuracies increase. The final test accuracy reaches 97.93%.

The prediction chart displays the model's segmentation results on the test set, showing relatively accurate tumor localization.

This tutorial primarily aims to introduce medical image segmentation training. For improved performance, consider experimenting with more complex architectures or data augmentation. Share your results on the SwanLab Benchmark Community!

7. Model Inference

Load the trained model (best_model.pth) and perform inference:

bash
python predict.py

predict.py:

python
import torch  
import torchvision.transforms as transforms  
from PIL import Image  
import matplotlib.pyplot as plt  
from net import UNet  
import numpy as np  
import os  

def load_model(model_path='best_model.pth', device='cuda'):  
    """Load trained model"""  
    try:  
        # Check if file exists  
        if not os.path.exists(model_path):  
            raise FileNotFoundError(f"Model file not found at {model_path}")  
            
        model = UNet(n_filters=32).to(device)  
        state_dict = torch.load(model_path, map_location=device, weights_only=True)  
        model.load_state_dict(state_dict)  
        model.eval()  
        print(f"Model loaded successfully from {model_path}")  
        return model  
    except Exception as e:  
        print(f"Error loading model: {str(e)}")  
        raise  

def preprocess_image(image_path):  
    """Preprocess input image"""  
    image = Image.open(image_path).convert('RGB')  
    display_image = image.resize((256, 256), Image.Resampling.BILINEAR)  
    
    transform = transforms.Compose([  
        transforms.ToTensor(),  
        transforms.Resize((256, 256)),  
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  
    ])  
    
    image_tensor = transform(image)  
    return image_tensor.unsqueeze(0), display_image  

def predict_mask(model, image_tensor, device='cuda', threshold=0.5):  
    """Generate segmentation mask"""  
    with torch.no_grad():  
        image_tensor = image_tensor.to(device)  
        prediction = model(image_tensor)  
        prediction = (prediction > threshold).float()  
    return prediction  

def visualize_result(original_image, predicted_mask):  
    """Visualize predictions"""  
    plt.figure(figsize=(12, 6))  
    plt.suptitle('Predictions')  
    
    # Original image  
    plt.subplot(131)  
    plt.imshow(original_image)  
    plt.title('Original Image')  
    plt.axis('off')  
    
    # Predicted mask  
    plt.subplot(132)  
    plt.imshow(predicted_mask.squeeze(), cmap='gray')  
    plt.title('Predicted Mask')  
    plt.axis('off')  
    
    # Overlay  
    plt.subplot(133)  
    plt.imshow(np.array(original_image))  
    plt.imshow(predicted_mask.squeeze(), cmap='Reds', alpha=0.3)  
    plt.title('Overlay')  
    plt.axis('off')  
        
    plt.tight_layout()  
    plt.savefig('./predictions.png')  
    print("Visualization saved as predictions.png")  

def main():  
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
    print(f"Using device: {device}")  
    
    try:  
        model_path = "/Users/zeyilin/Desktop/Coding/UNet-Medical/best_model.pth"  
        print(f"Attempting to load model from: {model_path}")  
        model = load_model(model_path, device)  
        
        image_path = "dataset/test/27_jpg.rf.b2a2b9811786cc32a23c46c560f04d07.jpg"  
        if not os.path.exists(image_path):  
            raise FileNotFoundError(f"Image file not found at {image_path}")  
            
        print(f"Processing image: {image_path}")  
        image_tensor, original_image = preprocess_image(image_path)  
        
        predicted_mask = predict_mask(model, image_tensor, device)  
        predicted_mask = predicted_mask.cpu().numpy()  
        
        print("Generating visualization...")  
        visualize_result(original_image, predicted_mask)  
        print("Results saved to predictions.png")  
        
    except Exception as e:  
        print(f"Error during prediction: {str(e)}")  
        raise  

if __name__ == '__main__':  
    main()
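
To run inference on a different scan, point image_path at any jpg from the test set (or your own image); preprocess_image applies the same 256×256 resize and ImageNet normalization used during training.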

Additional Notes

Hardware Specifications and Parameters

Training was conducted on an NVIDIA vGPU-32GB, completing 40 epochs in 13 minutes 22 seconds.

Peak GPU memory usage was 6.124 GB, so a GPU with more than 6 GB of VRAM can run this task as-is. To reduce the memory requirement, decrease the batch size (for example, set batch_size to 16 in the swanlab.init config).


References

Code: Full code in Section 5 or on Github
Training Logs: Unet-Medical-Segmentation - SwanLab
Model: UNet (PyTorch implementation)
Dataset: brain-tumor-image-dataset-semantic-segmentation - Kaggle
SwanLab: https://swanlab.cn