# Graph Attention Networks v2 (GATv2)
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/graphs/gatv2/__init__.py)
This is a PyTorch implementation of the GATv2 operator from the paper [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).
GATv2 works on graph data, similar to GAT. A graph consists of nodes and edges connecting them. For example, in the Cora dataset the nodes are research papers and the edges are citations that connect the papers.
The GATv2 operator fixes the static attention problem of the standard GAT. Static attention is when the attention to the key nodes has the same rank (order) for any query node. GAT computes attention from query node $i$ to key node $j$ as,
$$e_{ij} = \text{LeakyReLU}\Big(\mathbf{a}^\top \big[\mathbf{W} \vec{h}_i \Vert \mathbf{W} \vec{h}_j \big]\Big) = \text{LeakyReLU}\Big(\mathbf{a}_1^\top \mathbf{W} \vec{h}_i + \mathbf{a}_2^\top \mathbf{W} \vec{h}_j\Big)$$
Note that for any query node $i$, the attention rank (argsort) of the keys depends only on $\mathbf{a}_2^\top \mathbf{W} \vec{h}_j$. Therefore the attention rank of the keys remains the same (static) for all queries.
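To make this concrete, here is a small illustrative sketch (random, untrained weights; the names below are stand-ins, not the paper's code) that scores a few random nodes with both formulas and compares the per-query rankings of the keys:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes; all weights here are random stand-ins, not trained parameters
n_nodes, in_features, hidden = 5, 8, 16
h = torch.randn(n_nodes, in_features)

# GAT-style scores: e_ij = LeakyReLU(a1^T W h_i + a2^T W h_j)
W = torch.randn(in_features, hidden)
a1, a2 = torch.randn(hidden), torch.randn(hidden)
g = h @ W
e_gat = F.leaky_relu((g @ a1).unsqueeze(1) + (g @ a2).unsqueeze(0))

# GATv2-style scores: e_ij = a^T LeakyReLU(W_l h_i + W_r h_j)
W_l, W_r = torch.randn(in_features, hidden), torch.randn(in_features, hidden)
a = torch.randn(hidden)
g_l, g_r = h @ W_l, h @ W_r
e_gatv2 = F.leaky_relu(g_l.unsqueeze(1) + g_r.unsqueeze(0)) @ a

# For GAT every query ranks the keys in the same order (static attention);
# for GATv2 the ranking can change from query to query (dynamic attention)
print(e_gat.argsort(dim=1))
print(e_gatv2.argsort(dim=1))
```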
GATv2 allows dynamic attention by changing the attention mechanism,
$$e_{ij} = \mathbf{a}^\top \text{LeakyReLU}\Big(\mathbf{W} \big[\vec{h}_i \Vert \vec{h}_j \big]\Big) = \mathbf{a}^\top \text{LeakyReLU}\Big(\mathbf{W}_l \vec{h}_i + \mathbf{W}_r \vec{h}_j\Big)$$
The paper shows that GAT's static attention mechanism fails on some graph problems with a synthetic dictionary-lookup dataset. It is a fully connected bipartite graph where each node in one set (the query nodes) has a key associated with it, and each node in the other set has both a key and a value associated with it. The goal is to predict the values of the query nodes. GAT fails on this task because of its limited static attention.
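As a rough sketch of that dataset's structure (the tensor layout below is an assumption for illustration, not the paper's data loader), a fully connected bipartite graph with keys and values could be set up like this:

```python
import torch

n_pairs = 4            # number of (key, value) nodes; there are as many query nodes
n_nodes = 2 * n_pairs  # query nodes: 0..n_pairs-1, key-value nodes: n_pairs..2*n_pairs-1

# Fully connected bipartite adjacency: every query node connects to every key-value node
adj_mat = torch.zeros(n_nodes, n_nodes, dtype=torch.bool)
adj_mat[:n_pairs, n_pairs:] = True
adj_mat[n_pairs:, :n_pairs] = True

# Each query node carries a key; each key-value node carries a key and a value
keys = torch.randperm(n_pairs)                 # keys attached to the query nodes
kv_keys = torch.arange(n_pairs)                # keys attached to the key-value nodes
kv_values = torch.randint(0, 10, (n_pairs,))   # values to be looked up

# The target for query node i is the value of the key-value node whose key matches keys[i]
targets = kv_values[keys]
```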
Here is the training code for training a two-layer GATv2 on the Cora dataset.
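As a hedged sketch of what such a model might look like (the module name and hyper-parameters below are assumptions, not the linked experiment's exact code), two of the `GraphAttentionV2Layer` modules implemented below can be stacked for node classification:

```python
import torch
from torch import nn


class TwoLayerGATv2(nn.Module):
    # A hypothetical two-layer GATv2 node classifier built from the layer defined below
    def __init__(self, in_features: int, n_hidden: int, n_classes: int,
                 n_heads: int = 8, dropout: float = 0.6):
        super().__init__()
        # First layer concatenates the heads (n_hidden must be divisible by n_heads)
        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads,
                                            is_concat=True, dropout=dropout)
        self.activation = nn.ELU()
        # Second layer averages the heads and outputs one score per class
        self.layer2 = GraphAttentionV2Layer(n_hidden, n_classes, 1,
                                            is_concat=False, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
        x = self.dropout(x)
        x = self.activation(self.layer1(x, adj_mat))
        x = self.dropout(x)
        return self.layer2(x, adj_mat)


# e.g. model = TwoLayerGATv2(in_features=1433, n_hidden=64, n_classes=7)  # Cora-like sizes
```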
```python
import torch
from torch import nn
```
This is a single graph attention v2 layer. A GATv2 is made up of multiple such layers. It takes $\mathbf{h} = \{ \vec{h}_1, \vec{h}_2, \dots, \vec{h}_N \}$, where $\vec{h}_i \in \mathbb{R}^{F}$, as input and outputs $\mathbf{h'} = \{ \vec{h'}_1, \vec{h'}_2, \dots, \vec{h'}_N \}$, where $\vec{h'}_i \in \mathbb{R}^{F'}$.
```python
class GraphAttentionV2Layer(nn.Module):
```
* `in_features`, $F$, is the number of input features per node
* `out_features`, $F'$, is the number of output features per node
* `n_heads`, $K$, is the number of attention heads
* `is_concat` whether the multi-head results should be concatenated or averaged
* `dropout` is the dropout probability
* `leaky_relu_negative_slope` is the negative slope for leaky relu activation
* `share_weights` if set to `True`, the same matrix will be applied to the source and the target node of every edge

```python
    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2,
                 share_weights: bool = False):
```
```python
        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads
        self.share_weights = share_weights
```
Calculate the number of dimensions per head
```python
        if is_concat:
            assert out_features % n_heads == 0
```
If we are concatenating the multiple heads
```python
            self.n_hidden = out_features // n_heads
        else:
```
If we are averaging the multiple heads
```python
            self.n_hidden = out_features
```
Linear layer for initial source transformation; i.e. to transform the source node embeddings before self-attention
```python
        self.linear_l = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
```
If `share_weights` is `True`, the same linear layer is used for the target nodes
```python
        if share_weights:
            self.linear_r = self.linear_l
        else:
            self.linear_r = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
```
Linear layer to compute attention score $e_{ij}$
```python
        self.attn = nn.Linear(self.n_hidden, 1, bias=False)
```
The activation for attention score $e_{ij}$
```python
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
```
Softmax to compute attention $\alpha_{ij}$
```python
        self.softmax = nn.Softmax(dim=1)
```
Dropout layer to be applied for attention
```python
        self.dropout = nn.Dropout(dropout)
```
* `h`, $\mathbf{h}$, is the input node embeddings of shape `[n_nodes, in_features]`.
* `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`. We use shape `[n_nodes, n_nodes, 1]` since the adjacency is the same for each head. The adjacency matrix represents the edges (or connections) among nodes. `adj_mat[i][j]` is `True` if there is an edge from node `i` to node `j`.

```python
    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
```
Number of nodes
```python
        n_nodes = h.shape[0]
```
The initial transformations,
$$\vec{g}^k_{l_i} = \mathbf{W}^k_l \vec{h}_i \qquad \vec{g}^k_{r_i} = \mathbf{W}^k_r \vec{h}_i$$
for each head. We do two linear transformations and then split it up for each head.
```python
        g_l = self.linear_l(h).view(n_nodes, self.n_heads, self.n_hidden)
        g_r = self.linear_r(h).view(n_nodes, self.n_heads, self.n_hidden)
```
We calculate these for each head $k$. We have omitted $\cdot^k$ for simplicity.

$$e_{ij} = a(\mathbf{W}_l \vec{h}_i, \mathbf{W}_r \vec{h}_j) = a(\vec{g}_{l_i}, \vec{g}_{r_j})$$

$e_{ij}$ is the attention score (importance) from node $j$ to node $i$. We calculate this for each head.

$a$ is the attention mechanism that calculates the attention score. The paper sums $\vec{g}_{l_i}$ and $\vec{g}_{r_j}$, applies a $\text{LeakyReLU}$, and then does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{F'}$

$$e_{ij} = \mathbf{a}^\top \text{LeakyReLU}\Big(\vec{g}_{l_i} + \vec{g}_{r_j}\Big)$$

Note: The paper describes $e_{ij}$ as $e_{ij} = \mathbf{a}^\top \text{LeakyReLU}\Big(\mathbf{W} \big[\vec{h}_i \Vert \vec{h}_j\big]\Big)$, which is equivalent to the definition we use here.
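To see the equivalence, split $\mathbf{W}$ column-wise into two blocks, $\mathbf{W} = [\mathbf{W}_l \; \mathbf{W}_r]$, acting on the two halves of the concatenation. Then, using the definitions of $\vec{g}_{l_i}$ and $\vec{g}_{r_j}$ above,

$$\mathbf{a}^\top \text{LeakyReLU}\Big(\mathbf{W} \big[\vec{h}_i \Vert \vec{h}_j\big]\Big) = \mathbf{a}^\top \text{LeakyReLU}\Big(\mathbf{W}_l \vec{h}_i + \mathbf{W}_r \vec{h}_j\Big) = \mathbf{a}^\top \text{LeakyReLU}\Big(\vec{g}_{l_i} + \vec{g}_{r_j}\Big)$$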
First we calculate $\vec{g}_{l_i} + \vec{g}_{r_j}$ for all pairs of $i, j$.
`g_l_repeat` gets $\{\vec{g}_{l_1}, \vec{g}_{l_2}, \dots, \vec{g}_{l_N}, \vec{g}_{l_1}, \vec{g}_{l_2}, \dots, \vec{g}_{l_N}, \dots\}$ where the whole sequence of node embeddings is repeated `n_nodes` times.
```python
        g_l_repeat = g_l.repeat(n_nodes, 1, 1)
```
`g_r_repeat_interleave` gets $\{\vec{g}_{r_1}, \vec{g}_{r_1}, \dots, \vec{g}_{r_1}, \vec{g}_{r_2}, \vec{g}_{r_2}, \dots, \vec{g}_{r_2}, \dots\}$ where each node embedding is repeated `n_nodes` times.
```python
        g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)
```
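As a quick illustrative check of the two repeat patterns (a toy example, not part of the layer):

```python
import torch

x = torch.tensor([1, 2, 3])
# `repeat` tiles the whole sequence: 1, 2, 3, 1, 2, 3, 1, 2, 3
print(x.repeat(3))
# `repeat_interleave` repeats each element in place: 1, 1, 1, 2, 2, 2, 3, 3, 3
print(x.repeat_interleave(3))
```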
Now we add the two tensors to get $\{\vec{g}_{l_1} + \vec{g}_{r_1}, \vec{g}_{l_1} + \vec{g}_{r_2}, \dots, \vec{g}_{l_1} + \vec{g}_{r_N}, \vec{g}_{l_2} + \vec{g}_{r_1}, \vec{g}_{l_2} + \vec{g}_{r_2}, \dots, \vec{g}_{l_2} + \vec{g}_{r_N}, \dots\}$
```python
        g_sum = g_l_repeat + g_r_repeat_interleave
```
Reshape so that `g_sum[i, j]` is $\vec{g}_{l_i} + \vec{g}_{r_j}$
```python
        g_sum = g_sum.view(n_nodes, n_nodes, self.n_heads, self.n_hidden)
```
Calculate $e_{ij} = \mathbf{a}^\top \text{LeakyReLU}\Big(\vec{g}_{l_i} + \vec{g}_{r_j}\Big)$. `e` is of shape `[n_nodes, n_nodes, n_heads, 1]`
```python
        e = self.attn(self.activation(g_sum))
```
Remove the last dimension of size 1
```python
        e = e.squeeze(-1)
```
The adjacency matrix should have shape `[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`
```python
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
```
Mask $e_{ij}$ based on the adjacency matrix. $e_{ij}$ is set to $-\infty$ if there is no edge from $i$ to $j$.
```python
        e = e.masked_fill(adj_mat == 0, float('-inf'))
```
We then normalize attention scores (or coefficients)
$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{j' \in \mathcal{N}_i} \exp(e_{ij'})}$$
where $\mathcal{N}_i$ is the set of nodes connected to $i$.

We do this by setting unconnected $e_{ij}$ to $-\infty$, which makes $\exp(e_{ij}) \approx 0$ for unconnected pairs.
```python
        a = self.softmax(e)
```
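A tiny illustrative example (made-up scores, not part of the layer) of how the $-\infty$ mask interacts with the softmax:

```python
import torch

# Raw attention scores for one query node over three keys, one head
e = torch.tensor([1.0, 2.0, 0.5])
# Suppose only the first two keys are connected to the query
adj = torch.tensor([True, True, False])

e = e.masked_fill(~adj, float('-inf'))
a = torch.softmax(e, dim=0)
print(a)  # the unconnected key gets zero weight; the remaining weights sum to one
```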
Apply dropout regularization
```python
        a = self.dropout(a)
```
Calculate the final output for each head
$$\vec{h'}^k_i = \sum_{j \in \mathcal{N}_i} \alpha^k_{ij} \vec{g}^k_{r_j}$$
```python
        attn_res = torch.einsum('ijh,jhf->ihf', a, g_r)
```
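The einsum computes, for each node $i$ and head, the attention-weighted sum of the $\vec{g}_{r_j}$ vectors. A quick illustrative equivalence check with assumed toy sizes (not part of the layer):

```python
import torch

n_nodes, n_heads, n_hidden = 4, 2, 3
a = torch.rand(n_nodes, n_nodes, n_heads)
g_r = torch.rand(n_nodes, n_heads, n_hidden)

attn_res = torch.einsum('ijh,jhf->ihf', a, g_r)

# The same computation with an explicit loop over heads
manual = torch.stack([a[:, :, h] @ g_r[:, h, :] for h in range(n_heads)], dim=1)
print(torch.allclose(attn_res, manual))  # True
```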
Concatenate the heads
```python
        if self.is_concat:
```
$$\vec{h'}_i = \Big\Vert_{k=1}^{K} \vec{h'}^k_i$$
```python
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
```
Take the mean of the heads
```python
        else:
```
$$\vec{h'}_i = \frac{1}{K} \sum_{k=1}^{K} \vec{h'}^k_i$$
```python
            return attn_res.mean(dim=1)
```
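Finally, a minimal usage sketch of the layer on a small random graph (the sizes and the random adjacency below are assumptions for illustration):

```python
import torch

# Hypothetical sizes for a quick smoke test
n_nodes, in_features, out_features, n_heads = 6, 16, 32, 4

layer = GraphAttentionV2Layer(in_features, out_features, n_heads)

h = torch.rand(n_nodes, in_features)
# Random adjacency with self-loops, shared across heads: [n_nodes, n_nodes, 1]
adj_mat = torch.rand(n_nodes, n_nodes) > 0.5
adj_mat = (adj_mat | torch.eye(n_nodes, dtype=torch.bool)).unsqueeze(-1)

out = layer(h, adj_mat)
print(out.shape)  # torch.Size([6, 32]) since the four 8-dimensional heads are concatenated
```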