Generating Joy with PyTorch: Applying Arithmetic Operations in VAE Latent Space (Example Code)

In this tutorial, we will explore how to use Variational Autoencoders (VAEs) in PyTorch to perform arithmetic operations in the latent space. Specifically, we will focus on a fun application of VAEs: making people smile in images. A VAE encodes each image as a point in a continuous latent space, so we can generate new images of people smiling by moving those latent vectors in the direction associated with a smiling expression and decoding the result.

  1. Setting up the Environment:
    Before we begin, make sure you have PyTorch installed. You can install PyTorch using pip:

    pip install torch torchvision

    You will also need the following libraries, which can be installed with pip in the same way:

    matplotlib
    numpy
    Pillow (imported as PIL)
  2. Loading the Dataset:
    For this tutorial, we will use the CelebA dataset, which contains over 200,000 celebrity images with annotations for various attributes such as "smiling". You can download the CelebA dataset from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html.

    In this tutorial, we will only use the "smiling" attribute. We will load the dataset with torchvision's datasets.CelebA class and keep only the images labeled as smiling; a PyTorch DataLoader will batch them during training.

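```python
from torch.utils.data import Subset
from torchvision import datasets, transforms

# CelebA's aligned images are 178x218: crop a square face region, then resize to the
# 64x64 resolution the VAE expects. ToTensor scales pixels to [0, 1], which matches
# the sigmoid output and binary cross-entropy loss used later.
transform = transforms.Compose([
    transforms.CenterCrop(178),
    transforms.Resize(64),
    transforms.ToTensor(),
])

# Load the CelebA training split; each item is an (image, attributes) pair
celeba_dataset = datasets.CelebA(root='data', split='train', download=True, transform=transform)

# Keep only the images labeled as smiling (torchvision stores each attribute as 0 or 1)
smiling_col = celeba_dataset.attr_names.index('Smiling')
smiling_indices = (celeba_dataset.attr[:, smiling_col] == 1).nonzero(as_tuple=True)[0].tolist()
smiling_dataset = Subset(celeba_dataset, smiling_indices)
```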
  3. Building the VAE Model:
    Next, we will define the VAE model that will be used for generating new images of smiling faces. Here, we will use a simple VAE architecture with a fully-connected encoder and decoder.

    
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    def __init__(self, latent_dim=100, hidden_dim=256):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: flattened 3x64x64 image -> hidden layer -> mean and log-variance of q(z|x)
        self.fc1 = nn.Linear(3 * 64 * 64, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mu
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # logvar

        # Decoder: latent vector -> hidden layer -> flattened 3x64x64 image
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 3 * 64 * 64)

    def encode(self, x):
        x = x.view(-1, 3 * 64 * 64)
        h = F.relu(self.fc1(x))
        return self.fc21(h), self.fc22(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc3(z))
        # Sigmoid keeps pixel values in [0, 1] for the BCE reconstruction loss
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Instantiate the VAE model
vae = VAE()
```
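The reparameterize method is what makes the VAE trainable with backpropagation: the randomness is moved into eps, so z = mu + eps * std is a differentiable function of mu and logvar. A small standalone check, using made-up values for mu and logvar, confirms the gradients: dz/dmu is 1, and dz/dlogvar is 0.5 * eps * exp(0.5 * logvar).

```python
import torch

torch.manual_seed(0)
mu = torch.zeros(4, requires_grad=True)
logvar = torch.zeros(4, requires_grad=True)

# Same computation as VAE.reparameterize: eps carries the noise and needs no gradient
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std

z.sum().backward()
print(mu.grad)                                  # tensor([1., 1., 1., 1.])
# With logvar = 0, dz/dlogvar = 0.5 * eps * exp(0) = 0.5 * eps
print(torch.allclose(logvar.grad, 0.5 * eps))   # True
```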


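  4. Training the VAE Model:
    Now, we will train the VAE model on the smiling faces dataset. The loss function is the reconstruction loss (binary cross-entropy between each input image and its reconstruction) plus the KL divergence between the encoder's distribution q(z|x) and a standard normal prior.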
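For reference, here is the per-image objective that loss_function in the code below implements, written with \hat{x} for the decoder output (recon_x), D = 3 · 64 · 64 pixel values, d = latent_dim, and \log\sigma_j^2 for the encoder's logvar:

```latex
\mathcal{L}(x) =
\underbrace{-\sum_{i=1}^{D}\Big[x_i \log \hat{x}_i + (1 - x_i)\log\big(1 - \hat{x}_i\big)\Big]}_{\text{BCE}}
\;\underbrace{-\,\frac{1}{2}\sum_{j=1}^{d}\Big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\Big)}_{\text{KLD}}
```

The code sums this over every image in the batch (reduction='sum'). Minimizing it maximizes a single-sample estimate of the evidence lower bound (ELBO).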
```python
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

# Training settings (illustrative values; adjust to your hardware and time budget)
n_epochs = 10
log_interval = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = vae.to(device)

# Loss = reconstruction term (binary cross-entropy summed over pixels) + KL divergence
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 3 * 64 * 64), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Set up the optimizer and the data loader
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
train_loader = DataLoader(smiling_dataset, batch_size=128, shuffle=True)

def train(epoch):
    vae.train()
    train_loss = 0
    # CelebA yields (image, attributes) pairs; only the images are needed here
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

# Train the VAE model
for epoch in range(1, n_epochs + 1):
    train(epoch)
```
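The KLD line in loss_function is the closed-form KL divergence between a diagonal Gaussian N(mu, sigma^2) and the standard normal prior. As a quick sketch with random stand-in values for mu and logvar (a batch of 8 with latent size 100, matching the default latent_dim), it agrees with the general-purpose torch.distributions.kl_divergence:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 100)       # stand-in encoder means
logvar = torch.randn(8, 100)   # stand-in encoder log-variances

# Closed-form term, exactly as written in loss_function
kld_closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Same quantity via torch.distributions: KL(q(z|x) || N(0, I)), summed over all elements
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_reference = kl_divergence(q, p).sum()

print(torch.allclose(kld_closed_form, kld_reference, rtol=1e-4))  # True
```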
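  5. Generating Smiling Faces:
    Once the VAE model is trained, we can generate new faces by sampling random latent vectors from the standard normal prior and decoding them. Because this model was trained only on smiling faces, the decoded samples should mostly show smiling faces. To make a particular person smile, we instead perform arithmetic in the latent space: encode their image, move its latent vector in the direction associated with smiling, and decode the result.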

    
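```python
import torch
import matplotlib.pyplot as plt

# Generate new images of smiling faces by decoding random samples from the prior N(0, I)
def generate_images(n=10):
    vae.eval()
    device = next(vae.parameters()).device  # sample on whatever device the model lives on
    with torch.no_grad():
        z = torch.randn(n, vae.latent_dim, device=device)
        images = vae.decode(z).view(-1, 3, 64, 64)
        # Matplotlib expects channels-last (H, W, C) NumPy arrays on the CPU
        images = images.permute(0, 2, 3, 1).cpu().numpy()

    fig, axs = plt.subplots(1, n, figsize=(20, 10))
    for i in range(n):
        axs[i].imshow(images[i])
        axs[i].axis('off')
    plt.show()

# Generate new images of smiling faces
generate_images()
```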
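The generate_images function decodes random samples; it does not edit a particular face. One common way to perform the latent arithmetic is an attribute vector: the mean latent code of smiling faces minus the mean latent code of non-smiling faces, added to a face's own latent code before decoding. Below is a minimal sketch, assuming a trained vae with the encode and decode methods defined earlier, and assuming the model was trained on both smiling and non-smiling faces (for example on celeba_dataset instead of smiling_dataset), because a decoder trained only on smiling faces has not learned to reconstruct neutral ones. The function names attribute_vector and add_attribute are hypothetical.

```python
import torch


@torch.no_grad()  # no gradients are needed for editing
def attribute_vector(vae, with_attr, without_attr):
    """Latent direction for an attribute: mean code of images that have it minus
    mean code of images that don't. Uses the encoder mean (mu), not a random sample."""
    vae.eval()
    mu_with, _ = vae.encode(with_attr)
    mu_without, _ = vae.encode(without_attr)
    return mu_with.mean(dim=0) - mu_without.mean(dim=0)


@torch.no_grad()
def add_attribute(vae, images, direction, alpha=1.0):
    """Encode `images`, move each latent mean by `alpha * direction`, and decode.
    Returns a tensor with the same shape as `images`."""
    vae.eval()
    mu, _ = vae.encode(images)
    z = mu + alpha * direction  # direction has shape (latent_dim,) and broadcasts over the batch
    return vae.decode(z).view_as(images)
```

For example, with hypothetical batches smiling_batch and neutral_batch selected via the "Smiling" attribute column and moved to the same device as vae, smile = attribute_vector(vae, smiling_batch, neutral_batch) followed by add_attribute(vae, neutral_batch[:8], smile, alpha=1.5) should push those eight faces toward smiling. Larger alpha values typically exaggerate the change, and a negative alpha moves faces away from smiling.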



In this tutorial, we have explored how to use Variational Autoencoders (VAEs) to generate new images of people smiling by manipulating the latent space vectors associated with their facial expressions. Because the KL term pushes a VAE toward a smooth, continuous latent space, simple vector operations on latent codes tend to produce meaningful changes in the decoded images, which is what makes applications like making people smile possible. Experiment with different latent space arithmetic operations to create a wide range of facial expressions, and have fun exploring the possibilities of VAEs in PyTorch!