In this tutorial, we will learn how to save and load PyTorch models. Saving and loading models is an essential skill for machine learning practitioners as it allows you to save the state of your trained models and use them for inference or further training at a later time.
Before we proceed, make sure you have PyTorch installed in your environment. You can install PyTorch using pip by running the following command:
pip install torch torchvision
Now, let’s get started with saving and loading models in PyTorch.
Step 1: Define a Model
First, let’s define a simple neural network model using PyTorch. For this tutorial, we will create a basic feedforward neural network with one hidden layer.
import torch  # needed later for torch.save, torch.load and torch.argmax
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)   # input -> hidden
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)  # hidden -> class scores

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
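Before saving anything, it can help to confirm that the model runs end to end. The snippet below is a minimal sketch that repeats the class in compact form so it is self-contained; `dummy_batch` and its batch size of 4 are hypothetical values chosen only for illustration.

```python
import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    """Same architecture as above: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = SimpleNN(input_size=784, hidden_size=128, num_classes=10)
dummy_batch = torch.randn(4, 784)  # hypothetical batch: 4 random inputs of size 784
logits = model(dummy_batch)
print(tuple(logits.shape))  # one row of 10 class scores per input
```

The output has one row per input and one column per class, which is the shape the later inference step relies on.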
Step 2: Save a Model
To save a PyTorch model, we save its state dictionary. The state_dict() method returns an ordered dictionary that maps each parameter name to its tensor (the learnable weights and biases, plus any persistent buffers). We can write this dictionary to a file using the torch.save() function.
model = SimpleNN(input_size=784, hidden_size=128, num_classes=10)
# Save model
torch.save(model.state_dict(), 'simple_nn_model.pth')
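To see what actually ends up in the file, it may be useful to print the entries of the state dictionary. This is a small sketch, again repeating the class so it runs on its own; the `shapes` dictionary is a helper introduced here for illustration only.

```python
import torch.nn as nn


class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = SimpleNN(input_size=784, hidden_size=128, num_classes=10)

# Map each entry name to its tensor shape.
# nn.ReLU has no parameters, so it contributes no entries.
shapes = {name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

The keys are derived from the attribute names in the class (fc1, fc2), which is why the class definition used for loading has to use the same names.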
Step 3: Load a Model
To load a saved model, we instantiate the model class with the same arguments that were used when it was saved, then load the saved state dictionary using the load_state_dict() method. This copies the saved parameters into the model.
# Create model instance
model = SimpleNN(input_size=784, hidden_size=128, num_classes=10)
# Load saved model
model.load_state_dict(torch.load('simple_nn_model.pth'))
model.eval() # Set model to evaluation mode
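A quick way to check that saving and loading worked is to compare the restored weights with the originals. The sketch below writes to a temporary directory so it leaves no files behind. It also passes two optional arguments to torch.load: `map_location="cpu"`, which remaps tensors saved on a GPU so the file can be opened on a CPU-only machine, and `weights_only=True`, which restricts unpickling to tensors and plain containers and is available in recent PyTorch releases. The variable names `trained`, `restored`, and `weights_match` are illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


trained = SimpleNN(input_size=784, hidden_size=128, num_classes=10)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "simple_nn_model.pth")
    torch.save(trained.state_dict(), path)
    # Load onto the CPU regardless of where the tensors were saved.
    state = torch.load(path, map_location="cpu", weights_only=True)

# A fresh instance starts with different random weights until the state is loaded.
restored = SimpleNN(input_size=784, hidden_size=128, num_classes=10)
restored.load_state_dict(state)
restored.eval()

# Compare every saved tensor with its restored counterpart.
weights_match = all(
    torch.equal(tensor, restored.state_dict()[name])
    for name, tensor in trained.state_dict().items()
)
print(weights_match)
```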
Step 4: Use the Model for Inference
Once the model is loaded, we can use it for inference on new data. Here’s an example of how to make predictions using the loaded model:
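# Assuming `data` is a float tensor of shape (batch_size, 784)
with torch.no_grad():  # gradients are not needed for inference
    output = model(data)
predicted_class = torch.argmax(output, dim=1)  # index of the highest score per sample
print(f'Predicted class: {predicted_class.tolist()}')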
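Here is a self-contained sketch of the inference step that can be run without a saved file. It uses random data and an untrained model, so the predicted classes are meaningless; the point is the mechanics. One detail worth noting: `.item()` works only on a single-element tensor, so it suits a batch of one, while `.tolist()` handles any batch size. The softmax step is optional and is included here as an assumption about how one might read the scores as probabilities.

```python
import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


# In practice, the saved state_dict would be loaded into this instance first.
model = SimpleNN(input_size=784, hidden_size=128, num_classes=10)
model.eval()

data = torch.randn(3, 784)  # hypothetical batch of 3 random inputs

with torch.no_grad():  # skip gradient tracking during inference
    output = model(data)
    probabilities = torch.softmax(output, dim=1)  # optional: scores -> probabilities
    single = torch.argmax(model(data[:1]), dim=1).item()  # batch of one: .item() is fine

predictions = torch.argmax(output, dim=1).tolist()  # any batch size: use .tolist()
print(predictions)
```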
That’s it! You have successfully saved and loaded a PyTorch model. This is a fundamental skill in deep learning, and knowing how to save and load models will be invaluable as you work on more complex projects and datasets.
In this tutorial, we covered the basic steps to save and load PyTorch models. Remember to save the model’s state_dict() and load it into an instance of the same model class. Additionally, make sure to set the model to evaluation mode with model.eval() before using it for inference.
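If the goal is further training rather than inference, the optimizer state usually needs to be saved alongside the model. The following is a rough sketch of one common way to do that, not a method shown in the steps above: the dictionary keys (`epoch`, `model_state`, `optim_state`), the epoch value, and the small `nn.Linear` stand-in model are all assumptions made for illustration. SimpleNN would work the same way.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # small stand-in model for brevity
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One dummy training step so the optimizer has momentum buffers worth saving.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

checkpoint = {
    "epoch": 5,  # hypothetical epoch counter
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pth")
    torch.save(checkpoint, path)
    loaded_checkpoint = torch.load(path, map_location="cpu", weights_only=True)

# Rebuild the model and optimizer, then restore both from the loaded checkpoint.
resumed_model = nn.Linear(4, 2)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.01, momentum=0.9)
resumed_model.load_state_dict(loaded_checkpoint["model_state"])
resumed_optimizer.load_state_dict(loaded_checkpoint["optim_state"])
start_epoch = loaded_checkpoint["epoch"]

resumed_model.train()  # training mode, since the goal is to keep training
```

Note that the restore step reads from `loaded_checkpoint`, the object returned by torch.load, and not from the original `checkpoint` dictionary.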
I hope you found this tutorial helpful. Happy coding!
Hi, thank you for showing all the methods. I have a federated learning use case where I train a single model using 32 clients. Currently, I get a CUDA out-of-memory error. I want to know how each of the methods presented here affects memory allocation and reservation on a single GPU. Any comment or idea is appreciated. Many thanks.
There are a couple of mistakes at around 14:30. You are using the state from the 'checkpoint' variable, where I think you should be using state from the 'loaded_checkpoint' variable. In general, I think this part of the tutorial would be clearer if you'd used two files, one for saving and one for loading. Also, a better IDE would highlight unused variables which might have helped avoid this mistake.
Great video, thanks man 👍🏻
Thank you very much.🙏
How do you save losses and accuracy to a CSV file after each epoch?
Hi, I love your videos and they have been a great help to me.
I have a question: how can we use custom PyTorch models across different files?
Very helpful video. Thanks
HELP! I am confused. I have created and trained a model successfully. I saved the model using the lazy method, torch.save(model, MODEL_NAME), where MODEL_NAME is the file name with its full path. Now, in a separate file (forecast), I want to load the trained model, so in this new forecast file I load the model from disk with model = torch.load("C:\Users\user\Desktop\datasets\rain\model.pth"), as per the lazy method. It generates the error message AttributeError: Can't get attribute 'NeuralNet' on <module '__main__' from …. NeuralNet is the class I created that inherits from nn.Module, and model is instantiated from the NeuralNet class.
When I torch.save my model, I'm saving an object instantiated from my class called NeuralNet. I don't understand the error message. Do I have to create the class again in my forecast file? It doesn't make sense to me.
Any help will be greatly appreciated. (Learning about neural networks is like sipping water from a fire hose! It's easy to get overwhelmed!)
tnx
Amazing video. Can I confirm that this method lets us do fine-tuning? For example, we train our model on one dataset, save and load the model (with the weights), and then train on a similar dataset using the loaded model with the saved weights, right?
How can we convert saved .pt model to .pb ?
Thank you so much for making the wonderful tutorial!
Could you correct the load-checkpoint part again? I think the commands on lines 32 and 33 should be model.load_state_dict(loaded_checkpoint['model_state']) and optimizer.load_state_dict(loaded_checkpoint['optim_state']).
Your voice is irritating. Giving knowledge is one thing but that voice is distracting me from concentrating on the topic (looks like you are putting on an accent). And what is with ending each sentence like you are asking a question?
could you make video about semantic segmentation with ensemble model with different encoders on pytorch, thank you!
I need your help resolving the issue below:
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
Thank you so much for your lecture! I am trying to save all my model's test results in NIfTI (.nii) image format. Could you give me a tip? I saved and loaded the model for testing. I am doing image segmentation, and my goal is to get the predicted segmentation images.
Mr. Python Engineer does a good job, moves along at a reasonable speed.
Why does my model run again after loading the .pth file? Can you please help me?
Every pytorch video: "and you can ignore this warning right here" 😛 I am going to be first to fall into that trap. Great series! Thank you so much.
thank you for the clear lesson! Really useful!