Creating a Variational Autoencoder from Scratch Using PyTorch

A variational autoencoder (VAE) is a type of generative neural network: it learns a probability distribution over its training data and can then generate new samples that resemble the data it was trained on. In this tutorial, we will walk through how to build a VAE from scratch using PyTorch.

First, let’s define the architecture of the VAE. The VAE is composed of an encoder and a decoder. The encoder maps an input data sample to a distribution over the latent space (here, a Gaussian parameterized by a mean and a log-variance), while the decoder takes a sample from the latent space and reconstructs the original input. The VAE is trained by maximizing the evidence lower bound (ELBO); in practice we minimize the negative ELBO, which is the sum of two terms: the reconstruction loss and the KL divergence between the approximate posterior over the latent variables and a standard normal prior.
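
Written out for a single input x, with encoder outputs mean \mu and log-variance \log \sigma^2, the objective we minimize is the standard negative-ELBO form (this is exactly what the loss_function we define below computes):

\mathcal{L}(x) = \mathrm{BCE}(\hat{x}, x) - \frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)

where \hat{x} is the reconstruction, d is the dimensionality of the latent space, the first term is the reconstruction loss (binary cross-entropy), and the second term is the closed-form KL divergence between \mathcal{N}(\mu, \sigma^2 I) and the standard normal prior \mathcal{N}(0, I).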

Let’s start by defining the encoder and decoder networks in PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        mean = self.fc_mean(h)
        logvar = self.fc_logvar(h)

        return mean, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()

        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        x_recon = torch.sigmoid(self.fc2(h))

        return x_recon
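
As a quick sanity check (the sizes here are placeholders chosen to match flattened 28x28 MNIST images), you can instantiate the two networks and confirm the shapes line up:

encoder = Encoder(input_dim=784, hidden_dim=256, latent_dim=20)
decoder = Decoder(latent_dim=20, hidden_dim=256, output_dim=784)

x = torch.randn(32, 784)                 # dummy batch of 32 flattened images
mean, logvar = encoder(x)                # both have shape (32, 20)
x_recon = decoder(torch.randn(32, 20))   # shape (32, 784), values in (0, 1) from the sigmoid
print(mean.shape, logvar.shape, x_recon.shape)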

Next, we will define the VAE model that combines the encoder and decoder. Its reparameterize method implements the reparameterization trick: instead of sampling z directly from N(mean, sigma^2), we sample eps from a standard normal and compute z = mean + eps * std, which keeps the sampling step differentiable with respect to the encoder's outputs:

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()

        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        z = mean + eps*std
        return z

    def forward(self, x):
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        x_recon = self.decoder(z)

        return x_recon, mean, logvar
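
Sampling z directly would break the gradient path from the loss back to the encoder; the reparameterization trick sidesteps this by pushing the randomness into eps. A tiny standalone check on dummy tensors (purely illustrative) confirms that the reparameterized sample is differentiable with respect to both encoder outputs:

mean = torch.zeros(4, 20, requires_grad=True)
logvar = torch.zeros(4, 20, requires_grad=True)
z = mean + torch.randn_like(logvar) * torch.exp(0.5 * logvar)  # same computation as reparameterize
z.sum().backward()
print(mean.grad.shape, logvar.grad.shape)  # both (4, 20): gradients reach the encoder outputs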

Now, we will define the loss function for training the VAE:

def loss_function(x_recon, x, mean, logvar):
    # Reconstruction term: binary cross-entropy between the reconstruction and the input,
    # summed over all pixels and all examples in the batch
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Closed-form KL divergence between N(mean, sigma^2) and the N(0, I) prior
    kl_div = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

    return recon_loss + kl_div
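
One thing to keep in mind: with reduction='sum' the loss is summed over every pixel of every image in the batch, so its magnitude grows with the batch size. A quick illustration on random data (the dimensions are placeholders for this example only):

x = torch.rand(8, 784)                        # dummy batch with values in [0, 1]
x_recon, mean, logvar = VAE(784, 256, 20)(x)
loss = loss_function(x_recon, x, mean, logvar)
print(loss.item(), loss.item() / x.size(0))   # total loss vs. average loss per image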

Finally, we can train the VAE on a dataset. Below is an example of training the VAE model on the MNIST dataset:

# Load the MNIST dataset
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)

# Initialize the VAE model
input_dim = 784
hidden_dim = 256
latent_dim = 20
vae = VAE(input_dim, hidden_dim, latent_dim)

# Define the optimizer
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# Train the VAE
num_epochs = 10
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(train_loader):
        x = x.view(-1, 784)
        x_recon, mean, logvar = vae(x)
        loss = loss_function(x_recon, x, mean, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch+1, num_epochs, i+1, len(train_loader), loss.item()))
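
After training, new digits can be generated with the decoder alone by sampling latent vectors from the standard normal prior. Here is a minimal sketch (not part of the training loop above; it reuses the vae and latent_dim defined earlier and torchvision's save_image helper):

# Generate new samples by decoding random latent vectors
from torchvision.utils import save_image

vae.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim)              # 64 latent vectors from N(0, I)
    samples = vae.decoder(z)                     # shape (64, 784), values in (0, 1)
    save_image(samples.view(64, 1, 28, 28), 'vae_samples.png', nrow=8)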

This is a basic implementation of a VAE in PyTorch. You can further customize the VAE by modifying the network architecture, loss function, or training parameters. VAEs are a powerful tool for generating new data samples and can be applied to various types of datasets.
