A variational autoencoder (VAE) is a generative neural network that learns a probability distribution over its training data and can then be sampled to produce new data points similar to those it was trained on. In this tutorial, we will walk through how to build a VAE from scratch using PyTorch.
First, let’s define the architecture of the VAE. The VAE is composed of an encoder and a decoder. The encoder maps an input data sample to the parameters (mean and log-variance) of a distribution over a latent space, while the decoder maps a sample drawn from that latent space back to a reconstruction of the input. The VAE is trained by maximizing the evidence lower bound (ELBO), or equivalently by minimizing a loss with two parts: a reconstruction loss and the KL divergence between the approximate posterior over the latent variables and a standard normal prior.
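To make this concrete, one standard way to write the per-example loss (the negative ELBO) that we will minimize, for an approximate posterior $q(z \mid x) = \mathcal{N}(\mu, \sigma^2 I)$ and prior $p(z) = \mathcal{N}(0, I)$, is

$$
\mathcal{L}(x) = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]}_{\text{reconstruction loss}}
\;+\;
\underbrace{D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,\mathcal{N}(0, I)\big)}_{\text{KL divergence}},
\qquad
D_{\mathrm{KL}} = -\tfrac{1}{2}\sum_{j=1}^{d}\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big).
$$

The closed-form KL term on the right is exactly what the loss function defined later in this tutorial computes from the encoder's mean and log-variance outputs.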
Let’s start by defining the encoder and decoder networks in PyTorch:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Maps an input x to the mean and log-variance of q(z | x)."""
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        mean = self.fc_mean(h)
        logvar = self.fc_logvar(h)
        return mean, logvar

class Decoder(nn.Module):
    """Maps a latent sample z back to a reconstruction of x."""
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        x_recon = torch.sigmoid(self.fc2(h))  # outputs in (0, 1), matching the BCE loss used later
        return x_recon
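As a quick sanity check (a minimal sketch, not part of the tutorial's code), we can push a dummy batch through both networks and confirm the shapes. The dimensions here are just illustrative and match the MNIST setup used later (784 = 28×28 flattened pixels):

# Illustrative shape check with example dimensions
encoder = Encoder(input_dim=784, hidden_dim=256, latent_dim=20)
decoder = Decoder(latent_dim=20, hidden_dim=256, output_dim=784)

x = torch.randn(32, 784)          # dummy batch of 32 flattened images
mean, logvar = encoder(x)         # each has shape (32, 20)
x_recon = decoder(mean)           # shape (32, 784), values in (0, 1) due to the sigmoid
print(mean.shape, logvar.shape, x_recon.shape)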
Next, we will define the VAE model that combines the encoder and decoder:
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mean, logvar):
        # Sample z = mean + eps * std with eps ~ N(0, I); keeps sampling differentiable
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mean + eps * std
        return z

    def forward(self, x):
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        x_recon = self.decoder(z)
        return x_recon, mean, logvar
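The reparameterize method implements the reparameterization trick: rather than sampling z directly from N(mean, std²), we sample eps from N(0, I) and compute z = mean + eps * std, so gradients can flow back through mean and logvar during training. A quick end-to-end check of the combined model (a sketch with illustrative dimensions, not part of the training script) looks like this:

# Forward-pass sanity check with example dimensions
vae = VAE(input_dim=784, hidden_dim=256, latent_dim=20)
x = torch.rand(16, 784)                         # dummy inputs already in [0, 1]
x_recon, mean, logvar = vae(x)
print(x_recon.shape, mean.shape, logvar.shape)  # (16, 784) (16, 20) (16, 20)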
Now, we will define the loss function for training the VAE:
def loss_function(x_recon, x, mean, logvar):
    # Reconstruction term: Bernoulli negative log-likelihood, summed over pixels and batch
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Closed-form KL divergence between N(mean, std^2) and the standard normal prior
    kl_div = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return recon_loss + kl_div
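Note that binary_cross_entropy requires both x_recon and x to lie in [0, 1], which is why the decoder ends in a sigmoid and why MNIST pixels (already scaled to [0, 1] by ToTensor) need no further preprocessing. For data that is not naturally bounded in [0, 1], one common variant is a Gaussian reconstruction term, i.e. a mean-squared-error loss. The following is a sketch of that variant, not what this tutorial uses:

def loss_function_mse(x_recon, x, mean, logvar):
    # Gaussian (MSE) reconstruction term instead of the Bernoulli/BCE term above
    recon_loss = F.mse_loss(x_recon, x, reduction='sum')
    # Same closed-form KL divergence as before
    kl_div = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return recon_loss + kl_div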
Finally, we can train the VAE on a dataset. Below is an example of training the VAE model on the MNIST dataset:
# Load the MNIST dataset
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)

# Initialize the VAE model
input_dim = 784   # 28 x 28 images, flattened
hidden_dim = 256
latent_dim = 20
vae = VAE(input_dim, hidden_dim, latent_dim)

# Define the optimizer
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# Train the VAE
num_epochs = 10
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(train_loader):
        x = x.view(-1, 784)   # flatten each image into a 784-dimensional vector
        x_recon, mean, logvar = vae(x)
        loss = loss_function(x_recon, x, mean, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))
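Once training has run for a few epochs, you can generate new digits by sampling latent vectors from the standard normal prior and passing them through the decoder. Below is a minimal sketch; the output file name and grid layout are just examples:

from torchvision.utils import save_image

# Sample latent codes from the prior N(0, I) and decode them into images
vae.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim)
    samples = vae.decoder(z)                 # (64, 784), values in (0, 1)
    samples = samples.view(-1, 1, 28, 28)    # reshape into image format
    save_image(samples, 'vae_samples.png', nrow=8)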
This is a basic implementation of a VAE in PyTorch. You can further customize the VAE by modifying the network architecture, loss function, or training parameters. VAEs are a powerful tool for generating new data samples and can be applied to various types of datasets.