Enhancing Faster-RCNN with PyTorch for Object Detection: Custom Dataset for Wheat Detection


Object detection is a crucial task in computer vision that involves identifying and locating objects in an image. One popular algorithm for object detection is Faster R-CNN (Region-based Convolutional Neural Networks), which combines a region proposal network (RPN) to generate region proposals and a detection network to classify and refine these proposals.
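To make the region proposal step more concrete: the RPN scores a fixed set of reference boxes, called anchors, placed at every feature-map location, and predicts offsets that refine them into proposals. The sketch below builds the anchors for a single location in the way torchvision's `AnchorGenerator` does (aspect ratio defined as height / width, area fixed by the scale), without the rounding the library applies. The function name `base_anchors` and the example scale and ratios are illustrative choices, not part of the tutorial's code.

```python
import numpy as np


def base_anchors(scale: float, aspect_ratios=(0.5, 1.0, 2.0)) -> np.ndarray:
    """Anchors centred on (0, 0) for one scale, one row per aspect ratio (h / w).

    Rows are [x_min, y_min, x_max, y_max]; every anchor has area scale**2.
    """
    ratios = np.asarray(aspect_ratios, dtype=float)
    h_ratios = np.sqrt(ratios)      # taller boxes for ratio > 1
    w_ratios = 1.0 / h_ratios       # keeps w * h constant
    ws = w_ratios * scale
    hs = h_ratios * scale
    return np.stack([-ws, -hs, ws, hs], axis=1) / 2


anchors = base_anchors(32.0)
print(anchors)
```

Shifting these base anchors to every position of a feature map gives the full anchor set the RPN classifies as object or background.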

In this tutorial, we will walk through the process of fine-tuning a Faster R-CNN model using PyTorch on a custom dataset for wheat detection. This tutorial assumes you have a basic understanding of PyTorch and deep learning concepts.

Step 1: Setting up the Environment
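First, install PyTorch and torchvision, plus pandas and matplotlib, which the code below uses to read the annotations and plot results. You can install them using pip:

pip install torch torchvision pandas matplotlib

The Faster R-CNN implementation ships with torchvision itself (torchvision.models.detection), so no separate object detection library is needed.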

Step 2: Prepare the Dataset
For this tutorial, we will be using the Global Wheat Detection dataset from Kaggle. You can download the dataset from here: https://www.kaggle.com/c/global-wheat-detection/data

After downloading the dataset, extract the files and organize them into a folder structure like this:

- custom_dataset
  - train
    - image1.jpg
    - image2.jpg
    ...
  - train.csv

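The train.csv file contains the annotations: one row per bounding box, with the image_id (the image file name without the .jpg extension), the image width and height, the box as a string of the form [x, y, width, height], and the source of the image. Every box is a wheat head, so there is a single object class. Before training, you need to turn these rows into per-image tensors in the [x_min, y_min, x_max, y_max] format that torchvision's Faster R-CNN expects.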
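A minimal sketch of that conversion, assuming the Kaggle column names image_id and bbox: it groups the rows by image and turns each [x, y, width, height] string into [x_min, y_min, x_max, y_max]. The function name `parse_wheat_annotations` and the two sample rows are made up for illustration.

```python
import ast

import pandas as pd


def parse_wheat_annotations(df: pd.DataFrame) -> dict:
    """Group box-level rows into {image_id: [[x_min, y_min, x_max, y_max], ...]}.

    Assumes an "image_id" column and a "bbox" column holding strings
    like "[x, y, width, height]".
    """
    boxes_per_image = {}
    for image_id, bbox_str in zip(df["image_id"], df["bbox"]):
        # literal_eval safely parses the list literal stored as text
        x, y, w, h = ast.literal_eval(bbox_str)
        boxes_per_image.setdefault(image_id, []).append([x, y, x + w, y + h])
    return boxes_per_image


# Two hypothetical rows in the same layout as train.csv
sample = pd.DataFrame({
    "image_id": ["img_001", "img_001"],
    "bbox": ["[834.0, 222.0, 56.0, 36.0]", "[226.0, 548.0, 130.0, 58.0]"],
})
print(parse_wheat_annotations(sample))
```

Grouping by image matters because the detector is trained on whole images: each sample must carry all of that image's boxes, not just one.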

Step 3: Define Custom Datasets and Dataloaders
In PyTorch, you need to create a custom dataset class to load and preprocess the dataset. Here’s an example implementation for the wheat dataset:

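import os
import ast

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F

class WheatDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        # train.csv has one row per bounding box, so index the dataset by image
        self.annotations = pd.read_csv(csv_file)
        self.image_ids = self.annotations["image_id"].unique()
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # One sample per image, not per box
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        img_path = os.path.join(self.root_dir, f"{image_id}.jpg")
        image = Image.open(img_path).convert("RGB")

        # Each "bbox" entry is a string "[x, y, width, height]"
        records = self.annotations[self.annotations["image_id"] == image_id]
        xywh = torch.tensor([ast.literal_eval(b) for b in records["bbox"]], dtype=torch.float32)
        # Faster R-CNN expects boxes as [x_min, y_min, x_max, y_max]
        boxes = torch.cat([xywh[:, :2], xywh[:, :2] + xywh[:, 2:]], dim=1)
        labels = torch.ones((len(boxes),), dtype=torch.int64)  # 1 = wheat; 0 is reserved for background
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": area,
            "iscrowd": torch.zeros((len(boxes),), dtype=torch.int64),
        }

        # The model expects a float tensor in [0, 1] with shape [C, H, W]
        image = F.to_tensor(image)
        if self.transform:
            # A custom transform must take and return (image, target) so the boxes stay aligned
            image, target = self.transform(image, target)

        return image, target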
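Shape or dtype mistakes in the target dictionary often only surface as errors deep inside the model. A small checker like the one below, run on a few samples before training, catches the common ones: torchvision's Faster R-CNN expects floating-point boxes of shape [N, 4] in [x_min, y_min, x_max, y_max] order, int64 labels of shape [N], and rejects boxes with zero or negative width or height. The helper `check_target` is hypothetical, written for this tutorial.

```python
import torch


def check_target(target: dict) -> None:
    """Raise ValueError if a target dict does not match the training format."""
    boxes, labels = target["boxes"], target["labels"]
    if not boxes.is_floating_point() or boxes.ndim != 2 or boxes.shape[1] != 4:
        raise ValueError(f"boxes must be floating point with shape [N, 4], got {boxes.dtype} {tuple(boxes.shape)}")
    if labels.dtype != torch.int64 or labels.shape != (boxes.shape[0],):
        raise ValueError("labels must be int64 with shape [N]")
    # Degenerate boxes (x_max <= x_min or y_max <= y_min) make training fail
    degenerate = (boxes[:, 2] <= boxes[:, 0]) | (boxes[:, 3] <= boxes[:, 1])
    if degenerate.any():
        raise ValueError(f"degenerate boxes at indices {degenerate.nonzero().flatten().tolist()}")


good = {"boxes": torch.tensor([[10.0, 20.0, 50.0, 60.0]]), "labels": torch.tensor([1])}
check_target(good)  # passes silently
```

For example, calling `check_target` on the target returned by `dataset[i]` for a handful of indices is a cheap way to confirm the [x, y, width, height] conversion was applied.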

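Next, create dataloaders for training and validation. The folder layout above has no separate validation file, so one option is to hold out a random 20% of the samples in train.csv. Because each image has a different number of boxes, the default collate function cannot stack the targets into one tensor, so a custom collate_fn keeps each batch as tuples:

from torch.utils.data import random_split

def collate_fn(batch):
    return tuple(zip(*batch))  # ((img1, img2, ...), (tgt1, tgt2, ...))

full_dataset = WheatDataset(csv_file='custom_dataset/train.csv', root_dir='custom_dataset/train')
n_valid = int(0.2 * len(full_dataset))  # hold out 20% for validation
train_dataset, valid_dataset = random_split(full_dataset, [len(full_dataset) - n_valid, n_valid], generator=torch.Generator().manual_seed(42))
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=4, collate_fn=collate_fn)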
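If you prefer separate CSV files for training and validation, or your dataset indexes individual boxes rather than images, split on unique image_id values so that all boxes of one image end up in the same subset; otherwise boxes from the same image can leak into both. The sketch below assumes only an image_id column; the function name `split_by_image`, the 20% fraction, and the seed are illustrative choices.

```python
import numpy as np
import pandas as pd


def split_by_image(df: pd.DataFrame, valid_fraction: float = 0.2, seed: int = 42):
    """Split box-level annotations into (train_df, valid_df) by image_id."""
    image_ids = df["image_id"].unique()
    rng = np.random.default_rng(seed)
    rng.shuffle(image_ids)  # shuffle image ids in place
    n_valid = int(len(image_ids) * valid_fraction)
    valid_ids = set(image_ids[:n_valid])
    is_valid = df["image_id"].isin(valid_ids)
    return df[~is_valid].reset_index(drop=True), df[is_valid].reset_index(drop=True)
```

Usage might look like `train_df, valid_df = split_by_image(pd.read_csv("custom_dataset/train.csv"))`, followed by `to_csv` calls with whatever file names your dataset class expects.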

Step 4: Define the Faster R-CNN Model
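Next, define the Faster R-CNN model to fine-tune. Load torchvision's Faster R-CNN with a ResNet-50 FPN backbone, pre-trained on COCO, and replace its box predictor (the final classification and box-regression layers) with a new one sized for your classes:

import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Recent torchvision versions use `weights=` instead of the deprecated `pretrained=True`
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)

num_classes = 2  # wheat + background
in_features = model.roi_heads.box_predictor.cls_score.in_features

# Swap the 91-class COCO predictor for one with num_classes outputs
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)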

Step 5: Finetune the Model
Now it’s time to fine-tune the model on your custom dataset. Set up the optimizer and learning rate scheduler. No separate loss function is needed: in training mode, torchvision's Faster R-CNN computes its own classification and box-regression losses and returns them as a dictionary.

import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model.to(device)

# torchvision freezes the earliest backbone layers by default, so pass only trainable parameters
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)
# T_max matches num_epochs (10) so the learning rate decays once, down to eta_min
lr_scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.00001)
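CosineAnnealingLR follows the closed form given in the PyTorch documentation, η_t = η_min + ½(η_max − η_min)(1 + cos(π t / T_max)), where t counts scheduler steps (here, epochs). If T_max is smaller than the number of epochs you train for, the cosine continues past its minimum and the learning rate climbs back up. The sketch below evaluates that formula for T_max = 5 and T_max = 10 over 10 epochs; `cosine_lr` is a hypothetical helper, not a PyTorch function.

```python
import math


def cosine_lr(epoch: int, base_lr: float, eta_min: float, t_max: int) -> float:
    """Closed-form cosine annealing learning rate after `epoch` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


for t_max in (5, 10):
    lrs = [cosine_lr(e, 1e-4, 1e-5, t_max) for e in range(11)]
    print(t_max, [f"{lr:.1e}" for lr in lrs])
```

With T_max = 5 the rate reaches 1e-5 at epoch 5 and is back at 1e-4 by epoch 10; with T_max = 10 it decays once and ends at 1e-5.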

Now, train the model over multiple epochs and evaluate on the validation set:

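num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for images, targets in train_dataloader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In train mode the model returns a dict of losses
        # (classifier, box regression, RPN objectness, RPN box regression)
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    lr_scheduler.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_dataloader):.4f}")

# Evaluate on the validation set: in eval mode the model returns a list of
# dicts with "boxes", "labels" and "scores" instead of losses
model.eval()
all_outputs, all_targets = [], []
with torch.no_grad():
    for images, targets in valid_dataloader:
        outputs = model([image.to(device) for image in images])
        all_outputs.extend({k: v.cpu() for k, v in o.items()} for o in outputs)
        all_targets.extend(targets)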
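Detection models do not have a single "accuracy"; a simple way to get a number is to match predicted boxes to ground-truth boxes by intersection over union (IoU) and report precision and recall. The sketch below does greedy one-to-one matching for a single image: predictions are visited in descending score order, and each one counts as a true positive if it overlaps a still-unmatched ground-truth box with IoU at least 0.5. The function names and both 0.5 thresholds are assumptions for illustration; this is not the COCO mAP metric.

```python
import numpy as np


def box_iou(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise IoU between a [N, 4] and b [M, 4], boxes as [x_min, y_min, x_max, y_max]."""
    x1 = np.maximum(a[:, None, 0], b[None, :, 0])
    y1 = np.maximum(a[:, None, 1], b[None, :, 1])
    x2 = np.minimum(a[:, None, 2], b[None, :, 2])
    y2 = np.minimum(a[:, None, 3], b[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    return inter / (area_a[:, None] + area_b[None, :] - inter)


def precision_recall(pred_boxes, pred_scores, gt_boxes, iou_threshold=0.5, score_threshold=0.5):
    """Greedy matching of predictions to ground truth for one image."""
    keep = pred_scores >= score_threshold
    pred_boxes = pred_boxes[keep][np.argsort(-pred_scores[keep])]  # highest score first
    matched = np.zeros(len(gt_boxes), dtype=bool)
    tp = 0
    if len(pred_boxes) and len(gt_boxes):
        for row in box_iou(pred_boxes, gt_boxes):
            row = np.where(matched, -1.0, row)  # each ground-truth box can be matched once
            best = int(np.argmax(row))
            if row[best] >= iou_threshold:
                matched[best] = True
                tp += 1
    fp = len(pred_boxes) - tp
    fn = len(gt_boxes) - tp
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall
```

For each validation image, pass the model's predicted "boxes" and "scores" and the ground-truth "boxes" (converted with `.numpy()`), then average the per-image values or accumulate the counts across the whole set.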

Step 6: Inference
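Once the model is trained and evaluated, you can use it for inference on images it was not trained on. The code below takes one batch from the validation loader, runs the model in evaluation mode, and draws the predicted boxes (not the ground-truth boxes) whose confidence score is at least 0.5, a threshold you may want to tune:

import numpy as np
import matplotlib.pyplot as plt

def plot_image(image_tensor, prediction, score_threshold=0.5):
    # Convert a [C, H, W] float tensor in [0, 1] to an [H, W, C] uint8 array
    image = image_tensor.cpu().numpy().transpose((1, 2, 0)) * 255
    image = np.clip(image, 0, 255).astype(np.uint8)
    plt.imshow(image)

    # Boxes are [x_min, y_min, x_max, y_max]; Rectangle takes a corner, width and height
    for box, score in zip(prediction["boxes"], prediction["scores"]):
        if score < score_threshold:
            continue
        x_min, y_min, x_max, y_max = box.cpu().numpy()
        rect = plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, fill=False, edgecolor='red', linewidth=2)
        plt.gca().add_patch(rect)

    plt.show()

images, targets = next(iter(valid_dataloader))

model.eval()
with torch.no_grad():
    outputs = model([image.to(device) for image in images])

for i in range(len(images)):
    plot_image(images[i], outputs[i])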

And that’s it! You have successfully fine-tuned a Faster R-CNN model on a custom wheat detection dataset using PyTorch. Feel free to explore different hyperparameters, loss functions, and model architectures to improve the performance further.

22 Comments
@darshandushing5985
3 days ago

Sir i want to use RCNN for number plate detection, how to do this ?

Please reply

@rohithgamer7340
3 days ago

I mean come on, where are links?

@_maha_
3 days ago

how you evaluate the model

@studshelper9901
3 days ago

now found the diamond from coal field,
i used to detect text from image
Thanks a lot😀😀

@luisleal4169
3 days ago

great video! chan you share the code and dataset? do you have any github repo?

@SumanMondal-qz4qd
3 days ago

in the last cell where you have drawn the output image you didn't draw the boxes form the output you used the prior bounding boxes which are meant to use for training and validation…. I found this misleading

@دانهالسبيعي-ن8ه
3 days ago

How can I contact you? I want to clarify some points

@AbhiramMirthipati
3 days ago

How to get the accuracy of the model get printed ?

@nra6925
3 days ago

can you make tutorial for retinet model?

@nasimaislambithi9510
3 days ago

uploads the code link pls

@JahnaviYadrami
3 days ago

hi

@geekroy1728
3 days ago

bro posts only videos, but not github links🤣🤣🤣

@vijayalaxmiise1504
3 days ago

Sir, video has clarified many doubts. Thanks a lot

@tilakrajchoubey5534
3 days ago

will it work if my images are of different lengths??

@raehanfelda8956
3 days ago

can you make video about how to tuning yolov8 hyperparameter using raytune?

@brunonogueirarenzo2306
3 days ago

Great content, Datum!

@atamsujiwanto1840
3 days ago

how if we use 3 labels, sir? can you explain the code for 3 labels?

@karlm9584
3 days ago

I receive an error during setting up the targets for the model training in the for loop – for d in data: … targ["boxes"] = d[1]["boxes"].to(device) causes an error. On inspection, len(d) returns 1 so the index of d[1] is invalid. How to solve this error please? Any attempts to solve it results in the input data being incompatible with the model

@bholuramgurjar900
3 days ago

Good job! can you share the code?

@rv0_0
3 days ago

link to the code please
