# Graph Attention Networks (GAT)
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/graphs/gat/__init__.py)
This is a PyTorch implementation of the paper Graph Attention Networks.
GATs work on graph data. A graph consists of nodes and edges connecting nodes. For example, in the Cora dataset the nodes are research papers and the edges are citations that connect the papers.
GAT uses masked self-attention, similar to transformers. It consists of graph attention layers stacked on top of each other. Each graph attention layer takes node embeddings as input and outputs transformed embeddings. Each node embedding attends to the embeddings of the other nodes it is connected to. The details of graph attention layers are included alongside the implementation.
Here is the training code for training a two-layer GAT on the Cora dataset.
```python
import torch
from torch import nn
```
This is a single graph attention layer. A GAT is made up of multiple such layers.
It takes $\mathbf{h} = \{ h_1, h_2, \dots, h_N \}$, where $h_i \in \mathbb{R}^F$, as input and outputs $\mathbf{h'} = \{ h'_1, h'_2, \dots, h'_N \}$, where $h'_i \in \mathbb{R}^{F'}$.
```python
class GraphAttentionLayer(nn.Module):
```
- `in_features`, $F$, is the number of input features per node
- `out_features`, $F'$, is the number of output features per node
- `n_heads`, $K$, is the number of attention heads
- `is_concat` is whether the multi-head results should be concatenated or averaged
- `dropout` is the dropout probability
- `leaky_relu_negative_slope` is the negative slope for leaky relu activation

```python
    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2):
```
```python
        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads
```
Calculate the number of dimensions per head
```python
        if is_concat:
            assert out_features % n_heads == 0
            # If we are concatenating the multiple heads
            self.n_hidden = out_features // n_heads
        else:
            # If we are averaging the multiple heads
            self.n_hidden = out_features
```
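For example (numbers made up for illustration), with `out_features = 64` and `n_heads = 8`, each head works with `n_hidden = 8` features when concatenating, and with all `64` features when averaging.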
Linear layer for initial transformation; i.e. to transform the node embeddings before self-attention
```python
        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
```
Linear layer to compute attention score $e_{ij}$

```python
        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
```
The activation for attention score $e_{ij}$

```python
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
```
Softmax to compute attention $\alpha_{ij}$

```python
        self.softmax = nn.Softmax(dim=1)
```
Dropout layer to be applied for attention
```python
        self.dropout = nn.Dropout(dropout)
```
- `h`, $\mathbf{h}$, is the input node embeddings of shape `[n_nodes, in_features]`.
- `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`. We use shape `[n_nodes, n_nodes, 1]` since the adjacency is the same for each head.

The adjacency matrix represents the edges (or connections) among nodes. `adj_mat[i][j]` is `True` if there is an edge from node `i` to node `j`.
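As an aside, here is a minimal sketch (not part of this module; the edge list and sizes are made up) of how such a boolean adjacency mask could be built from a list of edges:

```python
import torch

# Hypothetical graph: 4 nodes and a made-up directed edge list
edges = [(0, 1), (1, 2), (2, 0), (3, 2)]
n_nodes = 4

# Boolean adjacency matrix with a trailing head dimension of 1,
# so the same mask is shared across all attention heads
adj_mat = torch.zeros(n_nodes, n_nodes, 1, dtype=torch.bool)
for i, j in edges:
    adj_mat[i, j, 0] = True

# Add self-loops so that each node also attends to itself,
# as in the paper's first-order neighborhood (which includes the node itself)
adj_mat |= torch.eye(n_nodes, dtype=torch.bool).unsqueeze(-1)
```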
```python
    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
```
Number of nodes
```python
        n_nodes = h.shape[0]
```
The initial transformation, $g^k_i = \mathbf{W}^k h_i$, for each head. We do a single linear transformation and then split it up for each head.
```python
        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
```
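As a quick shape check (sizes are arbitrary, for illustration only): the single linear layer produces `n_heads * n_hidden` features per node, and `view` splits them into a per-head dimension.

```python
# Hypothetical sizes: 3 nodes, 8 input features, 4 heads of 2 hidden features each
layer = GraphAttentionLayer(in_features=8, out_features=8, n_heads=4)
h = torch.rand(3, 8)
g = layer.linear(h).view(3, layer.n_heads, layer.n_hidden)
assert g.shape == (3, 4, 2)
```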
We calculate these for each head $k$. We have omitted $\cdot^k$ for simplicity.
$$e_{ij} = a(\mathbf{W} h_i, \mathbf{W} h_j) = a(g_i, g_j)$$
$e_{ij}$ is the attention score (importance) from node $j$ to node $i$. We calculate this for each head.
$a$ is the attention mechanism that calculates the attention score. The paper concatenates $g_i$, $g_j$ and does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{2 F'}$ followed by a $\text{LeakyReLU}$.
$$e_{ij} = \text{LeakyReLU}\Big(\mathbf{a}^\top \big[ g_i \Vert g_j \big]\Big)$$
First we calculate $[g_i \Vert g_j]$ for all pairs of $i, j$.
`g_repeat` gets $\{g_1, g_2, \dots, g_N, g_1, g_2, \dots, g_N, \dots\}$ where the whole sequence of node embeddings is repeated `n_nodes` times.
```python
        g_repeat = g.repeat(n_nodes, 1, 1)
```
`g_repeat_interleave` gets $\{g_1, g_1, \dots, g_1, g_2, g_2, \dots, g_2, \dots\}$ where each node embedding is repeated `n_nodes` times in a row.
```python
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
```
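To make the difference concrete, here is a tiny illustration (values made up, and simplified to a 2-D tensor) of how `repeat` and `repeat_interleave` order the rows:

```python
x = torch.tensor([[1.], [2.], [3.]])  # three toy "node embeddings"

# repeat tiles the whole sequence: rows 1, 2, 3, 1, 2, 3, 1, 2, 3
x.repeat(3, 1)

# repeat_interleave repeats each row consecutively: rows 1, 1, 1, 2, 2, 2, 3, 3, 3
x.repeat_interleave(3, dim=0)
```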
Now we concatenate to get $\{g_1 \Vert g_1, g_1 \Vert g_2, \dots, g_1 \Vert g_N, g_2 \Vert g_1, g_2 \Vert g_2, \dots, g_2 \Vert g_N, \dots\}$
```python
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
```
Reshape so that `g_concat[i, j]` is $g_i \Vert g_j$
```python
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
```
Calculate $e_{ij} = \text{LeakyReLU}\big(\mathbf{a}^\top [g_i \Vert g_j]\big)$. `e` is of shape `[n_nodes, n_nodes, n_heads, 1]`.
```python
        e = self.activation(self.attn(g_concat))
```
Remove the last dimension of size 1
```python
        e = e.squeeze(-1)
```
The adjacency matrix should have shape `[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`
```python
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
```
Mask $e_{ij}$ based on the adjacency matrix. $e_{ij}$ is set to $-\infty$ if there is no edge from $i$ to $j$.
```python
        e = e.masked_fill(adj_mat == 0, float('-inf'))
```
We then normalize attention scores (or coefficients):
$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})}$$
where $\mathcal{N}_i$ is the set of nodes connected to $i$.
We do this by setting unconnected $e_{ij}$ to $-\infty$, which makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.
```python
        a = self.softmax(e)
```
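A minimal check (with made-up scores) that the $-\infty$ mask removes unconnected pairs from the normalization:

```python
# Two nodes' raw scores over three neighbours; -inf marks missing edges
scores = torch.tensor([[1.0, 2.0, float('-inf')],
                       [float('-inf'), 0.5, 0.5]])
alpha = torch.softmax(scores, dim=1)
# alpha[0, 2] and alpha[1, 0] are exactly 0, and each row still sums to 1
```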
Apply dropout regularization
```python
        a = self.dropout(a)
```
Calculate final output for each head:
$$h'^k_i = \sum_{j \in \mathcal{N}_i} \alpha^k_{ij} g^k_j$$
Note: The paper includes the final activation $\sigma$ in $h'^k_i$. We have omitted this from the Graph Attention Layer implementation and use it in the GAT model, to match how other PyTorch modules are defined, with the activation as a separate layer.
```python
        attn_res = torch.einsum('ijh,jhf->ihf', a, g)
```
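For illustration (not part of the layer), the same contraction written with explicit loops over a small random example; the sizes here are made up:

```python
n_nodes, n_heads, n_hidden = 3, 2, 4           # made-up sizes
a = torch.rand(n_nodes, n_nodes, n_heads)      # attention coefficients
g = torch.rand(n_nodes, n_heads, n_hidden)     # transformed node embeddings

res = torch.einsum('ijh,jhf->ihf', a, g)

# Equivalent explicit computation: a weighted sum of neighbour features
res_loop = torch.zeros(n_nodes, n_heads, n_hidden)
for i in range(n_nodes):
    for k in range(n_heads):
        for j in range(n_nodes):
            res_loop[i, k] += a[i, j, k] * g[j, k]

assert torch.allclose(res, res_loop)
```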
Concatenate the heads
```python
        if self.is_concat:
```
$$h'_i = \Big\Vert_{k=1}^{K} h'^k_i$$
```python
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
```
Take the mean of the heads
```python
        else:
```
$$h'_i = \frac{1}{K} \sum_{k=1}^{K} h'^k_i$$
```python
            return attn_res.mean(dim=1)
```
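The training code mentioned above defines the full model; as a rough sketch (the hyper-parameters, naming, and structure here are assumptions for illustration, not the exact experiment code), a two-layer GAT in the spirit of the paper could look like this:

```python
class GAT(nn.Module):
    """A minimal two-layer GAT sketch (hyper-parameters are illustrative)."""

    def __init__(self, in_features: int, n_hidden: int, n_classes: int,
                 n_heads: int = 8, dropout: float = 0.6):
        super().__init__()
        # First layer concatenates the heads; n_hidden must be divisible by n_heads
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads,
                                          is_concat=True, dropout=dropout)
        self.activation = nn.ELU()
        # Output layer uses a single (averaged) head to produce class logits
        self.layer2 = GraphAttentionLayer(n_hidden, n_classes, 1,
                                          is_concat=False, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
        x = self.dropout(x)
        x = self.activation(self.layer1(x, adj_mat))
        x = self.dropout(x)
        return self.layer2(x, adj_mat)
```

Here `x` is the full node-feature matrix and `adj_mat` is the boolean adjacency mask of shape `[n_nodes, n_nodes, 1]`, so the whole graph is processed in a single forward pass.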