This is a PyTorch implementation of relative multi-headed attention from the paper Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
```python
import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
```
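This method shifts the $i$-th row of a matrix right by $i$ columns.

If the input is [[1, 2, 3], [4, 5, 6], [7, 8, 9]], the shifted result would be [[1, 2, 3], [0, 4, 5], [6, 0, 7]]. The entries below the diagonal end up holding zeros or values wrapped around from the previous row. Ideally we should mask out this lower triangle, but it is fine for our purpose: `get_scores` keeps only the last `key.shape[0]` columns of the shifted result, and none of those columns contains wrapped values.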
```python
def shift_right(x: torch.Tensor):
    # Concatenate a column of zeros
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)

    # Reshape and remove excess elements from the end
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    x = x_padded[:-1].view_as(x)

    return x
```
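To see the pad-and-reshape trick without PyTorch, here is a minimal pure-Python sketch on a 2-D list of lists. `shift_right_2d` is a hypothetical helper written only for illustration; unlike `shift_right`, it does not carry trailing batch or head dimensions.

```python
def shift_right_2d(rows):
    """Shift row i of a 2-D list right by i columns using the pad/flatten/truncate trick."""
    n_rows, n_cols = len(rows), len(rows[0])
    # Concatenate a column of zeros: every row gets one extra trailing 0,
    # then flatten in row-major order (what a contiguous tensor stores)
    flat = [value for row in rows for value in list(row) + [0]]
    # Drop the last n_rows values; this matches viewing the padded buffer as
    # (n_cols + 1, n_rows) and removing its last row with x_padded[:-1]
    flat = flat[: n_rows * n_cols]
    # Re-read what is left as n_rows rows of width n_cols (the view_as(x) step)
    return [flat[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]


print(shift_right_2d([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
# [[1, 2, 3], [0, 4, 5], [6, 0, 7]]
```

Because the flat buffer is re-read with one column fewer per row, row $i$ starts $i$ positions earlier in the buffer, which is exactly a right shift by $i$: for every column $c \ge i$ the result holds the original entry at column $c - i$.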
We override the Multi-Head Attention module, so we only need to write the `get_scores` method.
```python
class RelativeMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # The linear transformations do not need a bias since we explicitly include it
        # when calculating scores. However, having a bias for `value` might make sense.
        super().__init__(heads, d_model, dropout_prob, bias=False)

        # Number of relative positions
        self.P = 2 ** 12

        # Relative positional embeddings for key relative to the query.
        # We need 2P embeddings because the keys can be before or after the query.
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
        # Relative positional embedding bias for key relative to the query.
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
        # Positional embeddings for the query are independent of the position of the query
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
```
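To get a feel for the size of these tables, the sketch below counts their parameters. `relative_param_counts` is a hypothetical helper; it assumes `d_k = d_model // heads`, the usual per-head split, and the values `heads=8, d_model=512` are illustrative, not taken from the source.

```python
def relative_param_counts(heads: int, d_model: int, P: int = 2 ** 12) -> dict:
    """Parameter counts of the three relative-position tensors defined in __init__."""
    # Assumption: each head works in d_k = d_model // heads dimensions
    d_k = d_model // heads
    return {
        # 2P relative positions x heads x d_k
        "key_pos_embeddings": 2 * P * heads * d_k,
        # 2P relative positions x heads (one scalar bias per head)
        "key_pos_bias": 2 * P * heads,
        # one d_k-dimensional vector v per head, shared by every query position
        "query_pos_bias": heads * d_k,
    }


print(relative_param_counts(heads=8, d_model=512))
# {'key_pos_embeddings': 4194304, 'key_pos_bias': 65536, 'query_pos_bias': 512}
```

Since the tables are indexed by relative offset rather than absolute position, their size is fixed by P. Judging from the slice `[P - key.shape[0] : P + query.shape[0]]` used in `get_scores`, they cover any query and key lengths up to P each.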
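With absolute attention,

$$A^{\text{abs}}_{i,j} = \mathrm{lin}_q(X^q_i + P_i)^\top \mathrm{lin}_k(X^k_j + P_j) = \underbrace{Q_i^\top K_j}_{A} + \underbrace{Q_i^\top U^K_j}_{B} + \underbrace{{U^Q_i}^\top K_j}_{C} + \underbrace{{U^Q_i}^\top U^K_j}_{D}$$

where $Q_i, K_j$ are linear transformations of the original embeddings $X^q_i, X^k_j$, and $U^Q_i, U^K_j$ are linear transformations of the absolute positional encodings $P_i, P_j$.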
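The four-term expansion holds because $\mathrm{lin}_q$ and $\mathrm{lin}_k$ are treated as linear maps without a bias (matching `bias=False` in `__init__`), so each distributes over the sum of embedding and position encoding. The pure-Python sketch below checks this on hand-picked 2-D vectors; the matrices `W_q`, `W_k` and all vector values are arbitrary illustrations.

```python
def matvec(W, x):
    # Bias-free linear map: W @ x
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def add(a, b):
    return [ai + bi for ai, bi in zip(a, b)]


W_q = [[1, 2], [0, 1]]  # stands in for lin_q
W_k = [[2, 0], [1, 1]]  # stands in for lin_k
x_i, p_i = [1, 3], [0, -1]  # query token embedding X^q_i and its position encoding P_i
x_j, p_j = [2, 1], [1, 1]   # key token embedding X^k_j and its position encoding P_j

# Left-hand side: one score from the position-augmented inputs
a_abs = dot(matvec(W_q, add(x_i, p_i)), matvec(W_k, add(x_j, p_j)))

# Right-hand side: the four terms A, B, C, D
Q_i, U_Q_i = matvec(W_q, x_i), matvec(W_q, p_i)
K_j, U_K_j = matvec(W_k, x_j), matvec(W_k, p_j)
terms = {
    "A": dot(Q_i, K_j),
    "B": dot(Q_i, U_K_j),
    "C": dot(U_Q_i, K_j),
    "D": dot(U_Q_i, U_K_j),
}
print(a_abs, terms, sum(terms.values()))
# 40 {'A': 37, 'B': 20, 'C': -11, 'D': -6} 40
```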
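The authors reason that the attention to a given key should be the same regardless of the position of the query. Hence they replace $\underbrace{{U^Q_i}^\top K_j}_{C}$ with $\underbrace{v^\top K_j}_{C}$, where $v$ is a learned vector that does not depend on the query position.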
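For the second and fourth terms, relative positional encodings are introduced: $\underbrace{Q_i^\top U^K_j}_{B}$ is replaced with $\underbrace{Q_i^\top R_{i-j}}_{B}$, and $\underbrace{{U^Q_i}^\top U^K_j}_{D}$ with $\underbrace{S_{i-j}}_{D}$.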
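$$A^{\text{rel}}_{i,j} = \underbrace{Q_i^\top K_j}_{A} + \underbrace{Q_i^\top R_{i-j}}_{B} + \underbrace{v^\top K_j}_{C} + \underbrace{S_{i-j}}_{D}$$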
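To make the relative form concrete, the sketch below evaluates $A^{\text{rel}}_{i,j}$ term by term on a tiny example and checks that $A$ and $C$ can be merged into $(Q_i + v)^\top K_j$, the grouping `get_scores` uses. Storing $R$ and $S$ in dicts keyed by the offset $i - j$ is a hypothetical simplification for readability; the module itself keeps them in the `key_pos_embeddings` and `key_pos_bias` tables.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def rel_score(Q, K, v, R, S, i, j):
    """A^rel_{i,j} = Q_i^T K_j + Q_i^T R_{i-j} + v^T K_j + S_{i-j}."""
    offset = i - j
    a = dot(Q[i], K[j])       # A: content vs content
    b = dot(Q[i], R[offset])  # B: content vs relative position
    c = dot(v, K[j])          # C: global content bias
    d = S[offset]             # D: global relative-position bias
    return a + b + c + d


# Hypothetical toy values: two queries and two keys of dimension 2
Q = [[1, 0], [0, 1]]
K = [[1, 1], [2, 0]]
v = [1, -1]
R = {-1: [1, 2], 0: [0, 1], 1: [3, 0]}
S = {-1: 1, 0: 0, 1: -2}

scores = [[rel_score(Q, K, v, R, S, i, j) for j in range(2)] for i in range(2)]
print(scores)  # [[1, 6], [-1, 3]]

# A + C computed in one dot product as (Q_i + v)^T K_j
ac = [[dot([q + w for q, w in zip(Q[i], v)], K[j]) for j in range(2)] for i in range(2)]
assert ac == [[dot(Q[i], K[j]) + dot(v, K[j]) for j in range(2)] for i in range(2)]
```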
```python
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # R_k
        key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]
        # S_k
        key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]
        # v^T
        query_pos_bias = self.query_pos_bias[None, None, :, :]

        # (A + C)_{i,j} = Q_i^T K_j + v^T K_j
        ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
        # B'_{i,k} = Q_i^T R_k
        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
        # D'_{i,k} = S_k
        d = key_pos_bias[None, :, None, :]
        # Shift the rows of (B' + D')_{i,k} to get (B + D)_{i,j} = (B' + D')_{i,i-j}
        bd = shift_right(b + d)
        # Remove extra positions
        bd = bd[:, -key.shape[0]:]

        # Return the sum Q_i^T K_j + Q_i^T R_{i-j} + v^T K_j + S_{i-j}
        return ac + bd
```
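The shift-and-trim step at the end of `get_scores` is the least obvious part, so here is a pure-Python sketch that checks it on the query and key dimensions only. It re-implements the pad-and-reshape idea of `shift_right` as a hypothetical 2-D helper, `shift_right_2d`, and uses hypothetical `B'` and `D'` matrices with `len_q + len_k` columns, one per sliced positional embedding.

```python
def shift_right_2d(rows):
    # Pad each row with a trailing 0, flatten, drop the last n_rows values, re-split
    n_rows, n_cols = len(rows), len(rows[0])
    flat = [value for row in rows for value in list(row) + [0]]
    flat = flat[: n_rows * n_cols]
    return [flat[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]


def shift_and_trim(m, len_k):
    # Mirrors bd = shift_right(b + d) followed by bd = bd[:, -key.shape[0]:]
    return [row[-len_k:] for row in shift_right_2d(m)]


len_q, len_k = 3, 5
# Hypothetical B' with distinct values so we can trace where each entry comes from
b_prime = [[10 * i + k for k in range(len_q + len_k)] for i in range(len_q)]
b = shift_and_trim(b_prime, len_k)
# Entry (i, j) is taken from column len_q + j - i of row i, never from a wrapped value
assert all(b[i][j] == b_prime[i][len_q + j - i] for i in range(len_q) for j in range(len_k))

# D'_{i,k} = S_k does not depend on the query row, so every row is the same slice of S
s = list(range(2 * len_q))  # hypothetical sliced key_pos_bias values, with len_k == len_q
d_prime = [list(s) for _ in range(len_q)]
d = shift_and_trim(d_prime, len_q)
print(d)  # [[3, 4, 5], [2, 3, 4], [1, 2, 3]]
```

The result for `D'` is constant along each diagonal, so it depends only on $j - i$. Since column $k$ of the slice is table row `P - key.shape[0] + k`, entry $(i, j)$ uses table row `P - key.shape[0] + query.shape[0] + j - i`; when query and key lengths are equal this is `P + j - i`.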
```python
def _test_shift_right():
    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inspect(x)
    inspect(shift_right(x))

    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])

    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])


if __name__ == '__main__':
    _test_shift_right()
```