In this tutorial, we will cover how to implement image segmentation in PyTorch using the U-Net architecture, built from scratch. Image segmentation is the task of partitioning an image into multiple segments, where each segment represents a distinct object or region. U-Net is a popular architecture for segmentation tasks thanks to its effectiveness and ease of use.
To begin with, make sure you have PyTorch installed on your machine. You can install it using the following command:
pip install torch torchvision
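To verify the install worked, a quick sanity check (this snippet is an addition, not from the original tutorial) prints the installed version and whether a GPU is visible:

import torch
# Confirm the install imports cleanly and report GPU availability.
print(torch.__version__)
print('CUDA available:', torch.cuda.is_available())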
Next, we need to import the necessary libraries:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
Now, let’s define our dataset class. For simplicity, we will create a synthetic dataset with randomly generated images and corresponding segmentation masks; you can replace this with your own dataset later. Note that we apply the same transform to both the image and the mask: ToTensor is safe to share, but random geometric augmentations (flips, crops) would have to be applied jointly so the image and mask stay aligned.
class CustomDataset(Dataset):
    def __init__(self, num_samples, transform=None):
        self.num_samples = num_samples
        self.transform = transform

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate a random RGB image
        image = np.random.rand(256, 256, 3)
        image = (image * 255).astype(np.uint8)
        image = Image.fromarray(image)
        # Generate a random binary segmentation mask. PIL cannot handle
        # int64 arrays, so cast to uint8 and scale to 0/255 so that
        # ToTensor maps the mask to values 0.0 and 1.0.
        mask = np.random.randint(0, 2, size=(256, 256), dtype=np.uint8) * 255
        mask = Image.fromarray(mask)
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        return image, mask
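As a quick sanity check (an illustrative addition; sample_ds is a hypothetical name), we can instantiate the dataset and confirm the shapes that ToTensor produces:

# Image should be (3, 256, 256) float32; mask (1, 256, 256) with values {0, 1}.
sample_ds = CustomDataset(num_samples=4, transform=transforms.ToTensor())
img, msk = sample_ds[0]
print(img.shape, img.dtype)
print(msk.shape, msk.unique())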
Now, let’s define the U-Net architecture. U-Net consists of an encoder-decoder structure with skip connections that preserve high-resolution features. Here is a compact implementation (note that, unlike the original paper, we use padding=1 so each convolution preserves the spatial dimensions):
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Encoder
        self.down1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool2 = nn.MaxPool2d(2)
        # Decoder
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        # After concatenating the 64-channel upsampled features with the
        # 128-channel skip connection, the input here has 192 channels.
        self.up1 = nn.Sequential(
            nn.Conv2d(192, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.upconv2 = nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2)

    def forward(self, x):
        # Encoder
        x1 = self.down1(x)               # (B, 64, 256, 256)
        x = self.pool1(x1)               # (B, 64, 128, 128)
        x2 = self.down2(x)               # (B, 128, 128, 128)
        x = self.pool2(x2)               # (B, 128, 64, 64)
        # Decoder
        x = self.upconv1(x)              # (B, 64, 128, 128)
        x = torch.cat([x, x2], dim=1)    # skip connection -> (B, 192, 128, 128)
        x = self.up1(x)                  # (B, 64, 128, 128)
        x = self.upconv2(x)              # (B, 1, 256, 256), raw logits
        return x
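Before training, it is worth confirming the tensor shapes with a dummy forward pass (a minimal check added for illustration):

# A 256x256 RGB batch should map to a single-channel 256x256 logit map.
dummy = torch.randn(2, 3, 256, 256)
out = UNet()(dummy)
print(out.shape)  # expected: torch.Size([2, 1, 256, 256])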
Now, let’s instantiate the dataset, the DataLoader, the U-Net model, the loss function, and the optimizer, and write the training loop:
# Instantiate dataset and dataloader
dataset = CustomDataset(num_samples=1000, transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Instantiate U-Net model
model = UNet()

# Define loss function and optimizer. BCEWithLogitsLoss applies the
# sigmoid internally, so the model can output raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for images, masks in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        # ToTensor already yields masks of shape (B, 1, H, W), matching
        # the model output, so no unsqueeze is needed.
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / len(dataloader):.4f}')
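The tutorial does not save the trained weights, but checkpointing is easy to add with torch.save. A minimal sketch (the filename is illustrative):

# Save the trained weights (filename is illustrative).
torch.save(model.state_dict(), 'unet_checkpoint.pth')

# Later, restore them into a fresh model before inference.
restored = UNet()
restored.load_state_dict(torch.load('unet_checkpoint.pth'))
restored.eval()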
Finally, let’s visualize the results by passing a batch of images through the trained model:
# Visualize results
model.eval()
with torch.no_grad():
    images, masks = next(iter(dataloader))
    outputs = model(images)
    # Convert raw logits to probabilities in [0, 1]
    outputs = torch.sigmoid(outputs)

for i in range(3):
    plt.subplot(1, 3, i + 1)
    plt.imshow(outputs[i].squeeze().cpu().numpy(), cmap='gray')
    plt.axis('off')
plt.show()
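Viewing raw probabilities is fine for a quick look, but for quantitative evaluation segmentation work usually reports the Dice coefficient. A minimal sketch (the 0.5 threshold and eps smoothing are conventional choices, not from the original tutorial):

def dice_score(pred_probs, target, threshold=0.5, eps=1e-6):
    # Binarize the probabilities, then compute 2 * |A ∩ B| / (|A| + |B|).
    pred = (pred_probs > threshold).float()
    intersection = (pred * target).sum()
    return (2 * intersection + eps) / (pred.sum() + target.sum() + eps)

print('Dice:', dice_score(outputs, masks).item())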
And there you have it: a complete PyTorch image segmentation tutorial using the U-Net architecture, from scratch. Feel free to experiment with different datasets, hyperparameters, and network architectures to further improve the model’s performance.
These from-scratch videos and paper implementations take a lot of time for me to make. If you want to see more videos like this, please crush that like button and subscribe, and I'll keep making them 🙂
Support the channel ❤️:
https://www.youtube.com/channel/UCkzW5JSFwvKRjXABI-UTAkQ/join
Original paper: https://arxiv.org/abs/1505.04597
Paper review: https://youtu.be/oLvmLJkmXuc
⌚️ Timestamps:
0:00 – Introduction
1:03 – Model from scratch
22:20 – Dataset from scratch
29:50 – Training from scratch
39:48 – Utils (almost) from scratch
50:10 – Evaluation and Ending
Man that was amazing! It was pure quality content. Keep it up!
Hi, I would like to understand the reason for not applying transformations to the mask data.
Can you give us your plugin details, please? Much needed!
Why did you put 1 in unsqueeze for the targets?
Not a single moment of confusion in this video, thanks.
48:00 man you killed it, wow
Hey there, I know it's been 3 years, but at minute 46:15 your cam blocks the code. Thanks anyway, it's a really great video.
Hey bro, I know this video is from a long time ago. But thank you for teaching me and, most importantly, being an inspiration. I have now learned how to do the dataset, training loop, and Unet model, all from scratch in my head, just like you. I have also written a thesis on the subject as part of my bachelor's project at my university. Again, thank you, and I hope to learn more from you in the future.
Big data, please remember I like this video.
You are amazing! I have been struggling with this for 2 weeks and your video is so helpful. I can only imagine the amount of work you put into this. Thank you so much.
Hi, I noticed that in the paper the sizes change from (572, 572) to (570, 570) and (568, 568). Why do you still keep padding and stride equal to 1 (which means you keep the same H and W after each conv)?
Guys, how do I need to edit the mask for multiclass? Currently I have a target mask of shape (batch, 1, width, height) and a result of shape (batch, classes, width, height), and cross-entropy won't work with different shapes.
How can I get the image files on Kaggle?
{"error":{"message":"Unauthenticated"}}
Great video, thanks for the guidance. But during training, as the number of epochs increases, my loss keeps growing more negative. I have tried changing the loss function to cross-entropy, but the issue still isn't resolved. Would appreciate some help here. Thanks anyway ❤️
Could you explain more about "my_checkpoint.pth.tar"? I tried following your video, but I got an error in _legacy_load: magic_number = pickle_module.load(f, **pickle_load_args) raises EOFError: Ran out of input. I think it is related to "my_checkpoint.pth.tar". Thanks
Thank you for the in-depth explanation of how to implement UNET. I would love to see you update the GitHub repo to save the model, plus a separate display.py showing how to load the model and display the image segmentation predictions.
Thank you for the video. Was wondering if anyone knows why I would be getting "can't find file" errors?
I'm very thankful for the video, and it's a great implementation too, but I wish you would go into the details of why you do certain things and perhaps explain things a bit more.
Would be super helpful!
excellent