Transformer XL
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/xl/__init__.py)
This is an implementation of [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) in PyTorch.
A vanilla Transformer has a limited attention span, equal to the length of the sequence trained on in parallel, and all of these positions have fixed positional encodings. Transformer XL increases this attention span by letting each position attend to pre-calculated embeddings from the past. For instance, if the context length is l, it keeps the embeddings of all layers for the previous batch of length l and feeds them to the current step. If we used fixed positional encodings, these pre-calculated embeddings would have the same positions as the current context. To avoid that, Transformer XL introduces relative positional encoding, where the positional information is injected at the attention calculation.
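The idea can be sketched with plain dot-product attention (a toy example, not part of this implementation): keys and values are computed over the concatenation of the cached memories and the current segment, while queries come only from the current segment, so each query can attend to `mem_len + seq_len` positions.

```python
import torch

seq_len, mem_len, batch_size, d_model = 4, 8, 1, 16

x = torch.randn(seq_len, batch_size, d_model)    # current segment
mem = torch.randn(mem_len, batch_size, d_model)  # cached embeddings from the previous segment

# Keys/values come from the concatenation, queries only from the current segment
m_x = torch.cat((mem, x), dim=0)                 # [mem_len + seq_len, batch_size, d_model]
scores = torch.einsum('ibd,jbd->ijb', x, m_x) / d_model ** 0.5

# Each of the `seq_len` queries now attends over `mem_len + seq_len` positions
assert scores.shape == (seq_len, mem_len + seq_len, batch_size)
```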
An annotated implementation of relative multi-headed attention is in `relative_mha.py`.
Here is the training code and a notebook for training a Transformer XL model on the Tiny Shakespeare dataset.
```python
from typing import List, Optional

import torch
import torch.nn as nn

from labml_nn.utils import clone_module_list
from .relative_mha import RelativeMultiHeadAttention
from ..feed_forward import FeedForward
```
The Transformer XL model is composed of a number of these layers.
```python
class TransformerXLLayer(nn.Module):
```
* `d_model` is the token embedding size
* `self_attn` is the self attention module
* `feed_forward` is the feed forward module
* `dropout_prob` is the probability of dropping out after self attention and FFN

```python
    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
```
```python
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
```
* `x` is a tensor of the token level feature vectors of shape `[seq_len, batch_size, d_model]`
* `mem` is a tensor of the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
* `mask` is a matrix of shape `[seq_len, mem_len + seq_len, batch_size]` or `[seq_len, mem_len + seq_len, 1]`. `mask[i, j]` is true if token at `i` can see token at `j` (see the mask construction sketch after this class).

```python
    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                mask: torch.Tensor):
```
Normalize the vectors before doing self attention
```python
        z = self.norm_self_attn(x)
```
If there is memory
```python
        if mem is not None:
```
Normalize it
```python
            mem = self.norm_self_attn(mem)
```
Concatenate with z
```python
            m_z = torch.cat((mem, z), dim=0)
```
Ignore if there is no memory
```python
        else:
            m_z = z
```
Attention
```python
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
```
Add the attention results
```python
        x = x + self.dropout(self_attn)
```
Normalize for feed-forward
```python
        z = self.norm_ff(x)
```
Pass through the feed-forward network
```python
        ff = self.feed_forward(z)
```
Add the feed-forward results back
```python
        x = x + self.dropout(ff)
```
```python
        return x
```
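As a concrete example of the mask shape described above, here is a minimal sketch (not part of this implementation) of building a causal mask where every token of the current segment can see all memory positions and only the earlier positions of the segment itself:

```python
import torch

seq_len, mem_len = 4, 8

# Every query position can see all memory positions ...
mem_mask = torch.ones(seq_len, mem_len, dtype=torch.bool)
# ... and, within the current segment, only positions up to and including its own
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Concatenate along the key dimension and add the broadcast (batch) dimension,
# giving the `[seq_len, mem_len + seq_len, 1]` shape expected by `forward`
mask = torch.cat((mem_mask, causal), dim=1).unsqueeze(-1)
assert mask.shape == (seq_len, mem_len + seq_len, 1)
```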
This consists of multiple Transformer XL layers.
```python
class TransformerXL(nn.Module):
```
```python
    def __init__(self, layer: TransformerXLLayer, n_layers: int):
        super().__init__()
```
Make copies of the transformer layer
```python
        self.layers = clone_module_list(layer, n_layers)
```
Final normalization layer
```python
        self.norm = nn.LayerNorm([layer.size])
```
* `x` is a tensor of the token embedding vectors of shape `[seq_len, batch_size, d_model]`
* `mem` is a list of tensors of the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
* `mask` is the masking matrix

```python
    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
```
List to store token level feature vectors, which will become the memories for the next sequential batch.
```python
        new_mem = []
```
Run through each transformer layer
```python
        for i, layer in enumerate(self.layers):
```
Add to the list of feature vectors
```python
            new_mem.append(x.detach())
```
Memory
```python
            m = mem[i] if mem else None
```
Run through the transformer XL layer
```python
            x = layer(x=x, mem=m, mask=mask)
```
Finally, normalize the vectors
```python
        return self.norm(x), new_mem
```
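To see how the pieces fit together, here is a hedged usage sketch showing memories being carried across two consecutive segments. The constructor arguments for `RelativeMultiHeadAttention` and `FeedForward`, and the exact import paths, are assumptions based on the surrounding modules; the training code linked above shows the real setup.

```python
import torch

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention

d_model, n_heads, d_ff, dropout = 128, 4, 256, 0.1
seq_len, batch_size = 16, 2

# Assumed constructor signatures; see the linked modules for the exact ones
layer = TransformerXLLayer(d_model=d_model,
                           self_attn=RelativeMultiHeadAttention(n_heads, d_model, dropout),
                           feed_forward=FeedForward(d_model, d_ff, dropout),
                           dropout_prob=dropout)
model = TransformerXL(layer, n_layers=2)

# First segment: no memory yet, so the mask covers only the current positions
x1 = torch.randn(seq_len, batch_size, d_model)
mask1 = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)
out1, mem = model(x1, [], mask1)

# Second segment: prepend an all-ones block for the `seq_len` memory positions
x2 = torch.randn(seq_len, batch_size, d_model)
mask2 = torch.cat((torch.ones(seq_len, seq_len, dtype=torch.bool),
                   torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))), dim=1).unsqueeze(-1)
out2, mem = model(x2, mem, mask2)
```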