Compressive Transformer

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/compressive/__init__.py)

This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.

This is an extension of Transformer XL where past memories are compressed to give a longer attention range. That is, the furthest $n_{cm} c$ memories are compressed into $n_{cm}$ memories, where $c$ is the compression rate.

Compression operation

The compression operation is defined as $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$. The paper introduces multiple choices for $f_c$ and we have implemented only 1D convolution, which seems to give the best results. Each layer has a separate compression operation $f_c^{(i)}$, where $i$ is the layer number.
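To make the bookkeeping concrete, here is a rough sketch of how the oldest memories get compressed and moved into the compressed-memory buffer. This is an assumption-laden illustration (the `update_memories` helper and buffer handling are hypothetical; the real logic lives in the training/experiment code, not in this file):

```python
import torch

def update_memories(mem, c_mem, new_mem, mem_len: int, c_mem_len: int, compress):
    # `compress` is the compression function f_c, e.g. the Conv1dCompression module below.
    # Append the newest token-level features to the uncompressed memory.
    mem = torch.cat((mem, new_mem), dim=0)
    if mem.shape[0] > mem_len:
        # The oldest entries that overflow the memory budget get compressed ...
        overflow = mem[:-mem_len]          # assumed to be a multiple of c steps long
        c_mem = torch.cat((c_mem, compress(overflow)), dim=0)[-c_mem_len:]
        # ... and only the newest `mem_len` uncompressed memories are kept.
        mem = mem[-mem_len:]
    return mem, c_mem
```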

Training compression operation

Since training the compression with BPTT requires maintaining a very large computational graph (many time steps), the paper proposes an auto-encoding loss and an attention reconstruction loss. The auto-encoding loss decodes the original memories from the compressed memories and calculates the loss. The attention reconstruction loss computes multi-headed attention over the compressed memory and over the uncompressed memory, and takes the mean squared error between the two. We have implemented the latter here since it gives better results.

This implementation uses pre-layer normalization while the paper uses post-layer normalization. Pre-layer norm does the layer norm before FFN and self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.
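For reference, here is a minimal generic sketch of the two arrangements (illustrative code, not taken from this implementation):

```python
def pre_ln_block(x, norm, sublayer, dropout):
    # Pre-LN: normalize the sublayer input; the residual path stays un-normalized.
    return x + dropout(sublayer(norm(x)))

def post_ln_block(x, norm, sublayer, dropout):
    # Post-LN (as in the paper): normalize after the residual addition.
    return norm(x + dropout(sublayer(x)))
```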

Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.

from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list

1D Convolution Compression $f_c$

This is a simple wrapper around nn.Conv1d with some tensor dimension permutations.

class Conv1dCompression(nn.Module):

  • compression_rate is the compression rate $c$
  • d_model is the embedding size

    def __init__(self, compression_rate: int, d_model: int):

        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

mem has shape [seq_len, batch, d_model]

    def forward(self, mem: torch.Tensor):

Permute the dimensions of mem so that we can run it through the convolution layer. The convolution layer accepts input in the form [batch, features, sequence].

        mem = mem.permute(1, 2, 0)

Get compressed memory by running it through the convolution layer

        c_mem = self.conv(mem)

Permute back to form [seq_len, batch, d_model]

        return c_mem.permute(2, 0, 1)
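A quick shape check of the compression (illustrative numbers, not from the original code):

```python
compress = Conv1dCompression(compression_rate=4, d_model=128)
mem = torch.randn(32, 2, 128)        # [mem_len, batch, d_model]
c_mem = compress(mem)
assert c_mem.shape == (8, 2, 128)    # 32 memories compressed into 32 / 4 = 8
```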

Compressive Transformer Layer

This is the implementation of a single compressive transformer layer.

class CompressiveTransformerLayer(nn.Module):

  • d_model is the token embedding size
  • self_attn is the self attention module
  • feed_forward is the feed forward module
  • dropout_prob is the probability of dropping out after self attention and FFN
  • compress is the compression function $f_c$

    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float,
                 compress: Conv1dCompression):

        super().__init__()
        self.compress = compress
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

Concatenate the normalized token embeddings with memory and compressed memory.

  • z is the layer normalized token embeddings.
  • mem and c_mem are memory and compressed memory (not normalized).

    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):

If there is no memory, just return the token embeddings

        if mem is None:
            return z

If there is compressed memory, concatenate it with the memory

        if c_mem is not None:
            mem = torch.cat((c_mem, mem), dim=0)

Run the memory through the normalization layer

        mem = self.norm_self_attn(mem)

Concatenate normalized memory and normalized token embeddings

        return torch.cat((mem, z), dim=0)

  • x is a tensor of token level feature vectors of shape [seq_len, batch_size, d_model]
  • mem is a tensor of the past token level feature vectors (memory) of shape [mem_len, batch_size, d_model]
  • c_mem is a tensor of the compressed memory of shape [c_mem_len, batch_size, d_model]
  • mask is a matrix of shape [seq_len, c_mem_len + mem_len + seq_len, batch_size] or [seq_len, c_mem_len + mem_len + seq_len, 1]. mask[i, j] is true if token at i can see token at j.

    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                c_mem: Optional[torch.Tensor],
                mask: torch.Tensor):

Normalize the vectors before doing self attention

        z = self.norm_self_attn(x)

Normalize and concatenate memory and compressed memory

        m_z = self.concat_memory(z, mem, c_mem)

Attention

        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)

Add the attention results

        x = x + self.dropout(self_attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x

Compressive Transformer Model

This consists of multiple compressive transformer layers.

class CompressiveTransformer(nn.Module):

    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

  • x is a tensor of the token embedding vectors of shape [seq_len, batch_size, d_model]
  • mem is a list of tensors of the past token level feature vectors of shape [mem_len, batch_size, d_model] for each layer
  • c_mem is a list of tensors of the compressed memory of shape [c_mem_len, batch_size, d_model] for each layer
  • mask is the masking matrix

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):

List to store token level feature vectors, which will become the memories for the next sequential batch.

        new_mem = []

Run through each transformer layer

        for i, layer in enumerate(self.layers):

Add to the list of feature vectors

            new_mem.append(x.detach())

Memory

            m = mem[i] if mem else None

Compressed Memory

            cm = c_mem[i] if c_mem else None

Run through the compressive transformer layer

            x = layer(x=x, mem=m, c_mem=cm, mask=mask)

Finally, normalize the vectors

        return self.norm(x), new_mem
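The pieces above can be put together roughly as follows. This is a hedged sketch with illustrative sizes: the RelativeMultiHeadAttention and FeedForward constructor arguments are assumptions about the corresponding labml_nn modules, and a real model would also have an embedding layer and an output head.

```python
d_model, n_heads, d_ff, n_layers = 128, 4, 256, 3
seq_len, batch_size, mem_len, c_mem_len = 16, 2, 32, 8

layer = CompressiveTransformerLayer(
    d_model=d_model,
    self_attn=RelativeMultiHeadAttention(n_heads, d_model, 0.1),
    feed_forward=FeedForward(d_model, d_ff, 0.1),
    dropout_prob=0.1,
    compress=Conv1dCompression(compression_rate=4, d_model=d_model))
model = CompressiveTransformer(layer, n_layers)

# Token i may attend to every (compressed) memory and to current positions j <= i.
mem_mask = torch.ones(seq_len, c_mem_len + mem_len, 1, dtype=torch.bool)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)
mask = torch.cat((mem_mask, causal_mask), dim=1)   # [seq_len, c_mem_len + mem_len + seq_len, 1]

x = torch.randn(seq_len, batch_size, d_model)      # token-level embeddings
mem = [torch.randn(mem_len, batch_size, d_model) for _ in range(n_layers)]
c_mem = [torch.randn(c_mem_len, batch_size, d_model) for _ in range(n_layers)]
out, new_mem = model(x, mem, c_mem, mask)          # out: [seq_len, batch_size, d_model]
```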

Attention Reconstruction Loss

Attention reconstruction loss recreates the self-attention output with uncompressed memory and with compressed memory, and calculates the mean squared error between the two. It does this without positional encoding.

When calculating and training the compression function $f_c$ with the attention reconstruction loss, all parameters but $f_c$ are frozen. This includes the key/value projections and the bias/scaling after normalization.

Since this loss can be computed independently of the cross-entropy loss of the model, you could have a separate optimizer that only updates $f_c$. However, we use the same optimizer to update $f_c$, so when calculating the attention reconstruction loss, we detach all parameters other than $f_c$ from the gradient computation.

class AttentionReconstructionLoss:

layers is the list of Compressive Transformer layers

    def __init__(self, layers: nn.ModuleList):

        self.layers = layers
        self.loss_func = nn.MSELoss()

This is a reimplementation of 'PrepareForMultiHeadAttention' where the projections are done with the parameters detached from gradient computation.

    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):

Shape of the input except the embedding dimension; [seq_len, batch_size].

        head_shape = x.shape[:-1]

Detach projection weights and bias

        weight = pmha.linear.weight.detach()
        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None

Linear transform

        x = F.linear(x, weight, bias)

Split last dimension into heads

        x = x.view(*head_shape, pmha.heads, pmha.d_k)

Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]

        return x

This is a reimplementation of 'Multi-Head Attention' which calls prepare_for_attn instead of 'PrepareForMultiHeadAttention' to detach projection parameters.

    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):

Calculate query, key and value projections

        query = self.prepare_for_attn(layer.query, query)
        key = self.prepare_for_attn(layer.key, key)
        value = self.prepare_for_attn(layer.value, value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].

        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$

        scores *= layer.scale

Softmax attention along the key sequence dimension $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_k}}\Big)$

        attn = layer.softmax(scores)

Multiply by values $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_k}}\Big) V$

        return torch.einsum("ijbh,jbhd->ibhd", attn, value)

Perform layer normalization with shift and scale parameters detached.

    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):

Detach the shift (bias) and scaling (weight) parameters

        weight = ln.weight.detach() if ln.weight is not None else None
        bias = ln.bias.detach() if ln.bias is not None else None

Layer normalization

        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

This calculates the loss for a layer.

    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):

Detach the token embeddings and memory.

        h = h.detach()
        mem = mem.detach()

Compress the memory with $f_c^{(i)}$. The parameters of $f_c^{(i)}$ are the only parameters not detached from the gradient computation.

        c_mem = layer.compress(mem)

Normalize the embeddings and memories

        h = self.norm(layer.norm_self_attn, h)
        mem = self.norm(layer.norm_self_attn, mem)
        c_mem = self.norm(layer.norm_self_attn, c_mem)

Calculate the attention with uncompressed memory

        attn_mem = self.attn(layer.self_attn, h, mem, mem)

Calculate the attention with compressed memory

        attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)

Calculate the mean squared error

        return self.loss_func(attn_cmem, attn_mem)

    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):

Calculate the losses for each layer

        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]

Sum of the losses

        return sum(losses)
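During training, this reconstruction loss is added to the language-modelling loss, and the new_mem list returned by the model serves as the per-layer inputs h. A minimal sketch continuing the usage example above (variable names are illustrative; the real training loop is in the experiment code):

```python
recon_loss = AttentionReconstructionLoss(model.layers)

out, new_mem = model(x, mem, c_mem, mask)
ar_loss = recon_loss(new_mem, mem)   # h[n] is the input to layer n, mem[n] its uncompressed memory
ar_loss.backward()                   # only the Conv1dCompression parameters receive gradients here
```

In the actual experiment the same optimizer updates both the model parameters (through the cross-entropy loss) and $f_c$ (through this reconstruction term).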
