Creating a Variational Autoencoder from Scratch Using PyTorch

A variational autoencoder (VAE) is a generative neural network that learns a probability distribution over its training data and can then draw new samples that resemble the data it was trained on. In this tutorial, we will walk through how to build a VAE from scratch using PyTorch.

First, let’s define the architecture of the VAE. The VAE is composed of an encoder and a decoder. The encoder maps an input data sample to a distribution over a latent space, while the decoder takes a sample from the latent space and reconstructs the original input. The VAE is trained by maximizing the evidence lower bound (ELBO), which balances two terms: a reconstruction term measuring how well the decoder reproduces the input, and the KL divergence between the latent distribution produced by the encoder and a standard normal prior.
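
To make this concrete, for a single data point x the quantity we maximize is

ELBO(x) = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)

With a diagonal Gaussian encoder q_\phi(z \mid x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) and a standard normal prior p(z) = \mathcal{N}(0, I), the KL term has the closed form

D_{\mathrm{KL}} = -\tfrac{1}{2} \sum_{j=1}^{d} \big(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\big)

which is exactly what the loss function later in this post implements (in practice we minimize the negative ELBO).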

Let’s start by defining the encoder and decoder networks in PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        mean = self.fc_mean(h)
        logvar = self.fc_logvar(h)

        return mean, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()

        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        x_recon = torch.sigmoid(self.fc2(h))

        return x_recon
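
Before combining these into the full model, it can be worth running a quick shape check. The dimensions below (784-dimensional inputs, a 20-dimensional latent space) are just illustrative values, chosen to match the MNIST setup used later:

# Quick shape check with illustrative dimensions
enc = Encoder(input_dim=784, hidden_dim=256, latent_dim=20)
dec = Decoder(latent_dim=20, hidden_dim=256, output_dim=784)

x = torch.randn(32, 784)              # a dummy batch of 32 flattened 28x28 images
mean, logvar = enc(x)                 # both have shape (32, 20)
x_recon = dec(torch.randn(32, 20))    # shape (32, 784), values in (0, 1) from the sigmoid
print(mean.shape, logvar.shape, x_recon.shape)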

Next, we will define the VAE model that combines the encoder and decoder. Its reparameterize method implements the reparameterization trick: rather than sampling z directly from N(mean, var), which would block gradients, we sample eps from N(0, I) and compute z = mean + eps * std, so the sampling step stays differentiable with respect to mean and logvar:

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()

        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5*logvar)    # logvar is log(sigma^2), so std = exp(0.5 * logvar)
        eps = torch.randn_like(std)    # noise drawn from a standard normal, same shape as std
        z = mean + eps*std             # differentiable sample from N(mean, std^2)
        return z

    def forward(self, x):
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        x_recon = self.decoder(z)

        return x_recon, mean, logvar

Now, we will define the loss function for training the VAE. It is the negative ELBO: the binary cross-entropy reconstruction loss plus the closed-form KL divergence between the encoder’s Gaussian and the standard normal prior:

def loss_function(x_recon, x, mean, logvar):
    # Reconstruction term: how well the decoder reproduces the input
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Closed-form KL divergence between N(mean, var) and the standard normal prior
    kl_div = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

    return recon_loss + kl_div
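
As a quick sanity check, you can evaluate the loss on dummy tensors. With a zero mean and zero log-variance the latent distribution is exactly N(0, I), so the KL term vanishes and only the reconstruction term remains (the batch size and dimensions below are arbitrary):

# Sanity check on dummy data: with mean = 0 and logvar = 0, the KL term is exactly 0
x = torch.rand(8, 784)          # binary cross-entropy expects targets in [0, 1]
x_recon = torch.rand(8, 784)
mean = torch.zeros(8, 20)
logvar = torch.zeros(8, 20)
print(loss_function(x_recon, x, mean, logvar))   # equals the reconstruction term alone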

Finally, we can train the VAE on a dataset. Below is an example of training the VAE model on the MNIST dataset:

# Load the MNIST dataset
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)

# Initialize the VAE model
input_dim = 784
hidden_dim = 256
latent_dim = 20
vae = VAE(input_dim, hidden_dim, latent_dim)

# Define the optimizer
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# Train the VAE
num_epochs = 10
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(train_loader):
        x = x.view(-1, 784)  # flatten each 28x28 image into a 784-dim vector
        x_recon, mean, logvar = vae(x)
        loss = loss_function(x_recon, x, mean, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch+1, num_epochs, i+1, len(train_loader), loss.item()))
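
Once training has finished, generating new digits is just a matter of sampling latent vectors from the standard normal prior and passing them through the decoder. A minimal sketch (the grid size and file name are arbitrary choices):

# Generate new samples from the trained decoder
from torchvision.utils import save_image

vae.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim)        # sample latent vectors from the N(0, I) prior
    samples = vae.decoder(z)               # decode to pixel space, shape (64, 784)
    samples = samples.view(-1, 1, 28, 28)  # reshape into 28x28 grayscale images
save_image(samples, 'vae_samples.png', nrow=8)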

This is a basic implementation of a VAE in PyTorch. You can further customize the VAE by modifying the network architecture, loss function, or training parameters. VAEs are a powerful tool for generating new data samples and can be applied to various types of datasets.
