Callbacks in PyTorch Lightning #7

In PyTorch Lightning, callbacks let you customize the behavior of your training loop. A callback is a class attached to the Trainer that hooks into specific points before, during, and after training. In this tutorial, we will explore how to create and use callbacks in PyTorch Lightning and how they can improve your training workflow.

Step 1: Creating a Callback Class
To create a custom callback, you need to define a new class that inherits from the Callback base class provided by PyTorch Lightning. Here is an example of a simple callback class that prints a message at the start of each training epoch:

from pytorch_lightning.callbacks import Callback

class CustomCallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        print("Starting epoch...")

In this example, the CustomCallback class overrides the on_train_epoch_start hook, which is called at the beginning of each training epoch. You can override other hooks such as on_train_start, on_train_end, on_train_batch_start, and on_train_batch_end to run code at different stages of the training process.
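For instance, a callback that measures total training time can combine two of these hooks. This is a minimal sketch; the TimerCallback name and timing logic are illustrative, not part of Lightning:

import time
from pytorch_lightning.callbacks import Callback

class TimerCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        # Record the wall-clock time when training begins
        self.start_time = time.time()

    def on_train_end(self, trainer, pl_module):
        # Report the total elapsed training time once training finishes
        elapsed = time.time() - self.start_time
        print(f"Training finished in {elapsed:.1f} seconds")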

Step 2: Attaching the Callback to the Trainer
Once you have created your custom callback class, attach it to the Trainer via the callbacks argument, which accepts a list of callback instances, so you can attach multiple callbacks at once. Here is how to attach the CustomCallback:

from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(callbacks=[CustomCallback()])
trainer.fit(model)

In this example, we create an instance of MyModel and attach the CustomCallback to the Trainer. Now, whenever we call the fit method on the Trainer, the callback's on_train_epoch_start hook runs at the beginning of each training epoch.
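Here, MyModel stands for whatever LightningModule you are training; it is not defined in this tutorial. A minimal placeholder, assuming a simple regression setup, might look like this:

import torch
from torch import nn
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # A single linear layer as a stand-in for a real architecture
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Placeholder loss; a real model would also log metrics here
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)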

Step 3: Using Built-in Callbacks
PyTorch Lightning also provides a number of built-in callbacks that you can use without creating custom callback classes. Some of the most commonly used built-in callbacks include ModelCheckpoint, EarlyStopping, and LearningRateMonitor.

The ModelCheckpoint callback saves model checkpoints during training, the EarlyStopping callback stops training early if a monitored metric does not improve for a given number of epochs, and the LearningRateMonitor callback logs the current learning rate as training progresses.

You can use these built-in callbacks by instantiating them and passing the instances to the callbacks argument of the Trainer, just like custom callbacks. For example:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

model = MyModel()
trainer = Trainer(callbacks=[ModelCheckpoint(), EarlyStopping(monitor="val_loss"), LearningRateMonitor()])
trainer.fit(model)

In this example, we attach the ModelCheckpoint, EarlyStopping, and LearningRateMonitor callbacks to the Trainer without writing any custom callback classes. Note that EarlyStopping needs to be told which logged metric to monitor.
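In practice you will usually configure these callbacks rather than rely on their defaults. For example, EarlyStopping can be told how long to wait and which direction counts as an improvement (this sketch assumes your LightningModule logs a metric named val_loss):

early_stopping = EarlyStopping(
    monitor="val_loss",  # metric logged by the LightningModule
    patience=3,          # stop after 3 epochs without improvement
    mode="min",          # a lower val_loss counts as an improvement
)
trainer = Trainer(callbacks=[early_stopping])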

Step 4: Modifying Callback Behavior
You can modify the behavior of a callback by passing arguments to the constructor of the callback class. For example, you can set the filename for the ModelCheckpoint callback to save the model weights to a specific file:

model_checkpoint = ModelCheckpoint(filename='model-{epoch:02d}-{val_loss:.2f}')
trainer = Trainer(callbacks=[model_checkpoint])

In this example, the ModelCheckpoint callback saves checkpoints whose filenames are generated from the template, substituting the current epoch number and validation loss, with the standard .ckpt extension.
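After training, the same callback instance records where the best checkpoint was written (assuming a metric is being monitored), so you can reload the model from it:

trainer.fit(model)

# Path to the best checkpoint saved during training
print(model_checkpoint.best_model_path)
restored = MyModel.load_from_checkpoint(model_checkpoint.best_model_path)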

You can also modify the behavior of a callback by subclassing it and overriding its methods. For example, you can extend the ModelCheckpoint callback by overriding its on_train_epoch_end hook:

class CustomModelCheckpoint(ModelCheckpoint):
    def on_train_epoch_end(self, trainer, pl_module):
        # Announce the save, then let the parent class do the actual work
        print(f"Saving model weights (filename template: {self.filename})")
        super().on_train_epoch_end(trainer, pl_module)

In this example, the CustomModelCheckpoint class overrides the on_train_epoch_end hook of ModelCheckpoint to print a message before delegating to the parent class, which performs the actual checkpointing.
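The subclass is attached to the Trainer exactly like the built-in callback:

trainer = Trainer(callbacks=[CustomModelCheckpoint(filename='model-{epoch:02d}')])
trainer.fit(model)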

Conclusion
Callbacks are a powerful feature of PyTorch Lightning that allow you to customize the behavior of your training loop. You can create custom callbacks by defining new callback classes, attach them to the Trainer using the callbacks argument, and use built-in callbacks provided by PyTorch Lightning to improve your training process. By using callbacks effectively, you can implement advanced training techniques, monitor training progress, and save model weights at specific checkpoints during training.
