Utilizing PyTorch and Allegro Trains for Transfer Learning in Computer Vision

Transfer learning is a machine learning technique in which a model trained on one task is reused as the starting point for a second, related task. This approach is particularly useful in computer vision, where it lets you train high-quality models quickly even when labeled data is limited.

In this tutorial, we will walk through how to implement transfer learning in computer vision using PyTorch and Allegro Trains. PyTorch is a powerful deep learning library that provides the flexibility and efficiency needed for building and training complex neural networks. Allegro Trains is a machine learning experiment and model management platform that helps track and monitor experiments, collaborate with team members, and deploy models to production.

To get started, make sure you have PyTorch and Allegro Trains installed on your local machine. You can install PyTorch using pip:

pip install torch torchvision

You can install Allegro Trains using pip as well (the project has since been renamed ClearML and is also distributed as the clearml package):

pip install trains
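For Trains to record a run, the script creates a Task when it starts and then reports metrics to that task's logger. The sketch below is one way to wire this up, assuming the trains package installed above and an already configured Trains server. The helper make_reporter and its offline history fallback are hypothetical conveniences, not part of Trains; the Trains calls it relies on are Task.init, Task.get_logger and Logger.report_scalar.

```python
from typing import Callable, List, Optional, Tuple

# (title, series, value, iteration): same argument order as Trains' Logger.report_scalar
Reporter = Callable[[str, str, float, int], None]


def make_reporter(project_name: str, task_name: str, use_trains: bool = True,
                  history: Optional[List[Tuple[str, str, float, int]]] = None) -> Reporter:
    """Hypothetical helper: return a function that logs one scalar per call."""
    if use_trains:
        # Imported lazily so the offline path works without the trains package.
        from trains import Task

        # Registers this run with the configured Trains server; call once, early in the script.
        task = Task.init(project_name=project_name, task_name=task_name)
        logger = task.get_logger()

        def report_to_trains(title: str, series: str, value: float, iteration: int) -> None:
            logger.report_scalar(title=title, series=series, value=value, iteration=iteration)

        return report_to_trains

    # Offline fallback for dry runs: keep values in memory and echo them to stdout.
    records = history if history is not None else []

    def report_locally(title: str, series: str, value: float, iteration: int) -> None:
        records.append((title, series, value, iteration))
        print(f"[{project_name}/{task_name}] {title}/{series} @ {iteration}: {value}")

    return report_locally
```

In a training script you might call report = make_reporter('Pytorch Transfer Learning', 'Resnet18') near the top (the project and task names are illustrative) and report('accuracy', 'test', accuracy, epoch + 1) after each evaluation pass. Passing use_trains=False gives a dry run that only prints and records values locally.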

Now that you have the necessary dependencies installed, let's create a simple convolutional neural network (CNN) for image classification using PyTorch, trained from scratch as a baseline. We will use the popular CIFAR-10 dataset, which consists of 60,000 32×32 color images in 10 classes (50,000 for training and 10,000 for testing).

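import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # 3 input channels (RGB) -> 64 feature maps; padding=1 keeps the 32x32 size
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        # 2x2 max pooling halves the spatial size: 32x32 -> 16x16
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Flattened 64x16x16 feature map -> 10 class scores
        self.fc1 = nn.Linear(64 * 16 * 16, 10)

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc1(x)
        return x  # raw logits; nn.CrossEntropyLoss applies log-softmax itself

model = SimpleCNN()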
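The input size of fc1 (64 * 16 * 16) follows from the standard output-size formula for convolution and pooling layers, out = floor((in + 2 * padding - kernel) / stride) + 1. This small worked calculation (plain Python; conv_output_size is a hypothetical helper name) traces a 32×32 CIFAR-10 image through the network and also counts the model's parameters by hand.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of Conv2d / MaxPool2d with dilation 1 and floor rounding."""
    return (size + 2 * padding - kernel) // stride + 1

h = conv_output_size(32, kernel=3, stride=1, padding=1)  # conv1: 32 -> 32
h = conv_output_size(h, kernel=2, stride=2, padding=0)   # pool1: 32 -> 16
flat_features = 64 * h * h                               # 64 channels * 16 * 16

conv1_params = 64 * 3 * 3 * 3 + 64      # 3x3 kernels over 3 input channels, plus 64 biases
fc1_params = flat_features * 10 + 10    # weight matrix plus 10 biases
print(h, flat_features, conv1_params + fc1_params)  # prints: 16 16384 165642
```

The same total should come out of sum(p.numel() for p in model.parameters()) for the SimpleCNN defined above, which is a quick way to check the arithmetic.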

Next, we will load the CIFAR-10 dataset and define data loaders for training and testing:

# Convert images to tensors and scale each RGB channel from [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 50,000 training images, reshuffled every epoch
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 10,000 test images, kept in a fixed order for evaluation
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

Now we will train the model for 10 epochs on the training data, evaluating its accuracy on the test data after each epoch:

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    # Training pass over the full training set
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluation pass on the test set; no gradient tracking needed
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # index of the highest logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Epoch {epoch + 1}, Accuracy: {accuracy:.2f}%')

# Save only the learned weights (state_dict), not the whole Python object
torch.save(model.state_dict(), 'simple_cnn.pth')
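The same training and evaluation loop is needed again for the pre-trained model, so one option is to factor it into two functions. This is a sketch, not the post's own code: train_one_epoch and evaluate are hypothetical names, and the device argument is an added assumption that lets the loop run on a GPU if you move the model there first.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                    optimizer: torch.optim.Optimizer, device: str = "cpu") -> float:
    """Run one training pass and return the average loss per example."""
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * labels.size(0)  # undo the batch-mean reduction
    return running_loss / len(loader.dataset)


@torch.no_grad()  # no gradients are needed for evaluation
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Return top-1 accuracy in percent."""
    model.eval()
    correct = total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return 100.0 * correct / total
```

With these helpers, each epoch of either experiment becomes train_one_epoch(model, train_loader, criterion, optimizer) followed by evaluate(model, test_loader).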

At this point, we have trained a simple CNN from scratch on the CIFAR-10 dataset. Next, we will see whether transfer learning can do better by fine-tuning a pre-trained model on the same dataset.

We will use a ResNet-18 model pre-trained on ImageNet, available in torchvision's models module, and replace its 1,000-class output layer with a 10-class one for CIFAR-10:

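import torchvision.models as models

# Load ResNet-18 with ImageNet weights (weights API: torchvision >= 0.13; older versions use pretrained=True)
pretrained_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Replace the 1000-class ImageNet head with a new 10-class head
pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, 10)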

We will freeze all of the pre-trained layers and train only the new final fully connected layer on CIFAR-10, an approach often called feature extraction (full fine-tuning would also update some or all of the backbone layers):

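# Freeze every parameter so the pre-trained backbone is not updated
for param in pretrained_model.parameters():
    param.requires_grad = False

# Unfreeze the new classification head so it can be trained
for param in pretrained_model.fc.parameters():
    param.requires_grad = True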

We will define a new optimizer and criterion for fine-tuning the model:

# Pass only the trainable head to the optimizer; the frozen backbone never receives gradients
optimizer = optim.SGD(pretrained_model.fc.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

We will now train the new classification layer on the CIFAR-10 dataset, using the same loop as before:

for epoch in range(10):
    # Training pass; train() mode also lets the frozen layers' BatchNorm
    # running statistics update, even though their weights stay fixed
    pretrained_model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = pretrained_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluation pass on the test set; no gradient tracking needed
    pretrained_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = pretrained_model(images)
            predicted = outputs.argmax(dim=1)  # index of the highest logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Epoch {epoch + 1}, Accuracy: {accuracy:.2f}%')

torch.save(pretrained_model.state_dict(), 'resnet18_finetuned.pth')

In this tutorial, we trained a simple CNN from scratch on CIFAR-10 and then applied transfer learning with a pre-trained ResNet-18 in PyTorch. Starting from pre-trained weights and training only a new head is cheap, and it can often reach higher accuracy and converge faster than training from scratch, especially when labeled data is limited; the per-epoch test accuracy printed by both loops shows how the two approaches compare on your setup. Allegro Trains complements this workflow by tracking and managing machine learning experiments, making it easier to collaborate with team members and deploy models to production.
