Tutorial on Image Captioning Using PyTorch

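Image captioning is the task of generating a textual description of an image, which requires a model both to recognize what the image shows and to express it as a fluent sentence. In this tutorial, we will build an image captioning system with PyTorch: a pretrained convolutional neural network (CNN) encodes each image into a feature vector, and a long short-term memory (LSTM) network decodes that vector into a caption one word at a time. You will learn how to preprocess the data, build and train the model, and generate captions for new images.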

Step 1: Install the necessary libraries
Before we begin, make sure PyTorch and torchvision are installed. You can install both with pip:

pip install torch torchvision

Step 2: Download the dataset
For this tutorial, we will use the Flickr30k dataset, which contains 31,783 images, each annotated with five captions (158,915 captions in total). You can download the dataset from the following link: https://www.kaggle.com/marisakamisha/flicker30k
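
Because each image has five captions, a train/validation split should be made per image rather than per caption; otherwise captions of the same image end up in both splits and the validation loss looks better than it should. The following is a minimal sketch, assuming the captions have already been read into a list of (image_name, caption) pairs; split_by_image is a hypothetical helper written for this tutorial, not a library function.

```python
import random


def split_by_image(pairs, val_fraction=0.1, seed=0):
    """Split (image_name, caption) pairs so all captions of an image stay in one split."""
    images = sorted({image for image, _ in pairs})  # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(images)
    val_images = set(images[:int(len(images) * val_fraction)])
    train = [(img, cap) for img, cap in pairs if img not in val_images]
    val = [(img, cap) for img, cap in pairs if img in val_images]
    return train, val


# Toy data: 10 images with 5 captions each
toy_pairs = [(f'img{i}.jpg', f'caption {j} of image {i}') for i in range(10) for j in range(5)]
train_split, val_split = split_by_image(toy_pairs, val_fraction=0.2)
print(len(train_split), len(val_split))  # 40 10
```

With the real data, you would load every caption once, split the list this way, and use the two halves as the training and validation captions.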

Step 3: Preprocess the data
Next, we will preprocess the data by resizing the images and tokenizing the captions. We will use the torchvision library to load and preprocess the images, and the NLTK library to tokenize the captions. Make sure you have NLTK installed on your system:

pip install nltk

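torchvision's ImageFolder is designed for image classification: it expects one subfolder per class and returns (image, class_index) pairs, so it has no way to attach captions to images. Instead, we read the captions from a file, build a vocabulary from the training captions, and wrap each (image, caption) pair in a small custom Dataset that returns the transformed image together with the caption as a tensor of token indices. The code assumes each split has a CSV file with "image" and "caption" columns and uses placeholder paths; adapt load_captions and the paths to the files in your download. Here is the code to preprocess the data:

import csv
import os
from collections import Counter
import nltk
import torch
import torchvision.transforms as transforms
from nltk.tokenize import word_tokenize
from PIL import Image
from torch.utils.data import Dataset

nltk.download('punkt')  # newer NLTK releases may also require nltk.download('punkt_tab')
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # input size expected by ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])

def load_captions(csv_path):
    # Assumed layout: a CSV with "image" and "caption" columns, one row per caption
    with open(csv_path, newline='', encoding='utf-8') as f:
        return [(row['image'], row['caption']) for row in csv.DictReader(f)]

def build_vocab(pairs, min_freq=5):
    # Words seen fewer than min_freq times map to <unk>; <pad> gets index 0
    counts = Counter(tok for _, cap in pairs for tok in word_tokenize(cap.lower()))
    idx2token = ['<pad>', '<start>', '<end>', '<unk>'] + sorted(t for t, n in counts.items() if n >= min_freq)
    return {t: i for i, t in enumerate(idx2token)}, idx2token

class CaptionDataset(Dataset):
    def __init__(self, image_dir, pairs, token2idx, transform):
        self.image_dir, self.pairs, self.token2idx, self.transform = image_dir, pairs, token2idx, transform

    def __len__(self):
        return len(self.pairs)  # one sample per (image, caption) pair

    def __getitem__(self, idx):
        name, caption = self.pairs[idx]
        image = Image.open(os.path.join(self.image_dir, name)).convert('RGB')
        tokens = ['<start>'] + word_tokenize(caption.lower()) + ['<end>']
        ids = [self.token2idx.get(t, self.token2idx['<unk>']) for t in tokens]
        return self.transform(image), torch.tensor(ids)

train_pairs = load_captions('path/to/train_captions.csv')
token2idx, idx2token = build_vocab(train_pairs)  # vocabulary from training captions only
train_data = CaptionDataset('path/to/train/images/', train_pairs, token2idx, transform)
val_data = CaptionDataset('path/to/val/images/', load_captions('path/to/val_captions.csv'), token2idx, transform)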
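
To see what the dataset returns for a single caption, here is a dependency-free sketch of the same numericalization steps: wrap the tokens in <start> and <end>, map unknown words to <unk>, and map indices back to words when decoding. The toy vocabulary and the use of str.split in place of NLTK's word_tokenize are simplifications for illustration; build_toy_vocab, encode, and decode are hypothetical helpers.

```python
SPECIALS = ['<pad>', '<start>', '<end>', '<unk>']


def build_toy_vocab(words):
    # Special tokens come first, so <pad> gets index 0
    idx2token = SPECIALS + sorted(set(words))
    return {t: i for i, t in enumerate(idx2token)}, idx2token


def encode(caption, token2idx):
    # str.split stands in for NLTK's word_tokenize here
    tokens = ['<start>'] + caption.lower().split() + ['<end>']
    return [token2idx.get(t, token2idx['<unk>']) for t in tokens]


def decode(ids, idx2token):
    # Skip <pad>, <start> and <end>; unknown words stay visible as <unk>
    return ' '.join(idx2token[i] for i in ids if idx2token[i] not in ('<pad>', '<start>', '<end>'))


toy_t2i, toy_i2t = build_toy_vocab('a dog runs on the grass'.split())
toy_ids = encode('A dog runs on the beach', toy_t2i)
print(toy_ids)                    # [1, 4, 5, 8, 7, 9, 3, 2]
print(decode(toy_ids, toy_i2t))   # a dog runs on the <unk>
```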

Step 4: Build the model
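We will build a CNN-LSTM encoder-decoder model. The CNN, a ResNet-50 pretrained on ImageNet whose final classification layer is replaced by a linear layer, encodes each image into a 512-dimensional feature vector. The LSTM then generates the caption: the image feature is fed in as its first input step, followed by the embeddings of the caption words, and at every step a linear layer maps the LSTM's hidden state to scores over the vocabulary. Here is the code to build the model: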

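import torch
import torch.nn as nn
import torchvision.models as models

class CNN(nn.Module):
    def __init__(self, embed_size=512):
        super().__init__()
        # ResNet-50 pretrained on ImageNet; the weights= argument needs torchvision 0.13+
        # (older versions use pretrained=True)
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Replace the 1000-class classifier with a projection to the embedding size
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)

    def forward(self, x):
        return self.resnet(x)  # (batch, embed_size)

class LSTM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        # features: (batch, embed_size); captions: (batch, seq_len) token indices
        embeddings = self.embed(captions)
        # Prepend the image feature as time step 0, before the word embeddings
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        hiddens, _ = self.lstm(embeddings)
        return self.linear(hiddens)  # (batch, seq_len + 1, vocab_size) scores over the vocabulary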
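
The concatenation in LSTM.forward is easiest to follow through the tensor shapes. The sketch below uses tiny, arbitrary sizes, a random tensor in place of the CNN output (so no pretrained weights are downloaded), and toy_ names so it does not overwrite the tutorial's variables. It shows that prepending the image feature and dropping the caption's last token gives exactly one score vector per target token, which is the alignment the training loss relies on.

```python
import torch
import torch.nn as nn

# Tiny, arbitrary sizes: batch, caption length, vocabulary, embedding, hidden
B, T, V, E, H = 2, 5, 11, 8, 16

toy_embed = nn.Embedding(V, E)
toy_lstm = nn.LSTM(E, H, num_layers=1, batch_first=True)
toy_linear = nn.Linear(H, V)

toy_features = torch.randn(B, E)            # stand-in for the CNN output
toy_captions = torch.randint(0, V, (B, T))  # random token indices

# Prepend the image feature as time step 0 and drop the caption's last token
inputs = torch.cat((toy_features.unsqueeze(1), toy_embed(toy_captions[:, :-1])), dim=1)
hiddens, _ = toy_lstm(inputs)
scores = toy_linear(hiddens)

print(inputs.shape)  # torch.Size([2, 5, 8])
print(scores.shape)  # torch.Size([2, 5, 11]): one score vector per token of toy_captions
```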

Step 5: Train the model
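Now we will train the CNN-LSTM model with the Adam optimizer and the cross-entropy loss. Training uses teacher forcing: the LSTM receives the image feature followed by the ground-truth caption without its last token, so the output at each step is scored against the next token of the caption. Captions in a batch have different lengths, so they are padded to a common length, and the padded positions are excluded from the loss. The hyperparameter values below are only a starting point. Here is the code to train the model:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Example hyperparameters: a starting point, not tuned values
embed_size = 512      # must equal the CNN's output size (512)
hidden_size = 512
num_layers = 1
learning_rate = 3e-4
num_epochs = 10
vocab_size = len(idx2token)   # train_data, token2idx and idx2token come from Step 3
pad_idx = token2idx['<pad>']

def collate_fn(batch):
    # Captions differ in length, so pad every caption in the batch to the longest one
    images, captions = zip(*batch)
    return torch.stack(images), pad_sequence(list(captions), batch_first=True, padding_value=pad_idx)

dataloader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn = CNN().to(device)
lstm = LSTM(vocab_size, embed_size, hidden_size, num_layers).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)  # padded positions do not count
params = list(cnn.parameters()) + list(lstm.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)

cnn.train()
lstm.train()
for epoch in range(num_epochs):
    for images, captions in dataloader:
        images, captions = images.to(device), captions.to(device)
        features = cnn(images)
        # Input [image, w0 .. w(n-1)] vs. target [w0 .. wn]: each step predicts the next token
        outputs = lstm(features, captions[:, :-1])
        loss = criterion(outputs.reshape(-1, vocab_size), captions.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))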
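
The role of padding_value and ignore_index can be checked on a toy batch. This sketch assumes index 0 is the padding token and uses random scores over a made-up vocabulary of 10 tokens; it shows that the loss with ignore_index equals the loss computed over the real tokens only.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

PAD = 0  # assumed padding index

# Two token-index captions of different lengths
toy_batch = [torch.tensor([1, 5, 6, 2]), torch.tensor([1, 7, 2])]
padded = pad_sequence(toy_batch, batch_first=True, padding_value=PAD)
print(padded.tolist())  # [[1, 5, 6, 2], [1, 7, 2, 0]]

# Random scores over a 10-token vocabulary for every position in the batch
torch.manual_seed(0)
logits = torch.randn(padded.numel(), 10)
targets = padded.reshape(-1)

# With ignore_index, the padded target does not contribute to the mean
loss_with_pad = nn.CrossEntropyLoss(ignore_index=PAD)(logits, targets)

# The same loss computed over the real (non-padding) positions only
mask = targets != PAD
loss_real_only = nn.CrossEntropyLoss()(logits[mask], targets[mask])
print(torch.allclose(loss_with_pad, loss_real_only))  # True
```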

Step 6: Generate captions
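Finally, we will use the trained model to generate captions for new images with greedy decoding: starting from the <start> token, the model repeatedly predicts the most likely next word given the image and the words generated so far, stopping at <end> or after max_len words. Both networks are switched to evaluation mode first, so that ResNet's batch-normalization layers use their stored running statistics instead of per-batch statistics. Here is the code to generate captions:

def generate_caption(image, max_len=20):
    # image: a PIL image; max_len caps the caption length in words
    cnn.eval()
    lstm.eval()
    with torch.no_grad():
        image = transform(image).unsqueeze(0).to(device)
        features = cnn(image)
        captions = torch.tensor([[token2idx['<start>']]], device=device)

        caption = []
        for _ in range(max_len):
            outputs = lstm(features, captions)           # (1, current_len + 1, vocab_size)
            predicted = outputs[0, -1].argmax().item()   # most likely token at the last step
            if predicted == token2idx['<end>']:
                break
            caption.append(idx2token[predicted])
            captions = torch.cat((captions, torch.tensor([[predicted]], device=device)), dim=1)

        return ' '.join(caption)

# Example (placeholder path): generate_caption(Image.open('path/to/test.jpg').convert('RGB'))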
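
generate_caption handles one image at a time, so it can stop as soon as <end> appears. When decoding a batch, sequences finish at different steps; one common approach is to keep stepping until every sequence has produced <end> (or max_len is reached), pad already-finished sequences with <end> so all sequences stay the same length, and trim each result at its first <end>. The sketch below shows only that bookkeeping in plain Python. step_fn is a hypothetical stand-in for a model call; with this tutorial's model it would run the decoder on the whole batch and take the argmax of the scores at the last step. The toy step function is made up for the example.

```python
def greedy_decode_batch(step_fn, batch_size, start_id, end_id, max_len=20):
    """Greedy decoding for a batch whose sequences may end at different steps.

    step_fn(sequences) must return one next-token id per sequence.
    """
    sequences = [[start_id] for _ in range(batch_size)]
    finished = [False] * batch_size
    for _ in range(max_len):
        next_ids = step_fn(sequences)
        for i, token in enumerate(next_ids):
            if finished[i]:
                token = end_id  # keep equal lengths: finished sequences are padded with <end>
            sequences[i].append(token)
            finished[i] = finished[i] or token == end_id
        if all(finished):  # stop early once every sequence has produced <end>
            break
    results = []
    for seq in sequences:
        body = seq[1:]  # drop the start token
        # Cut at the first <end>; sequences that hit max_len keep all their tokens
        results.append(body[:body.index(end_id)] if end_id in body else body)
    return results


START, END = 1, 2


def toy_step(sequences):
    # Toy "model": sequence i emits 10, 11, 12, ... and then <end> after i + 1 words
    return [END if len(seq) > i + 1 else 10 + len(seq) - 1 for i, seq in enumerate(sequences)]


print(greedy_decode_batch(toy_step, batch_size=3, start_id=START, end_id=END))
# [[10], [10, 11], [10, 11, 12]]
```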

Congratulations! You have successfully built an image captioning system using PyTorch. By following this tutorial, you have learned how to preprocess data, build a model, train the model, and generate captions for images. Feel free to experiment with different architectures and hyperparameters to improve the performance of the image captioning system.
