Part 2: Building a Deep Learning Model with PyTorch for the Iris Dataset

In part 1 of this tutorial, we covered the basics of setting up a PyTorch environment and loading the Iris dataset. In this part, we define a small neural network, train it on the Iris training data, and measure its accuracy on a held-out test set.
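
The training and testing code in this part uses two DataLoaders, `train_loader` and `test_loader`, from part 1. Part 1 is not reproduced here, so the following is only a minimal sketch of one way to build them. It assumes scikit-learn is installed (for `load_iris`), an 80/20 stratified split, standardized features, and a batch size of 16. These are illustrative choices, not necessarily what part 1 used.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the 150 Iris samples: 4 float features each, labels 0, 1, 2
iris = load_iris()

# Assumption: 80/20 stratified split with a fixed seed (illustrative choice)
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42, stratify=iris.target
)

# Standardize features using statistics from the training split only
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# Features as float32 (what nn.Linear expects), labels as int64 class indices
train_ds = TensorDataset(torch.tensor(X_train, dtype=torch.float32),
                         torch.tensor(y_train, dtype=torch.long))
test_ds = TensorDataset(torch.tensor(X_test, dtype=torch.float32),
                        torch.tensor(y_test, dtype=torch.long))

# Assumption: batch size of 16; shuffle only the training data
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=16)
```

With loaders like these, the training and testing loops in this part can run as written.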

Defining the model

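First, we define the model: a small fully connected neural network with one hidden layer of 10 units. It takes the four Iris features (sepal length, sepal width, petal length, petal width) as input and outputs three scores, one for each species. Here's the code for defining the model: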

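```python
import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 4 input features -> 10 hidden units
        self.fc1 = nn.Linear(4, 10)
        # 10 hidden units -> 3 output scores (one per Iris species)
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # hidden layer with ReLU activation
        x = self.fc2(x)              # raw logits; no softmax here
        return x
```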
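
A quick way to sanity-check the architecture is to push a dummy batch through it and count its parameters. The sketch below repeats the class so it runs on its own; the random batch of five "flowers" is hypothetical input, not Iris data.

```python
import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = NeuralNetwork()
dummy = torch.randn(5, 4)  # hypothetical batch: 5 samples, 4 features each
logits = model(dummy)
print(logits.shape)  # torch.Size([5, 3])

# Trainable parameters: (4*10 + 10) + (10*3 + 3) = 83
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 83
```

One row of three logits per input sample is exactly the shape the training and testing code expects.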

Training the model

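Next, we train the model on the Iris training data. Because Iris is a three-class classification problem and the labels are integer class indices, we use the cross-entropy loss (`nn.CrossEntropyLoss`), which takes the model's raw logits directly, together with the Adam optimizer at a learning rate of 0.01. Here's the code for training the model: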

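```python
# train_loader is assumed to be the DataLoader from part 1
# (features as float32, labels as int64 class indices)
model = NeuralNetwork()
criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(100):
    for inputs, labels in train_loader:
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(inputs)            # forward pass: logits of shape (batch, 3)
        loss = criterion(outputs, labels)  # cross-entropy against the true classes
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1}, loss: {loss.item():.4f}')
```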
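
For multi-class classification with integer labels, PyTorch's standard loss is `nn.CrossEntropyLoss`. To see what it computes, the sketch below compares it with a manual calculation on two made-up rows of logits: the loss is the mean negative log-softmax probability assigned to each sample's true class. This is also why the model's `forward` method returns raw logits without applying a softmax.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical logits for 2 samples over the 3 Iris classes, and their true labels
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])  # int64 class indices, not one-hot vectors

loss = nn.CrossEntropyLoss()(logits, labels)

# Same value by hand: mean negative log-probability of each true class
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(torch.allclose(loss, manual))  # True
```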

Testing the model

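Finally, we evaluate the trained model on the held-out test set by computing its classification accuracy: the fraction of test samples whose highest-scoring class matches the true label. Here's the code for testing the model: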

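```python
model.eval()  # evaluation mode (matters for layers such as dropout or batch norm)
correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)  # index of the highest logit per row
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Accuracy: {accuracy:.2%}')
```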
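
The accuracy computation can be pulled into a small helper and checked on hand-made logits where the answer is known in advance. The function name `accuracy_from_logits` and the example values are hypothetical; `argmax(dim=1)` returns the same indices as the `torch.max(outputs, 1)` call in the testing loop.

```python
import torch

def accuracy_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of rows whose highest-scoring class matches the label."""
    predicted = logits.argmax(dim=1)  # same indices as torch.max(logits, 1)[1]
    return (predicted == labels).float().mean().item()

# Hypothetical logits for 4 samples; rows 0, 1 and 3 are predicted correctly
logits = torch.tensor([[3.0, 0.1, 0.2],
                       [0.0, 2.5, 1.0],
                       [0.3, 0.2, 0.1],
                       [0.1, 0.4, 1.9]])
labels = torch.tensor([0, 1, 2, 2])
print(accuracy_from_logits(logits, labels))  # 0.75
```

Testing a metric on a case you can verify by eye is a cheap way to catch mistakes such as reducing over the wrong dimension.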

Conclusion

Congratulations! You have successfully trained a neural network on the Iris dataset using PyTorch. The same define, train, and evaluate workflow carries over to more complex models and other datasets. Experiment with different network architectures, loss functions, and optimizers to improve the performance of your models. Happy coding!