PyTorch Lightning: Understanding the Lightning Module

In PyTorch Lightning, a Lightning Module is a class that inherits from pl.LightningModule and encapsulates your model together with its training, validation, and testing logic. It simplifies writing and managing PyTorch code, making it more organized and readable. In this tutorial, we will discuss how to create a Lightning Module and implement its essential methods.

  1. Install PyTorch Lightning:
    Before we start, make sure you have PyTorch Lightning installed. You can install it via pip:

    pip install pytorch-lightning
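    You can verify the installation by importing the package and printing its version:

    python -c "import pytorch_lightning as pl; print(pl.__version__)"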
  2. Create a Lightning Module:
    To create a Lightning Module, you need to define a class that inherits from pl.LightningModule and implement the necessary methods. Here’s a template for a basic Lightning Module:
import torch
import torch.nn as nn
import pytorch_lightning as pl

class MyLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        # MNIST batches arrive as (batch, 1, 28, 28); flatten to (batch, 784)
        return self.model(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log('val_loss', self.loss_fn(self(x), y))

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log('test_loss', self.loss_fn(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In this template, we define a simple two-layer network that classifies flattened 28×28 MNIST images into 10 classes. The forward method defines the forward pass, flattening each image into a 784-dimensional vector before it reaches the first linear layer. The training_step method computes and logs the loss for each training batch; validation_step and test_step do the same for validation and test batches, which the validation and test loops later in this tutorial rely on. Finally, the configure_optimizers method returns the optimizer used for training the model.
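Before wiring up real data, you can sanity-check the module with a random batch shaped like MNIST input. This is just a quick smoke test, not part of the training pipeline:

# Run a fake batch of 4 MNIST-sized images through the module
model = MyLightningModule()
x = torch.randn(4, 1, 28, 28)  # random stand-in for real images
logits = model(x)
print(logits.shape)  # torch.Size([4, 10])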

  3. Define Dataloaders:
    To train the Lightning Module, you need to create PyTorch DataLoader objects to load the data. Here’s an example of how to define dataloaders for a PyTorch dataset:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Load MNIST dataset
dataset = MNIST('./data', train=True, transform=ToTensor(), download=True)

# Split dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
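
Optionally, Lightning's LightningDataModule lets you group this dataloader logic into one reusable class. Here is a minimal sketch of the same MNIST split as a data module (the class name MNISTDataModule is our own choice):

# Minimal sketch: the same split and loaders, wrapped in a LightningDataModule
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./data', batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Called by the Trainer before fitting; builds the splits
        full = MNIST(self.data_dir, train=True, transform=ToTensor(), download=True)
        train_size = int(0.8 * len(full))
        self.train_dataset, self.val_dataset = random_split(
            full, [train_size, len(full) - train_size]
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

With a data module, the fit call in the next step becomes trainer.fit(model, datamodule=MNISTDataModule()).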
  4. Training the Model:
    After defining the Lightning Module and dataloaders, you can train the model using a Trainer object provided by PyTorch Lightning. Here’s an example of how to train the model:
# Create Lightning Module instance
model = MyLightningModule()

# Create Trainer object
trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)  # set accelerator="cpu" if no GPU is available

# Train the model
trainer.fit(model, train_loader, val_loader)
  5. Testing the Model:
    You can test the trained model using the trainer.test method with a test dataloader:
test_dataset = MNIST('./data', train=False, transform=ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=32)
trainer.test(model, dataloaders=test_loader)

That’s it! You have now created a Lightning Module for your PyTorch model and trained it using PyTorch Lightning. This approach simplifies the codebase, making it easier to manage and scale your deep learning projects. Feel free to explore more advanced features of PyTorch Lightning to further enhance your training pipeline.
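
For instance, among those more advanced features, the Trainer accepts callbacks such as ModelCheckpoint and EarlyStopping. Both can monitor the 'val_loss' value logged in validation_step above; here is a minimal sketch:

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Keep the best checkpoint by validation loss and stop early when it plateaus
checkpoint_cb = ModelCheckpoint(monitor='val_loss', mode='min')
early_stop_cb = EarlyStopping(monitor='val_loss', patience=3, mode='min')

trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_cb, early_stop_cb])
trainer.fit(model, train_loader, val_loader)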
