Implementing ResNet from Scratch using PyTorch

ResNet (Residual Network) is a deep neural network architecture introduced by Kaiming He et al. in the paper "Deep Residual Learning for Image Recognition." ResNet makes very deep networks easier to train by adding skip connections (shortcuts) that let the input of a block bypass its convolutional layers. This gives gradients a more direct path through the network and helps mitigate the vanishing gradient problem.
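
As a brief sketch of why shortcuts help (the notation here is ours and ignores the final ReLU): a block with an identity shortcut computes a residual function F(x) and adds its input back, so its Jacobian always contains an identity term.

```latex
y = \mathcal{F}(x) + x,
\qquad
\frac{\partial y}{\partial x} = \frac{\partial \mathcal{F}}{\partial x} + I
```

The identity term I passes the gradient straight through the block, even when the gradient of F itself is small.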

In this tutorial, we will implement a ResNet architecture using the PyTorch deep learning library from scratch. This tutorial assumes that you have some basic knowledge of deep learning concepts and are familiar with PyTorch.

  1. Import necessary libraries:
    First, we need to import the necessary libraries for our implementation. We will use torch, torch.nn, and torch.nn.functional, which provide the tools for building and training deep neural networks.
import torch
import torch.nn as nn
import torch.nn.functional as F
  2. Define the basic building blocks:
    We will start by defining the basic building blocks of a ResNet architecture, which include the convolutional layer, batch normalization, and the residual block.
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

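In the BasicBlock class, we define two 3x3 convolutional layers, each followed by batch normalization, and a shortcut (skip) connection. The shortcut is the identity unless the stride or the number of channels changes, in which case a 1x1 convolution with batch normalization projects the input to the matching shape. The expansion factor is set to 1 for the basic block. The forward method adds the shortcut output to the main path before the final ReLU.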
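
To see why the shortcut needs its own strided convolution, it can help to work through the spatial sizes by hand. The sketch below uses the standard convolution output-size formula, floor((W - K + 2P) / S) + 1; the helper name conv_out_size and the 32x32 input size are assumptions made for illustration and are not part of the tutorial code.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution: floor((W - K + 2P) / S) + 1."""
    return (size - kernel + 2 * padding) // stride + 1

# Main path of a BasicBlock created with stride=2, fed a 32x32 feature map.
main = conv_out_size(32, kernel=3, stride=2, padding=1)    # conv1 halves the size
main = conv_out_size(main, kernel=3, stride=1, padding=1)  # conv2 keeps the size

# Shortcut path: the 1x1 convolution uses the same stride and no padding.
shortcut = conv_out_size(32, kernel=1, stride=2, padding=0)

print(main, shortcut)
```

Both paths end at the same spatial size, so the addition in forward is well defined. With an identity shortcut and stride=2, the two tensors would not match.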

  3. Define the full ResNet architecture:
    Next, we will define the full ResNet architecture by stacking multiple BasicBlocks together. The same class can produce ResNet18 and ResNet34 by changing the number of blocks per stage. The deeper variants (ResNet50, ResNet101, and ResNet152) use a bottleneck block with an expansion factor of 4, which this tutorial does not implement.
class ResNet(nn.Module):

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

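In the ResNet class, we define the full ResNet architecture by stacking multiple BasicBlocks together in the _make_layer method. Only the first block of each stage receives the stage's stride; the remaining blocks use a stride of 1. The forward method specifies how the input is passed through the layers. Note that this is the CIFAR-style variant: the stem is a single 3x3 convolution without max pooling, and F.avg_pool2d(out, 4) assumes 32x32 inputs, which leave a 4x4 feature map after layer4.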
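
The bookkeeping in _make_layer is easier to follow when traced without any tensors. The sketch below mirrors that logic in plain Python and lists the (in_planes, planes, stride) arguments each block would receive; trace_layers is a hypothetical helper written for this illustration, not part of the model.

```python
def trace_layers(num_blocks, expansion=1):
    """Mimic the ResNet._make_layer bookkeeping.

    Returns one (in_planes, planes, stride) tuple per block, in order.
    """
    in_planes = 64  # channels produced by the stem convolution
    plan = []
    stage_cfg = zip((64, 128, 256, 512), num_blocks, (1, 2, 2, 2))
    for planes, n, first_stride in stage_cfg:
        # Only the first block of a stage downsamples; the rest use stride 1.
        strides = [first_stride] + [1] * (n - 1)
        for stride in strides:
            plan.append((in_planes, planes, stride))
            in_planes = planes * expansion
    return plan

for block_args in trace_layers([2, 2, 2, 2]):
    print(block_args)
```

Every tuple where the stride is not 1 or the two channel counts differ corresponds to a block that builds the 1x1 projection shortcut. For the ResNet18 configuration, that is the first block of layer2, layer3, and layer4.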

  4. Create an instance of the ResNet model:
    Now, we can create an instance of the ResNet model with the desired parameters. For example, to create a ResNet18 model, we can use the following code:
def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

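This code creates a ResNet18 model with 2 basic blocks in each of the four stages. Each block contains two 3x3 convolutions, so the stages contribute 16 layers; the initial convolution and the final linear layer bring the total to 18. The 1x1 shortcut convolutions are not included in this count.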
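
The same counting rule extends to other block configurations. As a small sketch (count_weighted_layers is a hypothetical helper, and it assumes BasicBlock-style blocks with two convolutions each):

```python
def count_weighted_layers(num_blocks, convs_per_block=2):
    """Count the layers that give ResNet-N its name.

    Counts the stem convolution, the main-path convolutions of every block,
    and the final linear layer. Shortcut 1x1 convolutions are not counted.
    """
    return 1 + convs_per_block * sum(num_blocks) + 1

print(count_weighted_layers([2, 2, 2, 2]))  # ResNet18 configuration
print(count_weighted_layers([3, 4, 6, 3]))  # ResNet34 configuration
```

Passing [3, 4, 6, 3] to the ResNet class with BasicBlock would therefore give the ResNet34 layout.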

  5. Train and test the ResNet model:
    Finally, we can train and test the ResNet model on a dataset. We can use the CIFAR-10 dataset as an example. Below is the code to train and test the model on the CIFAR-10 dataset:
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)

model = ResNet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(10):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:    # print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

model.eval()  # switch BatchNorm layers to their running statistics
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

In this code snippet, we load the CIFAR-10 dataset, create an instance of the ResNet18 model, define the loss function and optimizer, and train the model for 10 epochs. We then test the model on the test set and report the accuracy.
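
The code above never moves the model or the data to a GPU, so it runs on the CPU. One possible way to make evaluation device-aware is sketched below. The evaluate helper, the toy model, and the random batches are assumptions made so the example runs on its own; in the tutorial they would be replaced by ResNet18() and testloader.

```python
import torch
import torch.nn as nn

def evaluate(model, loader, device):
    """Return top-1 accuracy in [0, 1] over all batches in `loader`."""
    model.eval()  # BatchNorm uses running statistics instead of batch statistics
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in model and data (hypothetical), shaped like CIFAR-10 batches.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
toy_loader = [(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
              for _ in range(2)]

accuracy = evaluate(toy_model, toy_loader, device)
print(f"accuracy: {accuracy:.2%}")
```

The training loop would need the same two changes: call model.to(device) once before training, and move each batch of inputs and labels to the device inside the loop.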

And that’s it! You have now successfully implemented a ResNet architecture using PyTorch from scratch. You can further customize the model by adding more layers or experimenting with different hyperparameters. Happy coding!

39 Comments
@Bunny-eh4ji
1 month ago

Thank you for tutorial. You're a real mad lad for this.

@kamikamen_official
1 month ago

Will give this a look.

@lalithaevani5942
1 month ago

For the Stride part for down-sampling in each layer, in the paper, it is written to down-sample at conv3_1, conv4_1 and conv5_1. If I understand your code correctly does it mean that there are conv3_0, conv4_0, and conv5_0 and hence stride of 2 is applied to the second block in each layer?

@anthonydavid6578
1 month ago

I was hoping for a full reimplementation that includes the dataset preprocessing, training code, visualization, and so on. Are there any videos covering these?

@ywy6810
1 month ago

Thanks sir you are so kind

@user-jq5rg9ue8g
1 month ago

NotImplementedError: Module [ResNet] is missing the required "forward" function
I'm getting this error. Can anyone tell me about it?
It happens when I use:
def test():
    net = ResNet152()
    x = torch.randn(2, 3, 224, 224)
    y = net(x).to(device)
    print(y.shape)

test()

@krisnaargadewa4376
1 month ago

why bias not TRUE?

@Ssc2969
1 month ago

Hi, thanks a lot for this tutorial. This code is extremely helpful. If I use part of this code in my project and cite your GitHub link if my paper gets published, would that be, okay? Please let me know. Thanks!

@minhajuddinansari561
1 month ago

In the condition:
if stride != 1 or self.in_channels != out_channels*4

shouldn't it instead be self.out_channels != in_channels*4

EDIT: Oh you clarified that out_channels is out_channels * expansion

@science.20246
1 month ago

Should the final layer be a softmax, or do we keep the fc layer and go forward?

@abhisekpanigrahi1033
1 month ago

Hello Aladdin , Can you please make video explaining the concept of _make_layer function. It is really confusing.

@kostiantynbevzuk3807
1 month ago

Sry probably for stupid question, but dont we need to pass stride as parameter in `class.block.conv1` and set padding to 1, and `block.conv2` to stride=1 and padding to 0 instead? Or am I missing something from original paper?

@doggydoggy578
1 month ago

Omg, I don't know what is happening, but no matter what I try, the code returns the same error:
<ipython-input-44-a6a501bb55e0> in forward(self, x)
     33 print(x.shape,identity.shape)
     34 print('is identity_downsamples none ?', self.identity_downsamples==None)
---> 35 x += identity
     36 x = self.relu(x)
     37
RuntimeError: The size of tensor a (256) must match the size of tensor b (64) at non-singleton dimension 1

Help please
I have re check my code multiple times and make sure it is exactly as yours but to no avail I can't make it to work. 🙁 I run on Colab btw

@amegatron07
1 month ago

I love the idea of residual layers. Not taking math into account, on a higher level it intuitively seems useful, because with usual layers, the low-level information gets lost from layer to layer. But with skip-connections, we keep track of lower-level information, sort of. Unfortunately, I can't now remember the IRL-example to depict this, but in general it is the same: while constructing something high-level, we don't only need to see what we have at this high-level, but also need to keep track of some lower-level steps we're performing.

@ArunKumar-sg6jf
1 month ago

I learnt how you applied padding = 0 and padding = 1

@user-vn7fs1qp6i
1 month ago

Hi Aladdin! Thanks so much for the great content. I had a quick question at around 3:50 (calculating the padding). I'm looking at this formula [(W−K+2P)/S]+1 that people often use to calculate the output size, and tried letting W = 7, K = 3, S = 2 etc, but I just don't see how a P=3 would get us an output of 112. How can I calculate/estimate padding sizes from input and output sizes (+ kernel sizes, steps, etc)?

@aneekaazmat6653
1 month ago

Hello,
Your video was very interesting for me, as I am using ResNet for the first time. But I have a question about how we can use it for audio classification; I have to do boundary detection in music files. The shape of my mel spectrograms is not the same for all files: it is (80, 1789), (80, 3356), and so on, meaning the second dimension changes with every song. So how can I use this kind of mel spectrogram with ResNet?
Can you please make a video on audio classification using ResNet?

@LalitPandeyontube
1 month ago

I am trying to prune the residual blocks such that my resnet will have 3 residual blocks.. but I keep on getting an error with mat dimensions.

@OlegKorsak
1 month ago

u can do super().__init__()

@sourodipkundu8421
1 month ago

I didn't understand where did you implement the skip connection?