In PyTorch Lightning, a Lightning Module is a subclass of pl.LightningModule that encapsulates all the logic for your model, training, validation, and testing. It keeps your PyTorch code organized and readable by putting each of these pieces in a well-defined method. In this tutorial, we will discuss how to create a Lightning Module and implement its essential methods.
- Install PyTorch Lightning:
Before we start, make sure you have PyTorch Lightning installed. You can install it via pip:
    pip install pytorch-lightning
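To confirm the installation worked, you can look up the installed package versions from Python. The sketch below uses only the standard library; the helper name installed_versions is hypothetical, and the inclusion of the newer lightning distribution name is an assumption about what you may have installed alongside or instead of pytorch-lightning.

```python
from importlib import metadata


def installed_versions(packages=("torch", "pytorch-lightning", "lightning")):
    """Return {package name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # A None value means pip has not installed that package in this environment
    print(installed_versions())
```

If pytorch-lightning shows None, rerun the pip command above in the same environment you use to run your scripts.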
- Create a Lightning Module:
To create a Lightning Module, you need to define a class that inherits from pl.LightningModule and implement the necessary methods. Here’s a template for a basic Lightning Module:
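    import torch
    import torch.nn as nn
    import pytorch_lightning as pl


    class MyLightningModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Flatten(),  # 1x28x28 MNIST images -> 784 features
                nn.Linear(784, 128),
                nn.ReLU(),
                nn.Linear(128, 10),
            )
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, x):
            return self.model(x)

        def _shared_step(self, batch):
            # Loss computation shared by the training, validation, and test steps
            x, y = batch
            y_hat = self(x)
            return self.loss_fn(y_hat, y)

        def training_step(self, batch, batch_idx):
            return self._shared_step(batch)

        def validation_step(self, batch, batch_idx):
            self.log("val_loss", self._shared_step(batch))

        def test_step(self, batch, batch_idx):
            self.log("test_loss", self._shared_step(batch))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

In this template, we define a simple neural network with two linear layers for a classification task; the nn.Flatten layer reshapes each 28x28 image into the 784-element vector the first linear layer expects. We implement the forward method to define the forward pass of the model. The training_step method computes the loss for each batch during training. The validation_step and test_step methods reuse the same loss computation and log it, which is what allows trainer.fit to evaluate on a validation dataloader and trainer.test to run at all. Finally, the configure_optimizers method returns the optimizer used for training the model.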
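It can help to see what Lightning does with the loss that training_step returns. The following is a simplified sketch in plain PyTorch of one optimization step, assuming the default automatic optimization; the function name manual_training_step and the synthetic batch are illustrative, and the real training loop also handles device placement, gradient accumulation, logging, and more.

```python
import torch
import torch.nn as nn


def manual_training_step(model, optimizer, batch):
    """One optimization step written by hand (simplified sketch)."""
    x, y = batch
    optimizer.zero_grad()                         # clear gradients from the previous step
    y_hat = model(x)                              # forward pass
    loss = nn.functional.cross_entropy(y_hat, y)  # the value training_step would return
    loss.backward()                               # Lightning calls this for you
    optimizer.step()                              # and this
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Sequential(
        nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Synthetic MNIST-shaped batch: 32 images of 1x28x28 with labels 0..9
    batch = (torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,)))
    print(manual_training_step(model, optimizer, batch))
```

With a Lightning Module you write only the forward pass and loss; the zero_grad, backward, and step calls are made by the Trainer.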
- Define Dataloaders:
To train the Lightning Module, you need to create PyTorch DataLoader objects to load the data. Here’s an example of how to define dataloaders for a PyTorch dataset:
    from torch.utils.data import DataLoader, random_split
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor

    # Load MNIST dataset
    dataset = MNIST('./data', train=True, transform=ToTensor(), download=True)

    # Split dataset into training and validation sets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)
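Before training, it is worth checking what a batch from these dataloaders looks like. The sketch below repeats the split on a small synthetic dataset with MNIST-like shapes (an assumption, chosen so nothing has to be downloaded) and passes a seeded generator to random_split so the split is reproducible; the helper name make_loaders is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_loaders(n_samples=100, batch_size=32, seed=42):
    """Build train/val loaders over synthetic MNIST-shaped data with an 80/20 split."""
    dataset = TensorDataset(
        torch.randn(n_samples, 1, 28, 28),    # fake grayscale images
        torch.randint(0, 10, (n_samples,)),   # fake class labels 0..9
    )
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    generator = torch.Generator().manual_seed(seed)  # makes the split reproducible
    train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=generator)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    return train_loader, val_loader


if __name__ == "__main__":
    train_loader, val_loader = make_loaders()
    x, y = next(iter(train_loader))
    print(x.shape, y.shape)                    # torch.Size([32, 1, 28, 28]) torch.Size([32])
    print(len(train_loader), len(val_loader))  # 3 1
```

Each batch is a tuple of images with shape (batch, 1, 28, 28) and integer labels with shape (batch,), which is why the model needs to flatten its input before the first linear layer.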
- Training and Testing:
After defining the Lightning Module and dataloaders, you can train and test the model using a Trainer object provided by PyTorch Lightning. Here’s an example of how to train the model:
    # Create Lightning Module instance
    model = MyLightningModule()

    # Create Trainer object and train on one GPU.
    # Older Lightning releases used gpus=1, which recent versions no longer accept.
    trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)

    # Train the model
    trainer.fit(model, train_loader, val_loader)
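- Testing the Model:
You can test the trained model using the trainer.test method with a test dataloader. This requires the module to define a test_step method:
    test_loader = DataLoader(MNIST('./data', train=False, transform=ToTensor(), download=True), batch_size=32)
    # Recent Lightning versions take `dataloaders`; the older `test_dataloaders` keyword was removed
    trainer.test(model, dataloaders=test_loader)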
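For a classifier, loss alone is hard to interpret, so a test step often reports accuracy as well. The helper below is a minimal sketch in plain PyTorch; the name batch_accuracy and the example tensors are illustrative and not part of Lightning.

```python
import torch


def batch_accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class matches the target."""
    preds = logits.argmax(dim=1)
    return (preds == targets).float().mean().item()


if __name__ == "__main__":
    logits = torch.tensor([
        [2.0, 0.1, 0.3],   # predicts class 0
        [0.2, 1.5, 0.1],   # predicts class 1
        [0.9, 0.2, 0.4],   # predicts class 0 (wrong, target is 2)
        [0.1, 0.2, 3.0],   # predicts class 2
    ])
    targets = torch.tensor([0, 1, 2, 2])
    print(batch_accuracy(logits, targets))  # 0.75
```

Inside a test_step you could compute this from the model output and the labels and pass it to self.log under a name such as "test_acc", so that trainer.test reports it next to the loss.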
That’s it! You have now created a Lightning Module for your PyTorch model and trained it using PyTorch Lightning. This approach simplifies the codebase, making it easier to manage and scale your deep learning projects. Feel free to explore more advanced features of PyTorch Lightning to further enhance your training pipeline.
Thanks a lot man this is great help
What is your development environment
got introduced to pytorch lightning recently through my studies (no prior programming background) and was having quite some trouble getting the hang of it. this really helped it feel more approachable, thank you!
Thanks for making those videos. However, I just can't read what you type. Could you please, for your future videos, enlarge the code. It is nice to also see the presenter, since there is space on the screen I am not saying you should not appear. Again thank you.
Hello Aladdin, thanks for your contribution… Is it possible to make a tutorial for Inference on production and Transfer Learning / Fine tuning ?
I have a nn.Module class that is composed of multiple smaller nn.Module classes as parts of the forward process. I see that I could define the training,validation and test sets, as well as configuring the optimizer for it. But how do I define a Lightning module for these sub-component modules?
Awesome video. I'm quite new in Python. What IDE do you use? I like the autocomplete-functionality.
Dude what IDE are you using?
Thank you! I pretty much like your .zshrc file as well as .vimrc. Can you share these awesome themes & settings?
Short video but loved it.
Make more videos on Graph Neural Networks.
great lectures.
Thank you for the video. All your videos help in deeper understanding of the framework. I have been working with Pytorch TorchRec recently and would like to see a similar video on that. 🙂
This is one of the very few of your videos that I didn't have to pause to understand what's going on😅
Amazzzzzzzzzzzzing 🙂
Thank you for the clear example, I like how much cleaner lightning is, almost feels like Django (but hopefully a bit more flexible). Btw, can you recommend a good framework for ML research, for example for experimenting with slightly altered transformer architectures?