Back to Annotated Deep Learning Paper Implementations

Compressive Transformer Experiment

docs/transformers/compressive/experiment.html

latest — 12.6 KB
Original Source

home / transformers / compressive

View code on Github

#

Compressive Transformer Experiment

This is an annotated PyTorch experiment to train a compressive transformer model.

from typing import List, Tuple, NamedTuple

import torch
import torch.nn as nn

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
    CompressiveTransformerLayer, Conv1dCompression

#

class CompressedMemory(NamedTuple):
    """
    Holds the per-layer state carried between steps: a list of regular
    memories and a list of compressed memories, one tensor per layer.
    """
    mem: List[torch.Tensor]
    c_mem: List[torch.Tensor]

#

Auto regressive model

class AutoregressiveModel(nn.Module):
    """
    Auto-regressive language model wrapping a compressive transformer.
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
        """
        * `n_vocab`: vocabulary size
        * `d_model`: token embedding size
        * `transformer`: the compressive transformer module
        """
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer projecting back to vocabulary logits
        self.generator = nn.Linear(d_model, n_vocab)
        # Cached masks, grown lazily as longer inputs/memories arrive
        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: CompressedMemory):
        # Unpack memory and compressed memory (empty lists when there is none)
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem = []
            c_mem = []

        # Combined length of memory and compressed memory (needed for the masks)
        m_len = len(mem[0]) if mem else 0
        if c_mem:
            m_len += len(c_mem[0])

        # (Re)build the subsequent (causal) mask for the tokens when the cached
        # one is missing or too small
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
        # (Re)build the all-ones (full visibility) mask for the memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

        # Memory mask followed by the causal mask when there is memory;
        # just the causal mask otherwise
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
        else:
            mask = self.mask_x[:len(x), :len(x)]

        # Token embeddings
        x = self.src_embed(x)
        # Run through the transformer
        res, mem = self.transformer(x, mem, c_mem, mask)
        # Logits of the next token
        res = self.generator(res)
        return res, mem

#

Configurations

The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    """
    Configurations for the compressive transformer experiment.

    Defaults here can and will be overridden when the experiment starts.
    The base class supplies the dataset, device, optimizer, accuracy metric
    and training loop.
    """

    # The auto-regressive model
    model: AutoregressiveModel

    # Token embedding size
    d_model: int = 128
    # Number of attention heads
    heads: int = 4
    # Dropout probability
    dropout: float = 0.0
    # Number of features in FFN hidden layer
    d_ff: int = 256
    # Number of transformer layers
    n_layers: int = 6
    # Number of memories to keep
    mem_len: int = 8
    # State module that keeps memories separate between training and validation
    memory = SimpleStateModule()
    # Attention reconstruction loss
    attention_reconstruction_loss: AttentionReconstructionLoss
    # Compression rate
    compression_rate: int = 4
    # Compressed memory length
    c_mem_len: int = 128

    def init(self):
        """Set up trackers and state modules."""
        # Print accuracy and loss to the terminal
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Do not print the attention reconstruction loss in the terminal
        tracker.set_scalar("ar_loss.*", False)
        # Keep accuracy stats and memories separate for training and validation
        self.state_modules = [self.accuracy, self.memory]

    @torch.no_grad()
    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
            -> Tuple[CompressedMemory, List[torch.Tensor]]:
        """
        Concatenate the new memories and compress the oldest ones.

        Returns the updated `CompressedMemory` and the per-layer memories that
        were compressed in this call; the latter are needed for the attention
        reconstruction loss.
        """
        # Memory is disabled entirely by the configuration
        if self.mem_len == 0 and self.c_mem_len == 0:
            return CompressedMemory([], []), []

        # Unpack memory and compressed memory
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem, c_mem = [], []

        # Append the new memories per layer
        if mem:
            mem = [torch.cat((old, new), dim=0) for old, new in zip(mem, new_mem)]
        else:
            mem = new_mem

        # Compress the oldest memories when there are more than `mem_len`
        if len(mem[0]) > self.mem_len:
            # Number of compressed memories to make:
            # ceil((n' - mem_len) / compression_rate), where n' is the number
            # of memories we currently have
            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate
            # Number of raw memories that get compressed
            n_old = n_c_mem * self.compression_rate
            # Split every layer's memories into (to-compress, keep)
            pairs = [torch.split(m, [n_old, len(m) - n_old]) for m in mem]
            mem_to_compress = [p[0] for p in pairs]
            mem = [p[1] for p in pairs]
            # Compress with each layer's own compression function
            new_c_mem = [layer.compress(cm)
                         for layer, cm in zip(self.model.transformer.layers, mem_to_compress)]
            # Append to the existing compressed memories, if any
            if c_mem:
                c_mem = [torch.cat((old, new), dim=0) for old, new in zip(c_mem, new_c_mem)]
            else:
                c_mem = new_c_mem
            # Truncate the oldest compressed memories beyond `c_mem_len`
            if len(c_mem[0]) > self.c_mem_len:
                c_mem = [m[-self.c_mem_len:] for m in c_mem]
        else:
            # Nothing was compressed in this call
            mem_to_compress = []

        return CompressedMemory(mem, c_mem), mem_to_compress

    def step(self, batch: any, batch_idx: BatchIndex):
        """Run a single training/validation step."""
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Global step counts tokens processed; only advance it while training
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Memories carried over from the previous step
        mem = self.memory.get()
        # Run the model
        output, new_mem = self.model(data, mem)
        # Fold the new memories in and compress the oldest ones
        mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
        self.memory.set(mem)

        # Cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Attention reconstruction loss applies only if memories were
        # compressed in this step
        if mem_to_compress:
            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)
            tracker.add("ar_loss.", ar_loss)
            loss = loss + ar_loss

        # Accuracy metric
        self.accuracy(output, target)
        self.accuracy.track()

        # Optimize when in training mode
        if self.mode.is_train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            self.optimizer.step()
            # Log parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """Generate a short sample periodically while training."""
        # Starting prompt
        prompt = self.prompt
        # Collected (text, style) pairs for printing
        log = [(prompt, Text.subtle)]
        # Start with empty memories
        mem = CompressedMemory([], [])
        # Greedily sample 25 tokens
        for _ in monit.iterate('Sample', 25):
            # Tokenize the prompt as a (seq_len, batch=1) column
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Run the model
            output, new_mem = self.model(data, mem)
            # Greedy prediction
            output = output.argmax(dim=-1).squeeze(1)
            # Append the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Only feed the last character next iteration; the rest goes in
            # as memories
            prompt = prompt[-1:]
            # Record the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
            # Update and compress memory
            mem, _ = self.merge_compress_memory(mem, new_mem)
        # Print the sampled output
        logger.log(log)

#

Initialize the auto-regressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    """Initialize the auto-regressive model from the configurations."""
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    # Every layer shares the same hyper-parameters; each gets its own
    # 1D-convolution compression module
    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
        CompressiveTransformerLayer(d_model=c.d_model,
                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                    dropout_prob=c.dropout,
                                    compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
    return m.to(c.device)

#

Initialize the attention reconstruction loss

@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
    """Initialize the attention reconstruction loss over the transformer layers."""
    return AttentionReconstructionLoss(c.model.transformer.layers)

#

Run the experiment

def main():
    """Create and run the compressive transformer experiment."""
    # Create the experiment
    experiment.create(name="compressive_transformer", comment='')
    # Create configs
    conf = Configs()
    # Load configurations, overriding the defaults below
    experiment.configs(conf,
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'AdamW',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 8,
                        'mem_len': 8,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        'compression_rate': 2,
                        })
    # Register models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()

#

if __name__ == '__main__':
    main()

labml.ai