Decoding the Shadows: Pneumonia Classification in Chest X-rays 🫁📸

Photo by CDC on Unsplash


Welcome to a journey into the realm of medical image classification! In this blog, we'll explore the fascinating task of classifying chest X-ray images to detect pneumonia using deep learning techniques. Our adventure will involve downloading a specialized dataset, setting up the necessary libraries, constructing a convolutional neural network (CNN) model, and training it to make accurate predictions. Along the way, we'll ensure the model's resilience through checkpoints, monitor its progress with informative plots, and finally, witness its predictive capabilities.

Downloading the Dataset

Before we embark on our coding adventure, let's obtain the essential dataset for our task. We'll utilize Kaggle to download the Chest X-ray dataset, specifically curated for pneumonia detection.
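The download step is not shown in code, so here is a minimal sketch. It makes several assumptions. It uses the Kaggle dataset slug `paultimothymooney/chest-xray-pneumonia`, because the article does not name the exact dataset. It relies on the official `kaggle` Python package. It expects an API token in `~/.kaggle/kaggle.json`. Finally, it assumes the archive unpacks into a `chest_xray` folder, which matches the `/content/chest_xray` paths used later. `check_layout` is a hypothetical convenience helper. It lists any expected split/class folders that are missing after extraction.

```python
import os

# Assumption: the article does not name the exact Kaggle dataset; this slug is a commonly used one
DATASET_SLUG = "paultimothymooney/chest-xray-pneumonia"


def download_dataset(dest="/content"):
    """Download and unzip the dataset with the official Kaggle API (needs ~/.kaggle/kaggle.json)."""
    from kaggle.api.kaggle_api_extended import KaggleApi  # pip install kaggle

    api = KaggleApi()
    api.authenticate()
    api.dataset_download_files(DATASET_SLUG, path=dest, unzip=True)


def check_layout(root, splits=("train", "val", "test"), classes=("NORMAL", "PNEUMONIA")):
    """Hypothetical helper: return the expected split/class folders that are missing under root."""
    missing = []
    for split in splits:
        for cls in classes:
            folder = os.path.join(root, split, cls)
            if not os.path.isdir(folder):
                missing.append(folder)
    return missing
```

Usage: call `download_dataset("/content")`, then `check_layout("/content/chest_xray")`. An empty list means the folders are where the rest of the code expects them.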

Let's Set Up the Stage: Library Imports

To kick things off, we'll import the libraries and dependencies required for our project. PyTorch and torchvision provide the model, datasets, and image transforms. NumPy handles numerical operations, Matplotlib handles visualization, and tqdm draws progress bars.

import os
import time

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image  # Pillow, used to open single images at prediction time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchsummary import summary  # third-party package: pip install torchsummary

import tqdm
from glob import glob

Visualizing Chest X-ray Images

Before diving into data preprocessing and model training, let's look at a sample chest X-ray image. This gives us a first glimpse of the data. It also confirms that the dataset was downloaded and extracted to the expected location.

Code for Image Visualization

# Collect the paths of the PNEUMONIA images in the training split
path_train = "/content/chest_xray/train"
img_paths = glob(os.path.join(path_train, "PNEUMONIA", "*.jpeg"))

# Read and display the first image (a gray colormap shows single-channel X-rays correctly)
img = np.asarray(plt.imread(img_paths[0]))
plt.figure(figsize=(5, 5))
plt.imshow(img, cmap='gray')
plt.title("Sample Chest X-ray Image")
plt.axis('off')
plt.show()
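Before choosing a preprocessing pipeline, it helps to know what the raw files look like. The sketch below defines a hypothetical helper, `describe_images`, built on Pillow. It reports the size and color mode of a few files. X-ray images in a dataset like this can differ in resolution and may mix single-channel ("L") and three-channel ("RGB") files. That is why the pipeline in the next section resizes everything to 224×224 and converts it to grayscale.

```python
from collections import Counter

from PIL import Image


def describe_images(paths, limit=20):
    """Hypothetical helper: collect (width, height) sizes and a count of color modes for up to `limit` images."""
    sizes = []
    modes = Counter()
    for path in paths[:limit]:
        with Image.open(path) as im:
            sizes.append(im.size)  # Pillow reports (width, height)
            modes[im.mode] += 1
    return sizes, modes
```

For example, `describe_images(glob("/content/chest_xray/train/PNEUMONIA/*.jpeg"))` summarizes the first 20 pneumonia images in the training split.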

Loading the Dataset

With the images on disk, the next step is to load them into PyTorch datasets. We also apply one consistent preprocessing pipeline, so every image reaches the model with the same size, channel count, and value range.

Load and Structure the Dataset

Each split folder (train and val here) holds one subfolder per class: NORMAL and PNEUMONIA. This is exactly the layout torchvision's ImageFolder expects. It treats each subfolder as a class and labels every image inside it accordingly.

# Load and structure the dataset
data_path = '/content/chest_xray'
batch_size = 68
num_classes = 2
learning_rate = 0.001
# Checkpoints go to Google Drive so they survive a Colab restart (mount Drive first)
checkpoint_path = '/content/drive/MyDrive/project/chest_xray_checkpoint.pth'

train_path = os.path.join(data_path, 'train')
val_path = os.path.join(data_path, 'val')

# Define the data transformation pipeline (same order as at prediction time)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=1),  # Convert the PIL image to a single channel
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (1, 224, 224)
    transforms.Normalize(mean=[0.485], std=[0.229])  # ImageNet red-channel stats reused for one channel
])

# Load the dataset from the folder structure
train_dataset = ImageFolder(train_path, transform=transform)
val_dataset = ImageFolder(val_path, transform=transform)

Creating Data Loaders

Before training, we wrap each dataset in a DataLoader. A DataLoader groups images into batches and optionally shuffles them. It also loads them in parallel worker processes, so the GPU is not left waiting on disk reads.

Batch Processing for Training and Validation

We create one loader per split. The training loader reshuffles the images every epoch, so each batch mixes both classes in a different order. The validation loader keeps a fixed order, since shuffling has no effect on the evaluation metrics.

# Create data loaders for batch processing
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
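In the widely used Kaggle release of this dataset, the val folder holds only 16 images. Validation metrics computed on so few images can swing considerably between epochs. One option is to carve a larger validation subset out of the training folder with `random_split`, which the import list already brings in. The sketch below assumes that approach. The helper name `split_train_val`, the 10% fraction, and the seed are illustrative choices, not the article's setup. The demo uses a small stand-in tensor dataset instead of `train_dataset`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def split_train_val(dataset, val_fraction=0.1, seed=42):
    """Hypothetical helper: randomly split one dataset into train/val subsets, reproducibly."""
    val_size = int(len(dataset) * val_fraction)
    train_size = len(dataset) - val_size
    generator = torch.Generator().manual_seed(seed)  # fixed seed -> same split on every run
    return random_split(dataset, [train_size, val_size], generator=generator)


# Stand-in for train_dataset: 100 tiny fake grayscale "images" with random 0/1 labels
fake_dataset = TensorDataset(torch.randn(100, 1, 32, 32), torch.randint(0, 2, (100,)))
train_subset, val_subset = split_train_val(fake_dataset)

val_loader_demo = DataLoader(val_subset, batch_size=4, shuffle=False)
demo_images, demo_labels = next(iter(val_loader_demo))
print(len(train_subset), len(val_subset), tuple(demo_images.shape))
```

`random_split` does not stratify by class, so on an imbalanced dataset the class ratio of the validation subset can drift slightly from the training set. Both subsets also share the parent dataset's transform.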

Building the Model

Now let's construct the convolutional neural network (CNN) for pneumonia detection. ChestXrayCNN is deliberately compact. Two stages of convolution, ReLU, and 2×2 max pooling extract local image features. Two fully connected layers then turn those features into one score (logit) per class. Each pooling stage halves the spatial resolution, so a 224×224 input reaches the fully connected layers as a 64×56×56 feature map.

class ChestXrayCNN(nn.Module):
    def __init__(self, num_classes):
        super(ChestXrayCNN, self).__init__()

        # Convolutional Layer 1
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Convolutional Layer 2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully Connected Layers
        # Two 2x2 poolings shrink a 224x224 input to 56x56, giving 64 * 56 * 56 features
        self.fc1 = nn.Linear(64 * 56 * 56, 128)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(128, num_classes)

        # No sigmoid/softmax layer here: nn.CrossEntropyLoss expects raw logits

    def forward(self, x):
        # Convolutional Layer 1
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        # Convolutional Layer 2
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        # Flatten for Fully Connected Layers
        x = x.view(x.size(0), -1)

        # Fully Connected Layer 1
        x = self.fc1(x)
        x = self.relu3(x)

        # Fully Connected Layer 2
        x = self.fc2(x)

        return x
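The `64 * 56 * 56` input size of `fc1` follows from the standard output-size formula for convolution and pooling layers: out = floor((in + 2 * padding - kernel) / stride) + 1. The short pure-Python sketch below applies it layer by layer to a 224×224 input. `conv_out` is a hypothetical helper name.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d/MaxPool2d layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv1: 224 -> 224
size = conv_out(size, kernel=2, stride=2)             # pool1: 224 -> 112
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv2: 112 -> 112
size = conv_out(size, kernel=2, stride=2)             # pool2: 112 -> 56

flat_features = 64 * size * size
fc1_params = flat_features * 128 + 128  # weights + biases of nn.Linear(flat_features, 128)
print(size, flat_features, fc1_params)  # 56 200704 25690240
```

This also shows two consequences of the design. Changing the resolution in `transforms.Resize` requires updating `fc1`. And `fc1` alone holds over 25 million parameters, the bulk of the model.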

Setting the Stage: Device Placement and Optimization

We move the model to a GPU when one is available, falling back to the CPU otherwise, and run one batch through it as a quick shape check. We then define the training components. Cross-entropy is the loss. Stochastic Gradient Descent (SGD) with momentum is the optimizer. A ReduceLROnPlateau scheduler lowers the learning rate when the validation loss stops improving.

# Sending the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ChestXrayCNN(num_classes)
model.to(device)

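# Sanity check: run one batch through the untrained model to confirm the shapes line up
with torch.no_grad():
    images, labels = next(iter(train_loader))
    images = images.to(device)  # Move images to the same device as the model
    output = model(images)
    print(output.shape)  # Expected: (batch_size, num_classes)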

# Defining the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

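# Learning rate scheduling: multiply the LR by 0.1 once the validation loss
# has failed to improve for more than `patience` (here 5) consecutive epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)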
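In its default `mode='min'`, ReduceLROnPlateau counts consecutive epochs in which the monitored value fails to improve. Once that count exceeds `patience`, it multiplies the learning rate by `factor`. With `patience=5`, the first cut can only happen after six non-improving epochs. The five-epoch run used later in this article therefore never triggers it. The sketch below checks this behavior on a throwaway parameter with a flat, made-up validation loss. The `_demo` names are illustrative.

```python
import torch
import torch.optim as optim

# A single dummy parameter is enough to drive an optimizer/scheduler pair
param = torch.nn.Parameter(torch.zeros(1))
optimizer_demo = optim.SGD([param], lr=0.001, momentum=0.9)
scheduler_demo = optim.lr_scheduler.ReduceLROnPlateau(optimizer_demo, factor=0.1, patience=5)

lr_history = []
for step in range(8):
    fake_val_loss = 1.0  # made-up value: never improves after the first step
    scheduler_demo.step(fake_val_loss)
    lr_history.append(optimizer_demo.param_groups[0]["lr"])

print(lr_history)  # six steps at 0.001, then roughly 0.0001 from the seventh step on
```

If a learning-rate drop within a short run is the goal, a smaller `patience` (for example 1 or 2) makes the scheduler react sooner.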

Checkpoint Resilience: Safeguarding Model Progress

Training a model rarely runs as one smooth, uninterrupted session. Colab runtimes disconnect, kernels crash, and sometimes we stop a run by hand. To make training resilient, we save a checkpoint every time the validation loss improves. The checkpoint holds the model weights, the optimizer state, and the training history. At startup, we resume from that checkpoint if one exists.

Initializing Progress Tracking Variables

Before diving into the intricacies of our checkpoint mechanism, we initialize several variables to keep tabs on the training progress:

  • start_epoch: The epoch from which training starts (0 for a fresh run, or the epoch after the last saved checkpoint).

  • best_val_loss: A placeholder for the best validation loss observed, initialized to infinity.

  • train_losses, val_losses: Lists to track training and validation losses over epochs.

  • train_accuracies, val_accuracies: Lists to monitor training and validation accuracies.

Checking for a Previous Checkpoint

The first step is to check whether a previous checkpoint exists. If it does, we restore the model and optimizer states along with the progress-tracking variables, so training resumes where it stopped.

# Defaults for a fresh run (overwritten below if a checkpoint exists)
start_epoch = 0
best_val_loss = float('inf')
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load model and optimizer state from the checkpoint
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Update start_epoch, best_val_loss, and progress tracking variables
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['val_loss']
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    train_accuracies = checkpoint['train_accuracies']
    val_accuracies = checkpoint['val_accuracies']

If a checkpoint is found, training resumes at the epoch after the one that was saved, with the loss and accuracy histories restored.

Training Loop with Checkpointing

Our training loop runs for a fixed number of epochs. In each epoch, we train the model on the training set and evaluate it on the validation set. We then update the learning-rate scheduler and save a checkpoint if the validation loss improved.

Training and Validation Phases

# Set the number of epochs for training
num_epochs = 5

# Loop through epochs for training
for epoch in range(start_epoch, num_epochs):
    # Set the model to training mode
    model.train()
    total_correct = 0
    total_samples = 0
    running_loss = 0.0

    # Record the start time for the epoch
    start_time = time.time()

    # Iterate over batches in the training data
    for images, labels in tqdm.tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Calculate accuracy and update running loss
        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()
        running_loss += loss.item()

    # Calculate average training loss and accuracy for the epoch
    train_loss = running_loss / len(train_loader)
    train_accuracy = 100 * total_correct / total_samples

    # Set the model to evaluation mode
    model.eval()
    val_loss = 0
    total_correct = 0
    total_samples = 0

    # Evaluate the model on the validation data
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    # Calculate average validation loss and accuracy for the epoch
    val_loss /= len(val_loader)
    val_accuracy = 100 * total_correct / total_samples

    # Adjust learning rate based on validation loss
    scheduler.step(val_loss)

    # Record the end time for the epoch
    end_time = time.time()

    # Store the losses and accuracies for visualization
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    # Print and display progress information
    print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
    print(f"Train Accuracy: {train_accuracy:.2f}%, Validation Accuracy: {val_accuracy:.2f}%")
    print(f"Time taken for Epoch {epoch + 1}: {end_time - start_time:.2f} seconds")

    # Save the model and optimizer state if validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)  # Make sure the target folder exists
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accuracies': train_accuracies,
            'val_accuracies': val_accuracies
        }, checkpoint_path)

In this snippet, after each epoch, we check if the current validation loss is better than the previously recorded best validation loss. If it is, we update the best validation loss and save the model and optimizer states along with the training progress.
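One weak spot remains. If the session dies while `torch.save` is writing, the checkpoint file itself can be left truncated, which defeats the purpose of resuming. A common remedy is to write to a temporary file in the same directory and then swap it into place with `os.replace`. The sketch below wraps this in a hypothetical helper, `save_checkpoint_atomic`. On a regular local file system, `os.replace` swaps the file in a single step. Synced or network mounts such as Google Drive may not give the same guarantee.

```python
import os
import tempfile

import torch


def save_checkpoint_atomic(state, path):
    """Hypothetical helper: write `state` to a temp file next to `path`, then swap it into place."""
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)  # torch.save opens the file itself
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)  # readers see either the old file or the complete new one
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)  # do not leave half-written temp files behind
        raise
```

In the training loop, `torch.save({...}, checkpoint_path)` would become `save_checkpoint_atomic({...}, checkpoint_path)` with the same dictionary.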

Monitoring Training Progress with Plots

In this section, we monitor how training progressed over the epochs. Plotting the training and validation metrics recorded in the loop shows the learning curve of our ChestXrayCNN model. It also makes problems such as overfitting easy to spot.

Plotting the Training and Validation Loss Over Epochs

We start with the loss. Plotting the training and validation loss per epoch shows whether the model is still improving. Both curves should trend downward. A validation loss that rises while the training loss keeps falling is a sign of overfitting.

plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.show()
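The checkpoint is only overwritten when the validation loss improves. The saved weights therefore correspond to the epoch with the lowest validation loss, not necessarily the last one. The sketch below uses two hypothetical helpers. `best_epoch` locates that epoch in the recorded histories. `mark_best_epoch` draws it on an existing Matplotlib Axes.

```python
def best_epoch(val_losses, val_accuracies):
    """Return (1-based epoch, val loss, val accuracy) for the epoch with the lowest validation loss."""
    idx = min(range(len(val_losses)), key=val_losses.__getitem__)  # first minimum wins on ties
    return idx + 1, val_losses[idx], val_accuracies[idx]


def mark_best_epoch(ax, epoch):
    """Draw a dashed vertical line at `epoch` on an existing Matplotlib Axes and refresh its legend."""
    ax.axvline(epoch, linestyle="--", color="gray", label=f"Best epoch ({epoch})")
    ax.legend()
```

For example, you can call `mark_best_epoch(plt.gca(), best_epoch(val_losses, val_accuracies)[0])` just before `plt.show()` in the loss plot. It marks the epoch the checkpoint was saved from.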

Plotting the Training and Validation Accuracy Over Epochs

Next, we plot accuracy, the percentage of images the model classifies correctly. The training and validation accuracy curves show how the model's predictive performance evolves from epoch to epoch. Ideally, both rise and stay close together.

plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy Over Epochs')
plt.legend()
plt.show()
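Accuracy alone can look good on an imbalanced dataset. The Kaggle release of this dataset has noticeably more PNEUMONIA than NORMAL images, so a model that leans toward predicting pneumonia can still score well. Precision and recall for the PNEUMONIA class give a fuller picture. That class is label 1 under ImageFolder's alphabetical ordering. The sketch below computes these metrics from plain lists of labels and predictions. Such lists could be collected with `labels.tolist()` and `predicted.tolist()` in the validation loop. The helper name `binary_metrics` and the toy numbers are illustrative.

```python
def binary_metrics(labels, predictions, positive=1):
    """Accuracy, precision, and recall for one positive class (1 = PNEUMONIA with ImageFolder's ordering)."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == positive and p == positive)
    fp = sum(1 for y, p in pairs if y != positive and p == positive)
    fn = sum(1 for y, p in pairs if y == positive and p != positive)
    correct = sum(1 for y, p in pairs if y == p)

    accuracy = correct / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0  # of the images flagged as pneumonia, how many are
    recall = tp / (tp + fn) if tp + fn else 0.0     # of the pneumonia images, how many were flagged
    return accuracy, precision, recall


# Toy example: every pneumonia case is caught, but two normal images are also flagged
toy_labels = [1, 1, 1, 0, 0, 0]
toy_predictions = [1, 1, 1, 1, 1, 0]
print(binary_metrics(toy_labels, toy_predictions))  # accuracy ~0.67, precision 0.6, recall 1.0
```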

Making Predictions with the Model

Now comes the exciting part: seeing our trained model in action! The function below loads the saved checkpoint and applies the same preprocessing as in training. It then predicts the class of a single image and displays the image with its predicted label:

from PIL import Image  # Pillow, used to open the image file

def predict_image(image_path, model_path, class_names=("NORMAL", "PNEUMONIA")):
    # class_names must match train_dataset.classes (ImageFolder sorts class folders alphabetically)

    # Image Transformation (same steps as the training pipeline)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
        transforms.ToTensor(),
        transforms.Normalize((0.485,), (0.229,))  # Same normalization as in training
    ])

    # Load and Transform the Image
    image = Image.open(image_path)
    input_tensor = transform(image).unsqueeze(0)  # Add a batch dimension: (1, 1, 224, 224)

    # Load the Trained Model from the checkpoint (map_location='cpu' also works without a GPU)
    model = ChestXrayCNN(num_classes=len(class_names))
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Set the Model to Evaluation Mode
    model.eval()

    # Make Predictions
    with torch.no_grad():
        output = model(input_tensor)
        _, predicted = torch.max(output, 1)

    # Retrieve Predicted Class Name
    class_name = class_names[predicted.item()]

    # Visualize the Prediction on the original image
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    plt.title(f'Predicted Class: {class_name}')
    plt.show()

# Specify the Image and Model Paths
image_path = '/content/image.jpeg'
model_path = '/content/drive/MyDrive/project/chest_xray_checkpoint.pth'

# Make Prediction
predict_image(image_path, model_path)
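The model returns raw logits, so `torch.max` only tells us which class scored higher, not how confident the model is. Applying softmax over the class dimension turns the two logits into probabilities that sum to 1. The sketch below uses made-up logits rather than real model output. Inside `predict_image`, the same call could be applied to `output`.

```python
import torch
import torch.nn.functional as F

class_names = ["NORMAL", "PNEUMONIA"]  # ImageFolder's alphabetical class order

# Made-up logits for a batch of one image, shaped (batch, num_classes) like the model's output
logits = torch.tensor([[0.0, 2.0]])
probs = F.softmax(logits, dim=1)  # each row now sums to 1
confidence, predicted = torch.max(probs, 1)

print(class_names[predicted.item()], round(confidence.item(), 4))  # PNEUMONIA 0.8808
```

A softmax score reflects the model's relative preference between the two classes. It is not a calibrated clinical probability.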

Conclusion

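That concludes our deep learning journey. We downloaded a chest X-ray dataset from Kaggle and loaded it with ImageFolder and a grayscale preprocessing pipeline. We built a compact CNN with two convolutional stages and trained it with SGD, momentum, and a plateau-based learning-rate schedule. We made training resumable with checkpoints and tracked loss and accuracy with plots. Finally, we used the saved model to classify a new image for pneumonia detection 🚀.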
