Understanding ResNet: Explanation and PyTorch Implementation

In recent years, deep learning has gained enormous popularity in the field of artificial intelligence. With advances in hardware and algorithms, researchers can now train deep neural networks with millions or even billions of parameters, achieving state-of-the-art performance on tasks such as image classification, object detection, and natural language processing.

One of the most influential neural network architectures in the deep learning community is the Residual Network (ResNet). Introduced by Kaiming He et al. in their paper "Deep Residual Learning for Image Recognition" in 2015, ResNet has revolutionized the way deep neural networks are designed and trained.

In this tutorial, we will walk through the key concepts behind ResNet and provide a step-by-step guide on how to implement ResNet in PyTorch.

ResNet: Key Concepts

  1. Vanishing Gradient Problem: One of the main challenges in training deep neural networks is the vanishing gradient problem. During backpropagation, gradients are computed by the chain rule as a product of many per-layer derivatives; as the network gets deeper, this product tends to shrink toward zero, making it difficult for the early layers to learn meaningful representations.

  2. Skip Connections: ResNet addresses the vanishing gradient problem by introducing skip connections, also known as identity mappings. Instead of feeding each layer's output only into the next layer, ResNet adds a shortcut that bypasses one or more layers and adds the input x to the block's output, so the block computes F(x) + x rather than F(x) alone. The identity path gives gradients an unobstructed route back to earlier layers.

  3. Residual Blocks: The basic building block of a ResNet is the residual block. A residual block consists of two convolutional layers with batch normalization and ReLU activations, plus a skip connection that adds the input to the output. The network therefore learns the residual F(x) = H(x) - x rather than the full mapping H(x) directly, which is often easier to optimize.

  4. Bottleneck Architecture: To reduce the computational cost of very deep networks, ResNet introduces a bottleneck block consisting of three convolutional layers: a 1×1 convolution that reduces the channel count, a 3×3 convolution, and a 1×1 convolution that restores it. This cuts the number of parameters and the compute cost while maintaining accuracy (a minimal sketch follows this list).
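
To make the bottleneck idea concrete, here is a minimal sketch of such a block in PyTorch. Note that it is not used in the CIFAR-10 model built later in this tutorial (which uses the plain two-layer block); the class name BottleneckBlock is ours, while the expansion factor of 4 follows the design described in the paper.

import torch.nn as nn

class BottleneckBlock(nn.Module):
    # 1x1 reduce -> 3x3 -> 1x1 expand, plus a skip connection
    expansion = 4  # output channels = mid_channels * expansion, as in the paper

    def __init__(self, in_channels, mid_channels, stride=1):
        super(BottleneckBlock, self).__init__()
        out_channels = mid_channels * self.expansion
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Projection shortcut when the shape changes, identity otherwise
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)  # add the skip connection
        return self.relu(out)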

ResNet Implementation in PyTorch

Now, let’s move on to implementing ResNet in PyTorch. We will use the CIFAR-10 dataset for this tutorial, which consists of 60,000 32×32 color images in 10 classes (50,000 for training and 10,000 for testing).

  1. Import Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
  2. Define Residual Block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

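        # Projection shortcut: match dimensions when the stride or channel count changes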
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

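        # Add the skip connection before the final activation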
        out += self.shortcut(x)
        out = self.relu(out)

        return out
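
As a quick optional check (not part of the original tutorial), you can push a dummy CIFAR-sized tensor through a block and confirm the shapes: a stride-2 block with a projection shortcut should halve the spatial size and change the channel count.

block = ResidualBlock(16, 32, stride=2)
x = torch.randn(8, 16, 32, 32)   # a fake batch of 8 feature maps
print(block(x).shape)            # expected: torch.Size([8, 32, 16, 16])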
  3. Define ResNet Model
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

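        # Three stages of residual blocks; spatial resolution halves and width doubles at each stage transition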
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
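        # The first block of a stage may downsample (stride > 1); the remaining blocks keep stride 1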
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)

        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out
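
Before wiring up the data, it is worth confirming that the model produces logits of the right shape. As an aside, num_blocks=[2, 2, 2] yields a 14-layer network (6n+2 layers with n=2, following the CIFAR-style ResNets in the paper), while [3, 3, 3] would give the classic ResNet-20.

net = ResNet(ResidualBlock, [2, 2, 2])
dummy = torch.randn(4, 3, 32, 32)
print(net(dummy).shape)   # expected: torch.Size([4, 10]), one logit per class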
  4. Define Hyperparameters and Data Loaders
device = 'cuda' if torch.cuda.is_available() else 'cpu'

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
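
Normalizing with means and standard deviations of 0.5 simply maps pixel values into [-1, 1]. A common refinement, not used above, is to normalize with per-channel statistics computed on the CIFAR-10 training set; the values below are the commonly quoted approximations, so treat them as such.

# Optional alternative: per-channel CIFAR-10 statistics (commonly quoted values)
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2470, 0.2435, 0.2616))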
  5. Build the Model and Define the Training Loop
model = ResNet(ResidualBlock, [2, 2, 2])
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
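
With a starting learning rate of 0.1, ResNets on CIFAR-10 are typically trained with a learning-rate schedule. One is not used in this tutorial, but a common addition is a step decay; the milestones below are illustrative for a 20-epoch run.

# Optional: decay the learning rate by 10x at fixed epochs (a common recipe).
# If used, call scheduler.step() once per epoch at the end of the training loop.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 15], gamma=0.1)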

def train_model(model, criterion, optimizer, num_epochs):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

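            # Standard training step: reset gradients, forward pass, compute loss, backward pass, update weights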
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch {epoch+1}, Loss: {running_loss / len(trainloader):.4f}')

    print('Finished Training')
  6. Define the Evaluation Function
def evaluate_model(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy on test set: {100 * correct / total:.2f}%')
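
The classes tuple defined earlier can be put to use with a per-class breakdown, which often reveals which categories the model confuses. Here is a minimal sketch (not part of the original tutorial):

def evaluate_per_class(model):
    model.eval()
    correct = [0] * len(classes)
    total = [0] * len(classes)
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            _, predicted = torch.max(model(images), 1)
            for label, pred in zip(labels, predicted):
                total[label] += 1
                correct[label] += int(label == pred)
    for i, name in enumerate(classes):
        print(f'{name}: {100 * correct[i] / total[i]:.1f}%')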
  7. Train and Evaluate the Model
train_model(model, criterion, optimizer, num_epochs=20)
evaluate_model(model)
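
Once training finishes, you will likely want to keep the weights. Here is a minimal sketch using PyTorch's standard state-dict mechanism (the file name is arbitrary):

# Save the trained weights
torch.save(model.state_dict(), 'resnet_cifar10.pth')

# Later: rebuild the architecture and load the weights back
model = ResNet(ResidualBlock, [2, 2, 2]).to(device)
model.load_state_dict(torch.load('resnet_cifar10.pth', map_location=device))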

By following this tutorial, you should now have a better understanding of ResNet and how to implement it in PyTorch. ResNet has proven to be an effective architecture for a wide range of computer vision tasks, and with PyTorch you can easily experiment with different variants and train your own state-of-the-art models. Happy coding!
