Unlocking the Beauty of Nature with Machine Learning: Bird Species Classification

Nature's diversity has a way of captivating our hearts, and few creatures do it as effortlessly as the dazzling array of bird species that grace our planet. From the vibrant plumage of tropical parrots to the majestic flight of soaring eagles, birds offer us a glimpse into the splendor of the natural world. But what if we could take our appreciation for these wonders to the next level, harnessing the power of PyTorch to identify and classify them?

Join us on this exhilarating journey through the intersection of artificial intelligence and ornithology as we embark on a quest to classify 525 unique bird species. Our dataset consists of 84,635 high-quality training images, each meticulously showcasing a single bird species in all its glory. But that's not all – we also have 2,625 test images and 2,625 validation images, five per species, offering a robust evaluation of our model's performance. With this much data, our dataset provides ample opportunity for your PyTorch-powered model to soar.

This blog is more than just numbers; it's an adventure through the realms of AI and ornithology, brimming with insights, surprises, and breathtaking visuals of our feathered friends. Together, we'll dive into the captivating universe of bird species classification, unravel the enchantment of our dataset, and unleash the potential of PyTorch to view nature in a whole new light.

So, fasten your seatbelts, grab your binoculars, and let's take flight on this remarkable journey of discovery!

Downloading the Dataset

To obtain this invaluable dataset, we'll make use of a well-known data repository, the Kaggle platform. If you haven't already, you'll need to install the Kaggle Python library and authenticate using your Kaggle API credentials. These credentials grant us access to download the dataset directly into our working environment.

Download the Bird Species Dataset

First, let's install the Kaggle Python library and mount Google Drive (we'll use Drive to hold our Kaggle API token and, later, our model checkpoints):

!pip install -q kaggle
from google.colab import drive
drive.mount('/content/drive')
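
The Kaggle CLI also needs your API token before the download below will run. Here is a minimal sketch, assuming you have downloaded kaggle.json from your Kaggle account page and stored it somewhere on your Drive (the Drive path below is only an example; adjust it to wherever your kaggle.json lives):

# Copying the Kaggle API token to the location the Kaggle CLI expects
# (the Drive path is only an example; point it at your own kaggle.json)
!mkdir -p ~/.kaggle
!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json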

Next, we download the dataset and unpack the feathered treasures it contains.

!kaggle datasets download -d gpiosenka/100-bird-species
!unzip 100-bird-species.zip

Let's Set Up the Stage: Library Imports

Before we take off, let's set up our environment with the necessary libraries. The following code snippet showcases the imports that will lay the foundation for our journey:

# PyTorch for deep learning magic
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision import models

# Handling data with pandas
import pandas as pd

# Image processing with Pillow
from PIL import Image

# Plotting and visualization with matplotlib and seaborn
import matplotlib.pyplot as plt
import seaborn as sns

# Working with arrays using NumPy
import numpy as np

# Filesystem access, epoch timing, and progress bars
import os
import time
import tqdm

Now, let's break down the significance of each library in our toolkit:

  • PyTorch: The heart of our deep learning operations, providing tools for building, training, and evaluating neural networks.

  • Pandas: A versatile data manipulation library, crucial for handling and exploring our dataset.

  • Pillow: An image processing library that will help us work with the bird images.

  • Matplotlib and seaborn: For visualizing our data and model performance.

  • NumPy: Essential for numerical operations and array manipulations.

  • os, time, and tqdm: For navigating the dataset on disk, timing each training epoch, and displaying progress bars during training.

Loading the Dataset

The first step in our data preprocessing journey is loading and structuring the "100-bird-species" dataset. This enables us to access the images and their associated labels, setting the stage for effective training and evaluation. Let's see how this is done:

# Load and structure the dataset
train_data_path = '/content/train'
valid_data_path = '/content/valid'
batch_size = 128
num_classes = 525

# Define the data transformation pipeline
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# Load the dataset from the folder structure
train_dataset = ImageFolder(train_data_path, transform=train_transforms)
val_dataset = ImageFolder(valid_data_path, transform=val_transforms)

Creating Data Loaders

To feed our model efficiently during training, we create data loaders. Data loaders handle data loading, batching, and shuffling. This simplifies the process of training our model and ensures it receives the right data at the right time:

# Create data loaders for batch processing
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
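
As a quick, purely illustrative sanity check, we can pull a single batch from the training loader and confirm the tensor shapes match what the transforms promise:

# Sanity check: inspecting one batch from the training loader (illustrative only)
images, labels = next(iter(train_loader))
print(images.shape)                      # expected: torch.Size([128, 3, 224, 224])
print(labels.shape)                      # expected: torch.Size([128])
print('Number of classes:', len(train_dataset.classes))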

Building the ResNet Model

In our quest for bird species classification, the choice of a robust model is pivotal. We turn to ResNet, a pre-trained convolutional neural network renowned for its prowess in image classification tasks. ResNet's architecture, built around residual blocks, makes it possible to train very deep networks while mitigating the vanishing-gradient problem.

Modifying the Final Act: Customizing the Fully Connected Layer

Now, let's dive into the code that brings ResNet into our fold and tailors it for our specific task. The following snippet showcases the loading of a pre-trained ResNet model and the modification of its final fully connected layer:

# Loading a pre-trained ResNet model
model = models.resnet50(pretrained=True)

# Extracting the number of input features for the fully connected layer
num_ftrs = model.fc.in_features

# Modifying the final fully connected layer to match the number of bird species
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes))

Here's a breakdown of the steps:

  1. Loading ResNet: We acquire ResNet50, a variant of ResNet with 50 layers, pre-trained on ImageNet. This provides our model with a solid foundation for recognizing intricate patterns in images.

  2. Identifying Input Features: We extract the number of input features from the existing fully connected layer. This is crucial for the subsequent customization.

  3. Customizing the Fully Connected Layer: The final act involves modifying the fully connected layer of ResNet. By replacing it with a new layer, we ensure compatibility with the number of bird species in our dataset.
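
A quick aside before we move on: if compute is tight, a common variant (not what we do in the rest of this post) is to freeze the pre-trained backbone and train only the new fully connected layer. A minimal sketch of that idea, using the model we just built, would look like this:

# Optional variant (not used later in this post): freeze the pre-trained backbone
# so that only the newly added fully connected layer learns during training
for param in model.parameters():
    param.requires_grad = False   # freeze every pre-trained weight
for param in model.fc.parameters():
    param.requires_grad = True    # keep the new classification head trainable

If you take this route, remember to hand only the trainable parameters to the optimizer, for example optim.Adam(model.fc.parameters(), lr=0.001).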

Setting the Stage: Device Deployment and Optimization

Deploying to the Device:

To unleash the potential of our ResNet model, the first step is deploying it to the right device. PyTorch makes this straightforward: we use a GPU when one is available and fall back to the CPU otherwise. The following code snippet accomplishes this pivotal task:

# Sending the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device}')
model = model.to(device)

In this succinct move, we transfer the computational prowess of our model to the specified device, be it a CPU or GPU. This step sets the stage for efficient training and evaluation performed on the chosen hardware.

Loss Function and Optimizer

With our model in the spotlight, it's time to guide its learning process. The choice of a loss function and optimizer lays the foundation for a successful training journey. Here's the straightforward implementation:

# Defining the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

  1. Loss Function Selection: We employ the CrossEntropyLoss, a staple for classification tasks, to quantify the disparity between predicted and actual classes (see the tiny illustration after this list). This function propels our model towards accurate species identification.

  2. Optimizer Configuration: Adam, a popular optimization algorithm, takes the reins in steering our model's parameters towards minimizing the chosen loss. The learning rate, set at 0.001, governs the step size in this intricate dance of parameter adjustments.
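
To make the first point concrete, here is a tiny, made-up illustration of how CrossEntropyLoss turns raw logits and an integer class label into a single loss value (the numbers are invented for the example and are not part of the training code):

# A tiny, stand-alone illustration of CrossEntropyLoss on made-up logits
dummy_logits = torch.tensor([[2.0, 0.5, 0.1]])  # raw scores for three hypothetical classes
dummy_labels = torch.tensor([0])                # the true class is index 0
print(criterion(dummy_logits, dummy_labels))    # small loss, since class 0 already has the highest score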

Checkpoint Resilience: Safeguarding Model Progress

Creating Checkpoint Path

Before we embark on the dynamic journey of training our model, let's fortify our expedition with a safety net. The following code ensures the existence of a directory dedicated to housing checkpoints:

# Creating a directory for checkpoints if it doesn't exist
checkpoint_path = '/content/drive/MyDrive/project/BirdSpecies_pretrain_model_checkpoint.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

This simple yet crucial step establishes a haven for our model's progress, safeguarding against unforeseen interruptions and empowering us to resume training seamlessly.

Saving and Loading Checkpoints

In the unpredictable realm of machine learning training, interruptions are inevitable. Fear not, for we've equipped our model with the ability to save and load checkpoints. Here's how:

Saving the Day: A Function for Checkpoint Preservation

# Defining a function to save model checkpoints
def save_checkpoint(epoch, model, optimizer, loss, path):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, path)

This function encapsulates the model's vital signs at a particular epoch, ensuring that even in the face of adversity, we can revive and continue our training expedition.

Loading from Checkpoints

# Loading an existing checkpoint if available
def load_checkpoint(model, optimizer, path):
    if os.path.exists(path):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Loaded checkpoint. Resuming training from epoch {epoch}.")
        return model, optimizer, epoch, loss
    else:
        print("No checkpoint found. Starting from epoch 0.")
        return model, optimizer, 0, float('inf')

This dynamic duo of functions ensures that our model can gracefully recover from setbacks, allowing us to pick up where we left off. Join us in the next section as we dive into the heart of training.

Training Loop

In the vast landscape of artificial intelligence, the training loop stands as a crucial terrain where our model evolves and refines its predictive prowess. This section unveils the meticulous process of training our model using PyTorch. Buckle up as we traverse through epochs, witness the dance of loss functions, and safeguard our progress with checkpoints.

Code Implementation:

def train_model_with_checkpoint(model, train_loader, val_loader, criterion, optimizer, start_epoch, num_epochs, device, checkpoint_path):
    train_losses = []
    val_losses = []

    for epoch in range(start_epoch, num_epochs):
        start_time = time.time()  # Starting time for the epoch
        model.train()
        running_loss = 0.0

        for inputs, labels in tqdm.tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_time = time.time() - start_time  # Calculating epoch duration
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in tqdm.tqdm(val_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        # Saving checkpoint at the end of each epoch
        save_checkpoint(epoch, model, optimizer, avg_val_loss, checkpoint_path)

        print(f"|Epoch {epoch + 1}/{num_epochs}|Train Loss: {avg_train_loss:.4f}|Val Loss: {avg_val_loss:.4f}|Time: {epoch_time:.2f} seconds|")

    return train_losses, val_losses

# Loading the model and optimizer, and resuming training from the last checkpoint
model, optimizer, start_epoch, prev_val_loss = load_checkpoint(model, optimizer, checkpoint_path)

# Training the model and storing loss values
num_epochs = 15
train_losses, val_losses = train_model_with_checkpoint(model, train_loader, val_loader, criterion, optimizer, start_epoch, num_epochs, device, checkpoint_path)
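
The loop above tracks loss only. If you would also like a headline accuracy number on the validation set once training finishes, a short evaluation pass like the following sketch (reusing the model, val_loader, and device defined earlier) will do the job:

# Sketch: computing validation accuracy after training (not part of the loop above)
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
print(f'Validation accuracy: {100 * correct / total:.2f}%')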

Monitoring Training Progress with Plots

Now that we've gained insights into the training process and saved our model's progress, it's time to visualize how well our model is learning. We'll create plots to track training and validation metrics.

Training and Validation Loss

The first set of plots will showcase the loss values. Loss is a fundamental metric that quantifies how far off our model's predictions are from the true labels. Lower loss values indicate better model performance.

# Set Seaborn style and color palette
sns.set_style('darkgrid')
sns.set_palette('pastel')

# Defining the function to plot losses
def plot_losses(train_losses, val_losses):
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_losses, label='Training Loss', marker='o')
    plt.plot(epochs, val_losses, label='Validation Loss', marker='o')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()


plot_losses(train_losses, val_losses)

In the graph, you'll notice two lines: one representing the training loss and the other for the validation loss. If everything goes well, you should see the training loss decreasing and hopefully, the validation loss following a similar trend without significant deviations. This indicates that our model is learning effectively and not overfitting the training data.

Making Predictions with the Model

Embark on the final stretch of our journey into the world of wonders as we unravel the art of making predictions with our fine-tuned model. This section demystifies the process of loading saved models, preprocessing images, and predicting the species of our feathered friends.

Code Implementation:

# Loading the saved model checkpoint (map_location keeps this working even without a GPU)
saved_model_path = '/content/drive/MyDrive/project/BirdSpecies_pretrain_model_checkpoint.pth'
checkpoint = torch.load(saved_model_path, map_location='cpu')

# Loading the state_dict into the model
model = models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes))

# Matching the keys from the state_dict with the keys in the current model
model_dict = model.state_dict()
state_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k in model_dict}

# Loading the matched state_dict into the model
model.load_state_dict(state_dict)

# Setting the model to evaluation mode
model.eval()

# Defining a function to predict the bird species
def predict_bird(image_path):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    image = Image.open(image_path).convert('RGB')  # ensuring three channels for the normalization step
    image = transform(image)
    image = image.unsqueeze(0)  # Adding a batch dimension

    # Making the prediction
    with torch.no_grad():
        model.eval()
        outputs = model(image)
        _, predicted_idx = torch.max(outputs, 1)

    predicted_label = train_dataset.classes[predicted_idx.item()]
    return predicted_label

# Example: Predicting bird species from an image
image_path = 'ABBOTTS_BABBLER.jpg'
predicted_label = predict_bird(image_path)
print('Predicted bird species:', predicted_label)

This incantation of code breathes life into our trained model, allowing it to decipher the identity of a bird from an image. The predict_bird function serves as our magical portal into the realm of bird recognition, unveiling the predicted species with a simple invocation.
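
If a single label feels too terse, you can also ask the model for its top few guesses with confidence scores. The following sketch (a hypothetical helper named predict_topk, building on the same preprocessing as predict_bird) shows one way to do it:

# Sketch: top-5 predictions with softmax confidences (a hypothetical helper, not part of the original pipeline)
def predict_topk(image_path, k=5):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(Image.open(image_path).convert('RGB')).unsqueeze(0)

    with torch.no_grad():
        probs = torch.softmax(model(image), dim=1)   # turning logits into probabilities
        top_probs, top_idx = probs.topk(k, dim=1)    # the k most likely classes

    return [(train_dataset.classes[i.item()], p.item())
            for i, p in zip(top_idx[0], top_probs[0])]

# Example usage
for species, prob in predict_topk(image_path):
    print(f'{species}: {prob:.2%}')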

Conclusion

In this exhilarating journey through the fusion of artificial intelligence and ornithology, we've witnessed the power of PyTorch in classifying 525 unique bird species. Our model showcases the seamless integration of machine learning into the realm of nature exploration, born from a pre-trained ResNet model and a dataset brimming with avian diversity.