
Rotary Positional Embeddings (RoPE)


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/rope/__init__.py)


This is an implementation of Rotary Positional Embeddings (RoPE) in PyTorch.

Rotary Positional Embeddings (RoPE) encode position information of tokens with a rotation matrix that naturally incorporates explicit relative position dependency.

Here is the training code for a transformer model that uses RoPE on the Tiny Shakespeare dataset.

```python
import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
```


RoPE module

Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.

For a pair of features

Let $x^{(1)}_m$ and $x^{(2)}_m$ be two features of the key or query of any head at position $m$. Or for simplicity assume $x$ has only two features. Then the transformation is,

$$\begin{align}
\text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big) &= \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} x^{(1)}_m \\ x^{(2)}_m \end{pmatrix} \\
&= \begin{pmatrix} x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta \\ x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta \end{pmatrix}
\end{align}$$

where θ is a constant angle. The other pairs of features are transformed similarly.
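As a concrete illustration (not part of the implementation below), the rotation of a single pair can be written in a couple of lines of Python; the values of `x1`, `x2`, `m` and `theta` here are arbitrary:

```python
import math

# Rotate one feature pair (x1, x2) at position m by the angle m * theta.
def rope_pair(x1, x2, m, theta):
    c, s = math.cos(m * theta), math.sin(m * theta)
    return x1 * c - x2 * s, x2 * c + x1 * s

print(rope_pair(1.0, 0.0, m=0, theta=0.5))  # position 0 is left unrotated
print(rope_pair(1.0, 0.0, m=3, theta=0.5))  # the same pair rotated by 1.5 radians
```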

Attention is relative

For a pair of features, the dot-product attention score between two positions $m$ and $n$ would be

$$\begin{align}
&\Big\langle \text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big), \text{RoPE}\big(x^{(1)}_n, x^{(2)}_n, n\big) \Big\rangle \\
&= \big(x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta\big)\big(x^{(1)}_n \cos n\theta - x^{(2)}_n \sin n\theta\big) \\
&\quad + \big(x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta\big)\big(x^{(2)}_n \cos n\theta + x^{(1)}_n \sin n\theta\big) \\
&= x^{(1)}_m x^{(1)}_n \big(\cos m\theta \cos n\theta + \sin m\theta \sin n\theta\big) \\
&\quad + x^{(1)}_m x^{(2)}_n \big(-\cos m\theta \sin n\theta + \sin m\theta \cos n\theta\big) \\
&\quad + x^{(2)}_m x^{(1)}_n \big(-\sin m\theta \cos n\theta + \cos m\theta \sin n\theta\big) \\
&\quad + x^{(2)}_m x^{(2)}_n \big(\sin m\theta \sin n\theta + \cos m\theta \cos n\theta\big) \\
&= x^{(1)}_m x^{(1)}_n \cos (m - n)\theta + x^{(1)}_m x^{(2)}_n \sin (m - n)\theta \\
&\quad - x^{(2)}_m x^{(1)}_n \sin (m - n)\theta + x^{(2)}_m x^{(2)}_n \cos (m - n)\theta \\
&= \big(x^{(1)}_m \cos (m - n)\theta - x^{(2)}_m \sin (m - n)\theta\big) x^{(1)}_n \\
&\quad + \big(x^{(2)}_m \cos (m - n)\theta + x^{(1)}_m \sin (m - n)\theta\big) x^{(2)}_n \\
&= \Big\langle \text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m - n\big), \text{RoPE}\big(x^{(1)}_n, x^{(2)}_n, 0\big) \Big\rangle
\end{align}$$

This shows that for dot-product attention the rotary encodings give relative attention.
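A quick numeric check of this identity, with the same `rope_pair` helper as in the sketch above (repeated here so the snippet is self-contained; the values are arbitrary):

```python
import math

def rope_pair(x1, x2, m, theta):
    c, s = math.cos(m * theta), math.sin(m * theta)
    return x1 * c - x2 * s, x2 * c + x1 * s

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

theta, q, k = 0.3, (1.0, 2.0), (3.0, 4.0)
# Score between positions m = 7 and n = 4 ...
lhs = dot(rope_pair(*q, 7, theta), rope_pair(*k, 4, theta))
# ... matches the score with the query at position m - n = 3 and the key at position 0.
rhs = dot(rope_pair(*q, 3, theta), rope_pair(*k, 0, theta))
assert abs(lhs - rhs) < 1e-9
```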

For all features

The features are grouped into pairs and handled as above. A different $\theta$ is used for each pair.

The paper suggests using $\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}},\ i \in [1, 2, ..., \frac{d}{2}]\big\}$ for the $\frac{d}{2}$ pairs of features.
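This is the same computation `_build_cache` performs below; for example, with $d = 8$ the four pair angles decay geometrically:

```python
import torch

d = 8
# theta_i = 10000^(-2(i-1)/d) for i = 1, ..., d/2
theta = 1. / (10_000 ** (torch.arange(0, d, 2).float() / d))
print(theta)  # 1.0, 0.1, 0.01, 0.001
```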

We pair feature $i$ with feature $i + \frac{d}{2}$. So for position $m$ we transform

$$\begin{pmatrix} x^{(i)}_m \\ x^{(i + \frac{d}{2})}_m \end{pmatrix}$$

to

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$
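The class below applies this to all pairs at once: it multiplies the whole feature vector element-wise by $\cos m\theta_i$ and adds the element-wise product of $\sin m\theta_i$ with the half-rotated vector $[-x^{(\frac{d}{2}+1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$. A small standalone sketch (with arbitrary `d`, `m` and input) checking that this vectorized form matches the explicit pairwise rotation:

```python
import torch

d, m = 8, 5
x = torch.randn(d)
theta = 1. / (10_000 ** (torch.arange(0, d, 2).float() / d))  # [d/2] angles
angles = torch.cat([m * theta, m * theta])                    # [d]

# Vectorized form used by the implementation below
neg_half = torch.cat([-x[d // 2:], x[:d // 2]])
vectorized = x * angles.cos() + neg_half * angles.sin()

# Explicit rotation of each pair (x[i], x[i + d/2])
pairwise = torch.empty(d)
for i in range(d // 2):
    c, s = torch.cos(m * theta[i]), torch.sin(m * theta[i])
    pairwise[i] = x[i] * c - x[i + d // 2] * s
    pairwise[i + d // 2] = x[i + d // 2] * c + x[i] * s

assert torch.allclose(vectorized, pairwise)
```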

```python
class RotaryPositionalEmbeddings(nn.Module):
```


  • d is the number of features d
  • base is the constant used for calculating Θ
```python
    def __init__(self, d: int, base: int = 10_000):
```


```python
        super().__init__()

        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None
```


Cache cos and sin values

```python
    def _build_cache(self, x: torch.Tensor):
```


Return if cache is already built

```python
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return
```


Get sequence length

```python
        seq_len = x.shape[0]
```


$\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}},\ i \in [1, 2, ..., \frac{d}{2}]\big\}$

```python
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
```


Create position indexes [0, 1, ..., seq_len - 1]

```python
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
```


Calculate the product of position index and $\theta_i$

```python
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
```


Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}}, m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}}]$

```python
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
```


Cache them

```python
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]
```
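For instance, with `seq_len = 3` and `d = 4` (the values used in `_test_rotary` below), `theta` has shape `[2]`, `idx_theta` has shape `[3, 2]`, `idx_theta2` has shape `[3, 4]`, and the cached tensors have shape `[3, 1, 1, 4]`, so they broadcast over the batch and head dimensions.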


```python
    def _neg_half(self, x: torch.Tensor):
```


$\frac{d}{2}$

```python
        d_2 = self.d // 2
```


Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$

```python
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
```
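As a quick sanity check (assuming the class above is in scope), for $d = 4$ this maps $[1, 2, 3, 4]$ to $[-3, -4, 1, 2]$:

```python
x = torch.tensor([1., 2., 3., 4.]).view(1, 1, 1, 4)
print(RotaryPositionalEmbeddings(4)._neg_half(x))  # tensor([[[[-3., -4.,  1.,  2.]]]])
```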


  • x is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d]
```python
    def forward(self, x: torch.Tensor):
```


Cache cos and sin values

```python
        self._build_cache(x)
```


Sequence length

```python
        seq_len = x.shape[0]
```


Split the features; we can choose to apply rotary embeddings only to a partial set of features.

```python
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
```
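For example, with $d_k = 32$ per head and `rope_percentage = 0.5` in the attention module below, only the first 16 features of each head are rotated and the remaining 16 pass through unchanged.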


Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$

```python
        neg_half_x = self._neg_half(x_rope)
```


Calculate

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$

for $i \in \{1, 2, ..., \frac{d}{2}\}$

```python
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])
```


```python
        return torch.cat((x_rope, x_pass), dim=-1)
```


Multi-head attention with rotary positional embeddings

We override the multi-head attention module from the original transformer.

```python
class RotaryPEMultiHeadAttention(MultiHeadAttention):
```


```python
    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
        super().__init__(heads, d_model, dropout_prob)
```


Rotary positional embedding layers

```python
        d_rope = int(self.d_k * rope_percentage)
        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
```


Calculate scores between queries and keys

```python
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
```


Calculate dot-product with RoPE

```python
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
```
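A shape-only sketch of how this is used (the dimensions here are arbitrary and only for illustration): `get_scores` receives the per-head queries and keys and returns scores indexed by query position, key position, batch and head.

```python
mha = RotaryPEMultiHeadAttention(heads=4, d_model=128, rope_percentage=0.5)
q = torch.randn(10, 2, 4, 32)  # [query_seq_len, batch_size, heads, d_k]
k = torch.randn(12, 2, 4, 32)  # [key_seq_len, batch_size, heads, d_k]
print(mha.get_scores(q, k).shape)  # torch.Size([10, 12, 2, 4])
```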


Testing RoPE with a simple example

```python
def _test_rotary():
```


```python
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
    x = x[:, None, None, :]
    inspect(x)

    rotary_pe = RotaryPositionalEmbeddings(4)
    inspect(rotary_pe(x))


if __name__ == '__main__':
    _test_rotary()
```
