[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/rope/__init__.py)
This is an implementation of Rotary Positional Embeddings (RoPE) in PyTorch.
Rotary Positional Embeddings (RoPE) encode position information of tokens with a rotation matrix that naturally incorporates explicit relative position dependency.
Here is the code for training a transformer model with RoPE on the Tiny Shakespeare dataset.
```python
import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
```
Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.
Let $x^{(1)}_m$ and $x^{(2)}_m$ be two features of the key or query of any head at position $m$. Or, for simplicity, assume $x$ has only two features. Then the transformation is,
$$\text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big) =
\begin{pmatrix}
\cos m \theta & - \sin m \theta \\
\sin m \theta & \cos m \theta
\end{pmatrix}
\begin{pmatrix}
x^{(1)}_m \\
x^{(2)}_m
\end{pmatrix}
=
\begin{pmatrix}
x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m \theta \\
x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m \theta
\end{pmatrix}$$
where $\theta$ is a constant angle. The other pairs of features are transformed similarly.
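As a quick numerical illustration (not part of the original implementation), the sketch below applies this rotation to a single made-up feature pair, once with the expanded formula and once with an explicit $2 \times 2$ rotation matrix.

```python
import math
import torch

x1, x2 = 1.0, 2.0   # the two features of one pair (made-up values)
m, theta = 3, 0.5   # token position and the constant angle (made-up values)

cos_m, sin_m = math.cos(m * theta), math.sin(m * theta)

# RoPE(x1, x2, m) written out as in the equation above
rotated = torch.tensor([x1 * cos_m - x2 * sin_m,
                        x2 * cos_m + x1 * sin_m])

# The same result via an explicit 2x2 rotation matrix applied to the pair
rotation = torch.tensor([[cos_m, -sin_m],
                         [sin_m, cos_m]])
assert torch.allclose(rotated, rotation @ torch.tensor([x1, x2]))
```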
For a pair of features, the dot-product attention score between two positions $m$ and $n$ would be
$$\begin{align}
\big\langle \text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big), \text{RoPE}\big(x^{(1)}_n, x^{(2)}_n, n\big) \big\rangle
&= \big(x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m \theta\big) \big(x^{(1)}_n \cos n\theta - x^{(2)}_n \sin n \theta\big) \\
&\quad + \big(x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m \theta\big) \big(x^{(2)}_n \cos n\theta + x^{(1)}_n \sin n \theta\big) \\
&= x^{(1)}_m x^{(1)}_n \big(\cos m\theta \cos n\theta + \sin m\theta \sin n\theta\big) \\
&\quad + x^{(1)}_m x^{(2)}_n \big(-\cos m\theta \sin n\theta + \sin m\theta \cos n\theta\big) \\
&\quad + x^{(2)}_m x^{(1)}_n \big(-\sin m\theta \cos n\theta + \cos m\theta \sin n\theta\big) \\
&\quad + x^{(2)}_m x^{(2)}_n \big(\sin m\theta \sin n\theta + \cos m\theta \cos n\theta\big) \\
&= x^{(1)}_m x^{(1)}_n \cos (m - n)\theta + x^{(1)}_m x^{(2)}_n \sin (m - n)\theta \\
&\quad - x^{(2)}_m x^{(1)}_n \sin (m - n)\theta + x^{(2)}_m x^{(2)}_n \cos (m - n)\theta \\
&= \big(x^{(1)}_m \cos (m - n)\theta - x^{(2)}_m \sin (m - n)\theta\big) x^{(1)}_n \\
&\quad + \big(x^{(2)}_m \cos (m - n)\theta + x^{(1)}_m \sin (m - n)\theta\big) x^{(2)}_n \\
&= \big\langle \text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m - n\big), \text{RoPE}\big(x^{(1)}_n, x^{(2)}_n, 0\big) \big\rangle
\end{align}$$
This shows that for dot-product attention the rotary encodings give relative attention.
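The identity above can also be checked numerically. This is a small sketch of my own (not from the paper); the pair values, positions, and angle are made up.

```python
import math
import torch

def rope_pair(x1, x2, m, theta):
    """Rotate the pair (x1, x2) by m * theta, as in the equation above."""
    c, s = math.cos(m * theta), math.sin(m * theta)
    return torch.tensor([x1 * c - x2 * s, x2 * c + x1 * s])

theta = 0.3
xm, xn = (1.0, 2.0), (0.5, -1.5)   # made-up query and key pairs
m, n = 7, 4                        # made-up positions

lhs = rope_pair(*xm, m, theta) @ rope_pair(*xn, n, theta)
rhs = rope_pair(*xm, m - n, theta) @ rope_pair(*xn, 0, theta)
assert torch.allclose(lhs, rhs)    # the score depends only on m - n
```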
The features are grouped into pairs and handled as above. They use a different $\theta$ for each pair.

The paper suggests using $\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}},\ i \in [1, 2, \dots, \frac{d}{2}]\big\}$ for the $\frac{d}{2}$ pairs of features.
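For example, with $d = 8$ this gives $\theta_1 = 10000^{0} = 1$, $\theta_2 = 10000^{-\frac{2}{8}} = 0.1$, $\theta_3 = 0.01$, and $\theta_4 = 0.001$, so low-indexed pairs rotate quickly with position while high-indexed pairs rotate slowly.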
We pair feature $i$ with feature $i + \frac{d}{2}$. So for position $m$ we transform
$$\begin{pmatrix}
x^{(i)}_m \\
x^{(i + \frac{d}{2})}_m
\end{pmatrix}$$
to
$$\begin{pmatrix}
x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
x^{(i + \frac{d}{2})}_m \cos m \theta_i + x^{(i)}_m \sin m \theta_i
\end{pmatrix}$$
```python
class RotaryPositionalEmbeddings(nn.Module):
```
* `d` is the number of features $d$
* `base` is the constant used for calculating $\Theta$

```python
    def __init__(self, d: int, base: int = 10_000):
```
```python
        super().__init__()

        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None
```
Cache cos and sin values
```python
    def _build_cache(self, x: torch.Tensor):
```
Return if cache is already built
```python
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return
```
Get sequence length
```python
        seq_len = x.shape[0]
```
$\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}},\ i \in [1, 2, \dots, \frac{d}{2}]\big\}$
```python
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
```
Create position indexes [0, 1, ..., seq_len - 1]
```python
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
```
Calculate the product of position index and $\theta_i$
```python
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
```
Concatenate so that for row $m$ we have $[m \theta_0, m \theta_1, \dots, m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, \dots, m \theta_{\frac{d}{2}}]$
```python
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
```
Cache them
```python
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]
```
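Note that `idx_theta` has shape `[seq_len, d/2]` and `idx_theta2` has shape `[seq_len, d]`, so the cached tensors have shape `[seq_len, 1, 1, d]` and broadcast over the batch and head dimensions of the input.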
```python
    def _neg_half(self, x: torch.Tensor):
```
$\frac{d}{2}$
```python
        d_2 = self.d // 2
```
Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, \dots, -x^{(d)}, x^{(1)}, x^{(2)}, \dots, x^{(\frac{d}{2})}]$
```python
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
```
`x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`

```python
    def forward(self, x: torch.Tensor):
```
Cache cos and sin values
```python
        self._build_cache(x)
```
Sequence length
```python
        seq_len = x.shape[0]
```
Split the features; we can choose to apply rotary embeddings to only a subset of the features.
```python
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
```
Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, \dots, -x^{(d)}, x^{(1)}, x^{(2)}, \dots, x^{(\frac{d}{2})}]$
```python
        neg_half_x = self._neg_half(x_rope)
```
Calculate

$$\begin{pmatrix}
x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
x^{(i + \frac{d}{2})}_m \cos m \theta_i + x^{(i)}_m \sin m \theta_i
\end{pmatrix}$$

for $i \in \{1, 2, \dots, \frac{d}{2}\}$
```python
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])
```
```python
        return torch.cat((x_rope, x_pass), dim=-1)
```
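Here is a minimal usage sketch (not part of the original file). The sizes below are made up, and it assumes the `RotaryPositionalEmbeddings` class above is in scope.

```python
# Made-up sizes for illustration: 10 tokens, batch of 2, 4 heads, 16 features per head
seq_len, batch_size, n_heads, d_head = 10, 2, 4, 16
x = torch.randn(seq_len, batch_size, n_heads, d_head)

# Rotate only the first 8 features of each head; the remaining 8 pass through unchanged
rope = RotaryPositionalEmbeddings(d=8)
y = rope(x)

assert y.shape == x.shape
assert torch.allclose(y[..., 8:], x[..., 8:])  # pass-through features are untouched
```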
We override the multi-head attention module of the original transformer.
```python
class RotaryPEMultiHeadAttention(MultiHeadAttention):
```
```python
    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
        super().__init__(heads, d_model, dropout_prob)
```
Rotary positional embedding layers
```python
        d_rope = int(self.d_k * rope_percentage)
        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
```
```python
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
```
Calculate dot-product with RoPE
```python
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
```
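Below is a small usage sketch of my own. It assumes the base `MultiHeadAttention` from `labml_nn.transformers.mha` takes `query`, `key`, and `value` tensors of shape `[seq_len, batch_size, d_model]` as keyword arguments; the sizes are made up.

```python
# Made-up sizes: 10 tokens, batch of 2, model width 64 split over 4 heads
seq_len, batch_size, d_model = 10, 2, 64
x = torch.randn(seq_len, batch_size, d_model)

# With the default rope_percentage=0.5, RoPE is applied to half the features of each head
mha = RotaryPEMultiHeadAttention(heads=4, d_model=d_model)
out = mha(query=x, key=x, value=x)

assert out.shape == (seq_len, batch_size, d_model)
```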
Testing RoPE with a simple example
```python
def _test_rotary():
```
```python
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
    x = x[:, None, None, :]
    inspect(x)

    rotary_pe = RotaryPositionalEmbeddings(4)
    inspect(rotary_pe(x))


if __name__ == '__main__':
    _test_rotary()
```
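One extra sanity check (my addition, not in the original test): at position $m = 0$ the rotation angle is $0$, so the embedding of the first token should be returned unchanged.

```python
x = torch.randn(5, 1, 1, 4)
rope = RotaryPositionalEmbeddings(4)
# cos(0) = 1 and sin(0) = 0, so position 0 is left as-is
assert torch.allclose(rope(x)[0], x[0])
```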