Building a Custom ResNet Model for Flower Classification

In the world of deep learning, image classification is a fascinating problem. We're often amazed at the ability of computers to recognize objects in images. In this blog post, we'll embark on a journey to build a custom ResNet model for flower classification using the PyTorch framework and name it FlowerResNet.

Our mission revolves around a unique dataset known as "Flowers-299" from Kaggle, which comprises a staggering total of 115,944 images of various flowers. Our goal is to harness the power of deep learning to teach a computer to differentiate between these floral beauties.

Let's dive into the fascinating world of computer vision, where we'll explore data acquisition, model architecture, training processes, and model evaluation. By the end of this journey, you'll be equipped to create custom models for image recognition tasks, just like our custom ResNet model for flower classification.

Data Acquisition and Preprocessing

To embark on our journey into flower classification, we begin with the essential step of data acquisition and preprocessing. Let's navigate through the process of obtaining, organizing, and preparing the "Flowers-299" images for our deep-learning adventure.

Downloading the Dataset

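The first step is to acquire the "Flowers-299" dataset, and for this purpose, we'll leverage the Kaggle platform. If you haven't already, you'll need to install the Kaggle Python library and authenticate using your Kaggle API credentials. These credentials enable us to access and download the dataset directly into our environment. The snippets in this post assume a Google Colab notebook: lines starting with "!" are shell commands, and Google Drive is mounted so that training checkpoints can be stored there later.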

!pip install -q kaggle
from google.colab import drive
drive.mount('/content/drive')
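
The commands above install the client and mount Drive, but the Kaggle command-line tool also needs an API token before any download will work. It looks for a file named kaggle.json in the ~/.kaggle directory. As a minimal sketch, assuming you have already downloaded a token from your Kaggle account page and stored it on Drive (the helper name install_kaggle_token and the Drive path in the final comment are illustrative, not part of the original notebook), the token can be copied into place like this:

```python
import shutil
import tempfile
from pathlib import Path


def install_kaggle_token(token_path, config_dir):
    """Copy a kaggle.json token into config_dir and restrict its permissions."""
    config_dir = Path(config_dir)
    config_dir.mkdir(parents=True, exist_ok=True)
    target = config_dir / "kaggle.json"
    shutil.copyfile(token_path, target)
    target.chmod(0o600)  # the Kaggle client warns if other users can read the token
    return target


# Dry run in a temporary directory with a fake token, so nothing real is touched
with tempfile.TemporaryDirectory() as tmp:
    fake_token = Path(tmp) / "downloaded_kaggle.json"
    fake_token.write_text('{"username": "example", "key": "not-a-real-key"}')
    installed = install_kaggle_token(fake_token, Path(tmp) / ".kaggle")
    installed_ok = installed.exists()

# In Colab the real call might look like this (hypothetical Drive location):
# install_kaggle_token("/content/drive/MyDrive/kaggle.json", Path.home() / ".kaggle")
```

Keeping the token on Drive means it survives Colab session resets, but treat it like a password and avoid committing it to a repository.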

With the Kaggle client installed and Drive mounted, we can download the archive and unpack the floral treasures it contains.

!kaggle datasets download -d bogdancretu/flower299
!unzip flower299.zip -d /content/Flowers299

Introduction and Setup

In the world of deep learning and image classification, harnessing the power of various libraries and packages is essential. Before we dive into the fascinating realm of creating a custom ResNet model for flower classification using PyTorch, let's take a moment to introduce the key tools we'll be using.

Libraries and Packages

  • Python: Our primary programming language, Python serves as the foundation for this project.

  • NumPy: The fundamental library for numerical computations, NumPy will help us work with multi-dimensional arrays, making data manipulation a breeze.

  • Matplotlib: For visualizing data and monitoring the progress of our model, Matplotlib comes in handy. We'll use it to create loss and accuracy curves.

  • PyTorch: This deep learning framework will be the cornerstone of our project, allowing us to create and train our custom ResNet model efficiently.

  • TorchVision: A PyTorch library for computer vision tasks, TorchVision will provide us with useful tools and pre-processing functions.

Additional Libraries

To assist us in training our model, we will also use one more package:

  • TQDM: A handy progress bar that keeps us informed about the progress of various processes, such as training and validation.

With these tools at our disposal, we're well-equipped to tackle the task of building a flower classification model. Let's get started by importing the necessary packages and diving into the exciting world of deep learning.

import os
import time

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder

import tqdm

Now that we're all set up, let's proceed with our journey into the realm of custom ResNet models and flower classification.

Data Preprocessing

The raw dataset, while rich in botanical diversity, requires some essential preprocessing to make it suitable for deep learning. Here are the steps we follow:

Loading the Dataset

The first step in our data preprocessing journey is loading and structuring the "Flowers-299" dataset. This enables us to access the images and their associated labels, setting the stage for effective training and evaluation. Let's see how this is done:

# Load and structure the dataset
data_path = '/content/Flowers299'
batch_size = 128
num_classes = 299

# Define the data transformation pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the dataset from the folder structure
dataset = ImageFolder(data_path, transform=transform)

Splitting into Training and Validation Sets

For effective model training and evaluation, we split our dataset into two distinct sets: the training set and the validation set. The training set is used to fit the model's parameters, while the validation set, which the model never trains on, helps us assess how well it generalizes. In our case, we randomly allocate 80% of the data to training and 20% to validation:

# Split the dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
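
One caveat: random_split returns two views of the same ImageFolder, so the validation images pass through the same random flips and rotations as the training images. If you would rather validate on un-augmented images, one option is to load the ImageFolder without a transform and attach a transform to each subset afterwards. The sketch below shows the idea on a toy tensor dataset; TransformedSubset is a hypothetical helper, not part of PyTorch or of the original code.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, random_split


class TransformedSubset(Dataset):
    """Hypothetical helper: applies a transform to the items of a wrapped subset."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, label = self.subset[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label


# Toy stand-in for ImageFolder(data_path) created WITHOUT a transform
toy_dataset = TensorDataset(torch.zeros(10, 3), torch.arange(10))
train_part, val_part = random_split(
    toy_dataset, [8, 2], generator=torch.Generator().manual_seed(0)
)

# Stand-ins for an augmenting training transform and a plain evaluation transform
train_view = TransformedSubset(train_part, transform=lambda x: x + 1.0)
val_view = TransformedSubset(val_part, transform=lambda x: x)

train_sample, _ = train_view[0]
val_sample, _ = val_view[0]
```

In the real pipeline the training transform would be the full augmentation pipeline shown earlier, and the validation transform would keep only the resize, tensor conversion, and normalization steps.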

Creating Data Loaders

To feed our model efficiently during training, we create data loaders. Data loaders handle data loading, batching, and shuffling. This simplifies the process of training our model and ensures it receives the right data at the right time:

# Create data loaders for batch processing
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
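
For a sense of scale, here is the arithmetic behind the split and the loaders. It assumes ImageFolder finds all 115,944 images quoted for the dataset; the actual count depends on the archive you download.

```python
import math

total_images = 115_944  # figure quoted for Flowers-299; an assumption about the archive
batch_size = 128

train_size = int(0.8 * total_images)
val_size = total_images - train_size

# DataLoader keeps the last, smaller batch by default (drop_last=False)
train_batches = math.ceil(train_size / batch_size)
val_batches = math.ceil(val_size / batch_size)

print(train_size, val_size)        # 92755 23189
print(train_batches, val_batches)  # 725 182
```

So each epoch consists of roughly 725 training steps followed by 182 validation batches, which is useful to keep in mind when reading the progress bar.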

Building a Custom FlowerResNet Model

In this section, we'll delve into the heart of our flower classification project: the custom ResNet model. ResNet, short for Residual Network, is a powerful convolutional neural network architecture known for its ability to handle deep networks effectively. Our custom model follows the ResNet-34 layout, with four stages of basic residual blocks containing 3, 4, 6, and 3 blocks respectively.

Understanding ResNet

Before we dive into our code, let's grasp the core concepts of ResNet. ResNet introduced the concept of residual blocks, which make much deeper neural networks trainable by easing the vanishing gradient problem. These residual blocks contain skip connections that add a block's input to its output, so the convolutional layers only need to learn a residual function. In simpler terms, instead of learning the full mapping from input to output, each block learns the difference between its desired output and its input.
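
A tiny experiment makes the idea concrete. If the convolutional branch of a block outputs zeros, the block simply passes its input through (up to the final ReLU), so stacking more blocks never makes it harder to represent "do nothing". The snippet is a simplified sketch with a single convolution and no batch normalization, not the block used in the model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Residual branch with all-zero weights: it contributes nothing to the output
branch = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False)
nn.init.zeros_(branch.weight)

x = torch.randn(1, 4, 8, 8)
with torch.no_grad():
    out = torch.relu(branch(x) + x)  # skip connection: branch output plus input

identity_preserved = torch.equal(out, torch.relu(x))
print(identity_preserved)
```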

Residual Block

# Code example of a basic residual block
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.shortcut(identity)
        out = self.relu(out)
        return out

Customizing for Flower Classification

To adapt ResNet for our flower classification task, we introduce specific modifications. For instance, we customize the number of classes to match the 299 flower categories in our dataset. Additionally, we configure the network to handle 3-channel input data.

Custom ResNet Model

# Code example of our custom ResNet model
class FlowerResNet(nn.Module):
    def __init__(self, num_classes):
        super(FlowerResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        layers = []
        layers.append(BasicBlock(in_channels, out_channels, stride))
        in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(BasicBlock(in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

Our custom FlowerResNet model is designed to transform the raw pixel values of flower images into meaningful predictions about the flower species.
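
To see how a 224x224 image shrinks on its way through the network, we can trace the spatial size with the standard output-size formula for convolutions and pooling. This is a plain-Python sketch that mirrors the kernel sizes, strides, and padding used in the FlowerResNet code; it does not run the model itself.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224
trace = []

size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
trace.append(("conv1", 64, size))
size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
trace.append(("maxpool", 64, size))

# (name, output channels, stride of the first block, number of blocks)
stages = [("layer1", 64, 1, 3), ("layer2", 128, 2, 4),
          ("layer3", 256, 2, 6), ("layer4", 512, 2, 3)]
for name, channels, stride, blocks in stages:
    # Only the first 3x3 convolution of a stage changes the resolution
    size = conv_out(size, kernel=3, stride=stride, padding=1)
    trace.append((name, channels, size))

for name, channels, size_out in trace:
    print(f"{name:8s} -> {channels:3d} channels, {size_out}x{size_out}")

# Weighted layers: conv1 + two convolutions per block + the final linear layer
num_blocks = sum(blocks for _, _, _, blocks in stages)
weighted_layers = 1 + 2 * num_blocks + 1
print(weighted_layers)
```

The last stage produces a 512-channel 7x7 feature map, which adaptive average pooling reduces to the 512-dimensional vector expected by the final linear layer. The count of 34 weighted layers (leaving out the 1x1 shortcut convolutions, as is conventional) is where the name ResNet-34 comes from.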

Training the FlowerResNet Model

With the FlowerResNet model in place, it's time to embark on the training journey. Training a deep learning model involves iteratively adjusting its parameters so that its predictions on our flower dataset become more accurate. Let's break down the essential steps and explore how they're implemented in code.

Defining the Loss Function

The choice of a suitable loss function is crucial for training a classification model. In our flower classification task, we use the Cross-Entropy Loss, which is a common choice for multi-class classification problems. It measures the dissimilarity between the predicted class probabilities and the true class labels.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create an instance of the custom model
model = FlowerResNet(num_classes)
model.to(device)

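# Sanity check: run one batch through the untrained model
with torch.no_grad():
    images, labels = next(iter(train_loader))
    images = images.to(device)  # Move images to the same device as the model
    output = model(images)
print(output.shape)  # one row of num_classes scores per image in the batch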

# Code example of defining the loss function
criterion = nn.CrossEntropyLoss()
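
Note that nn.CrossEntropyLoss takes the raw scores (logits) from the final linear layer together with integer class indices and applies log-softmax internally, which is why FlowerResNet has no softmax layer of its own. The plain-Python sketch below computes the same quantity for a single sample; the function name is illustrative. It also gives a handy reference point: a model that assigns equal scores to all 299 classes has a loss of ln(299), about 5.70, so a freshly initialized model typically starts near that value.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    peak = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target]


# Equal scores for every class: the model is effectively guessing
uniform_loss = cross_entropy([0.0] * 299, target=0)

# A much higher score for the correct class: the loss drops towards zero
confident_loss = cross_entropy([10.0] + [0.0] * 298, target=0)

print(f"{uniform_loss:.4f}")
print(confident_loss < uniform_loss)
```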

Selecting the Optimization Algorithm

The optimization algorithm determines how the model's parameters are adjusted during training to minimize the loss. We use the Stochastic Gradient Descent (SGD) optimizer with momentum, a popular choice for training deep neural networks.

# Code example of selecting the optimizer
learning_rate = 0.001
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
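
Under the hood, SGD with momentum keeps a running "velocity" for every parameter. In PyTorch's formulation (with the default dampening of 0), the velocity is updated as v = momentum * v + grad and the parameter then moves by -lr * v. The sketch below applies that rule to a single scalar with a constant gradient, purely for illustration:

```python
def sgd_momentum_step(param, grad, velocity, lr=0.001, momentum=0.9):
    """One SGD-with-momentum update for a scalar parameter (PyTorch-style rule)."""
    velocity = momentum * velocity + grad
    param = param - lr * velocity
    return param, velocity


param, velocity = 0.0, 0.0
history = []
for step in range(3):
    # Constant gradient of 1.0, chosen only to make the numbers easy to follow
    param, velocity = sgd_momentum_step(param, grad=1.0, velocity=velocity)
    history.append((round(param, 6), round(velocity, 6)))

print(history)
```

With a constant gradient the velocity keeps growing towards grad / (1 - momentum), ten times the raw gradient for momentum 0.9, which is how momentum speeds up progress along directions where successive gradients agree.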

Monitoring and Visualizing Training Progress

As we embark on the training journey of our FlowerResNet model for flower classification, it's vital to keep a close eye on the model's performance and progress. We rely on three techniques: learning rate scheduling, checkpointing, and tracking loss and accuracy after every epoch.

Learning Rate Scheduling

Learning rate scheduling is a technique that dynamically adjusts the learning rate during training. It can help fine-tune the model's convergence and improve training efficiency. In our code, we use the ReduceLROnPlateau scheduler from PyTorch to reduce the learning rate when the validation loss plateaus.

# Learning rate scheduling with ReduceLROnPlateau
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
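
To get a feel for what the scheduler does, the sketch below feeds it a validation loss that never improves and records the learning rate after each epoch. The tiny linear model and the constant loss are stand-ins chosen for illustration; only the scheduler settings match the ones used in this post.

```python
import torch.nn as nn
import torch.optim as optim

toy_model = nn.Linear(2, 1)  # stand-in for FlowerResNet
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.001, momentum=0.9)
toy_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    toy_optimizer, factor=0.1, patience=5
)

lr_history = []
for epoch in range(10):
    stalled_val_loss = 1.0  # pretend the validation loss is stuck
    toy_scheduler.step(stalled_val_loss)
    lr_history.append(toy_optimizer.param_groups[0]["lr"])

print(lr_history[0], lr_history[-1])
```

Once the loss has failed to improve for more than "patience" consecutive epochs, the learning rate is multiplied by "factor", so within these ten stalled epochs it drops from 0.001 to 0.0001.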

Initialization and Checkpointing

Before diving into the training loop, it's essential to set up the initial conditions for training and optionally load previous training progress. Here's what we do:

# Initialize training variables
start_epoch = 0
best_val_loss = float('inf')
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
checkpoint_path = '/content/drive/MyDrive/project/flowerresnet_checkpoint.pth'

# If a checkpoint file exists, load previous training progress
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine too
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['val_loss']
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    train_accuracies = checkpoint['train_accuracies']
    val_accuracies = checkpoint['val_accuracies']

Training Loop

The heart of the training process lies in the training loop. This loop iteratively updates the model's parameters using the defined loss and optimizer. It comprises multiple epochs, allowing the model to learn from the training data during each pass.

# Training loop
num_epochs = 40

for epoch in range(start_epoch, num_epochs):
    model.train()
    total_correct = 0
    total_samples = 0
    running_loss = 0.0

    start_time = time.time()  # Start time for the epoch

    for images, labels in tqdm.tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()
        running_loss += loss.item()

    train_loss = running_loss / len(train_loader)
    train_accuracy = 100 * total_correct / total_samples

    model.eval()
    val_loss = 0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    val_accuracy = 100 * total_correct / total_samples

    scheduler.step(val_loss)

    end_time = time.time()  # End time for the epoch
    epoch_time = end_time - start_time  # Total time taken for the epoch

    # Store the losses and accuracies
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
    print(f"Train Accuracy: {train_accuracy:.2f}%, Validation Accuracy: {val_accuracy:.2f}%")
    print(f"Time taken for Epoch {epoch + 1}: {epoch_time:.2f} seconds")

    # Save the model and optimizer state in the checkpoint file
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accuracies': train_accuracies,
            'val_accuracies': val_accuracies
        }, checkpoint_path)

Each training loop iteration brings our model closer to its full potential. Because the checkpoint is written only when the validation loss improves, the file always holds the best model seen so far, and an interrupted run resumes from that best epoch rather than from the very last one.

Monitoring Training Progress with Plots

Now that we've gained insights into the training process and saved our model's progress, it's time to visualize how well our custom ResNet model for flower classification is learning. We'll create plots to track training and validation metrics.

Training and Validation Loss

The first set of plots will showcase the loss values. Loss is a fundamental metric that quantifies how far off our model's predictions are from the true labels. Lower loss values indicate better model performance.

# Plotting Training and Validation Loss
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

In the graph, you'll notice two lines: one representing the training loss and the other for the validation loss. If everything goes well, you should see the training loss decreasing and hopefully the validation loss following a similar trend without significant deviations. This indicates that our model is learning effectively and not overfitting the training data.

Training and Validation Accuracy

Accuracy is another crucial metric. It measures the model's ability to make correct predictions. Higher accuracy values are indicative of a model that can classify flower images more accurately.

# Plotting Training and Validation Accuracy
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.show()

Similar to the loss plots, these graphs display training and validation accuracy. A well-behaved model should see both training and validation accuracy increasing as training progresses. However, if the validation accuracy starts to decrease or plateaus while training accuracy continues to rise, it could be a sign of overfitting.
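
The same check can be done numerically on the lists recorded during training. The helper below is a hypothetical sketch, and the numbers are made up for illustration; they are not results from the FlowerResNet run.

```python
def summarize_curves(train_losses, val_losses):
    """Return the best validation epoch (1-based), the final gap, and epochs since the best."""
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__) + 1
    final_gap = val_losses[-1] - train_losses[-1]
    epochs_since_best = len(val_losses) - best_epoch
    return best_epoch, final_gap, epochs_since_best


# Made-up curves: training loss keeps falling while validation loss turns upward
example_train = [4.8, 3.6, 2.7, 2.0, 1.5, 1.1]
example_val = [4.9, 3.9, 3.2, 3.0, 3.1, 3.3]

best_epoch, final_gap, epochs_since_best = summarize_curves(example_train, example_val)
print(best_epoch, round(final_gap, 2), epochs_since_best)
```

A widening gap between validation and training loss, together with several epochs since the best validation loss, is the numerical counterpart of the diverging curves described above.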

By regularly examining these plots, we can make informed decisions about our model's training and potentially take corrective actions if necessary. In the next section, we'll delve into evaluating our trained model on new flower images and using it for predictions.

Making Predictions with Your FlowerResNet Model

The training journey for your custom ResNet model has been exciting, but the real magic happens when we use it to make predictions on new and unseen flower images. In this section, we'll explore how to load your trained model, preprocess an image, and make predictions.

Preparing the Model and Image

Before making predictions, we need to load our trained model and preprocess the input image. Here's how we do it:

import os

from PIL import Image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Function to get class names: sorted subfolder names, the same order ImageFolder uses
def get_class_names(folder_path):
    subfolders = [f.name for f in os.scandir(folder_path) if f.is_dir()]
    return sorted(subfolders)

# Define the folder path containing your flower class subdirectories
folder_path = '/content/Flowers299'
class_names = get_class_names(folder_path)

def predict_image(image_path, model_path):
    # Define the image transformation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load and transform the image
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)

    # Load the trained model onto the CPU; map_location makes a GPU-trained checkpoint loadable anywhere
    model = FlowerResNet(num_classes=299)  # Assuming 'FlowerResNet' is the name of your custom model
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])  # Load the model state_dict from the checkpoint

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)

    # Get the predicted flower class name
    flower_name = class_names[predicted.item()]  # 'class_names' should be defined as shown earlier

    # Display the image and its predicted label
    image = Image.open(image_path)
    plt.imshow(image)
    plt.axis('off')
    plt.title(flower_name)
    plt.show()

# Paths to example images
sunflower_image_path = '/content/sunflower_image.jpg'
acacia_image_path = '/content/acacia_image.jpg'
zenobia_image_path = '/content/Zenobia_image.jpg'

# Path to your trained model checkpoint
model_path = '/content/drive/MyDrive/project/flowerresnet_checkpoint.pth'

# Make predictions for the example images
predict_image(sunflower_image_path, model_path)
predict_image(acacia_image_path, model_path)
predict_image(zenobia_image_path, model_path)

Viewing Predictions

The code above loads the trained model, preprocesses the input image, and predicts the flower class. It then displays the original image along with the predicted label using Matplotlib.

You can use this approach to make predictions on any flower image by specifying the image path and the path to your trained model checkpoint. It's a satisfying moment to see your model in action, correctly classifying flowers from new and unseen data.
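
With 299 classes, it is often informative to look beyond the single best guess. As a sketch, the model's output scores can be turned into probabilities with softmax and the few most likely classes listed; the function name, logits, and class names below are made up for illustration.

```python
import torch


def top_k_predictions(logits, class_names, k=3):
    """Return the k most likely (class name, probability) pairs for one image."""
    probabilities = torch.softmax(logits, dim=1)
    top_probs, top_indices = torch.topk(probabilities, k, dim=1)
    return [
        (class_names[index], prob)
        for index, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    ]


# Made-up scores for a four-class toy problem
example_names = ["daisy", "rose", "sunflower", "tulip"]
example_logits = torch.tensor([[0.5, 2.0, 3.0, -1.0]])

top3 = top_k_predictions(example_logits, example_names, k=3)
print([name for name, _ in top3])
```

Inside predict_image, the tensor called output plays the role of example_logits, and class_names is the list built from the dataset folders.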

Conclusion

In this blog post, we embarked on an exciting journey into the world of deep learning and image classification. Using the PyTorch framework, we built a custom ResNet model for flower classification, allowing our computer to recognize and categorize flowers from images.

We began by acquiring the "Flowers-299" dataset from Kaggle, which comprises a whopping 115,944 images spanning 299 flower classes. We discussed essential data preprocessing steps, including dataset loading, splitting, and data augmentation, to ensure our model's effectiveness.

The heart of our journey lay in the training process. We carefully monitored the training progress, visualizing loss and accuracy curves, and employed techniques such as learning rate scheduling to fine-tune our model. Witnessing the learning journey of our custom ResNet model was both thrilling and enlightening.

Finally, we used the trained model to make predictions on new and unseen flower images, which is where the project truly comes to life. Remember, you can use this model to classify any flower image you encounter, as long as the flower belongs to one of the 299 classes it was trained on!

The road to deep learning is never-ending, and there's always room for improvement. You might consider experimenting with different architectures, optimizing hyperparameters, or applying techniques such as transfer learning to achieve even better results.

Lastly, we encourage you to explore the complete code and experiment further with this project. You can access the Jupyter Notebook and the code used in this blog post in the GitHub repository associated with this project.

Thank you for joining us on this journey into the world of deep learning and flower classification. We hope this blog post has inspired you to delve deeper into the fascinating field of computer vision and machine learning.

Keep blooming with your AI projects!