Object detection is a crucial task in computer vision that involves identifying and locating objects in an image. One popular algorithm for object detection is Faster R-CNN (Region-based Convolutional Neural Networks), which combines a region proposal network (RPN) to generate region proposals and a detection network to classify and refine these proposals.
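To make "refine these proposals" concrete, here is a minimal sketch of the box parameterization described in the original Faster R-CNN paper. The network predicts offsets (dx, dy, dw, dh) relative to a reference box (an anchor in the RPN, a proposal in the detection head), and those offsets are decoded into a refined box. The function name `decode_box` is hypothetical. torchvision implements the same idea in its internal box coder and additionally scales the offsets by fixed per-coordinate weights, so treat this as an illustration rather than its exact code.

```python
import math


def decode_box(anchor, deltas):
    """Apply predicted offsets to a reference box.

    anchor: (x1, y1, x2, y2) reference box (an anchor or a proposal).
    deltas: (dx, dy, dw, dh) offsets predicted by the network.
    Returns the refined box as (x1, y1, x2, y2).
    """
    x1, y1, x2, y2 = anchor
    dx, dy, dw, dh = deltas

    # Convert the reference box to center/size form.
    wa, ha = x2 - x1, y2 - y1
    xa, ya = x1 + 0.5 * wa, y1 + 0.5 * ha

    # Centers shift in proportion to the box size; sizes scale exponentially.
    x, y = xa + dx * wa, ya + dy * ha
    w, h = wa * math.exp(dw), ha * math.exp(dh)

    # Back to corner form.
    return (x - 0.5 * w, y - 0.5 * h, x + 0.5 * w, y + 0.5 * h)


# Zero offsets leave the box unchanged; dx = 0.1 shifts it right by 10% of its width.
print(decode_box((0, 0, 10, 10), (0.0, 0.0, 0.0, 0.0)))  # (0.0, 0.0, 10.0, 10.0)
print(decode_box((0, 0, 10, 10), (0.1, 0.0, 0.0, 0.0)))  # (1.0, 0.0, 11.0, 10.0)
```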
In this tutorial, we will walk through the process of finetuning a Faster R-CNN model using PyTorch on a custom dataset for wheat detection. This tutorial assumes you have a basic understanding of PyTorch and deep learning concepts.
Step 1: Setting up the Environment
First, install PyTorch and torchvision with pip. Faster R-CNN ships with torchvision itself (in `torchvision.models.detection`), so no separate object detection package is needed. The data loading and plotting code below also uses pandas, Pillow, and matplotlib:

```shell
pip install torch torchvision pandas pillow matplotlib
```
Step 2: Prepare the Dataset
For this tutorial, we will be using the Global Wheat Detection dataset from Kaggle. You can download the dataset from here: https://www.kaggle.com/c/global-wheat-detection/data
After downloading the dataset, extract the files and organize them into a folder structure like this:
- custom_dataset
  - train
    - image1.jpg
    - image2.jpg
  - train.csv
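The train.csv file has one row per bounding box, with the columns `image_id`, `width` and `height` (of the image), `bbox`, and `source`. There is a single class (wheat head), and `bbox` is stored as a string of the form `[x_min, y_min, width, height]`, whereas torchvision's Faster R-CNN expects boxes as `[x_min, y_min, x_max, y_max]` tensors. You therefore need to preprocess the annotations before training. Kaggle also provides no validation split, so set aside a portion of the training images for validation.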
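A short pandas pass can parse the box strings and carve out a validation split. This is a minimal sketch assuming the Kaggle column names (`image_id`, `bbox`); the helper names, the 20% validation fraction, the seed, and the output file names `train_split.csv` and `valid_split.csv` are arbitrary choices for this tutorial. The split is made by `image_id` rather than by row, so no image contributes boxes to both sets.

```python
import ast
import os

import numpy as np
import pandas as pd


def xywh_to_xyxy(bbox_str):
    """Convert a bbox string "[x_min, y_min, width, height]" to [x_min, y_min, x_max, y_max]."""
    x, y, w, h = ast.literal_eval(bbox_str)
    return [x, y, x + w, y + h]


def split_by_image(df, valid_fraction=0.2, seed=42):
    """Split annotation rows into train/valid DataFrames, keeping each image in exactly one split."""
    rng = np.random.default_rng(seed)
    image_ids = df["image_id"].unique()
    rng.shuffle(image_ids)
    n_valid = int(len(image_ids) * valid_fraction)
    is_valid = df["image_id"].isin(set(image_ids[:n_valid]))
    return df[~is_valid], df[is_valid]


def make_split_files(csv_path="custom_dataset/train.csv", out_dir="custom_dataset",
                     valid_fraction=0.2, seed=42):
    """Write hypothetical train_split.csv / valid_split.csv files next to train.csv."""
    annotations = pd.read_csv(csv_path)
    train_df, valid_df = split_by_image(annotations, valid_fraction, seed)
    train_df.to_csv(os.path.join(out_dir, "train_split.csv"), index=False)
    valid_df.to_csv(os.path.join(out_dir, "valid_split.csv"), index=False)
```

Run `make_split_files()` once after extracting the data. Both split files still point at images in the same `train` folder.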
Step 3: Define Custom Datasets and Dataloaders
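In PyTorch, you load data through a custom Dataset class. Because train.csv has one row per box, the example below indexes the dataset by unique `image_id`, gathers all boxes for an image in `__getitem__`, and converts each box from `[x_min, y_min, width, height]` to the `[x_min, y_min, x_max, y_max]` format torchvision expects:

```python
import ast
import os

import pandas as pd
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class WheatDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        annotations = pd.read_csv(csv_file)
        # train.csv has one row per box, so group the bbox strings by image.
        self.boxes_by_image = annotations.groupby("image_id")["bbox"].apply(list)
        self.image_ids = list(self.boxes_by_image.index)
        self.root_dir = root_dir
        self.transform = transform  # optional callable taking and returning (image, target)

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        img_path = os.path.join(self.root_dir, f"{image_id}.jpg")  # image_id has no extension
        image = F.to_tensor(Image.open(img_path).convert("RGB"))  # CHW float tensor in [0, 1]

        # Parse "[x_min, y_min, width, height]" strings, then convert to [x_min, y_min, x_max, y_max].
        boxes = torch.tensor([ast.literal_eval(b) for b in self.boxes_by_image[image_id]],
                             dtype=torch.float32)
        boxes[:, 2:] += boxes[:, :2]

        num_boxes = boxes.shape[0]
        target = {
            "boxes": boxes,
            "labels": torch.ones((num_boxes,), dtype=torch.int64),  # one class: wheat = 1
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros((num_boxes,), dtype=torch.int64),
        }

        if self.transform:
            image, target = self.transform(image, target)
        return image, target
```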
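Because the dataset hands both the image and its target to `transform`, any augmentation must move the boxes along with the pixels. Below is a minimal sketch of a random horizontal flip, assuming the image has already been converted to a `C x H x W` tensor (for example with `torchvision.transforms.functional.to_tensor`) and that boxes are `[x_min, y_min, x_max, y_max]` in pixels. The class name `FlipWithBoxes` is our own, not a torchvision API.

```python
import random

import torch


class FlipWithBoxes:
    """Randomly flip a CHW image tensor left-right and mirror its xyxy boxes to match."""

    def __init__(self, p=0.5):
        self.p = p  # probability of flipping

    def __call__(self, image, target):
        if random.random() < self.p:
            width = image.shape[-1]
            image = image.flip(-1)  # reverse the width axis
            boxes = target["boxes"].clone()
            # The new x_min comes from the old x_max, and vice versa.
            boxes[:, [0, 2]] = width - target["boxes"][:, [2, 0]]
            target = {**target, "boxes": boxes}  # area, labels, etc. are unchanged by a flip
        return image, target
```

Pass it only to the training set, e.g. `WheatDataset(..., transform=FlipWithBoxes(p=0.5))`, and leave the validation set unaugmented.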
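Next, create dataloaders for training and validation. Kaggle does not ship a validation split, so this assumes you have divided train.csv by `image_id` into two files (the names below are placeholders). Because each image has a different number of boxes, the default collate function cannot stack the targets, so pass a `collate_fn` that keeps them as tuples:

```python
def collate_fn(batch):
    return tuple(zip(*batch))  # (images, targets) as tuples; no stacking

# Hypothetical split files: disjoint subsets of train.csv, both pointing at train/ images.
train_dataset = WheatDataset(csv_file='custom_dataset/train_split.csv', root_dir='custom_dataset/train')
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4, collate_fn=collate_fn)
valid_dataset = WheatDataset(csv_file='custom_dataset/valid_split.csv', root_dir='custom_dataset/train')
valid_dataloader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=4, collate_fn=collate_fn)
```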
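Detection data needs a custom `collate_fn` because the default collate function tries to stack each field of a batch into a single tensor, which fails for detection targets: each image has a different number of boxes (and, in general datasets, images may differ in size too). Transposing the batch with `tuple(zip(*batch))` instead yields one tuple of images and one tuple of targets, the list-style input torchvision detection models accept. This plain-Python sketch uses strings and lists as stand-ins for tensors to show the reshaping:

```python
# Stand-ins for (image, target) pairs: the "images" are strings and each
# target holds a different number of boxes, as in real detection data.
batch = [
    ("image_a", {"boxes": [[0, 0, 5, 5]]}),
    ("image_b", {"boxes": [[1, 1, 4, 4], [2, 2, 8, 8]]}),
]

# zip(*batch) transposes the list of pairs into (all images), (all targets).
images, targets = tuple(zip(*batch))
print(images)                              # ('image_a', 'image_b')
print([len(t["boxes"]) for t in targets])  # [1, 2]
```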
Step 4: Define the Faster R-CNN Model
Next, define the model. Rather than training from scratch, load a Faster R-CNN with a ResNet-50 FPN backbone pre-trained on COCO from `torchvision.models.detection`, then replace its box predictor (the final classification and box-regression layers) with a new one sized for your classes:
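```python
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# weights="DEFAULT" loads COCO-pretrained weights (torchvision >= 0.13; replaces pretrained=True).
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

num_classes = 2  # wheat + background (label 0 is reserved for background)

# Swap the COCO predictor for one that outputs scores and boxes for our classes.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
```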
Step 5: Finetune the Model
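Now it’s time to finetune the model on your custom dataset. Set up the optimizer and learning rate scheduler. No separate loss function is needed: in training mode, torchvision's detection models compute their own classification and box-regression losses and return them as a dictionary.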
```python
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Only optimize trainable parameters (torchvision freezes the earliest backbone layers by default).
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)
# T_max=10 matches num_epochs below, so the learning rate decays once from 1e-4 to 1e-5.
lr_scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.00001)
```
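`CosineAnnealingLR` follows the closed form η_t = η_min + ½(η_max − η_min)(1 + cos(π·t / T_max)), where η_max is the optimizer's initial learning rate and t counts calls to `lr_scheduler.step()`. Evaluating it in plain Python shows why `T_max` should match the number of epochs when the scheduler is stepped once per epoch: the rate reaches `eta_min` at t = T_max and, if training continues, climbs back up along the same cosine. The helper name `cosine_lr` is ours; the defaults mirror the values used in this tutorial.

```python
import math


def cosine_lr(t, lr_max=1e-4, lr_min=1e-5, t_max=10):
    """Learning rate after t scheduler steps under cosine annealing (no restarts)."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / t_max))


# Compare 10 epochs of training under T_max=5 and T_max=10.
for t_max in (5, 10):
    schedule = [cosine_lr(t, t_max=t_max) for t in range(11)]
    print(f"T_max={t_max}: " + ", ".join(f"{lr:.2e}" for lr in schedule))
```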
Now, train the model over multiple epochs and evaluate on the validation set:
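```python
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for images, targets in train_dataloader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In training mode the model returns a dict of losses (classifier, box, RPN); sum them.
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    lr_scheduler.step()  # one scheduler step per epoch
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_dataloader):.4f}")
    # Evaluate on the validation set here
```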
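One simple way to fill in the validation step is to reuse the training losses as a rough proxy. torchvision detection models return losses only in training mode (in eval mode they return predictions), so this sketch keeps the model in `train()` mode but disables gradients. The name `validation_loss` is hypothetical. Train mode also keeps training-time behavior such as proposal sampling, so treat the number as a trend indicator; for a proper detection metric such as mAP, run the model in `eval()` mode and compare its predicted boxes against the ground truth instead.

```python
import torch


@torch.no_grad()
def validation_loss(model, dataloader, device):
    """Average summed detection loss over a dataloader, without updating any weights."""
    model.train()  # torchvision detection models only return losses in train mode
    total, batches = 0.0, 0
    for images, targets in dataloader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        total += sum(loss.item() for loss in loss_dict.values())
        batches += 1
    return total / max(batches, 1)
```

For example, print `validation_loss(model, valid_dataloader, device)` at the end of each epoch and watch whether it keeps falling along with the training loss.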
Step 6: Inference
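Once the model is trained, switch it to evaluation mode to run inference. The example below takes a batch from the validation dataloader, predicts boxes, and draws those with a confidence score above 0.5:

```python
import matplotlib.pyplot as plt
import torch


def plot_image(image_tensor, boxes, color="red"):
    """Show a CHW image tensor with [x_min, y_min, x_max, y_max] boxes drawn on top."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image_tensor.cpu().permute(1, 2, 0).numpy())  # values already in [0, 1]
    for box in boxes:
        x_min, y_min, x_max, y_max = box.cpu().tolist()
        rect = plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                             fill=False, edgecolor=color, linewidth=2)
        ax.add_patch(rect)
    ax.axis("off")
    plt.show()


model.eval()  # in eval mode the model returns predictions instead of losses
images, targets = next(iter(valid_dataloader))
with torch.no_grad():
    outputs = model([image.to(device) for image in images])

for image, output in zip(images, outputs):
    keep = output["scores"] > 0.5  # drop low-confidence detections
    plot_image(image, output["boxes"][keep])
```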
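To compare predictions with the ground-truth boxes quantitatively rather than by eye, detection work usually measures overlap with intersection over union (IoU): the area shared by two boxes divided by the area they cover together. `torchvision.ops.box_iou` computes the full pairwise matrix for tensors; the pure-Python function below (named `iou` here, our own helper) just shows the arithmetic for two `[x_min, y_min, x_max, y_max]` boxes.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x_min, y_min, x_max, y_max)."""
    # Overlap rectangle; width or height clamps to 0 if the boxes do not intersect.
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h

    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# Two 10x10 boxes overlapping by half their width: 50 / 150.
print(iou((0, 0, 10, 10), (5, 0, 15, 10)))  # 0.3333333333333333
```

A prediction is commonly counted as a match for a ground-truth box when their IoU exceeds a threshold such as 0.5.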
And that’s it! You have successfully finetuned a Faster R-CNN model on a custom wheat detection dataset using PyTorch. Feel free to explore different hyperparameters, loss functions, and model architectures to improve the performance further.
