Skip to content

CIFAR10 Image Classification

INFO

Introduction to Image Classification and Machine Learning

Overview

CIFAR-10 is a classic image classification dataset comprising 60,000 32×32-pixel color images divided into 10 categories (e.g., airplane, automobile, bird), with 50,000 for training and 10,000 for testing.

CIFAR-10 is widely used for image classification tasks. The objective is to build a model that performs 10-class classification on input images, outputting a probability score for each category. Because of its low resolution, cluttered backgrounds, and modest size, the dataset is often used to test a model's generalization and feature-extraction capabilities, and it serves as a standard benchmark for deep learning beginners. Typical approaches train CNNs (e.g., ResNet, AlexNet) with data augmentation and a cross-entropy loss; the best models exceed 95% accuracy. Its lightweight nature makes CIFAR-10 popular for education and research. Its companion dataset, CIFAR-100, uses the same 32×32 image format but 100 finer-grained classes.

CIFAR-10 includes images from the following 10 classes:

  • Airplane
  • Automobile
  • Bird
  • Cat
  • Deer
  • Dog
  • Frog
  • Horse
  • Ship
  • Truck

This case study focuses on:

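  • Implementing, training, and evaluating a ResNet50 (Residual Neural Network) using PyTorch, fine-tuned from ImageNet-pretrained weights and evaluated on a held-out validation split of the training set.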
  • Tracking hyperparameters, logging metrics, and visualizing training progress with SwanLab.

Environment Setup

This example requires Python>=3.8. Ensure Python is installed on your system.

Dependencies:

torch
torchvision
swanlab

Quick installation:

bash
pip install torch torchvision swanlab

Full Code

python
import os
import random
import numpy as np
import torch
from torch import nn, optim, utils
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose, Resize, Lambda
import swanlab

def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's RNGs for reproducible runs.

    When CUDA is available, cuDNN is additionally switched to deterministic
    kernels and its benchmark autotuner is turned off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def log_images(loader, num_images=16):
    """Log a SwanLab preview of the first ``num_images`` samples from ``loader``.

    ``loader`` is iterated as ``(images, labels)`` batches. Each image is
    wrapped in ``swanlab.Image`` with a ``Label: ...`` caption at 128x128.
    All images are logged together under the ``Preview/CIFAR10`` key.
    """
    preview = []
    for images, labels in loader:
        for image, label in zip(images, labels):
            if len(preview) >= num_images:
                break
            preview.append(swanlab.Image(image, caption=f"Label: {label}", size=(128, 128)))
        # Stop pulling further batches once the quota is filled.
        if len(preview) >= num_images:
            break
    swanlab.log({"Preview/CIFAR10": preview})

if __name__ == "__main__":
    # Seed Python/NumPy/PyTorch RNGs before any data or model is created.
    set_seed(42)

    # Configure device. Older PyTorch builds have no torch.backends.mps,
    # hence the AttributeError guard.
    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False

    device = "cuda" if torch.cuda.is_available() else "mps" if use_mps else "cpu"

    # Initialize SwanLab; the entries in `config` are recorded with the run
    # and read back below through `run.config`.
    run = swanlab.init(
        project="CIFAR10",
        experiment_name="resnet50-pretrained",
        config={
            "model": "Resnet50",
            "optim": "Adam",
            "lr": 1e-4,
            "batch_size": 32,
            "num_epochs": 5,
            "train_dataset_num": 45000,
            "val_dataset_num": 5000,
            "device": device,
            "num_classes": 10,
        },
    )

    # Convert images to tensors and upsample them from 32x32 to the 224x224
    # input size the ResNets were pretrained at. CIFAR-10 images are already
    # 3-channel RGB, so no channel conversion is needed.
    # NOTE(review): no ImageNet mean/std Normalize is applied even though the
    # backbone is ImageNet-pretrained; consider adding one.
    transform = Compose([
        ToTensor(),
        Resize((224, 224), antialias=True),
    ])

    # Split the CIFAR-10 training set into train/validation subsets; the two
    # sizes must add up to len(dataset). A dedicated fixed-seed generator keeps
    # the split identical across runs.
    dataset = CIFAR10(os.getcwd(), train=True, download=True, transform=transform)
    train_dataset, val_dataset = utils.data.random_split(
        dataset,
        [run.config.train_dataset_num, run.config.val_dataset_num],
        generator=torch.Generator().manual_seed(42),
    )

    train_loader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)
    # Batched validation: in eval mode predictions do not depend on batch size
    # (up to floating-point rounding). One forward pass per image
    # (batch_size=1) made validation needlessly slow.
    val_loader = utils.data.DataLoader(val_dataset, batch_size=run.config.batch_size, shuffle=False)

    # Build an ImageNet-pretrained ResNet and replace its final fully connected
    # layer with a fresh `num_classes`-way classifier.
    from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152

    model_builders = {
        "Resnet18": resnet18,
        "Resnet34": resnet34,
        "Resnet50": resnet50,
        "Resnet101": resnet101,
        "Resnet152": resnet152,
    }
    if run.config.model not in model_builders:
        # Fail loudly here instead of with a NameError on `model` further down.
        raise ValueError(
            f"Unsupported model {run.config.model!r}; expected one of: {', '.join(model_builders)}"
        )
    # weights="IMAGENET1K_V1" selects the same checkpoint as the deprecated
    # pretrained=True flag.
    model = model_builders[run.config.model](weights="IMAGENET1K_V1")
    model.fc = nn.Linear(model.fc.in_features, run.config.num_classes)

    model.to(torch.device(device))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=run.config.lr)

    # Optional: Preview dataset images
    log_images(train_loader, 8)

    steps_per_epoch = len(train_loader)

    # Training loop. Every metric is logged on one global, iteration-based
    # step axis, so SwanLab charts line up and steps never decrease.
    for epoch in range(1, run.config.num_epochs + 1):
        epoch_start_step = (epoch - 1) * steps_per_epoch
        epoch_end_step = epoch_start_step + steps_per_epoch - 1
        swanlab.log({"train/epoch": epoch}, step=epoch_start_step)
        model.train()
        train_correct = 0
        train_total = 0

        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(output, 1)
            train_total += y.size(0)
            train_correct += (predicted == y).sum().item()

            if batch_idx % 40 == 0:
                loss_value = loss.item()
                print(f"Epoch [{epoch}/{run.config.num_epochs}], Iteration [{batch_idx + 1}/{steps_per_epoch}], Loss: {loss_value}")
                swanlab.log({"train/loss": loss_value}, step=epoch_start_step + batch_idx)

        train_accuracy = train_correct / train_total
        swanlab.log({"train/acc": train_accuracy}, step=epoch_end_step)

        # Validation
        model.eval()
        correct = 0
        total = 0
        val_loss = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                output = model(x)
                # CrossEntropyLoss returns the batch mean. Weighting it by batch
                # size keeps the epoch value a per-sample mean even when the
                # last batch is smaller.
                val_loss += criterion(output, y).item() * y.size(0)
                _, predicted = torch.max(output, 1)
                total += y.size(0)
                correct += (predicted == y).sum().item()

        accuracy = correct / total
        avg_val_loss = val_loss / total
        swanlab.log({"val/acc": accuracy, "val/loss": avg_val_loss}, step=epoch_end_step)

Switching ResNet Models

The code supports the following ResNet variants:

  • ResNet18
  • ResNet34
  • ResNet50
  • ResNet101
  • ResNet152

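To switch models, change the `model` value in `config`. It must exactly match one of the strings the code checks for: `"Resnet18"`, `"Resnet34"`, `"Resnet50"`, `"Resnet101"`, or `"Resnet152"` (note the lowercase "net"):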

python
    run = swanlab.init(
        ...
        config={
            "model": "Resnet50",
        ...
        },
    )

Demo