
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/xl/__init__.py)

# Transformer XL

This is an implementation of Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context in PyTorch.

The vanilla Transformer has a limited attention span, equal to the length of the sequence trained in parallel, and all of these positions have a fixed positional encoding. Transformer XL increases this attention span by letting each position pay attention to precalculated past embeddings. For instance, if the context length is l, it keeps the embeddings of all layers from the previous batch of length l and feeds them to the current step. If we used fixed positional encodings, these pre-calculated embeddings would have the same positions as the current context. To avoid this, the paper introduces relative positional encoding, where the positional encodings are injected at the attention calculation.
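
For reference, the relative attention score between query position $i$ and key position $j$ decomposes in the paper into content and position terms, with learnable biases $u$, $v$ and relative positional encodings $R_{i-j}$:

$$A^{\mathrm{rel}}_{i,j} = \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,E}\, E_{x_j}}_{(a)} + \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,R}\, R_{i-j}}_{(b)} + \underbrace{u^{\top} W_{k,E}\, E_{x_j}}_{(c)} + \underbrace{v^{\top} W_{k,R}\, R_{i-j}}_{(d)}$$

Terms (a) and (c) depend on content, while (b) and (d) depend only on the relative offset $i-j$, which is what allows cached embeddings to be reused without caring about their absolute positions.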

The annotated implementation of relative multi-head attention is in relative_mha.py.

Here is the training code, along with a notebook for training a Transformer XL model on the Tiny Shakespeare dataset.
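
To make the memory idea concrete, here is a minimal sketch (not part of the implementation) of how the attention span grows: queries come only from the current segment, while keys and values span the cached states of the previous segment plus the current segment.

```python
import torch

seq_len, mem_len, batch_size, d_model = 4, 8, 2, 16

# Hidden states cached from the previous segment (kept fixed; no gradients flow into them)
mem = torch.randn(mem_len, batch_size, d_model)
# Hidden states of the current segment
x = torch.randn(seq_len, batch_size, d_model)

# Keys and values cover the memory followed by the current segment,
# so each query can attend over up to mem_len + seq_len positions
kv = torch.cat((mem, x), dim=0)
print(kv.shape)  # torch.Size([12, 2, 16])
```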

from typing import List, Optional

import torch
import torch.nn as nn

from labml_nn.utils import clone_module_list
from .relative_mha import RelativeMultiHeadAttention
from ..feed_forward import FeedForward

## Transformer XL Layer

The Transformer XL model comprises a number of these layers.

class TransformerXLLayer(nn.Module):

  • d_model is the token embedding size
  • self_attn is the self attention module
  • feed_forward is the feed forward module
  • dropout_prob is the probability of dropping out after self attention and FFN
    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
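
As a usage sketch, a layer could be constructed like this. It assumes RelativeMultiHeadAttention takes (heads, d_model, dropout_prob) and FeedForward takes (d_model, d_ff, dropout), as they do elsewhere in labml_nn; check those modules for the exact signatures.

```python
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.xl import TransformerXLLayer
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention

d_model = 128

# Assumed constructor signatures (see relative_mha.py and feed_forward.py)
layer = TransformerXLLayer(d_model=d_model,
                           self_attn=RelativeMultiHeadAttention(heads=4, d_model=d_model, dropout_prob=0.1),
                           feed_forward=FeedForward(d_model, d_ff=256, dropout=0.1),
                           dropout_prob=0.1)
```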


  • x is a tensor of the token level feature vectors of shape [seq_len, batch_size, d_model]
  • mem is a tensor of the past token level feature vectors of shape [mem_len, batch_size, d_model]
  • mask is a matrix of shape [seq_len, mem_len + seq_len, batch_size] or [seq_len, mem_len + seq_len, 1]. mask[i, j] is true if token at i can see token at j.
    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                mask: torch.Tensor):

Normalize the vectors before doing self attention

        z = self.norm_self_attn(x)

If there is memory

        if mem is not None:

Normalize it

            mem = self.norm_self_attn(mem)

Concatenate with z

            m_z = torch.cat((mem, z), dim=0)

If there is no memory, just use the current vectors

        else:
            m_z = z

Attention

        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)

Add the attention results

        x = x + self.dropout(self_attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x
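
To call the layer, a causal mask can be built so that position i of the current segment attends to all of the memory and to positions up to i of the segment. A minimal sketch, reusing layer and d_model from the construction sketch above:

```python
import torch

seq_len, mem_len, batch_size = 4, 8, 2

x = torch.randn(seq_len, batch_size, d_model)
mem = torch.randn(mem_len, batch_size, d_model)

# mask[i, j] is true if query position i may attend to key position j,
# where the keys are the memory followed by the current segment
mask = torch.tril(torch.ones(seq_len, mem_len + seq_len), diagonal=mem_len).bool()
mask = mask.unsqueeze(-1)  # [seq_len, mem_len + seq_len, 1] broadcasts over the batch

out = layer(x=x, mem=mem, mask=mask)
print(out.shape)  # torch.Size([4, 2, 128])
```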

## Transformer XL Model

This consists of multiple Transformer XL layers.

class TransformerXL(nn.Module):

    def __init__(self, layer: TransformerXLLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)
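
clone_module_list comes from labml_nn.utils; the assumption here is that it simply deep-copies the layer n_layers times into an nn.ModuleList, roughly equivalent to:

```python
import copy

import torch.nn as nn


def clone_module_list(module: nn.Module, n: int) -> nn.ModuleList:
    # n independent deep copies, each with its own parameters
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
```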

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])


  • x is a tensor of the token embedding vectors of shape [seq_len, batch_size, d_model]
  • mem is a list of tensors of the past token level feature vectors of shape [mem_len, batch_size, d_model] for each layer
  • mask is the masking matrix
    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):

List to store token level feature vectors, which will become the memories for the next sequential batch.

        new_mem = []

Run through each transformer layer

        for i, layer in enumerate(self.layers):

Add the feature vectors to the list of memories; they are detached so that gradients do not propagate back through them

            new_mem.append(x.detach())

Memory for this layer, if provided

            m = mem[i] if mem else None

Run through the transformer XL layer

            x = layer(x=x, mem=m, mask=mask)

Finally, normalize the output vectors and return them along with the new memories

        return self.norm(x), new_mem
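
Putting it together, the memories returned here are meant to be fed back in on the next segment. Below is a minimal sketch of that pattern; the random embeddings, the mask, and the trimming of memories to a fixed mem_len are illustrative stand-ins for what the linked training code does, and layer and d_model come from the earlier construction sketch.

```python
import torch

from labml_nn.transformers.xl import TransformerXL

model = TransformerXL(layer, n_layers=6)

seq_len, mem_len, batch_size = 4, 8, 2
mem = []  # no memories before the first segment

for step in range(3):
    # Token-level embeddings for the current segment (stand-in for a real embedding layer)
    x = torch.randn(seq_len, batch_size, d_model)

    # Causal mask over the concatenated memory + current segment
    m_len = mem[0].shape[0] if mem else 0
    mask = torch.tril(torch.ones(seq_len, m_len + seq_len), diagonal=m_len).bool().unsqueeze(-1)

    out, new_mem = model(x, mem, mask)

    # Carry hidden states forward, keeping only the latest mem_len positions per layer
    mem = new_mem if not mem else [torch.cat((m, nm), dim=0)[-mem_len:] for m, nm in zip(mem, new_mem)]
```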
