Transformer XL Experiment
This is an annotated PyTorch experiment to train a Transformer XL model.
from typing import List

import torch
import torch.nn as nn

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
Transformer
        self.transformer = transformer
Final layer
        self.generator = nn.Linear(d_model, n_vocab)
Masks
        self.mask_x = None
        self.mask_mem = None
    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
Length of the memory
        m_len = len(mem[0]) if mem else 0
Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
Create an all ones (full visibility) mask for memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
Use the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]
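To make the mask shapes concrete, here is a small standalone sketch of the same construction, using torch.tril in place of subsequent_mask (which produces the same lower-triangular pattern with a trailing unit dimension):

import torch

seq_len, m_len = 3, 2
# Causal (lower-triangular) mask over the current tokens: [seq_len, seq_len, 1]
mask_x = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)
# Full-visibility mask over the memories: [seq_len, m_len, 1]
mask_mem = mask_x.new_ones(seq_len, m_len, 1)
# Concatenate along the key dimension: [seq_len, m_len + seq_len, 1]
mask = torch.cat((mask_mem, mask_x), dim=1)
print(mask[:, :, 0])
# tensor([[True, True, True, False, False],
#         [True, True, True, True,  False],
#         [True, True, True, True,  True]])

Each query position can attend to every memory position, but only to current tokens up to and including itself.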
Token embeddings
        x = self.src_embed(x)
Run it through the transformer
        res, mem = self.transformer(x, mem, mask)
Generate logits of the next token
        res = self.generator(res)
        return res, mem
The default configs can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
Token embedding size
    d_model: int = 128
Number of attention heads
    heads: int = 4
Dropout probability
    dropout: float = 0.0
Number of features in FFN hidden layer
    d_ff: int = 256
Number of transformer layers
    n_layers: int = 6
Number of memories to keep
    mem_len: int = 128
State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
This will keep the accuracy metric stats and memories separate for training and validation.
        self.state_modules = [self.accuracy, self.memory]
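SimpleStateModule holds a piece of state (the memories here), and because it is registered in state_modules the trainer keeps an independent copy of that state for training and for validation. Conceptually it behaves like this illustrative sketch (not labml's actual implementation):

class ModeScopedState:
    """Keep an independent copy of some state per mode ('train'/'valid')."""
    def __init__(self):
        self._states = {}
        self._mode = 'train'

    def switch(self, mode: str):
        # Called when the trainer moves between training and validation
        self._mode = mode

    def get(self):
        return self._states.get(self._mode, [])

    def set(self, value):
        self._states[self._mode] = value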
Concatenate memories and remove old memories to keep a maximum of mem_len memories.
    def merge_memory(self, old_mem, new_mem):
If it's configured not to use memory
        if self.mem_len == 0:
            return []
Concatenate with old memory
        if old_mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
        else:
            mem = new_mem
Truncate old memories
        if len(mem[0]) > self.mem_len:
            mem = [m[-self.mem_len:] for m in mem]
        return mem
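For example, with two layers, a feature size of 4, and mem_len = 6 (all of these numbers are illustrative), merging 5 remembered steps with 3 new ones keeps only the latest 6 per layer:

import torch

mem_len = 6
old_mem = [torch.zeros(5, 1, 4) for _ in range(2)]  # 5 remembered steps for each of 2 layers
new_mem = [torch.zeros(3, 1, 4) for _ in range(2)]  # 3 freshly computed steps per layer
mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]  # now 8 steps per layer
mem = [m[-mem_len:] for m in mem]                   # drop the 2 oldest steps
print(mem[0].shape)                                 # torch.Size([6, 1, 4])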
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get memories
        mem = self.memory.get()
Run the model
        output, new_mem = self.model(data, mem)
Merge memory
        mem = self.merge_memory(mem, new_mem)
Update memories
        self.memory.set(mem)
Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Memory
        mem = []
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
Move to device
            data = data.to(self.device)
Get the model output
            output, new_mem = self.model(data, mem)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)
Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Only feed the last character to the model in the next iteration; the rest will go in as memories
            prompt = prompt[-1:]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
Update memory
            mem = self.merge_memory(mem, new_mem)
Print the sampled output
        logger.log(log)
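Feeding only the newest token while earlier context lives in the memories is the core of Transformer-XL inference. Here is a minimal sketch of that pattern against a generic model(x, mem) -> (logits, new_mem) interface; model, stoi and itos are hypothetical stand-ins for the pieces used above:

import torch

def greedy_sample(model, stoi, itos, prompt: str, n_tokens: int = 25, mem_len: int = 32):
    """Greedy Transformer-XL decoding: after the first step only the newest
    token is fed in; earlier context is carried entirely in the memories."""
    mem, data, text = [], prompt, prompt
    for _ in range(n_tokens):
        x = torch.tensor([stoi[c] for c in data]).unsqueeze(-1)   # shape [seq_len, 1]
        logits, new_mem = model(x, mem)
        next_char = itos[logits.argmax(dim=-1)[-1, 0].item()]     # greedy pick at the last position
        text += next_char
        data = next_char                                          # only the newest token next time
        # Same merge-and-truncate logic as Configs.merge_memory above
        mem = [torch.cat((m, n))[-mem_len:] for m, n in zip(mem, new_mem)] if mem else new_mem
    return text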
@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
        TransformerXLLayer(d_model=c.d_model,
                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                           dropout_prob=c.dropout), c.n_layers))
    return m.to(c.device)
def main():
Create experiment
    experiment.create(name="transformer_xl", comment='')
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 2,
                        'mem_len': 32,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
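'Noam' refers to the learning-rate schedule from Attention Is All You Need, which warms up linearly and then decays with the inverse square root of the step count; the configured learning rate (1.0 here) acts as a multiplicative factor. A sketch of the schedule (the warmup length here is an assumption; labml's default may differ):

def noam_lr(step: int, d_model: int = 128, warmup: int = 4000, factor: float = 1.0) -> float:
    # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)  # guard against step = 0 on the first update
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)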
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run the training loop (TrainValidConfigs.run)
        conf.run()
if __name__ == '__main__':
    main()