Creating a Variational Autoencoder in PyTorch Without Using Pre-trained Models

Variational Autoencoders (VAEs) are a type of neural network architecture used for generating data by learning a latent space representation of the input data. Unlike traditional autoencoders, which simply compress and decompress data, VAEs learn a probabilistic distribution over the latent space, which allows them to generate new data points.

In this tutorial, we will implement a variational autoencoder from scratch using PyTorch. This tutorial assumes some familiarity with neural networks and PyTorch, so if you are new to these concepts, I recommend familiarizing yourself with them before continuing.
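
Concretely, a VAE is trained by maximizing the evidence lower bound (ELBO), which in practice means minimizing a loss with two parts: a reconstruction term that measures how well the decoder reproduces the input, and a KL divergence term that keeps the encoder's distribution q(z|x) close to a standard normal prior p(z). We will implement both terms in Step 4.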

Step 1: Importing the necessary libraries
First, we need to import the necessary libraries for building our VAE. We will use PyTorch for the model and optimization, torchvision for loading the dataset, and Matplotlib for visualization.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

Step 2: Loading the data
For this tutorial, we will use the MNIST dataset, which contains 28x28 grayscale images of handwritten digits. We will use the torchvision library to load and preprocess the dataset.

# Define the transformation to be applied to the data
transform = transforms.Compose([transforms.ToTensor()])

# Load the dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
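
If you want to confirm what the DataLoader produces, an optional check like the one below prints the batch shape. Each batch is a tensor of 1x28x28 images with pixel values in [0, 1] (thanks to ToTensor()), which is why we will later flatten each image to a 784-dimensional vector and use a sigmoid output in the decoder.

# Optional sanity check on one batch
images, labels = next(iter(train_loader))
print(images.shape)                               # torch.Size([128, 1, 28, 28])
print(images.min().item(), images.max().item())   # pixel values lie in [0, 1]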

Step 3: Building the VAE model
Next, we will define the architecture of our VAE model. The encoder maps each input image to the mean and log-variance of a Gaussian distribution over the latent space, and the decoder maps latent vectors back to images.

class VAE(nn.Module):
    def __init__(self, input_size=784, latent_size=20):
        super(VAE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_size, 400),
            nn.ReLU(),
            nn.Linear(400, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU()
        )

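        # Two linear heads map the encoder output to the mean and
        # log-variance of the approximate posterior q(z|x)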
        self.mu = nn.Linear(100, latent_size)
        self.log_var = nn.Linear(100, latent_size)

        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 100),
            nn.ReLU(),
            nn.Linear(100, 200),
            nn.ReLU(),
            nn.Linear(200, 400),
            nn.ReLU(),
            nn.Linear(400, input_size),
            nn.Sigmoid()
        )

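    # The reparameterization trick: instead of sampling z directly from N(mu, sigma^2),
    # which is not differentiable, we sample eps ~ N(0, 1) and compute z = mu + sigma * eps
    # so that gradients can flow back through mu and log_var.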
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        x = self.encoder(x)
        mu = self.mu(x)
        log_var = self.log_var(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var
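
As a quick sanity check, you can instantiate the model and push a random dummy batch through it to confirm that the shapes line up (the variable names below are only for illustration):

# Optional shape check with a random dummy batch
model = VAE()
dummy = torch.rand(4, 784)
recon, mu, log_var = model(dummy)
print(recon.shape, mu.shape, log_var.shape)  # torch.Size([4, 784]) torch.Size([4, 20]) torch.Size([4, 20])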

Step 4: Defining the loss function
In VAEs, the loss function is a combination of the reconstruction loss and the KL divergence term. We will define the loss function as follows:

def loss_function(recon_x, x, mu, log_var):
    # Reconstruction loss: pixel-wise binary cross-entropy, summed over the batch
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the standard normal prior
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD
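
For reference, the KLD term above is the closed-form KL divergence between the approximate posterior q(z|x) = N(mu, sigma^2) produced by the encoder and the standard normal prior p(z) = N(0, I):

KL(q(z|x) || p(z)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

where log(sigma^2) is the log_var output of the encoder. Using binary cross-entropy for the reconstruction term works here because MNIST pixel values lie in [0, 1] and the decoder ends in a sigmoid.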

Step 5: Training the VAE
Now, we will train our VAE on the MNIST training data. In each iteration we flatten the images to 784-dimensional vectors, run them through the model, and minimize the loss defined above.

vae = VAE()
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

for epoch in range(10):
    total_loss = 0
    for i, (x, _) in enumerate(train_loader):
        x = x.view(-1, 784)
        optimizer.zero_grad()
        x_recon, mu, log_var = vae(x)
        loss = loss_function(x_recon, x, mu, log_var)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()

    print('Epoch:', epoch, 'Loss:', total_loss / len(train_loader))
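
Note that total_loss / len(train_loader) reports the average summed loss per batch. If you prefer the average loss per image, which is easier to compare across batch sizes, you can replace the print statement with, for example:

    print('Epoch:', epoch, 'Loss per image:', total_loss / len(train_loader.dataset))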

Step 6: Generating new samples
After training the VAE, we can generate new samples by drawing random latent vectors from the standard normal prior and passing them through the decoder.

vae.eval()
with torch.no_grad():
    z = torch.randn(16, 20)
    generated_samples = vae.decoder(z)
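
The result is a tensor of shape (16, 784), where each row is a flattened 28x28 image; we will reshape these for display in the next step.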

Step 7: Visualizing the results
Finally, we can visualize the original images from the last training batch, their reconstructions, and the samples generated in the previous step.

fig, axes = plt.subplots(3, 8, figsize=(16, 6))

for i in range(8):
    # Top row: original images from the last training batch
    axes[0, i].imshow(x[i].view(28, 28).detach().numpy(), cmap='gray')
    axes[0, i].axis('off')

    # Middle row: their reconstructions
    axes[1, i].imshow(x_recon[i].view(28, 28).detach().numpy(), cmap='gray')
    axes[1, i].axis('off')

    # Bottom row: samples generated from random latent vectors
    axes[2, i].imshow(generated_samples[i].view(28, 28).detach().numpy(), cmap='gray')
    axes[2, i].axis('off')

plt.show()

In this tutorial, we have implemented a variational autoencoder from scratch using PyTorch. VAEs are powerful tools for generating data and learning latent representations of complex data distributions. I hope this tutorial has been helpful in understanding and implementing VAEs.
