Compressive Transformer
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/compressive/__init__.py)
This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.
This is an extension of Transformer XL where past memories are compressed to give a longer attention range. That is, the furthest $n_{cm} c$ memories are compressed into $n_{cm}$ memories, where $c$ is the compression rate.
The compression operation is defined as $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$. The paper introduces multiple choices for $f_c$ and we have only implemented 1D convolution, which seems to give the best results. Each layer has a separate compression operation $f_c^{(i)}$ where $i$ is the layer number.
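For intuition, here is a minimal sketch (with hypothetical sizes, not part of the implementation below) showing how a 1D convolution with kernel size and stride equal to $c$ maps $nc$ memory vectors to $n$ compressed vectors:

```python
import torch
from torch import nn

c, d_model = 4, 16  # hypothetical compression rate and embedding size
conv = nn.Conv1d(d_model, d_model, kernel_size=c, stride=c)

# 12 memories (n * c with n = 3), in [batch, features, sequence] form
mem = torch.randn(1, d_model, 12)
c_mem = conv(mem)
assert c_mem.shape == (1, d_model, 3)  # compressed to n = 3 memories
```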
Since training compression with BPTT requires maintaining a very large computational graph (many time steps), the paper proposes an auto-encoding loss and an attention reconstruction loss. The auto-encoding loss decodes the original memories from the compressed memories and calculates the loss. Attention reconstruction loss computes the multi-headed attention results on the compressed memory and on uncompressed memory and gets a mean squared error between them. We have implemented the latter here since it gives better results.
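For comparison, here is a minimal sketch of the auto-encoding loss described above, which is not implemented in this file. It assumes a hypothetical transposed-convolution decoder that inverts the length change of the compression:

```python
import torch
import torch.nn.functional as F
from torch import nn

def auto_encoding_loss(compress: nn.Conv1d, decode: nn.ConvTranspose1d, mem: torch.Tensor):
    # hypothetical sketch: mem is in [batch, d_model, mem_len] form
    c_mem = compress(mem)          # compress the memories
    reconstruction = decode(c_mem) # decode them back to the original length
    return F.mse_loss(reconstruction, mem)
```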
This implementation uses pre-layer normalization while the paper uses post-layer normalization. Pre-layer norm does the layer norm before FFN and self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.
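As a rough illustration (not part of this implementation), the two residual arrangements for a single sub-layer look like this:

```python
# Post-layer norm (as in the paper): normalize after adding the residual
def post_norm_block(x, sublayer, norm):
    return norm(x + sublayer(x))

# Pre-layer norm (as used here): normalize before the sub-layer;
# the residual pass-through is left unnormalized
def pre_norm_block(x, sublayer, norm):
    return x + sublayer(norm(x))
```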
Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.
```python
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list
```
This is a simple wrapper around nn.Conv1d with some tensor dimension permutations.
```python
class Conv1dCompression(nn.Module):
```
* compression_rate is the compression rate $c$
* d_model is the embedding size

```python
    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)
```
mem has shape [seq_len, batch, d_model]
```python
    def forward(self, mem: torch.Tensor):
```
Permute the dimensions of mem so that we can run it through the convolution layer. The convolution layer accepts input in the form [batch, features, sequence].
```python
        mem = mem.permute(1, 2, 0)
```
Get compressed memory by running it through the convolution layer
```python
        c_mem = self.conv(mem)
```
Permute back to form [seq_len, batch, d_model]
```python
        return c_mem.permute(2, 0, 1)
```
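For example, with a hypothetical compression rate of 4, a memory of 16 vectors is compressed down to 4, keeping the [seq_len, batch, d_model] layout at the module boundary:

```python
compress = Conv1dCompression(compression_rate=4, d_model=128)
mem = torch.randn(16, 8, 128)  # [mem_len, batch, d_model]
c_mem = compress(mem)
assert c_mem.shape == (4, 8, 128)  # [mem_len // 4, batch, d_model]
```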
This is the implementation of a single compressive transformer layer
```python
class CompressiveTransformerLayer(nn.Module):
```
* d_model is the token embedding size
* self_attn is the self attention module
* feed_forward is the feed forward module
* dropout_prob is the probability of dropping out after self attention and FFN
* compress is the compression function $f_c$

```python
    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float,
                 compress: Conv1dCompression):
        super().__init__()
        self.compress = compress
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
```
Concatenate the normalized token embeddings with memory and compressed memory.
* z is the layer normalized token embeddings.
* mem and c_mem are the memory and compressed memory (not normalized).

```python
    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
```
If there is no memory just return the token embeddings
```python
        if mem is None:
            return z
```
If there is compressed memory, concatenate it with the memory
```python
        if c_mem is not None:
            mem = torch.cat((c_mem, mem), dim=0)
```
Run the memory through the normalization layer
```python
        mem = self.norm_self_attn(mem)
```
Concatenate normalized memory and normalized token embeddings
```python
        return torch.cat((mem, z), dim=0)
```
* x is a tensor of token level feature vectors of shape [seq_len, batch_size, d_model]
* mem is a tensor of the past token level feature vectors (memory) of shape [mem_len, batch_size, d_model]
* c_mem is a tensor of the compressed memory of shape [c_mem_len, batch_size, d_model]
* mask is a matrix of shape [seq_len, c_mem_len + mem_len + seq_len, batch_size] or [seq_len, c_mem_len + mem_len + seq_len, 1]. mask[i, j] is true if token at i can see token at j.

```python
    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                c_mem: Optional[torch.Tensor],
                mask: torch.Tensor):
```
Normalize the vectors before doing self attention
```python
        z = self.norm_self_attn(x)
```
Normalize and concatenate memory and compressed memory
```python
        m_z = self.concat_memory(z, mem, c_mem)
```
Attention
```python
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
```
Add the attention results
```python
        x = x + self.dropout(self_attn)
```
Normalize for feed-forward
```python
        z = self.norm_ff(x)
```
Pass through the feed-forward network
```python
        ff = self.feed_forward(z)
```
Add the feed-forward results back
```python
        x = x + self.dropout(ff)
```
```python
        return x
```
This consists of multiple compressive transformer layers
```python
class CompressiveTransformer(nn.Module):
```
```python
    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
        super().__init__()
```
Make copies of the transformer layer
```python
        self.layers = clone_module_list(layer, n_layers)
```
Final normalization layer
```python
        self.norm = nn.LayerNorm([layer.size])
```
* x is a tensor of the token embedding vectors of shape [seq_len, batch_size, d_model]
* mem is a list of tensors of the past token level feature vectors of shape [mem_len, batch_size, d_model] for each layer
* c_mem is a list of tensors of the compressed memory of shape [c_mem_len, batch_size, d_model] for each layer
* mask is the masking matrix

```python
    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
```
List to store token level feature vectors, which will become the memories for the next sequential batch.
```python
        new_mem = []
```
Run through each transformer layer
```python
        for i, layer in enumerate(self.layers):
```
Add to the list of feature vectors
```python
            new_mem.append(x.detach())
```
Memory
```python
            m = mem[i] if mem else None
```
Compressed Memory
```python
            cm = c_mem[i] if c_mem else None
```
Run through the transformer XL layer
```python
            x = layer(x=x, mem=m, c_mem=cm, mask=mask)
```
Finally, normalize the vectors
```python
        return self.norm(x), new_mem
```
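As a rough usage sketch with hypothetical hyper-parameters (omitting the embedding, positional encoding, and output layers, and assuming the constructor signatures of RelativeMultiHeadAttention and FeedForward from labml_nn), a forward pass with no memories yet might look like this:

```python
d_model, heads, d_ff, dropout = 128, 4, 256, 0.1  # hypothetical hyper-parameters
layer = CompressiveTransformerLayer(d_model=d_model,
                                    self_attn=RelativeMultiHeadAttention(heads, d_model, dropout),
                                    feed_forward=FeedForward(d_model, d_ff, dropout),
                                    dropout_prob=dropout,
                                    compress=Conv1dCompression(compression_rate=4, d_model=d_model))
model = CompressiveTransformer(layer, n_layers=3)

x = torch.randn(10, 4, d_model)  # [seq_len, batch_size, d_model]
# causal mask of shape [seq_len, seq_len, 1], since there are no memories yet
mask = torch.tril(torch.ones(10, 10, dtype=torch.bool)).unsqueeze(-1)
out, new_mem = model(x, mem=[], c_mem=[], mask=mask)
```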
Attention reconstruction loss recreates the self-attention output with uncompressed memory and with compressed memory and calculates the mean squared error between the two. It does this without positional encoding.
When calculating and training the compression function $f_c$ with the attention reconstruction loss, all parameters except $f_c$ are frozen. This includes the key/value projections and the bias/scaling after normalization.
Since this loss can be computed independently of the cross-entropy loss of the model, you can have a separate optimizer that only updates $f_c$. However, we use the same optimizer to update $f_c$, so when calculating the attention reconstruction loss we detach all parameters except $f_c$ from the gradient computation.
```python
class AttentionReconstructionLoss:
```
layers is the list of Compressive Transformer layers
```python
    def __init__(self, layers: nn.ModuleList):
        self.layers = layers
        self.loss_func = nn.MSELoss()
```
This is a reimplementation of 'PrepareForMultiHeadAttention' where the projections are done with the parameters detached from gradient computation.
* pmha is the 'PrepareForMultiHeadAttention' module
* x is a tensor with the token embeddings

```python
    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):
```
Shape of the input except the embedding dimension: [seq_len, batch_size].
```python
        head_shape = x.shape[:-1]
```
Detach projection weights and bias
```python
        weight = pmha.linear.weight.detach()
        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
```
Linear transform
```python
        x = F.linear(x, weight, bias)
```
Split last dimension into heads
```python
        x = x.view(*head_shape, pmha.heads, pmha.d_k)
```
Output has shape [seq_len, batch_size, heads, d_k]
```python
        return x
```
This is a reimplementation of 'Multi-Head Attention' which calls prepare_for_attn instead of 'PrepareForMultiHeadAttention' to detach projection parameters.
```python
    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
```
Calculate query, key and value projections
```python
        query = self.prepare_for_attn(layer.query, query)
        key = self.prepare_for_attn(layer.key, key)
        value = self.prepare_for_attn(layer.value, value)
```
Compute the attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
```python
        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
```
Scale the scores: $\frac{Q K^\top}{\sqrt{d_k}}$
```python
        scores *= layer.scale
```
Apply softmax along the key sequence dimension: $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_k}}\Big)$
```python
        attn = layer.softmax(scores)
```
Multiply by the values: $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_k}}\Big) V$
```python
        return torch.einsum("ijbh,jbhd->ibhd", attn, value)
```
Perform layer normalization with shift and scale parameters detached.
```python
    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
```
Detach the shift (bias) and scaling (weight) parameters
```python
        weight = ln.weight.detach() if ln.weight is not None else None
        bias = ln.bias.detach() if ln.bias is not None else None
```
Layer normalization
```python
        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)
```
This calculates the loss for a layer
```python
    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
```
Detach the token embeddings and memory.
```python
        h = h.detach()
        mem = mem.detach()
```
Compress the memory with $f_c^{(i)}$. The parameters of $f_c^{(i)}$ are the only parameters not detached from the gradient computation.
```python
        c_mem = layer.compress(mem)
```
Normalize the embeddings and memories
```python
        h = self.norm(layer.norm_self_attn, h)
        mem = self.norm(layer.norm_self_attn, mem)
        c_mem = self.norm(layer.norm_self_attn, c_mem)
```
Calculate the attention with uncompressed memory
```python
        attn_mem = self.attn(layer.self_attn, h, mem, mem)
```
Calculate the attention with compressed memory
```python
        attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)
```
Calculate the mean square error
```python
        return self.loss_func(attn_cmem, attn_mem)
```
```python
    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):
```
Calculate the losses for each layer
```python
        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]
```
Sum of the losses
```python
        return sum(losses)
```
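Following the note above about using a single optimizer, here is a rough sketch of how this loss might be combined with the model's cross entropy loss. The names model, cross_entropy_loss, new_mem (the per-layer inputs saved during the forward pass), mem_to_compress, ar_loss_coef and optimizer are hypothetical and not part of this file:

```python
# hypothetical training step combining both losses with a single optimizer
attn_reconstruction_loss = AttentionReconstructionLoss(model.layers)

loss = cross_entropy_loss + ar_loss_coef * attn_reconstruction_loss(new_mem, mem_to_compress)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```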