Back to Annotated Deep Learning Paper Implementations

Transformer XL Experiment

docs/transformers/xl/experiment.html

latest · 8.6 KB
Original Source

home › transformers › xl

View code on Github

#

Transformer XL Experiment

This is an annotated PyTorch experiment to train a Transformer XL model.

from typing import Any, List

import torch
import torch.nn as nn
from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text

from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer

#

Autoregressive model

class AutoregressiveModel(nn.Module):
    """
    ## Auto regressive model

    Embeds token ids, runs them through a Transformer XL together with the
    memories of earlier segments, and projects the result to next-token logits.
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
        """
        :param n_vocab: vocabulary size
        :param d_model: token embedding size
        :param transformer: the Transformer XL stack
        """
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer
        self.generator = nn.Linear(d_model, n_vocab)
        # Masks; built lazily in `forward` and cached for reuse
        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
        """
        :param x: token ids; `len(x)` (dim 0) is treated as the sequence length
        :param mem: per-layer memories from previous segments (empty list if none)
        :return: tuple of (next-token logits, memories returned by the transformer)
        """
        # Length of the memory
        m_len = len(mem[0]) if mem else 0

        # Create a subsequent mask for tokens.
        # The cached masks are plain attributes, not registered buffers, so
        # `Module.to()` does not move them; rebuild when the input is on a
        # different device than the cached mask.
        if (self.mask_x is None
                or self.mask_x.shape[0] < len(x)
                or self.mask_x.device != x.device):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
        # Create an all ones (full visibility) mask for memory
        if (self.mask_mem is None
                or self.mask_mem.shape[1] < m_len
                or self.mask_mem.shape[0] < len(x)
                or self.mask_mem.device != self.mask_x.device):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

        if m_len:
            # Concatenate the masks if there is memory: memory positions
            # first, then the causal mask over the current tokens
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
        else:
            # Use the subsequent mask otherwise
            mask = self.mask_x[:len(x), :len(x)]

        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, mem = self.transformer(x, mem, mask)
        # Generate logits of the next token
        res = self.generator(res)

        return res, mem

#

Configurations

The default configs can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The defaults below can and will be overridden when the experiment starts
    (see the overrides passed in `main`).
    """

    # The auto-regressive model (built by the `autoregressive_model` option)
    model: AutoregressiveModel

    # Token embedding size
    d_model: int = 128
    # Number of attention heads
    heads: int = 4
    # Dropout probability
    dropout: float = 0.0
    # Number of features in FFN hidden layer
    d_ff: int = 256
    # Number of transformer layers
    n_layers: int = 6
    # Number of memories to keep
    mem_len: int = 128
    # State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()

    def init(self):
        """Set tracker indicators and register the per-mode state modules."""
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # This will keep the accuracy metric stats and memories separate
        # for training and validation.
        self.state_modules = [self.accuracy, self.memory]

    def merge_memory(self, old_mem, new_mem):
        """
        Concatenate memories and remove old memories to keep a maximum of
        `mem_len` memories.

        :param old_mem: per-layer memories from earlier steps (empty list if none)
        :param new_mem: per-layer memories produced by the latest forward pass
        :return: merged per-layer memories, each limited to its last `mem_len`
            entries along dim 0; an empty list when `mem_len` is 0
        """
        # If it's configured not to use memory
        if self.mem_len == 0:
            return []

        # Concatenate with old memory along dim 0 (the dimension whose length
        # the model uses as the memory length)
        if old_mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
        else:
            mem = new_mem

        # Truncate old memories. The `mem` guard avoids an IndexError when
        # there are no per-layer memories at all (e.g. a zero-layer model).
        if mem and len(mem[0]) > self.mem_len:
            mem = [m[-self.mem_len:] for m in mem]

        return mem

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        Training/validation step.

        :param batch: indexable pair `(data, target)` of tensors
        :param batch_idx: position of this batch within the epoch
        """
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Get memories
        mem = self.memory.get()
        # Run the model
        output, new_mem = self.model(data, mem)
        # Merge memory
        mem = self.merge_memory(mem, new_mem)
        # Update memories
        self.memory.set(mem)

        # Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """
        Sampling function to generate samples periodically while training.

        Greedily decodes 25 tokens starting from `self.prompt`. After the first
        step only the last character is fed to the model; earlier context is
        carried in the memories.
        """
        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Memory
        mem = []

        # Generation needs no gradients; don't build autograd graphs
        with torch.no_grad():
            # Sample 25 tokens
            for _ in monit.iterate('Sample', 25):
                # Tokenize the prompt and move it to the device
                data = self.text.text_to_i(prompt).unsqueeze(-1)
                data = data.to(self.device)
                # Get the model output
                output, new_mem = self.model(data, mem)
                # Get the model prediction (greedy)
                output = output.argmax(dim=-1).squeeze(1)
                # Text of the predicted token (looked up once, used twice)
                predicted = self.prompt_separator + self.text.itos[output[-1]]
                # Add the prediction to the prompt, then keep only the last
                # character for the next iteration; the rest goes in as memories
                prompt = (prompt + predicted)[-1:]
                # Add the prediction for logging
                log += [(predicted, Text.value)]
                # Update memory
                mem = self.merge_memory(mem, new_mem)

        # Print the sampled output
        logger.log(log)

#

Initialize the auto-regressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    ### Initialize the auto-regressive model

    Builds a `TransformerXLLayer` template, hands it to `TransformerXL` together
    with `c.n_layers`, wraps the stack in an `AutoregressiveModel`, and moves
    the result to `c.device`.
    """
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    # Vocabulary and embedding sizes, resolved before any sub-module is built
    n_tokens = c.n_tokens
    d_model = c.d_model

    # Template layer: relative multi-head self-attention + feed-forward network
    layer = TransformerXLLayer(
        d_model=d_model,
        self_attn=RelativeMultiHeadAttention(c.heads, d_model, c.dropout),
        feed_forward=FeedForward(d_model, c.d_ff, c.dropout),
        dropout_prob=c.dropout,
    )
    transformer = TransformerXL(layer, c.n_layers)

    model = AutoregressiveModel(n_tokens, d_model, transformer)
    return model.to(c.device)

#

Run the experiment

def main():
    """
    ### Run the experiment

    Creates the experiment, applies the configuration overrides below,
    registers the model for checkpointing, and runs training.
    """
    # Create experiment
    experiment.create(name="transformer_xl", comment='')
    # Create configs
    conf = Configs()

    # Configurations to override
    overrides = {
        'tokenizer': 'character',
        'text': 'tiny_shakespeare',
        'optimizer.learning_rate': 1.,
        'optimizer.optimizer': 'Noam',
        'prompt': 'It is',
        'prompt_separator': '',

        'train_loader': 'sequential_train_loader',
        'valid_loader': 'sequential_valid_loader',

        'seq_len': 2,
        'mem_len': 32,
        'epochs': 128,
        'batch_size': 32,
        'inner_iterations': 25,
    }
    # Load configurations
    experiment.configs(conf, overrides)

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment and run the training loop (`TrainValidConfigs.run`)
    with experiment.start():
        conf.run()

#

# Run the experiment only when executed as a script, not when imported
if __name__ == '__main__':
    main()

labml.ai