DCGAN (Deep Convolutional Generative Adversarial Network) is a type of generative model that can generate high-quality images. In this tutorial, we will learn how to implement a DCGAN using PyTorch.
Before we start, please make sure you have PyTorch installed on your system. You can install it using pip:
pip install torch torchvision
Now, let’s start by importing the necessary libraries:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
Next, we will define the generator and discriminator models. The generator takes a random noise vector as input and generates an image, while the discriminator takes an image as input and predicts whether it is real or fake.
class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input: (nz, 1, 1) -> (ngf*4, 4, 4)
            nn.ConvTranspose2d(nz, ngf*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # (ngf*4, 4, 4) -> (ngf*2, 8, 8)
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            # (ngf*2, 8, 8) -> (ngf, 16, 16)
            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # (ngf, 16, 16) -> (nc, 32, 32)
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized data
        )

    def forward(self, input):
        return self.main(input)
class Discriminator(nn.Module):
    def __init__(self, nc, ndf):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # (nc, 32, 32) -> (ndf, 16, 16)
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf, 16, 16) -> (ndf*2, 8, 8)
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2, 8, 8) -> (ndf*4, 4, 4)
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4, 4, 4) -> (1, 1, 1)
            nn.Conv2d(ndf*4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  # probability that the input image is real
        )

    def forward(self, input):
        return self.main(input)
Now, we will define the hyperparameters and create the models:
# Hyperparameters
nz = 100   # size of the latent noise vector
ngf = 64   # base number of generator feature maps
ndf = 64   # base number of discriminator feature maps
nc = 1     # number of image channels (1 for grayscale MNIST)
# Create the models (on GPU if one is available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(nz, ngf, nc).to(device)
discriminator = Discriminator(nc, ndf).to(device)
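The DCGAN paper also recommends initializing all convolutional and batch-norm weights from a normal distribution with mean 0 and standard deviation 0.02. Here is a minimal sketch of that initialization (the helper name weights_init is our own):

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Conv and ConvTranspose layers: N(0, 0.02)
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm layers: N(1, 0.02) weights, zero bias
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.zeros_(m.bias.data)

generator.apply(weights_init)
discriminator.apply(weights_init)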
Next, we will define the loss function and the optimizers. We use binary cross-entropy loss, and Adam with a learning rate of 0.0002 and β₁ = 0.5, the settings recommended in the DCGAN paper:
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
Now, we will load the dataset and prepare the data loader. MNIST images are 28×28, but the networks above expect 32×32 inputs, so we resize them; we also normalize the pixels to [-1, 1] to match the generator's Tanh output:
transform = transforms.Compose([
    transforms.Resize(32),  # match the 32×32 size the networks expect
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # scale pixels to [-1, 1]
])
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
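As a quick, optional sanity check, you can pass a batch of noise through both networks to confirm the shapes line up before training:

with torch.no_grad():
    noise = torch.randn(16, nz, 1, 1, device=device)
    fake = generator(noise)
    print(fake.shape)                 # expected: torch.Size([16, 1, 32, 32])
    print(discriminator(fake).shape)  # expected: torch.Size([16, 1, 1, 1])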
Finally, we will train the models:
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # Update discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        discriminator.zero_grad()
        real_images = data[0].to(device)
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, device=device)
        output = discriminator(real_images).view(-1)
        loss_real = criterion(output, real_labels)
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(batch_size, device=device)
        # detach() so this pass does not backpropagate into the generator
        output = discriminator(fake_images.detach()).view(-1)
        loss_fake = criterion(output, fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()
        # Update generator: maximize log(D(G(z)))
        generator.zero_grad()
        output = discriminator(fake_images).view(-1)
        loss_G = criterion(output, real_labels)  # generator wants fakes labeled real
        loss_G.backward()
        optimizer_G.step()
        if i % 100 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(dataloader)}], '
                  f'Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}')
# Save the models
torch.save(generator.state_dict(), 'generator.pt')
torch.save(discriminator.state_dict(), 'discriminator.pt')
That’s it! You have successfully implemented a DCGAN using PyTorch. Feel free to experiment with different hyperparameters and datasets to generate high-quality images.
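To generate images with the trained generator, you can load the saved weights, sample some noise, and save a grid of outputs. For example (a minimal sketch; 'samples.png' is just an illustrative filename):

from torchvision.utils import save_image

generator.load_state_dict(torch.load('generator.pt'))
generator.eval()
with torch.no_grad():
    noise = torch.randn(64, nz, 1, 1, device=device)
    samples = generator(noise)
# Undo the [-1, 1] normalization before saving an 8x8 grid of samples.
save_image(samples * 0.5 + 0.5, 'samples.png', nrow=8)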
Hi, one question: why is the output from the generator 64×64? MNIST images are 28×28.
Awesome video. Thank you so much.
I can't understand line 70:
output = netD(fake.detach()).reshape(-1)
What is "detach"?
Thank you for the video. Well explained and engaging as well. Keep up the good work.
Great video!
Thanks for the amazing tutorial!!! What if we use kernel size 3? How can we do this? It'll be great if you give your valuable opinion on this! Thanks ♥
Great. But there's a question on labels: why are the labels of the fake instances we've made torch.ones(…), while the discriminator must classify them as fake images with label zero? Shouldn't we use torch.zeros(…) to label fake images?
There is a new video playlist entirely focused on GANs now, and I've remade this tutorial, so do check out the links below. I think they are better than this video in terms of explanations and quality (although this is probably still pretty good 🙂).
New DCGAN video:
https://youtu.be/IZtv9s_Wx9I
GAN Playlist:
https://www.youtube.com/playlist?list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va
Quick question here: if I use CIFAR-10 and image size 128, I get the error "Target and input must have the same number of elements.". What is the problem here?
Great video! However, I have a question: How can I make my own dataset of images which I would give to this network?
Thank you
I have a question!
When you're doing "lossG.backward()" after the command output = netD(fake).reshape(-1),
aren't you also backpropagating through the discriminator before even reaching the generator?
Isn't that going to decrease the learning of the discriminator a little bit?
Hi, I liked your tutorial, it was great! I just want to know if there is a way to not use MNIST but use my own images (my dataset is of floor plans), keeping what you did here and only changing the part where you load "dataset = datasets.MNIST( root="dataset/", train=True, transform=my_transforms, download=True)". Thank you for the response.
Dude perfect explanation, thanks 🙂
Amazing video! Could you explain how I can use a custom dataset? I am a little bit confused.
Hello Aladdin, great tutorial! You use sigmoid with BCELoss. Wouldn't be beneficial to use BCEWithLogitsLoss and remove the sigmoid? According to pytorch docs the latter is more stable. Just curious as my understanding of the inner torch mechanics could be wrong 🙂
This video is great for learning PyTorch.
Please, I want to know if you can explain recommender systems the way you explained GANs in this video. I'm a junior data scientist and I have some difficulty figuring out how they really work and the math behind them.
Hi, thank you very much. Your video is very useful, and your explanation cleared up so many things in my mind.
Great content! Please cover DeepLab, R-CNN, Mask R-CNN, YOLO, and other detection and segmentation algorithms…