Relative Multi-Headed Attention

This is a PyTorch implementation of relative multi-headed attention from the paper Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.

```python
import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
```

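This function shifts the $i$-th row of a matrix by $i$ columns.

If the input is [[1, 2, 3], [4, 5, 6], [7, 8, 9]], the shifted result is [[1, 2, 3], [0, 4, 5], [6, 0, 7]]. The entries below the diagonal are padding zeros or values wrapped around from the previous row. Ideally they would be masked out, but that is unnecessary here: in get_scores they fall into columns that are sliced off afterwards.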
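
For illustration, here is a minimal pure-Python sketch of the same pad, flatten and trim trick on a 2-D nested list. It assumes row-major (contiguous) storage, which is what `view` relies on, and ignores the trailing dimensions the tensor version carries; `shift_right_list` is a hypothetical helper, not part of the library.

```python
def shift_right_list(rows):
    """Shift row i of a 2-D list right by i columns, mimicking the tensor trick.

    Entries that land in the first i columns of row i come from the end of the
    previous (zero-padded) row, just as in the tensor version.
    """
    n_rows, n_cols = len(rows), len(rows[0])
    # Append a zero column, then flatten row-major (like contiguous memory)
    flat = [v for row in rows for v in row + [0]]
    # Dropping the last n_rows elements mirrors `x_padded[:-1]` after the view
    flat = flat[: n_rows * n_cols]
    # Re-split into rows of the original width (`view_as(x)`)
    return [flat[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]


print(shift_right_list([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
# [[1, 2, 3], [0, 4, 5], [6, 0, 7]]
```

Each padded row is one element longer than the target width, so when the flat sequence is re-split at the original width, new row $i$ starts $i$ elements before padded row $i$ did, which pushes its contents right by $i$.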

```python
def shift_right(x: torch.Tensor):
```

Concatenate a column of zeros

```python
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)
```

Reshape and remove excess elements from the end

```python
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    x = x_padded[:-1].view_as(x)
```

```python
    return x
```

Relative Multi-Head Attention Module

We override the Multi-Head Attention module, so we only need to write the get_scores method.

```python
class RelativeMultiHeadAttention(MultiHeadAttention):
```

```python
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
```

The linear transformations do not need a bias, since we explicitly include it when calculating the scores. However, having a bias for `value` might make sense.

```python
        super().__init__(heads, d_model, dropout_prob, bias=False)
```

Number of relative positions

```python
        self.P = 2 ** 12
```

Relative positional embeddings for the key relative to the query. We need $2P$ embeddings because the key can be before or after the query.

```python
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
```

Relative positional embedding bias for the key relative to the query.

```python
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
```

Positional bias $v$ for the query; it is independent of the position of the query.

```python
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
```
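
Most of the extra memory comes from key_pos_embeddings. A small sketch of the parameter arithmetic, assuming d_k = d_model // heads (the usual per-head split; worth checking against the base class) and the default P = 2 ** 12. The values heads=8 and d_model=512 are illustrative, the function name is hypothetical, and the count excludes the projections inherited from MultiHeadAttention.

```python
def relative_position_param_count(heads: int, d_model: int, P: int = 2 ** 12) -> dict:
    """Count the extra parameters added by the relative-position terms.

    Assumes d_k = d_model // heads (an assumption about the base class).
    """
    d_k = d_model // heads
    counts = {
        "key_pos_embeddings": 2 * P * heads * d_k,  # shape (2P, heads, d_k)
        "key_pos_bias": 2 * P * heads,              # shape (2P, heads)
        "query_pos_bias": heads * d_k,              # shape (heads, d_k)
    }
    counts["total"] = sum(counts.values())
    return counts


print(relative_position_param_count(heads=8, d_model=512))
# {'key_pos_embeddings': 4194304, 'key_pos_bias': 65536, 'query_pos_bias': 512, 'total': 4260352}
```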

Get relative attention scores

With absolute positional encodings, the attention score between query $i$ and key $j$ is

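$$
\begin{aligned}
A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^\top \, lin_k(X^k_j + P_j) \\
&= \underbrace{Q_i^\top K_j}_{A} + \underbrace{Q_i^\top U^K_j}_{B} + \underbrace{{U^Q_i}^\top K_j}_{C} + \underbrace{{U^Q_i}^\top U^K_j}_{D}
\end{aligned}
$$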

where $Q_i$ and $K_j$ are linear transformations of the original embeddings $X^q_i$ and $X^k_j$, and $U^Q_i$ and $U^K_j$ are linear transformations of the absolute positional encodings $P_i$ and $P_j$.
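
A quick numeric check of this expansion, assuming bias-free projections $lin_q(x) = W_q x$ and $lin_k(x) = W_k x$ (matching bias=False in the constructor). The weights and vectors below are arbitrary illustrative values.

```python
def matvec(W, x):
    # Bias-free linear map: W is a list of rows
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def vadd(a, b):
    return [ai + bi for ai, bi in zip(a, b)]


# Illustrative 2-D embeddings and 2x2 projections (hypothetical values)
W_q = [[1.0, 2.0], [0.5, -1.0]]
W_k = [[0.0, 1.0], [3.0, 1.0]]
X_i, P_i = [1.0, -2.0], [0.5, 0.25]   # query token embedding and its position encoding
X_j, P_j = [2.0, 1.0], [-1.0, 0.75]   # key token embedding and its position encoding

# Left-hand side: project the summed inputs, then take the dot product
full = dot(matvec(W_q, vadd(X_i, P_i)), matvec(W_k, vadd(X_j, P_j)))

# Right-hand side: the four terms A, B, C, D
Q_i, U_Q = matvec(W_q, X_i), matvec(W_q, P_i)
K_j, U_K = matvec(W_k, X_j), matvec(W_k, P_j)
A, B = dot(Q_i, K_j), dot(Q_i, U_K)
C, D = dot(U_Q, K_j), dot(U_Q, U_K)

print(full, A + B + C + D)  # 8.375 8.375
```

With biased projections the product would pick up additional terms involving the biases, so the four-term split is exact only for bias-free maps.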

The authors reason that the attention to a given key should be the same regardless of the position of the query. Hence they replace term $C$, ${U^Q_i}^\top K_j$, with $v^\top K_j$, where $v$ is a learned vector that is the same for every query position.

For the second and fourth terms, relative positional encodings are introduced, which depend only on the offset $i - j$: term $B$, $Q_i^\top U^K_j$, is replaced with $Q_i^\top R_{i-j}$, and term $D$, ${U^Q_i}^\top U^K_j$, with $S_{i-j}$.

$$
A^{rel}_{i,j} = \underbrace{Q_i^\top K_j}_{A} + \underbrace{Q_i^\top R_{i-j}}_{B} + \underbrace{v^\top K_j}_{C} + \underbrace{S_{i-j}}_{D}
$$
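
A minimal sketch of the relative score for a single query/key pair, assuming $R$ and $S$ are lookup tables keyed by the offset $i - j$ (hypothetical dictionaries standing in for the learned key_pos_embeddings and key_pos_bias). Shifting both positions by the same amount leaves the score unchanged.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def relative_score(Q_i, K_j, v, R, S, i, j):
    """A^rel_{i,j} = Q_i.K_j + Q_i.R_{i-j} + v.K_j + S_{i-j}.

    R maps an offset i - j to a vector and S maps it to a scalar; both are
    hypothetical lookup tables standing in for learned parameters.
    """
    offset = i - j
    return (dot(Q_i, K_j)          # A: content-based addressing
            + dot(Q_i, R[offset])  # B: content-dependent positional bias
            + dot(v, K_j)          # C: global content bias
            + S[offset])           # D: global positional bias


R = {1: [1.0, 1.0]}
S = {1: -0.5}
Q_i, K_j, v = [1.0, 0.0], [2.0, 3.0], [0.5, 0.5]

print(relative_score(Q_i, K_j, v, R, S, i=3, j=2))  # 5.0
print(relative_score(Q_i, K_j, v, R, S, i=7, j=6))  # 5.0, same offset i - j = 1
```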

```python
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
```

$R_k$

```python
        key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]
```

$S_k$

```python
        key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]
```
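
The slice [P - key.shape[0] : P + query.shape[0]] always selects key.shape[0] + query.shape[0] rows of the table. A short sketch of that index arithmetic (the function name is hypothetical): the slice stays inside the $2P$-row table only while both lengths are at most $P$, and the sketch checks this explicitly because Python-style slicing does not raise on out-of-range bounds.

```python
def relative_slice(P: int, q_len: int, k_len: int) -> range:
    """Row indices selected by table[P - k_len : P + q_len] from a 2P-row table."""
    # Python-style slicing does not raise on out-of-range bounds, so check explicitly
    if not (0 < q_len <= P and 0 < k_len <= P):
        raise ValueError("q_len and k_len must both be in (0, P]")
    return range(P - k_len, P + q_len)


rows = relative_slice(P=2 ** 12, q_len=4, k_len=6)
print(rows.start, rows.stop, len(rows))  # 4090 4100 10
```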

$v^\top$

```python
        query_pos_bias = self.query_pos_bias[None, None, :, :]
```

$(A + C)_{i,j} = Q_i^\top K_j + v^\top K_j$

```python
        ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
```
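
Adding $v$ to the query before the contraction computes $A$ and $C$ in a single einsum, because $(Q_i + v)^\top K_j = Q_i^\top K_j + v^\top K_j$. The loop version below is a sketch of what the subscripts 'ibhd,jbhd->ijbh' compute (a dot product over d for every query i, key j, batch b and head h), using plain nested lists; the helper name and toy values are hypothetical.

```python
def einsum_ibhd_jbhd(q, k):
    """Loop version of torch.einsum('ibhd,jbhd->ijbh', q, k) on nested lists.

    q is indexed [i][b][h][d] and k is indexed [j][b][h][d]; the result
    out[i][j][b][h] is the dot product over d for each batch b and head h.
    """
    n_b, n_h = len(q[0]), len(q[0][0])
    return [
        [
            [
                [sum(x * y for x, y in zip(q_i[b][h], k_j[b][h])) for h in range(n_h)]
                for b in range(n_b)
            ]
            for k_j in k
        ]
        for q_i in q
    ]


# Toy sizes: 2 queries, 3 keys, batch 1, 1 head, d_k = 2
q = [[[[1, 0]]], [[[0, 1]]]]
k = [[[[1, 2]]], [[[3, 4]]], [[[5, 6]]]]
v = [10, 20]  # hypothetical position-independent query bias for the single head

a = einsum_ibhd_jbhd(q, k)  # term A only
# Adding v to every query before the contraction yields A + C in one einsum
q_plus_v = [[[[x + y for x, y in zip(q_i[0][0], v)]]] for q_i in q]
ac = einsum_ibhd_jbhd(q_plus_v, k)

print([[a[i][j][0][0] for j in range(3)] for i in range(2)])   # [[1, 3, 5], [2, 4, 6]]
print([[ac[i][j][0][0] for j in range(3)] for i in range(2)])  # [[51, 113, 175], [52, 114, 176]]
```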

$B'_{i,k} = Q_i^\top R_k$

```python
        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
```

$D'_{i,k} = S_k$

```python
        d = key_pos_bias[None, :, None, :]
```

Shift the rows of $(B' + D')_{i,k}$ to get $(B + D)_{i,j} = (B' + D')_{i,i-j}$

```python
        bd = shift_right(b + d)
```

Remove extra positions

```python
        bd = bd[:, -key.shape[0]:]
```
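
A small pure-Python check of what the shift followed by this slice achieves, assuming b + d has key length + query length columns (as produced by the slices of the positional tables). Every kept entry $(i, j)$ is taken from row $i$ at column $j + q_{len} - i$, and none of the wrapped-around entries from shift_right survive. Sizes and values are illustrative, and `shift_right_list` is a hypothetical list-based stand-in for shift_right.

```python
def shift_right_list(rows):
    # Same pad / flatten / trim trick as shift_right, on a 2-D nested list
    n_rows, n_cols = len(rows), len(rows[0])
    flat = [v for row in rows for v in row + [0]]
    flat = flat[: n_rows * n_cols]
    return [flat[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]


q_len, k_len = 3, 5
# Toy (B' + D') scores with k_len + q_len columns; entry (i, m) is 100 * i + m
bd_prime = [[100 * i + m for m in range(k_len + q_len)] for i in range(q_len)]

shifted = shift_right_list(bd_prime)
bd = [row[-k_len:] for row in shifted]  # mirrors bd[:, -key.shape[0]:]

# Every kept entry comes from its own row i, at column j + q_len - i,
# so none of the wrapped-around values survive the slice
for i in range(q_len):
    for j in range(k_len):
        assert bd[i][j] == bd_prime[i][j + q_len - i]

print(bd)  # [[3, 4, 5, 6, 7], [102, 103, 104, 105, 106], [201, 202, 203, 204, 205]]
```

The source column $j + q_{len} - i$ depends on $i$ and $j$ only through $j - i$ (plus a constant), which is how a single contraction against the sliced positional tables yields scores that depend on the relative offset.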

Return the sum $\underbrace{Q_i^\top K_j}_{A} + \underbrace{Q_i^\top R_{i-j}}_{B} + \underbrace{v^\top K_j}_{C} + \underbrace{S_{i-j}}_{D}$

```python
        return ac + bd
```

```python
def _test_shift_right():
    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inspect(x)
    inspect(shift_right(x))

    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])

    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])


if __name__ == '__main__':
    _test_shift_right()
```

labml.ai