Tutorial-4: Building Graph Convolutional Networks (GCN) and Graph Attention Networks (GAT) from scratch in PyTorch and PyTorch Geometric

In this tutorial, we implement Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs) twice: first from scratch in PyTorch, and then with PyTorch Geometric, which provides pre-implemented layers and utilities for graph neural networks.

Graph Neural Networks (GNNs) have gained popularity in recent years due to their ability to learn representations of graph-structured data. GCNs and GATs are two widely used GNN architectures that have performed well on graph-based tasks such as node classification.

We will first implement a simple version of GCN from scratch using PyTorch. Then, we will implement a simplified GAT model from scratch. Finally, we will use PyTorch Geometric to implement both GCN and GAT.

Let’s get started with the implementation of GCN from scratch:

  1. Implementing GCN from scratch:

First, we need to define the GCN layer. Here’s the code to implement the GCN layer:

import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

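    def forward(self, x, adj_matrix):
        # x: (num_nodes, input_dim); adj_matrix: dense (num_nodes, num_nodes)
        x = self.linear(x)               # transform each node's features
        x = torch.matmul(adj_matrix, x)  # aggregate features over neighbors
        return x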

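In the code above, we define a GCNLayer class that takes the input and output dimensions and initializes a linear layer. In the forward method, we first apply the linear transformation to the node features and then perform the graph convolution by multiplying the adjacency matrix with the transformed features, so that each node sums the features of its neighbors. Note that the layer does not normalize the adjacency matrix or add self-loops itself; it assumes adj_matrix has already been prepared that way (for example, the symmetrically normalized matrix with self-loops used in the original GCN formulation).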
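Since the layer expects a ready-to-use adjacency matrix, the preprocessing step can be sketched as follows. This is a minimal sketch, assuming a dense, unweighted adjacency matrix without existing self-loops; the helper name normalize_adjacency and the toy graph are illustrative and not part of the original tutorial. It computes D^{-1/2} (A + I) D^{-1/2}, where D is the degree matrix of A + I.

```python
import torch

def normalize_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """Return D^{-1/2} (A + I) D^{-1/2} for a dense float adjacency matrix A."""
    # Add self-loops so each node keeps its own features during aggregation.
    adj_hat = adj + torch.eye(adj.size(0), dtype=adj.dtype, device=adj.device)
    # Degree of each node in the self-looped graph (always >= 1).
    deg = adj_hat.sum(dim=1)
    deg_inv_sqrt = deg.pow(-0.5)
    # Scale rows and columns: entry (i, j) becomes a_ij / sqrt(d_i * d_j).
    return deg_inv_sqrt.unsqueeze(1) * adj_hat * deg_inv_sqrt.unsqueeze(0)

# Toy undirected path graph 0 - 1 - 2 (hypothetical example data).
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
adj_norm = normalize_adjacency(adj)
```

The result can be passed as adj_matrix to the layer above. Without this scaling, repeated multiplication by a raw adjacency matrix tends to inflate the features of high-degree nodes.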

Next, we will define the GCN model using multiple GCN layers. Here’s the code to implement the GCN model:

class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(input_dim, hidden_dim)
        self.gcn2 = GCNLayer(hidden_dim, output_dim)

    def forward(self, x, adj_matrix):
        x = self.gcn1(x, adj_matrix)
        x = torch.relu(x)
        x = self.gcn2(x, adj_matrix)
        return x

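In the code above, we define a GCN class that takes the input dimension, hidden dimension, and output dimension and initializes two GCN layers. In the forward method, we apply the first GCN layer, pass its output through a ReLU activation, and then apply the second GCN layer. No activation is applied after the second layer, so the model returns raw scores (logits) for each node.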
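Written as equations, the forward pass of this two-layer model looks roughly like the following. The symbols are assumptions introduced here for illustration: $\hat{A}$ is the adj_matrix passed in, $X$ the input features, $W_1, b_1$ and $W_2, b_2$ the (transposed) weights and biases of the two linear layers, and $\mathbf{1}$ a column vector of ones. Note that, as in the code, the bias is added before the neighborhood aggregation.

```latex
H = \mathrm{ReLU}\!\left( \hat{A} \left( X W_1 + \mathbf{1} b_1^{\top} \right) \right),
\qquad
Z = \hat{A} \left( H W_2 + \mathbf{1} b_2^{\top} \right)
```

Here $Z$ has one row of output scores per node.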

Now, let’s move on to implementing GAT from scratch:

  2. Implementing GAT from scratch:

First, we need to define the GAT layer. Here’s the code to implement the GAT layer:

class GATLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_heads):
        super(GATLayer, self).__init__()
        self.num_heads = num_heads
        self.heads = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])

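    def forward(self, x, adj_matrix):
        # Each head is a separate linear projection of the node features.
        heads_out = [head(x) for head in self.heads]
        x = torch.stack(heads_out, dim=1)  # (num_nodes, num_heads, output_dim)
        x = torch.mean(x, dim=1)           # average the heads
        # Neighbors are weighted by adj_matrix, not by learned attention scores.
        x = torch.matmul(adj_matrix, x)
        return x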

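In the code above, we define a GATLayer class that takes the input dimension, output dimension, and number of heads and initializes one linear layer per head. In the forward method, we apply each head to the node features, average the head outputs, and then aggregate over neighbors by multiplying with the adjacency matrix. Note that this is a simplified layer: it does not compute learned attention coefficients, so every neighbor is weighted by its fixed entry in adj_matrix rather than by an attention score as in a full GAT.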
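For comparison, here is a minimal sketch of what a single head with learned attention coefficients could look like on a dense adjacency matrix. The class name DenseAttentionHead, the parameter names, and the toy graph are assumptions made for this illustration and are not the tutorial's code. It scores each pair of connected nodes with LeakyReLU(a_src · h_i + a_dst · h_j), masks out non-edges, and applies a softmax over each node's neighbors (including the node itself).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseAttentionHead(nn.Module):
    """One attention head over a dense adjacency matrix (illustrative sketch)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim, bias=False)
        # Attention vector split into a source part and a target part.
        self.attn_src = nn.Parameter(torch.randn(output_dim) * 0.1)
        self.attn_dst = nn.Parameter(torch.randn(output_dim) * 0.1)

    def forward(self, x, adj_matrix):
        h = self.linear(x)  # (num_nodes, output_dim)
        # scores[i, j] = a_src . h_i + a_dst . h_j, built by broadcasting.
        scores = (h @ self.attn_src).unsqueeze(1) + (h @ self.attn_dst).unsqueeze(0)
        scores = F.leaky_relu(scores, negative_slope=0.2)
        # Allow attention only along edges and from each node to itself.
        eye = torch.eye(x.size(0), dtype=torch.bool, device=x.device)
        mask = (adj_matrix > 0) | eye
        scores = scores.masked_fill(~mask, float("-inf"))
        alpha = torch.softmax(scores, dim=1)  # each row sums to 1
        return alpha @ h, alpha

# Toy undirected path graph 0 - 1 - 2 with 4 features per node (hypothetical data).
torch.manual_seed(0)
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
x = torch.randn(3, 4)
head = DenseAttentionHead(4, 2)
out, alpha = head(x, adj)
```

Several such heads could be run in parallel and their outputs concatenated or averaged. The dense N x N score matrix keeps the sketch short but does not scale to large graphs.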

Next, we will define the GAT model using multiple GAT layers. Here’s the code to implement the GAT model:

class GAT(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_heads):
        super(GAT, self).__init__()
        self.gat1 = GATLayer(input_dim, hidden_dim, num_heads)
        self.gat2 = GATLayer(hidden_dim, output_dim, num_heads)

    def forward(self, x, adj_matrix):
        x = self.gat1(x, adj_matrix)
        x = torch.relu(x)
        x = self.gat2(x, adj_matrix)
        return x

In the code above, we define a GAT class that takes the input dimension, hidden dimension, output dimension, and number of heads and initializes two GAT layers. In the forward method, we apply the first GAT layer, pass its output through a ReLU activation, and then apply the second GAT layer.

Now, we will use PyTorch Geometric to implement GCN and GAT:

  3. Using PyTorch Geometric:

PyTorch Geometric provides a wide range of tools and pre-implemented modules for working with graph data. We can use PyTorch Geometric to easily implement GCN and GAT models without having to build them from scratch.
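One practical difference from the from-scratch models is the graph format: PyTorch Geometric layers take an edge_index tensor of shape (2, num_edges) instead of a dense adjacency matrix. The conversion can be sketched in plain PyTorch as below. The helper name adjacency_to_edge_index and the toy graph are illustrative assumptions, and the sketch assumes a symmetric (undirected) adjacency matrix so that edge direction does not matter. PyTorch Geometric also ships its own helper for this, torch_geometric.utils.dense_to_sparse.

```python
import torch

def adjacency_to_edge_index(adj: torch.Tensor) -> torch.Tensor:
    """Convert a dense (N, N) adjacency matrix into a (2, E) edge list."""
    # nonzero() returns one (row, col) pair per edge; transpose to shape (2, E).
    return adj.nonzero().t().contiguous()

# Toy undirected path graph 0 - 1 - 2 (hypothetical example data).
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
edge_index = adjacency_to_edge_index(adj)
print(edge_index)
```

Each undirected edge appears twice, once per direction, which is the convention PyTorch Geometric uses for undirected graphs.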

Here’s an example of how to implement a GCN model using PyTorch Geometric:

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x

In the code above, we define a GCN class that takes the input dimension, hidden dimension, and output dimension and initializes two GCNConv layers from PyTorch Geometric. In the forward method, we read the node features and edge_index from the data object, apply the first GCNConv layer, pass its output through a ReLU activation, and then apply the second GCNConv layer.

Similarly, we can implement a GAT model using PyTorch Geometric as follows:

from torch_geometric.nn import GATConv

class GAT(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_heads):
        super(GAT, self).__init__()
        self.conv1 = GATConv(input_dim, hidden_dim, heads=num_heads)
        self.conv2 = GATConv(hidden_dim*num_heads, output_dim, heads=1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x

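In the code above, we define a GAT class that takes the input dimension, hidden dimension, output dimension, and number of heads and initializes two GATConv layers from PyTorch Geometric. Because GATConv concatenates the outputs of its heads by default, the first layer produces hidden_dim*num_heads features per node, which is why the second layer takes that as its input dimension. In the forward method, we apply the first GATConv layer, pass its output through a ReLU activation, and then apply the second, single-head GATConv layer.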

In conclusion, we have implemented simplified Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs) from scratch in PyTorch, and then built both models with PyTorch Geometric. PyTorch Geometric provides a convenient way to implement GNN models and makes it easier to work with graph data. You can further explore different variations and extensions of GCNs and GATs to enhance the performance of your graph neural networks. Happy coding!

6 Comments
@truongsonnguyen144
8 days ago

Can you send me the code in this video, thank you so much.

@parthjoshi5045
8 days ago

what happens if we keep num_heads == 1 , einsum is giving some sort of error for that, any help is appreciated?

@sanjaykrish8719
8 days ago

Excellent implementation and explanation. Can we get the notebook? It would have been great if you could have run the final tasks in the custom model which you built instead of on PyG

@florianellsaesser4165
8 days ago

Nice explanation. Thank you so much. Which original implementation are you referring to in the video. Thank you!

@PN-eq8oe
8 days ago

Thank you for sharing such an informative video. I've learned a lot from this video. I noticed that you used data.train_mask to compute test_acc (27:52). I just wanted to ask if this was what you intended or if it might have been a mistake. Also, I noticed that you added 'heads' argument in the GAT class but I don't see where it was used. Maybe further modification is required? Thank you again for taking the time to create such great content, and I look forward to seeing more from you in the future.

@Sashx123
8 days ago

it's hard to find GNNs videos for beginners so this is nice ! i've been trying to apply GCNs for text classification, i don't quite understand how to make train/test split of text nodes (since we can also have word nodes), can you also make videos of GCN for textual data ?
