Building an Image Segmentation Model Using PyTorch and U-Net: Starting from the Ground Up

In this tutorial, we will implement image segmentation from scratch using PyTorch and the U-Net architecture. Image segmentation assigns every pixel of an image to a segment, where each segment represents a distinct object or region. U-Net is a popular architecture for this task because its skip connections help it produce precise, full-resolution masks.

To begin with, make sure you have PyTorch installed on your machine. You can install it using the following command:

pip install torch torchvision

Next, we need to import the necessary libraries:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

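Now, let’s define our dataset class. For simplicity, we will create a synthetic dataset with randomly generated images and corresponding binary segmentation masks. Because the masks are pure noise, the model has nothing meaningful to learn from them; the dataset only serves to exercise the pipeline end to end. You can replace it with your own dataset later.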

class CustomDataset(Dataset):
    def __init__(self, num_samples, transform=None):
        self.num_samples = num_samples
        self.transform = transform

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate random image
        image = np.random.rand(256, 256, 3)
        image = (image * 255).astype(np.uint8)
        image = Image.fromarray(image)

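        # Generate a random binary mask, stored as 0/255 uint8 so that PIL
        # accepts the array and ToTensor() rescales it to 0.0/1.0
        mask = np.random.randint(0, 2, size=(256, 256), dtype=np.uint8) * 255
        mask = Image.fromarray(mask)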

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

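Now, let’s define the U-Net architecture. U-Net is an encoder-decoder network: the encoder downsamples the image while increasing the number of feature channels, the decoder upsamples back to the input resolution, and skip connections copy encoder features across to the decoder to preserve high-resolution detail. The version below is deliberately small, with two encoder stages and a single skip connection, compared with the four downsampling stages of the original paper. Here is the implementation of U-Net: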

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Encoder
        self.down1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(2)

        # Decoder
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up1 = nn.Sequential(
            # Input channels: 64 from upconv1 + 128 from the x2 skip connection
            nn.Conv2d(192, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.upconv2 = nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2)

    def forward(self, x):
        # Encoder
        x1 = self.down1(x)
        x = self.pool1(x1)
        x2 = self.down2(x)
        x = self.pool2(x2)

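        # Decoder
        x = self.upconv1(x)            # upsample back to the resolution of x2
        x = torch.cat([x, x2], dim=1)  # skip connection: concatenate along channels
        x = self.up1(x)
        x = self.upconv2(x)            # upsample to the input resolution, 1 channel of logits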

        return x
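
The tensor shapes in the forward pass can be traced with a dummy input. This sketch uses a single convolution per stage (a simplification, not the full blocks above) with the same channel counts, kernel sizes, and strides, and assumes a 256x256 input as in the synthetic dataset.

```python
import torch
from torch import nn

x = torch.randn(1, 3, 256, 256)  # dummy batch containing one RGB image

down1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
down2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2)
upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

x1 = down1(x)                        # (1, 64, 256, 256): padding=1 keeps height and width
x2 = down2(pool(x1))                 # (1, 128, 128, 128): pooling halves height and width
bottom = pool(x2)                    # (1, 128, 64, 64)
up = upconv1(bottom)                 # (1, 64, 128, 128): transposed conv doubles them again
merged = torch.cat([up, x2], dim=1)  # (1, 192, 128, 128): channels add up, 64 + 128

for name, t in [("x1", x1), ("x2", x2), ("bottom", bottom), ("up", up), ("merged", merged)]:
    print(name, tuple(t.shape))
```

Whatever layer consumes the concatenated tensor must accept the summed channel count. The 3x3 convolutions with padding=1 keep height and width unchanged, unlike the unpadded convolutions in the original U-Net paper, where each convolution trims the feature map by two pixels.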

Now, let’s instantiate the dataset and the U-Net model, define the loss function and optimizer, and write the training loop:

# Instantiate dataset
dataset = CustomDataset(num_samples=1000, transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Instantiate U-Net model
model = UNet()

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for images, masks in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        # ToTensor() already gives masks the shape (N, 1, H, W), matching outputs
        loss = criterion(outputs, masks.float())
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
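
nn.BCEWithLogitsLoss applies the sigmoid internally, which is why the model returns raw logits during training and the sigmoid only appears at inference time. The sketch below, on made-up tensors, compares it with an explicit sigmoid followed by nn.BCELoss; in both cases the target must be a float tensor with the same shape as the model output.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 1, 8, 8)                     # raw model outputs, no sigmoid applied
targets = torch.randint(0, 2, (4, 1, 8, 8)).float()  # float mask with the same shape

loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)
loss_manual = nn.BCELoss()(torch.sigmoid(logits), targets)

# The two values agree up to floating-point rounding
print(loss_with_logits.item(), loss_manual.item())
```

The fused version is generally the safer choice numerically when logits become large in magnitude.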

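Finally, let’s visualize the results by passing a batch of images through the trained model. The batch here comes from the same synthetic dataloader; with real data, use a held-out test set instead: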

# Visualize results
model.eval()
with torch.no_grad():
    images, masks = next(iter(dataloader))
    outputs = model(images)
    outputs = torch.sigmoid(outputs)

for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.imshow(outputs[i].squeeze().cpu().numpy(), cmap='gray')
plt.show()
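
The plots show per-pixel probabilities. To get a hard mask and a number to track, the probabilities are usually thresholded and compared with the ground truth. The helper names below (binarize, dice_score, pixel_accuracy) and the 0.5 threshold are assumptions for illustration, not part of the tutorial code.

```python
import torch

def binarize(probs, threshold=0.5):
    """Turn per-pixel probabilities into a 0/1 mask."""
    return (probs > threshold).float()

def dice_score(pred, target, eps=1e-8):
    """Dice coefficient between two binary masks of the same shape."""
    intersection = (pred * target).sum()
    return (2 * intersection / (pred.sum() + target.sum() + eps)).item()

def pixel_accuracy(pred, target):
    """Fraction of pixels where prediction and target agree."""
    return (pred == target).float().mean().item()

# Toy 2x2 example: three of the four pixels are predicted correctly
probs = torch.tensor([[0.9, 0.2], [0.6, 0.4]])
target = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
pred = binarize(probs)

print(pred)
print(pixel_accuracy(pred, target))  # 0.75
print(dice_score(pred, target))      # close to 0.667
```

On the random synthetic masks both scores should stay near chance level, since there is no pattern for the model to learn.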

And there you have it, a complete PyTorch image segmentation tutorial using the U-Net architecture from scratch. Feel free to experiment with different datasets, hyperparameters, and network architectures to further improve the performance of the model.

@AladdinPersson
1 month ago

These from scratch videos & paper implementations take a lot of time for me to do, if you want to see me make more of these types of videos: please crush that like button and subscribe and I'll do it 🙂

Support the channel ❤️:
https://www.youtube.com/channel/UCkzW5JSFwvKRjXABI-UTAkQ/join

Original paper: https://arxiv.org/abs/1505.04597
Paper review: https://youtu.be/oLvmLJkmXuc

⌚️ Timestamps:
0:00 – Introduction
1:03 – Model from scratch
22:20 – Dataset from scratch
29:50 – Training from scratch
39:48 – Utils (almost) from scratch
50:10 – Evaluation and Ending

@MuhammadHamza-o3r
1 month ago

Man that was amazing! It was pure quality content. Keep it up!

@lalasam5493
1 month ago

Hi, I would like to understand why transformations are not applied to the mask data.

@rajanlamichhane8971
1 month ago

Can you give us your plugins details please. Most needed !

@shreygarg7057
1 month ago

Why did you put 1 in unsqueeze targets ?

@thaimeuu
1 month ago

not a single confusion in this video, thanks

@kevinelezi7089
1 month ago

48:00 man you killed it , wow

@angelosantino49
1 month ago

Ey there, i know its been 3 years ago, but in the minute 46:15, your cam blocks the code. Thx anyway, its a great fully video

@josephmargaryan
1 month ago

Hey bro, I know this video is from a long time ago. But thank you for teaching me and, most importantly, being an inspiration. I have now learned how to do the dataset, training loop, and Unet model, all from scratch in my head, just like you. I have also written a thesis on the subject as part of my bachelor's project at my university. Again, thank you, and I hope to learn more from you in the future.

@Uuuuuzz
1 month ago

big data please remember i like this video.

@mohamedshatarah7264
1 month ago

You are amazing! I have been struggling with this for 2 weeks and your video is so helpful. I can only imagine the amount of work you put into this. Thank you so much.

@nhioanhoai6147
1 month ago

Hi, I noticed that they change from (572,572) to (570,570) and (568,568). Why do you still keep padding and stride equal to 1 (which means that you keep the same H and W after the CNN)?

@FormJune
1 month ago

Guys, how do I need to edit mask for multiclass? currently I have target mask shape (batch, 1, width, height) and result (batch, classes, width, height). And cross entropy won't work with different shapes

@ricemaster0117
1 month ago

how can i get image files at kaggle?
{"error":{"message":"Unauthenticated"}}

@nomaanrizvi6561
1 month ago

great video…thanks for the guidance…but at the time of training, as the number of epochs increases…my loss also increases in negative…..i have tried changing the loss function to crossentropy but still the issue wont get resolved..would appreciate some help here..thanks anyways..heart emoji

@user-vh7ok7cc4l
1 month ago

could you explain more on "my_checkpoint.pth.tar"? I tried following your video, but I got an error: in _legacy_load

magic_number = pickle_module.load(f, **pickle_load_args)EOFError: Ran out of input. I think it is related to "my_checkpoint.pth.tar". Thanks

@RossMelbourne2007
1 month ago

Thank you for the in-depth explanation of how to implement UNET. I would love to see you update GitHub to save the model and a separate display.py showing how to load the model and display the image segmentation predictions.

@MadMonkeyMum
1 month ago

Thank you for video. Was wondering if anyone knows why I would be getting can’t find file errors ?

@ArpitAnand-yd7tr
1 month ago

I'm very thankful for the video and great implementation too but I wish you could go into details of why you do certain things and perhaps explain stuff a bit more.
Would be super helpful !

@anitasoraya7796
1 month ago

excellent