Regularization Techniques in Neural Networks: L2 Regularization and Weight Decay – Theory and Implementation in PyTorch

Introduction:
In this tutorial, we will discuss L2 regularization, also known as weight decay, in the context of neural networks. L2 regularization is a commonly used technique to prevent overfitting in neural networks by adding a penalty term to the loss function that discourages large weights. We will explain the theory behind L2 regularization and how to implement it in PyTorch.

Theory:
L2 regularization adds a term to the loss function that penalizes large weights: the sum of the squared values of all weights. The total loss is then the original loss plus the L2 penalty scaled by a hyperparameter lambda:

Loss_total = Loss_original + lambda * ||w||^2

Here ||w||^2 is the squared L2 norm of the weight vector w, i.e. the sum of the squares of all weights. Adding this term encourages the optimization process to find smaller weights, which tend to generalize better to unseen data.
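
To make the formula concrete, here is a minimal sketch of how the penalty could be computed by hand in PyTorch. The tensor w and the values of lam and loss_original are illustrative placeholders, not part of the full example further below:

import torch

# Illustrative weight tensor standing in for a model's weights
w = torch.randn(128, 784, requires_grad=True)

lam = 0.01                         # regularization strength (lambda)
loss_original = torch.tensor(2.3)  # placeholder for the task loss (e.g. cross-entropy)

# ||w||^2: the sum of the squared weight values
l2_penalty = (w ** 2).sum()

# Total loss as in the formula above
loss_total = loss_original + lam * l2_penalty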

Implementation in PyTorch:
Now, let’s implement L2 regularization in PyTorch. Below is a simple neural network example using L2 regularization:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a simple neural network class
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten 28x28 MNIST images into 784-dim vectors
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load the MNIST dataset with torchvision
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize the model and optimizer
model = SimpleNN()
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Training loop
num_epochs = 5  # number of passes over the training data
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

In this code snippet, we define a simple neural network with two fully connected layers, load the MNIST dataset, and initialize the model, optimizer, and loss function. In the optimizer definition, we specify the weight_decay parameter, which plays the role of lambda from the theory section: it controls the strength of the L2 regularization.
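
As a variation on the example above (not required for basic use), weight decay can be applied selectively through PyTorch's parameter groups, for instance to the weight matrices but not the biases. The sketch below assumes the SimpleNN model defined earlier:

# Regularize only the weight matrices, leaving the biases unpenalized (illustrative)
decay_params = [p for name, p in model.named_parameters() if name.endswith("weight")]
no_decay_params = [p for name, p in model.named_parameters() if name.endswith("bias")]

optimizer = optim.SGD(
    [
        {"params": decay_params, "weight_decay": 0.01},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=0.01,
)

Excluding biases from weight decay is a common convention, since they contribute little to overfitting compared with the weight matrices.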

In the training loop, we compute the loss and perform backpropagation as usual. We never add the penalty term to the loss ourselves: with SGD, the weight_decay parameter makes the optimizer add weight_decay * w to each weight's gradient before the update, which has the same effect as including an L2 penalty in the loss (up to a constant factor that depends on the convention).
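
As a simplified sketch (not PyTorch's actual implementation), the same update could be written out by hand like this, reusing the model and hyperparameters from the example above:

# Simplified, hand-written SGD step with weight decay (sketch only)
lr = 0.01
weight_decay = 0.01
with torch.no_grad():
    for p in model.parameters():
        if p.grad is None:
            continue
        grad = p.grad + weight_decay * p  # gradient of the task loss plus the L2 penalty
        p -= lr * grad                    # standard gradient descent update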

Conclusion:
In this tutorial, we have discussed L2 regularization, its theory, and how to implement it in PyTorch. L2 regularization is a powerful technique to prevent overfitting in neural networks by penalizing large weights. By incorporating L2 regularization into your model, you can improve generalization performance and make your neural network more robust to unseen data.
