Saving and Loading Models in PyTorch Tutorial 17

In this tutorial, we will learn how to save and load PyTorch models. Saving and loading models is an essential skill for machine learning practitioners as it allows you to save the state of your trained models and use them for inference or further training at a later time.

Before we proceed, make sure you have PyTorch installed in your environment. You can install PyTorch using pip by running the following command:

pip install torch torchvision

Now, let’s get started with saving and loading models in PyTorch.

Step 1: Define a Model
First, let’s define a simple neural network model using PyTorch. For this tutorial, we will create a basic feedforward neural network with one hidden layer.

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

Step 2: Save a Model
To save a PyTorch model, we can use the state_dict() method. This method returns a dictionary that maps each layer’s name to its learnable parameters (weights and biases) and any registered buffers. We can save this dictionary to a file using the torch.save() function.

model = SimpleNN(input_size=784, hidden_size=128, num_classes=10)

# Save model
torch.save(model.state_dict(), 'simple_nn_model.pth')
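To make the contents of the state dictionary concrete, here is a minimal sketch that prints the entries for the SimpleNN class from Step 1. The class is repeated so the snippet runs on its own, and the `shapes` variable is just an illustrative name.

```python
import torch
import torch.nn as nn

# Same layer layout as SimpleNN in Step 1, repeated so this runs standalone
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = SimpleNN(input_size=784, hidden_size=128, num_classes=10)

# state_dict() maps parameter names to tensors; ReLU has no parameters,
# so only the two linear layers appear
shapes = {name: tuple(t.shape) for name, t in model.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
# fc1.weight (128, 784)
# fc1.bias (128,)
# fc2.weight (10, 128)
# fc2.bias (10,)
```

Only tensors are stored, not the Python class itself, which is why Step 3 has to create a model instance before loading.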

Step 3: Load a Model
To load a saved model, we can instantiate the model class and load the saved state dictionary using the load_state_dict() method. This will load the saved parameters into the model.

# Create model instance
model = SimpleNN(input_size=784, hidden_size=128, num_classes=10)

# Load saved model
model.load_state_dict(torch.load('simple_nn_model.pth'))
model.eval()  # Set model to evaluation mode
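If the file was saved on a GPU machine and is loaded on a CPU-only one, torch.load() needs a map_location argument. The sketch below shows a full save and load round trip under that assumption. A tiny nn.Linear layer stands in for the real model, and the temporary file path is only for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-in model; the same calls apply to SimpleNN
model = nn.Linear(4, 2)

# Save the parameters to a temporary file (illustrative path)
path = os.path.join(tempfile.mkdtemp(), "demo_model.pth")
torch.save(model.state_dict(), path)

# Recreate the architecture, then load the saved parameters.
# map_location remaps tensors saved on a GPU onto the CPU.
restored = nn.Linear(4, 2)
state = torch.load(path, map_location=torch.device("cpu"))
result = restored.load_state_dict(state)
restored.eval()  # evaluation mode for inference

# load_state_dict reports any keys that did not match
print(result.missing_keys, result.unexpected_keys)

# Check that every tensor survived the round trip unchanged
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)
```

The architecture passed to the new instance has to match the one that was saved, otherwise load_state_dict() raises an error about mismatched keys or shapes.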

Step 4: Use the Model for Inference
Once the model is loaded, we can use it for inference on new data. Here’s an example of how to make predictions using the loaded model:

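# `data` is a single input sample with a batch dimension: shape (1, 784).
# A random tensor stands in for real input here.
data = torch.randn(1, 784)
with torch.no_grad():  # gradients are not needed for inference
    output = model(data)
predicted_class = torch.argmax(output, dim=1)
print(f'Predicted class: {predicted_class.item()}')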
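Note that .item() only works when the result holds exactly one value. For a batch of several samples, a sketch like the following may be more useful. The nn.Sequential model here is an untrained stand-in with the same layer sizes as SimpleNN, and the random batch is placeholder data, so the predicted classes themselves are meaningless.

```python
import torch
import torch.nn as nn

# Untrained stand-in with the same layer sizes as SimpleNN
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
model.eval()

# Placeholder batch of 5 samples, each with 784 features
data = torch.randn(5, 784)

# Disable gradient tracking to save memory during inference
with torch.no_grad():
    output = model(data)  # shape (5, 10): one score per class per sample

# One predicted class index per sample
predicted = torch.argmax(output, dim=1)
print(predicted.tolist())
```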

That’s it! You have successfully saved and loaded a PyTorch model. This is a fundamental skill in deep learning, and knowing how to save and load models will be invaluable as you work on more complex projects and datasets.

In this tutorial, we covered the basic steps to save and load PyTorch models. Remember to save the model’s state_dict() and load it into an instance of the same model class. Additionally, make sure to set the model to evaluation mode (model.eval()) before using it for inference.
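If the goal is further training rather than inference, the optimizer state is usually saved alongside the model. Below is a minimal sketch of that idea. The dictionary keys ("epoch", "model_state", "optim_state"), the SGD optimizer, and the tiny nn.Linear model are assumptions chosen for illustration, not something this tutorial prescribes.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-in model and an optimizer (both illustrative choices)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Bundle everything needed to resume training into one dictionary
checkpoint = {
    "epoch": 5,  # hypothetical epoch counter
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
}
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save(checkpoint, path)

# Later: rebuild the model and optimizer, then restore their states
loaded_checkpoint = torch.load(path, map_location="cpu")
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)
new_model.load_state_dict(loaded_checkpoint["model_state"])
new_optimizer.load_state_dict(loaded_checkpoint["optim_state"])  # lr becomes 0.01
start_epoch = loaded_checkpoint["epoch"]

new_model.train()  # training mode, since we are resuming training
print(start_epoch, new_optimizer.param_groups[0]["lr"])
```

When restoring, read the states from the dictionary returned by torch.load(), not from the original in-memory dictionary, so the code still works in a fresh session.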

I hope you found this tutorial helpful. Happy coding!

35 Comments
@longdang7791
1 month ago

Hi, thank you for showing all the methods. I have a federated learning use-case where I train a single model using 32 clients. Currently, I have a CUDA out of memory issue. I want to know how each of the methods presented here affects memory allocation and reservation on a single GPU. Any comment/idea is appreciated. Many thanks.

@npomfret
1 month ago

There are a couple of mistakes at around 14:30. You are using the state from the 'checkpoint' variable, where I think you should be using state from the 'loaded_checkpoint' variable. In general, I think this part of the tutorial would be clearer if you'd used two files, one for saving and one for loading. Also, a better IDE would highlight unused variables which might have helped avoid this mistake.

@ubaidghante8604
1 month ago

Great video, thanks man 👍🏻

@GurpreetSingh-si2gh
1 month ago

Thank you very much.🙏

@shivanshchakrawarti9289
1 month ago

How do you save losses and accuracy in a csv file? after each epoch?

@mohammadusamah819
1 month ago

Hi I love your videos and it has been very much help for me.
I have a question that how can we use pytorch custom models on different files

@amankushwaha8927
1 month ago

Very helpful video. Thanks

@lakeguy65616
1 month ago

HELP! I am confused. I have created and trained a model successfully. I have saved the model using the lazy method (torch.save(model, MODEL_NAME)) ** model_name is the file name and full path. Now, in a separate file(forecast), I want to load the trained model. So in this new forecast file, I've loaded the model from disk "model = torch.load("C:\Users\user\Desktop\datasets\rain\model.pth") " as per the lazy method. it generates an error message AttributeError: Can't get attribute 'NeuralNet' on <module '__main__' from …. NeuralNet is the class I created that inherited from NN.module. model is instantiated from the NeuralNet class.
When I torch.save my model, I'm saving an object instantiated from my class called NeuralNet. I don't understand the error message. Do I have to create the class again in my forecast file? It doesn't make sense to me.

any help will be greatly appreciated. (learning about Neural Networks is like sipping water from a fire hose!, its easy to get overwhelmed!)

@Dr.Bimechi
1 month ago

tnx

@MuhammadHussain-ws1xs
1 month ago

Amazing video, can i confirm that by using this method we can do fine tuning, for example train our model on one dataset, then save and load the model (with the weights), we can then train on a similar dataset but using the loaded model with the saved weights, right?

@omerbinali9472
1 month ago

How can we convert saved .pt model to .pb ?

@nguyenduydatnguyenduydat2640
1 month ago

Thank you so much for making the wonderful tutorial!
Would you like to correct again the part load checkpoint. I thought the command in line 32 33 should be model.load_state_dict(loaded_checkpoint['model_state'])
and
optimizer.load_state_dict(loaded_checkpoint['optim_state'])

@ashwiniaditya8445
1 month ago

Your voice is irritating. Giving knowledge is one thing but that voice is distracting me from concentrating on the topic (looks like you are putting on an accent). And what is with ending each sentence like you are asking a question?

@Sparklerated
1 month ago

could you make video about semantic segmentation with ensemble model with different encoders on pytorch, thank you!

@annemahoro9966
1 month ago

I need your help on how I can resolve below issues:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

@mywayluna1715
1 month ago

Thank you so much for your lecture! I tried to save all my testing model result with nii(image) format. Could you give me a tip for me? I saved and loaded model to test. I did the image segmentation, my goal is to get a predicted segmentation images.

@geoffreyexoo
1 month ago

Mr. Python Engineer does a good job, moves along at a reasonable speed.

@newajsorifnishad9169
1 month ago

Why does my model run again after loading the .pth file? Can u plz help me

@amerel-samman9929
1 month ago

Every pytorch video: "and you can ignore this warning right here" 😛 I am going to be first to fall into that trap. Great series! Thank you so much.

@inesylla9706
1 month ago

thank you for the clear lesson! Really useful!