Generating Images with GANs in PyTorch: From Beginner to Advanced | Part 6 of 6

In this tutorial, we will learn how to generate images with Generative Adversarial Networks (GANs). GANs are a class of deep learning models that learn to produce realistic images from a dataset of real images. We will use PyTorch, a popular deep learning framework, to implement a simple GAN that generates handwritten digits.

This tutorial is part 6 of a series called "Deep Learning with PyTorch: Zero to GANs". If you haven’t already, I recommend checking out the previous parts of the series to get a good understanding of the basics of deep learning and PyTorch.

Before we begin, make sure you have Python installed on your machine along with the required libraries: PyTorch, torchvision (used below to load MNIST), NumPy, and Matplotlib. You can install these libraries using pip:

pip install torch torchvision numpy matplotlib

Let’s get started with the tutorial.

Introduction to GANs

Generative Adversarial Networks (GANs) are a type of deep learning model composed of two neural networks, the generator and the discriminator. The generator network takes random noise as input and generates images, while the discriminator network tries to distinguish between real images from the dataset and fake images generated by the generator.

During training, the generator and discriminator are trained together in a competitive (adversarial) setup, taking turns: on each step the discriminator is updated to better tell real images from fake ones, and then the generator is updated to produce images that fool the discriminator. As this competition continues, the generator gradually gets better at producing realistic images.
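The competition described above is commonly written as a two-player minimax objective, where $D(x)$ is the discriminator's estimated probability that $x$ is real, $G(z)$ is the image the generator produces from noise $z \sim p_z$, and $p_{\text{data}}$ is the distribution of real images. The lines after the objective show how it maps onto the binary cross-entropy loss $\ell$ (`nn.BCELoss`) used in the training loop, with real images labeled 1 and fake images labeled 0:

```latex
\begin{aligned}
\min_G \max_D V(D, G) &= \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right] \\[4pt]
\ell(p, y) &= -\left[\, y \log p + (1 - y) \log (1 - p) \,\right] \\[4pt]
\mathcal{L}_D &= \ell\big(D(x), 1\big) + \ell\big(D(G(z)), 0\big)
  = -\log D(x) - \log\left(1 - D(G(z))\right) \\[4pt]
\mathcal{L}_G &= \ell\big(D(G(z)), 1\big) = -\log D(G(z))
\end{aligned}
```

Minimizing $\mathcal{L}_D$, averaged over a batch, is the same as maximizing $V$ with respect to $D$. For the generator, the training loop labels fake images as real, so it minimizes $-\log D(G(z))$ instead of $\log(1 - D(G(z)))$. This common "non-saturating" variant helps keep the generator's gradients from vanishing early in training, when the discriminator rejects fakes easily.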

Implementing a GAN in PyTorch

To implement a GAN in PyTorch, we will define the generator and discriminator as small fully connected networks (multilayer perceptrons) that operate on flattened 28x28 images. We will use the MNIST dataset, which contains handwritten digits, as our training dataset.

First, let’s import the required libraries and define some hyperparameters for our GAN:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Hyperparameters
batch_size = 64
noise_dim = 100
num_epochs = 100

Next, let’s load the MNIST dataset and create data loaders for training:

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
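`transforms.Normalize((0.5,), (0.5,))` computes `(x - 0.5) / 0.5` for each pixel, mapping the `[0, 1]` range produced by `ToTensor()` to `[-1, 1]`. That is why the generator ends with `Tanh`, whose outputs lie in the same range. The sketch below reproduces this arithmetic in plain PyTorch; `normalize` and `denormalize` are hypothetical helper names (not part of the tutorial), and `denormalize` simply inverts the mapping, which can be handy for converting generated images back to `[0, 1]` before displaying or saving them.

```python
import torch

def normalize(x: torch.Tensor, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    """Same per-pixel arithmetic as transforms.Normalize((mean,), (std,))."""
    return (x - mean) / std

def denormalize(x: torch.Tensor, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: map normalized values back to [0, 1], clipping outliers."""
    return (x * std + mean).clamp(0.0, 1.0)

# ToTensor() yields pixel values in [0, 1]; check the endpoints and the midpoint
pixels = torch.tensor([0.0, 0.5, 1.0])
scaled = normalize(pixels)
print(scaled.tolist())               # [-1.0, 0.0, 1.0]
print(denormalize(scaled).tolist())  # [0.0, 0.5, 1.0]
```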

Now, let’s define the generator and discriminator networks:

# Generator network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
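            nn.Linear(noise_dim, 128),  # noise vector -> hidden layer
            nn.LeakyReLU(0.2),
            nn.Linear(128, 784),        # 784 = 28 * 28 pixels
            nn.Tanh()                   # outputs in [-1, 1], matching the normalized data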
        )

    def forward(self, x):
        return self.model(x)

# Discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 128),        # flattened 28x28 image -> hidden layer
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()                # probability that the input image is real
        )

    def forward(self, x):
        return self.model(x)
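Before training, it can help to confirm that the shapes line up: MNIST batches arrive as `(batch, 1, 28, 28)` tensors, the training loop flattens them to `(batch, 784)`, the generator maps `(batch, noise_dim)` noise to `(batch, 784)` images, and the discriminator returns one probability per image. The sketch below rebuilds the two networks with `nn.Sequential` only so that it runs on its own (the layer sizes mirror the classes above), uses random tensors as a stand-in for MNIST, and `check_gan_shapes` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

noise_dim = 100

# Same layers as the Generator and Discriminator classes above
generator = nn.Sequential(
    nn.Linear(noise_dim, 128), nn.LeakyReLU(0.2),
    nn.Linear(128, 784), nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Linear(784, 128), nn.LeakyReLU(0.2),
    nn.Linear(128, 1), nn.Sigmoid(),
)

def check_gan_shapes(gen: nn.Module, disc: nn.Module, noise_dim: int = 100, batch: int = 64) -> None:
    """Hypothetical helper: pass random data through both networks and check shapes and ranges."""
    with torch.no_grad():
        # Stand-in for an MNIST batch: (batch, channels, height, width)
        real_images = torch.rand(batch, 1, 28, 28)
        real = real_images.view(batch, -1)           # flatten to (batch, 784)
        fake = gen(torch.randn(batch, noise_dim))    # generated images, (batch, 784)
        probs = disc(torch.cat([real, fake]))        # one probability per image

    assert real.shape == (batch, 784)
    assert fake.shape == (batch, 784)
    assert fake.min() >= -1 and fake.max() <= 1      # Tanh range
    assert probs.shape == (2 * batch, 1)
    assert probs.min() >= 0 and probs.max() <= 1     # Sigmoid range
    print("shapes OK:", tuple(real.shape), tuple(fake.shape), tuple(probs.shape))

check_gan_shapes(generator, discriminator, noise_dim)  # shapes OK: (64, 784) (64, 784) (128, 1)
```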

Next, let’s define the training loop for our GAN:

# Initialize the generator and discriminator networks
generator = Generator()
discriminator = Discriminator()

# Loss function and optimizers
criterion = nn.BCELoss()
gen_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)

# Training loop
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        # The last batch may be smaller than batch_size, so read its actual size
        current_batch_size = real_images.size(0)

        # ---- Train the discriminator ----
        # Flatten each 1x28x28 image into a 784-dimensional vector
        real = real_images.view(current_batch_size, -1)
        real_labels = torch.ones(current_batch_size, 1)
        fake_labels = torch.zeros(current_batch_size, 1)

        noise = torch.randn(current_batch_size, noise_dim)
        fake = generator(noise)

        disc_optimizer.zero_grad()

        disc_real = discriminator(real)
        # detach() keeps this loss from backpropagating into the generator
        disc_fake = discriminator(fake.detach())

        disc_loss_real = criterion(disc_real, real_labels)
        disc_loss_fake = criterion(disc_fake, fake_labels)

        disc_loss = disc_loss_real + disc_loss_fake
        disc_loss.backward()

        disc_optimizer.step()

        # ---- Train the generator ----
        gen_optimizer.zero_grad()

        # Reuse the same fake batch (not detached this time) and ask the
        # updated discriminator to label it as real
        disc_fake = discriminator(fake)

        gen_loss = criterion(disc_fake, real_labels)
        gen_loss.backward()

        gen_optimizer.step()

        # Log progress periodically instead of on every step
        if i % 200 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i}/{len(train_loader)}], '
                  f'Discriminator Loss: {disc_loss.item():.4f}, Generator Loss: {gen_loss.item():.4f}')
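The loop above is easier to test and reuse if one iteration is pulled out into a function. The sketch below is a hypothetical refactoring (the name `train_step` and its signature are not from the tutorial) that performs the same discriminator and generator updates on one batch and returns both losses. The `device` argument assumes you have already moved both models to that device with `.to(device)` if you want to train on a GPU. Running it on random tensors, as in the usage lines, is a quick way to check the update logic without downloading MNIST.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def train_step(generator, discriminator, gen_optimizer, disc_optimizer,
               criterion, real_images, noise_dim, device="cpu"):
    """One GAN update on a single batch; returns (disc_loss, gen_loss) as floats."""
    batch = real_images.size(0)
    real = real_images.view(batch, -1).to(device)
    real_labels = torch.ones(batch, 1, device=device)
    fake_labels = torch.zeros(batch, 1, device=device)

    # Discriminator: real images -> 1, generated images -> 0
    fake = generator(torch.randn(batch, noise_dim, device=device))
    disc_optimizer.zero_grad()
    disc_loss = (criterion(discriminator(real), real_labels)
                 + criterion(discriminator(fake.detach()), fake_labels))
    disc_loss.backward()
    disc_optimizer.step()

    # Generator: try to make the discriminator label fakes as real
    gen_optimizer.zero_grad()
    gen_loss = criterion(discriminator(fake), real_labels)
    gen_loss.backward()
    gen_optimizer.step()

    return disc_loss.item(), gen_loss.item()

# Usage on random data standing in for one MNIST batch, with the tutorial's layer sizes
noise_dim = 100
generator = nn.Sequential(nn.Linear(noise_dim, 128), nn.LeakyReLU(0.2),
                          nn.Linear(128, 784), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(784, 128), nn.LeakyReLU(0.2),
                              nn.Linear(128, 1), nn.Sigmoid())
gen_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
criterion = nn.BCELoss()

random_batch = torch.rand(64, 1, 28, 28) * 2 - 1   # values in [-1, 1], like normalized MNIST
d_loss, g_loss = train_step(generator, discriminator, gen_optimizer, disc_optimizer,
                            criterion, random_batch, noise_dim)
print(f"Discriminator Loss: {d_loss:.4f}, Generator Loss: {g_loss:.4f}")
```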

Generating Images

Finally, let’s generate some images using the trained generator network:

# Generate images
with torch.no_grad():
    noise = torch.randn(8, noise_dim)
    fake_images = generator(noise)

    fake_images = fake_images.view(-1, 28, 28)

    fig, axs = plt.subplots(2, 4, figsize=(12, 6))
    for i in range(8):
        axs[i//4, i%4].imshow(fake_images[i].cpu().numpy(), cmap='gray')
        axs[i//4, i%4].axis('off')

    plt.show()
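Training for many epochs takes a while, so you may want to save the trained generator and reload it later to generate new images without retraining. A minimal sketch using PyTorch's `state_dict` mechanism is below. The file name `generator.pth` is an arbitrary choice, and `build_generator` is a hypothetical helper that rebuilds the generator with `nn.Sequential` (same layer sizes as the `Generator` class) only so the snippet runs on its own; with the tutorial's class you would save `generator.state_dict()` and load it into a fresh `Generator()` in the same way.

```python
import torch
import torch.nn as nn

noise_dim = 100

def build_generator() -> nn.Module:
    # Hypothetical helper: same layer sizes as the tutorial's Generator class
    return nn.Sequential(nn.Linear(noise_dim, 128), nn.LeakyReLU(0.2),
                         nn.Linear(128, 784), nn.Tanh())

generator = build_generator()
# ... train the generator here ...

# Save only the learned parameters, not the whole Python object
torch.save(generator.state_dict(), "generator.pth")

# Later: rebuild the architecture and load the saved parameters into it
restored = build_generator()
restored.load_state_dict(torch.load("generator.pth"))
restored.eval()

with torch.no_grad():
    noise = torch.randn(8, noise_dim)
    images = restored(noise).view(-1, 28, 28)   # 8 images, values in [-1, 1]
    # Same noise -> same images from the original and the restored generator
    assert torch.equal(images, generator(noise).view(-1, 28, 28))
print(images.shape)  # torch.Size([8, 28, 28])
```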

That’s it! We have successfully implemented a simple GAN in PyTorch for generating images. Feel free to experiment with different architectures, hyperparameters, and datasets to improve the quality of generated images. GANs are a powerful tool for image generation and can be used for various applications like image synthesis, data augmentation, and image manipulation. Have fun experimenting with GANs and deep learning!

