
Train a Graph Attention Network (GAT) on Cora dataset


from typing import Dict

import numpy as np
import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs

Cora Dataset

The Cora dataset is a dataset of research papers. For each paper we are given a binary feature vector that indicates the presence of words. Each paper is classified into one of 7 classes. The dataset also includes the citation network.

The papers are the nodes of the graph and the edges are the citations.

The task is to classify the nodes into the 7 classes, using the feature vectors and the citation network as inputs.
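For reference, here is roughly what rows of the two raw files look like (illustrative rows only, assuming the standard LINQS release that the loader below downloads; the direction of a citation pair does not matter here, since the loader symmetrizes the graph):

# cora.content: <paper_id> <1433 binary word indicators> <class_label>, e.g.
#     31336  0 0 0 1 ... 0  Neural_Networks
# cora.cites: <cited_paper_id> <citing_paper_id>, e.g.
#     35  1033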

class CoraDataset:

Labels for each node

    labels: torch.Tensor

Class names mapped to unique integer indices

    classes: Dict[str, int]

Feature vectors for all nodes

    features: torch.Tensor

Adjacency matrix with the edge information. adj_mat[i][j] is True if there is an edge from i to j.

    adj_mat: torch.Tensor

Download the dataset

    @staticmethod
    def _download():

        if not (lab.get_data_path() / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   lab.get_data_path() / 'cora.tgz')
            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())

Load the dataset

    def __init__(self, include_edges: bool = True):

Whether to include edges. This is to test how much accuracy is lost if we ignore the citation network.

        self.include_edges = include_edges

Download dataset

        self._download()

Read the paper ids, feature vectors, and labels

        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))

Load the citations; it is a list of pairs of integers.

        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)

Get the feature vectors

        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))

Normalize the feature vectors

        self.features = features / features.sum(dim=1, keepdim=True)

Get the class names and assign a unique integer to each of them

        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}

Get the labels as those integers

        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)

Get the paper ids

        paper_ids = np.array(content[:, 0], dtype=np.int32)

Map of paper id to index

        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}

Empty adjacency matrix - an identity matrix

        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)

Mark the citations in the adjacency matrix

        if self.include_edges:
            for e in citations:

The pair of paper indexes

                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]

We build a symmetrical graph: if paper i referenced paper j, we place an edge from i to j as well as an edge from j to i.

                self.adj_mat[e1][e2] = True
                self.adj_mat[e2][e1] = True
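As a quick sanity check, one could load the dataset and inspect the tensors; a minimal sketch, assuming the standard Cora release (2,708 papers, 1,433 word features, 7 classes):

# Hypothetical usage; the shapes assume the standard Cora release
dataset = CoraDataset(include_edges=True)
print(dataset.features.shape)  # torch.Size([2708, 1433])
print(dataset.labels.shape)    # torch.Size([2708])
print(len(dataset.classes))    # 7
print(dataset.adj_mat.shape)   # torch.Size([2708, 2708])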

Graph Attention Network (GAT)

This graph attention network has two graph attention layers.

class GAT(nn.Module):

  • in_features is the number of features per node
  • n_hidden is the number of features in the first graph attention layer
  • n_classes is the number of classes
  • n_heads is the number of heads in the graph attention layers
  • dropout is the dropout probability

    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):

        super().__init__()

First graph attention layer where we concatenate the heads

        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)

Activation function after first graph attention layer

        self.activation = nn.ELU()

Final graph attention layer where we average the heads

        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)

Dropout

        self.dropout = nn.Dropout(dropout)

  • x is the feature vectors of shape [n_nodes, in_features]
  • adj_mat is the adjacency matrix of shape [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1]

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):

Apply dropout to the input

        x = self.dropout(x)

First graph attention layer

        x = self.layer1(x, adj_mat)

Activation function

        x = self.activation(x)

Dropout

        x = self.dropout(x)

Output layer (without activation) for logits

        return self.output(x, adj_mat)
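Before training on Cora, it can help to smoke-test the model on random inputs; a minimal sketch with hypothetical sizes (not the configuration used in this experiment):

# 10 nodes with 16 features each; n_hidden must be divisible by n_heads
# since the first layer concatenates the heads.
model = GAT(in_features=16, n_hidden=8, n_classes=3, n_heads=2, dropout=0.6)
x = torch.randn(10, 16)
# Self-loop-only adjacency matrix, with the extra dimension for the heads
adj_mat = torch.eye(10, dtype=torch.bool).unsqueeze(-1)
logits = model(x, adj_mat)
assert logits.shape == (10, 3)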

A simple function to calculate the accuracy

def accuracy(output: torch.Tensor, labels: torch.Tensor):

    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
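A toy check: two of the three predictions below match the labels, so the result is 2/3.

output = torch.tensor([[0.9, 0.1],
                       [0.2, 0.8],
                       [0.7, 0.3]])
labels = torch.tensor([0, 1, 1])
assert abs(accuracy(output, labels) - 2 / 3) < 1e-6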

Configurations

class Configs(BaseConfigs):

Model

    model: GAT

Number of nodes to train on

    training_samples: int = 500

Number of features per node in the input

    in_features: int

Number of features in the first graph attention layer

    n_hidden: int = 64

Number of heads

    n_heads: int = 8

Number of classes for classification

    n_classes: int

Dropout probability

    dropout: float = 0.6

Whether to include the citation network

    include_edges: bool = True

Dataset

    dataset: CoraDataset

Number of training iterations

    epochs: int = 1_000

Loss function

    loss_func = nn.CrossEntropyLoss()

Device to train on

This creates configs for the device, so that we can change the device by passing a config value.

    device: torch.device = DeviceConfigs()

Optimizer

    optimizer: torch.optim.Adam

Training loop

We do full-batch training since the dataset is small. If we were to sample and train, we would have to sample a set of nodes for each training step, along with the edges that span across those selected nodes (a sketch of such sampling is given after this method).

    def run(self):

Move the feature vectors to the device

        features = self.dataset.features.to(self.device)

Move the labels to the device

        labels = self.dataset.labels.to(self.device)

Move the adjacency matrix to the device

        edges_adj = self.dataset.adj_mat.to(self.device)

Add an empty third dimension for the heads

        edges_adj = edges_adj.unsqueeze(-1)

Random indexes

        idx_rand = torch.randperm(len(labels))

Nodes for training

        idx_train = idx_rand[:self.training_samples]

Nodes for validation

        idx_valid = idx_rand[self.training_samples:]

Training loop

        for epoch in monit.loop(self.epochs):

Set the model to training mode

            self.model.train()

Make all the gradients zero

            self.optimizer.zero_grad()

Evaluate the model

            output = self.model(features, edges_adj)

Get the loss for training nodes

            loss = self.loss_func(output[idx_train], labels[idx_train])

Calculate gradients

            loss.backward()

Take an optimization step

            self.optimizer.step()

Log the loss

            tracker.add('loss.train', loss)

Log the accuracy

            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

Set the model to evaluation mode for validation

            self.model.eval()

No need to compute gradients

            with torch.no_grad():

Evaluate the model again

                output = self.model(features, edges_adj)

Calculate the loss for validation nodes

                loss = self.loss_func(output[idx_valid], labels[idx_valid])

Log the loss

                tracker.add('loss.valid', loss)

Log the accuracy

                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

Save logs

            tracker.save()
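As mentioned above, for a graph too large for full-batch training one could instead train on induced subgraphs: sample a subset of nodes and keep only the edges whose endpoints both fall inside the sample. A minimal sketch (not part of this experiment; sample_subgraph is a hypothetical helper):

def sample_subgraph(features: torch.Tensor, adj_mat: torch.Tensor, n_sample: int):
    # Pick a random subset of nodes
    idx = torch.randperm(features.shape[0])[:n_sample]
    # Keep the sampled nodes' features and only the edges among them
    return features[idx], adj_mat[idx][:, idx], idx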

Create Cora dataset

@option(Configs.dataset)
def cora_dataset(c: Configs):

    return CoraDataset(c.include_edges)

Get the number of classes

calculate(Configs.n_classes, lambda c: len(c.dataset.classes))

Number of features in the input

calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])

Create GAT model

@option(Configs.model)
def gat_model(c: Configs):

    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)

Create configurable optimizer

@option(Configs.optimizer)
def _optimizer(c: Configs):

    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
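With the overrides set in main() below, this configurable optimizer resolves to roughly the equivalent of the following plain PyTorch call (a sketch of the equivalent, not the OptimizerConfigs internals):

optimizer = torch.optim.Adam(conf.model.parameters(), lr=5e-3, weight_decay=5e-4)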

def main():

Create configurations

    conf = Configs()

Create an experiment

    experiment.create(name='gat')

Calculate configurations

    experiment.configs(conf, {

Adam optimizer

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })

Start and watch the experiment

    with experiment.start():

Run the training

        conf.run()

if __name__ == '__main__':
    main()
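Since include_edges is a config option, the ablation without the citation network is a one-line change; a hypothetical variant of the overrides in main():

experiment.configs(conf, {
    'include_edges': False,
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 5e-3,
    'optimizer.weight_decay': 5e-4,
})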
