Implementing a Variational Autoencoder for Handwritten Digits in PyTorch: A Code Example

In this tutorial, we will implement a Variational Autoencoder (VAE) for generating handwritten digits using PyTorch. VAEs are a type of generative model that learn to encode data into, and decode it from, a continuous latent space. Because new samples can be drawn from that latent space, VAEs are widely used for tasks such as image generation and representation learning.

We will use the MNIST dataset, which consists of 28×28 pixel grayscale images of handwritten digits ranging from 0 to 9. We will build a VAE model that can generate new handwritten digits based on this dataset.

First, let’s install the necessary libraries:

pip install torch torchvision matplotlib

Next, let’s import the required libraries in our Python script:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import matplotlib.pyplot as plt

Now, let’s define the VAE model. It consists of an encoder network, a decoder network, and a reparameterization step that lets gradients flow through the sampling operation:

class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()

        # Encoder: maps a flattened image to a 256-dim hidden representation
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )

        # Two linear heads produce the mean and log-variance of q(z|x)
        self.mu = nn.Linear(256, latent_dim)
        self.log_var = nn.Linear(256, latent_dim)

        # Decoder: maps a latent code back to pixel space; the final Sigmoid
        # keeps outputs in [0, 1] to match the normalized input images
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, log_var):
        # Reparameterization trick: sample z = mu + sigma * eps with
        # eps ~ N(0, I), so gradients can flow back through mu and log_var
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        x = self.encoder(x)
        mu = self.mu(x)
        log_var = self.log_var(x)

        z = self.reparameterize(mu, log_var)
        x_reconstructed = self.decoder(z)

        return x_reconstructed, mu, log_var
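
Before moving on, it can help to sanity-check the model’s output shapes with a throwaway batch. Here is a quick sketch (the batch size of 16 and the variable names are arbitrary, not part of the final script):

check_model = VAE(input_dim=28*28, latent_dim=64)
check_x = torch.rand(16, 28*28)  # random values in [0, 1], like ToTensor output
x_rec, mu, log_var = check_model(check_x)
print(x_rec.shape, mu.shape, log_var.shape)
# torch.Size([16, 784]) torch.Size([16, 64]) torch.Size([16, 64])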

Next, let’s define the VAE loss, which is the sum of two terms: a reconstruction loss that measures how faithfully the decoder reproduces the input, and a KL divergence loss that pulls the approximate posterior toward the standard normal prior, keeping the latent space smooth enough to sample from:

def vae_loss(x, x_reconstructed, mu, log_var):
    # Reconstruction loss: pixel-wise binary cross-entropy, summed over the batch
    reconstruction_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')

    # KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, I),
    # in closed form: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_divergence_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return reconstruction_loss + kl_divergence_loss
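
A quick way to build intuition for the KL term: it vanishes exactly when the encoder already matches the prior, i.e. when mu = 0 and log_var = 0. A small sketch with dummy tensors (the shapes are arbitrary):

zero_mu = torch.zeros(4, 64)
zero_log_var = torch.zeros(4, 64)
kl = -0.5 * torch.sum(1 + zero_log_var - zero_mu.pow(2) - zero_log_var.exp())
print(kl.item())  # prints 0.0 (possibly -0.0 due to floating-point signed zero)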

Now, let’s define the training loop for the VAE model:

def train_vae(model, dataloader, optimizer, num_epochs):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for x, _ in dataloader:
            # Flatten each 28x28 image into a 784-dim vector
            x = x.view(-1, 28*28)
            x_reconstructed, mu, log_var = model(x)

            loss = vae_loss(x, x_reconstructed, mu, log_var)
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report the average loss per training sample
        print(f'Epoch {epoch+1}, Loss: {total_loss/len(dataloader.dataset):.4f}')
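
The loop above runs entirely on the CPU. If a GPU is available, a small addition moves the model and each batch over; here is a sketch (the device variable is introduced here and is not part of the original script):

# Pick a GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Before training: model.to(device)
# Inside the batch loop: x = x.view(-1, 28*28).to(device)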

Finally, let’s load the MNIST dataset, initialize the VAE model, and train the model:

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor()
])

mnist_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_dataset, batch_size=128, shuffle=True)

# Initialize VAE model
input_dim = 28*28
latent_dim = 64
model = VAE(input_dim, latent_dim)

# Set optimizer and number of epochs
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 10

# Train the model
train_vae(model, dataloader, optimizer, num_epochs)
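
Training for several epochs can take a while on a CPU, so it is worth saving the learned weights once training finishes. A minimal sketch (the filename vae_mnist.pt is just an example):

# Save the trained weights so sampling can be rerun without retraining
torch.save(model.state_dict(), 'vae_mnist.pt')

# Later, restore them into a freshly constructed model:
# model = VAE(input_dim, latent_dim)
# model.load_state_dict(torch.load('vae_mnist.pt'))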

After training the VAE model, we can generate new handwritten digits by sampling from the latent space:

def generate_samples(model, num_samples, latent_dim):
    model.eval()

    with torch.no_grad():
        # Sample latent codes from the standard normal prior and decode them
        z = torch.randn(num_samples, latent_dim)
        generated_samples = model.decoder(z)

    return generated_samples

# Generate new samples
num_samples = 10
generated_samples = generate_samples(model, num_samples, latent_dim)

# Display generated samples
fig, axs = plt.subplots(1, num_samples, figsize=(20, 4))
for i in range(num_samples):
    axs[i].imshow(generated_samples[i].view(28, 28).numpy(), cmap='gray')
    axs[i].axis('off')

plt.show()
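
Before wrapping up, one quick experiment: interpolate between two random latent codes and decode each intermediate point. Nearby latent vectors should decode to similar-looking digits. A sketch reusing the trained model from above (the number of steps is arbitrary):

model.eval()
with torch.no_grad():
    z_start = torch.randn(1, latent_dim)
    z_end = torch.randn(1, latent_dim)
    steps = 8

    fig, axs = plt.subplots(1, steps, figsize=(16, 2))
    for i, t in enumerate(torch.linspace(0, 1, steps)):
        z = (1 - t) * z_start + t * z_end  # point on the line between the two codes
        image = model.decoder(z).view(28, 28)
        axs[i].imshow(image.numpy(), cmap='gray')
        axs[i].axis('off')

    plt.show()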

And that’s it! We have successfully built a Variational Autoencoder for generating handwritten digits using PyTorch. Feel free to experiment with different hyperparameters, network architectures, and datasets to further improve the model’s performance.
