docs/transformers/compressive/experiment.html
This is an annotated PyTorch experiment to train a compressive transformer model.
from typing import List, Tuple, NamedTuple

import torch
import torch.nn as nn

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
    CompressiveTransformerLayer, Conv1dCompression
class CompressedMemory(NamedTuple):
    mem: List[torch.Tensor]
    c_mem: List[torch.Tensor]
class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
Transformer
        self.transformer = transformer
Final layer
        self.generator = nn.Linear(d_model, n_vocab)
Masks
        self.mask_x = None
        self.mask_mem = None
    def forward(self, x: torch.Tensor, mem: CompressedMemory):
Get memory and compressed memory
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem = []
            c_mem = []
Total length of the memory and compressed memory (for masks)
        m_len = len(mem[0]) if mem else 0
        if c_mem:
            m_len += len(c_mem[0])
Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
Create an all ones (full visibility) mask for memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
Use only the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]
Token embeddings
        x = self.src_embed(x)
Run it through the transformer
        res, mem = self.transformer(x, mem, c_mem, mask)
Generate logits of the next token
        res = self.generator(res)

        return res, mem
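To make the mask shapes concrete, here is a small standalone illustration. It assumes subsequent_mask returns a lower-triangular boolean mask of shape [seq_len, seq_len, 1], which matches how it is sliced above; the tensor names are only for illustration.

import torch

seq_len, m_len = 4, 3
# Causal mask over the new tokens: position i may attend to positions <= i
mask_x = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)
# Full-visibility mask over the (compressed) memories
mask_mem = mask_x.new_ones(seq_len, m_len, 1)
# Memories come first along the key dimension
mask = torch.cat((mask_mem, mask_x), dim=1)
print(mask.shape)  # torch.Size([4, 7, 1])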
The default configurations can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
Token embedding size
    d_model: int = 128
Number of attention heads
    heads: int = 4
Dropout probability
    dropout: float = 0.0
Number of features in FFN hidden layer
    d_ff: int = 256
Number of transformer layers
    n_layers: int = 6
Number of memories to keep
    mem_len: int = 8
State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()
Attention Reconstruction Loss
    attention_reconstruction_loss: AttentionReconstructionLoss
Compression rate
    compression_rate: int = 4
Compressed memory length
    c_mem_len: int = 128

    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Do not print the attention reconstruction loss in the terminal
        tracker.set_scalar("ar_loss.*", False)
This will keep the accuracy metric stats and memories separate for training and validation.
        self.state_modules = [self.accuracy, self.memory]
Concatenate new memories and compress the oldest memories.
    @torch.no_grad()
    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
            -> Tuple[CompressedMemory, List[torch.Tensor]]:
If the configurations specify not to use memory
        if self.mem_len == 0 and self.c_mem_len == 0:
            return CompressedMemory([], []), []
Get memory and compressed memory
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem, c_mem = [], []
Concatenate new memories with old memory
        if mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
        else:
            mem = new_mem
Compress the oldest memories if there are more memories than mem_len
        if len(mem[0]) > self.mem_len:
Calculate the number of compressed memories to create, $n_{cm} = \lceil \frac{n'_m - N_m}{c} \rceil$, where $n'_m$ is the number of memories we have, $N_m$ is the maximum number of memories we maintain (mem_len), and $c$ is the compression rate.
            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate
Number of memories to compress, $c \, n_{cm}$
            n_old = n_c_mem * self.compression_rate
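As a quick check of the ceiling division written in integer arithmetic, $\lceil a / b \rceil = (a + b - 1) \,//\, b$: with the defaults mem_len = 8 and compression_rate = 4, suppose 13 memories have accumulated.

mem_len, compression_rate, n_mem = 8, 4, 13
n_c_mem = (n_mem - mem_len + compression_rate - 1) // compression_rate  # ceil(5 / 4) = 2
n_old = n_c_mem * compression_rate  # the 8 oldest memories get compressed into 2; 5 stay raw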
A list to keep memories that need to be compressed for each layer.
            mem_to_compress = []
A list to keep the memories that do not get compressed for each layer.
            uncompressed_mem = []
Iterate through memories of each layer.
            for m in mem:
Split the memories at $c \, n_{cm}$
                cm, m = torch.split(m, [n_old, len(m) - n_old])
Collect memories to compress
                mem_to_compress.append(cm)
Collect remaining memories
                uncompressed_mem.append(m)
Update the memories
            mem = uncompressed_mem
Compress the memories
            new_c_mem = []
            for i, layer in enumerate(self.model.transformer.layers):
                new_c_mem.append(layer.compress(mem_to_compress[i]))
Concatenate newly compressed memories with old compressed memories
            if c_mem:
                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
If there are no old compressed memories
            else:
                c_mem = new_c_mem
Truncate old memories
            if len(c_mem[0]) > self.c_mem_len:
                c_mem = [m[-self.c_mem_len:] for m in c_mem]
No memories are compressed if the number of memories does not exceed mem_len
        else:
            mem_to_compress = []
Return the memories and the memories that were compressed; the compressed-away memories are needed to compute the attention reconstruction loss.
        return CompressedMemory(mem, c_mem), mem_to_compress
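To trace this bookkeeping with the settings used in main() below (seq_len = 8, mem_len = 8, compression_rate = 2): the first step leaves 8 memories per layer; the next step appends 8 more, so $n_{cm} = \lceil (16 - 8) / 2 \rceil = 4$ and $c \, n_{cm} = 8$. That is, the 8 oldest memories are compressed into 4, mem is back to length 8, and c_mem grows by 4 each step until it is truncated at c_mem_len.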
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get memories
        mem = self.memory.get()
Run the model
        output, new_mem = self.model(data, mem)
Merge and compress memory
        mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
Update memories
        self.memory.set(mem)
Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate attention reconstruction loss if memories were compressed in this step
        if mem_to_compress:
Get attention reconstruction loss
            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)
Track attention reconstruction loss
            tracker.add("ar_loss.", ar_loss)
Add attention reconstruction loss to loss
            loss = loss + ar_loss
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
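The attention reconstruction loss used above comes from labml_nn.transformers.compressive and follows the Compressive Transformer paper: it measures how well attention computed over the compressed memories reproduces attention computed over the raw memories they replaced, so that this loss trains only the compression function. The snippet below is a simplified single-head sketch of the idea, with made-up shapes and a mean-pooling stand-in for the compression; the real AttentionReconstructionLoss reuses each layer's own attention module and runs per layer.

import torch
import torch.nn.functional as F

def attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    # Plain scaled dot-product attention; q is [q_len, d], kv is [kv_len, d]
    scores = q @ kv.t() / kv.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ kv

def reconstruction_loss(h, old_mem, c_mem):
    # MSE between attending over raw memories and over their compressed versions
    return F.mse_loss(attention(h, old_mem), attention(h, c_mem))

h = torch.randn(8, 16)                         # current hidden states
old_mem = torch.randn(8, 16)                   # memories that were compressed away
c_mem = old_mem.reshape(4, 2, 16).mean(dim=1)  # stand-in for Conv1dCompression (rate 2)
print(reconstruction_loss(h, old_mem, c_mem))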
    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Start with an empty memory
        mem = CompressedMemory([], [])
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
Move to device
            data = data.to(self.device)
Get the model output
            output, new_mem = self.model(data, mem)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)
Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Only feed the last character to the model in the next iteration; the rest will go in as memories
            prompt = prompt[-1:]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
Update and compress memory
            mem, _ = self.merge_compress_memory(mem, new_mem)
Print the sampled output
        logger.log(log)
@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
        CompressiveTransformerLayer(d_model=c.d_model,
                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                    dropout_prob=c.dropout,
                                    compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
    return m.to(c.device)
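Conv1dCompression is imported from labml_nn.transformers.compressive. As a rough sketch of the idea (an assumption based on the paper's convolutional compression function, not a verbatim copy; see the actual module for details), it compresses every compression_rate consecutive memory vectors into one with a strided 1D convolution:

import torch
import torch.nn as nn

class Conv1dCompressionSketch(nn.Module):
    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        # Kernel size and stride both equal to the compression rate
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # mem is [mem_len, batch, d_model]; Conv1d expects [batch, d_model, mem_len]
        x = mem.permute(1, 2, 0)
        x = self.conv(x)
        # Back to [mem_len // compression_rate, batch, d_model]
        return x.permute(2, 0, 1)

print(Conv1dCompressionSketch(4, 128)(torch.randn(8, 32, 128)).shape)  # torch.Size([2, 32, 128])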
@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
    return AttentionReconstructionLoss(c.model.transformer.layers)


def main():
Create experiment
    experiment.create(name="compressive_transformer", comment='')
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'AdamW',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 8,
                        'mem_len': 8,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        'compression_rate': 2,
                        })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
TrainValidConfigs.run
        conf.run()
if __name__ == '__main__':
    main()
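Assuming the repository layout matches the import paths above, the experiment can presumably be launched with python -m labml_nn.transformers.compressive.experiment; the tracked metrics and sampled text then appear in labml's experiment tracker.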