# Rotary Positional Embeddings with Relative distance (RoPER)


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/rope/value_pe/__init__.py)

RoPER is the work of Georges Harik (@gharik), and this implementation is based on his original code.

Rotary Positional Embeddings (RoPE) include relative positions in the attention score calculation. However, the value embeddings themselves do not get any positional information, except what they can pick up implicitly from causal attention.

RoPER adds relative positional information explicitly to value embeddings. Specifically, it adds the relative positions of the tokens it paid attention to. We use the same rotary positional embeddings to rotate the values in attention. Then, after taking the weighted sum, we rotate the final result in the opposite direction. This is equivalent to rotating each of the values (before attention) relative to the current position.

Here's the training code for a transformer model with RoPER on an arithmetic addition task, where we can see a significant improvement over RoPE.

## Relative distances in embeddings

For any head, let $a_{n,i}$ be the attention from position $n$ to position $i$, and $v_i$ be the value embeddings at position $i$. Let's denote the individual features as $v^{(1)}_i, v^{(2)}_i, \dots$

Normally, we would take the weighted sum of value embeddings

$$o^{(j)}_n = \sum_i a_{n,i} v^{(j)}_i$$

This doesn't explicitly add any information about the distance between the positions $i$ and $n$ to the final result $o^{(j)}_n$.

RoPER pairs features like RoPE and transforms them. For a pair $v^{(1)}_m$ and $v^{(2)}_m$, it transforms them by $\text{RoPE}\big(v^{(1)}_m, v^{(2)}_m, m\big)$. Let us denote the transformed features by $\hat{v}^{(1)}_m, \hat{v}^{(2)}_m$. Then it rotates the weighted sum $\hat{o}^{(j)}_n$ in the reverse direction with $\text{RoPE}\big(\hat{o}^{(1)}_n, \hat{o}^{(2)}_n, -n\big)$. Note the $-n$.

Note that,

$$\begin{aligned}
\text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big) &= \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} x^{(1)}_m \\ x^{(2)}_m \end{pmatrix} \\
&= \begin{pmatrix} x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta \\ x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta \end{pmatrix}
\end{aligned}$$
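To make the rotation concrete, here is a minimal sketch of this 2D rotation in plain Python (the helper name `rope_2d` and the single fixed `theta` are illustrative assumptions, not part of the implementation below):

```python
import math

def rope_2d(x1: float, x2: float, m: int, theta: float = 0.5):
    # Rotate the feature pair (x1, x2) by the angle m * theta,
    # matching the RoPE(x1, x2, m) definition above.
    return (x1 * math.cos(m * theta) - x2 * math.sin(m * theta),
            x2 * math.cos(m * theta) + x1 * math.sin(m * theta))
```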

The final output after the transformations is,

$$\text{RoPE}\big(\hat{o}^{(1)}_n, \hat{o}^{(2)}_n, -n\big) = \begin{pmatrix} \hat{o}^{(1)}_n \cos n\theta + \hat{o}^{(2)}_n \sin n\theta \\ \hat{o}^{(2)}_n \cos n\theta - \hat{o}^{(1)}_n \sin n\theta \end{pmatrix}$$

Note that $\sin(-n\theta) = -\sin n\theta$.

Let's expand the first term $\hat{o}^{(1)}_n \cos n\theta + \hat{o}^{(2)}_n \sin n\theta$,

$$\begin{aligned}
\hat{o}^{(1)}_n \cos n\theta + \hat{o}^{(2)}_n \sin n\theta
&= \sum_i a_{n,i} \hat{v}^{(1)}_i \cos n\theta + \sum_i a_{n,i} \hat{v}^{(2)}_i \sin n\theta \\
&= \sum_i a_{n,i} \big( v^{(1)}_i \cos i\theta - v^{(2)}_i \sin i\theta \big) \cos n\theta \\
&\quad + \sum_i a_{n,i} \big( v^{(2)}_i \cos i\theta + v^{(1)}_i \sin i\theta \big) \sin n\theta \\
&= \sum_i a_{n,i} v^{(1)}_i \big( \cos i\theta \cos n\theta + \sin i\theta \sin n\theta \big) \\
&\quad + \sum_i a_{n,i} v^{(2)}_i \big( \cos i\theta \sin n\theta - \sin i\theta \cos n\theta \big) \\
&= \sum_i a_{n,i} v^{(1)}_i \cos (i - n)\theta - \sum_i a_{n,i} v^{(2)}_i \sin (i - n)\theta
\end{aligned}$$

Similarly, we can show the second term is equal to,

$$\sum_i a_{n,i} v^{(2)}_i \cos (i - n)\theta + \sum_i a_{n,i} v^{(1)}_i \sin (i - n)\theta$$

Which gives,

$$\begin{aligned}
\text{RoPE}\big(\hat{o}^{(1)}_n, \hat{o}^{(2)}_n, -n\big)
&= \begin{pmatrix} \sum_i a_{n,i} v^{(1)}_i \cos (i - n)\theta - \sum_i a_{n,i} v^{(2)}_i \sin (i - n)\theta \\ \sum_i a_{n,i} v^{(2)}_i \cos (i - n)\theta + \sum_i a_{n,i} v^{(1)}_i \sin (i - n)\theta \end{pmatrix} \\
&= \sum_i a_{n,i} \, \text{RoPE}\big(v^{(1)}_i, v^{(2)}_i, i - n\big)
\end{aligned}$$

That is, the weighted average of values rotated relative to the current position.
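The derivation above can also be checked numerically. The following is a small sanity-check sketch (the attention weights, values, `theta`, and the `rotate` helper are all made up for illustration):

```python
import torch

def rotate(v: torch.Tensor, angle: float) -> torch.Tensor:
    # 2D rotation of a feature pair, i.e. RoPE(v[0], v[1], m) with angle = m * theta
    c, s = torch.cos(torch.tensor(angle)), torch.sin(torch.tensor(angle))
    return torch.stack([v[0] * c - v[1] * s, v[1] * c + v[0] * s])

theta = 0.37                             # arbitrary frequency
n = 3                                    # current (query) position
a = torch.tensor([0.1, 0.2, 0.3, 0.4])   # attention weights a_{n,i}
v = torch.randn(4, 2)                    # value feature pairs v_i

# RoPER: rotate each value by i*theta, take the weighted sum, rotate back by -n*theta
o_hat = sum(a[i] * rotate(v[i], i * theta) for i in range(4))
lhs = rotate(o_hat, -n * theta)

# The identity says this equals the weighted sum of values rotated by (i - n)*theta
rhs = sum(a[i] * rotate(v[i], (i - n) * theta) for i in range(4))

assert torch.allclose(lhs, rhs, atol=1e-6)
```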

Here's an experiment that uses RoPER on an arithmetic addition task.

```python
from typing import Optional

import torch

from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention
```

## RoPE module that rotates in the opposite direction

This inherits from the RoPE rotation implementation and changes the direction of rotation.

```python
class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
```

* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`

```python
    def forward(self, x: torch.Tensor):
```

Cache $\cos$ and $\sin$ values

```python
        self._build_cache(x)
```

Split the features; we can choose to apply rotary embeddings to only a subset of the features.

```python
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
```

Calculate $\big[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, \dots, -x^{(d)}, x^{(1)}, x^{(2)}, \dots, x^{(\frac{d}{2})}\big]$

```python
        neg_half_x = self._neg_half(x_rope)
```
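`_neg_half` is inherited from `RotaryPositionalEmbeddings`. As a rough sketch of what it computes, based on the description above (an illustrative standalone version, not the parent class's verbatim code):

```python
import torch

def neg_half(x: torch.Tensor, d: int) -> torch.Tensor:
    # Swap the two halves of the features and negate the new first half:
    # [-x^(d/2 + 1), ..., -x^(d), x^(1), ..., x^(d/2)]
    d_2 = d // 2
    return torch.cat([-x[..., d_2:], x[..., :d_2]], dim=-1)
```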

Calculate

$$\begin{pmatrix} x^{(i)}_m \cos(-m\theta_i) - x^{(i + \frac{d}{2})}_m \sin(-m\theta_i) \\ x^{(i + \frac{d}{2})}_m \cos(-m\theta_i) + x^{(i)}_m \sin(-m\theta_i) \end{pmatrix} = \begin{pmatrix} x^{(i)}_m \cos m\theta_i + x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i - x^{(i)}_m \sin m\theta_i \end{pmatrix}$$

for $i \in \{1, 2, \dots, \frac{d}{2}\}$

```python
        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
```

```python
        return torch.cat((x_rope, x_pass), dim=-1)
```
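Since this module negates the rotation angle, composing it with the standard `RotaryPositionalEmbeddings` should give back the input on the rotated features. A hedged sketch of that check, assuming `labml_nn` is installed and the class above is in scope:

```python
import torch
from labml_nn.transformers.rope import RotaryPositionalEmbeddings

d = 16
rope = RotaryPositionalEmbeddings(d)
reverse_rope = ReverseRotaryPositionalEmbeddings(d)

# [seq_len, batch_size, n_heads, d]
x = torch.randn(8, 2, 4, d)

# Rotating by m*theta_i and then by -m*theta_i should be the identity
assert torch.allclose(reverse_rope(rope(x)), x, atol=1e-5)
```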

## Multi-head attention with rotary positional embeddings

We override the multi-head attention from the original transformer implementation.

```python
class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
```

```python
    def __init__(self, heads: int, d_model: int,
                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
                 dropout_prob: float = 0.0):
        super().__init__(heads, d_model, rope_percentage, dropout_prob)
```

Rotary positional embedding layers

```python
        d_rope_value = int(self.d_k * rope_value_percentage)

        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
```

`query`, `key` and `value` are the tensors that store the collections of query, key and value vectors. They have shape `[seq_len, batch_size, d_model]`.

`mask` has shape `[seq_len, seq_len, batch_size]` and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.

```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
```

`query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`

```python
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
```

Prepare `query`, `key` and `value` for attention computation. These will then have shape `[seq_len, batch_size, heads, d_k]`.

```python
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
```

Compute attention scores $Q K^\top$. This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.

```python
        scores = self.get_scores(query, key)
```

Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$

```python
        scores *= self.scale
```

Apply mask

```python
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
```

Softmax attention along the key sequence dimension $\underset{seq}{\text{softmax}}\Big(\frac{Q K^\top}{\sqrt{d_k}}\Big)$

```python
        attn = self.softmax(scores)
```

Apply dropout

```python
        attn = self.dropout(attn)
```

Rotate value embeddings before taking the weighted sum so that they contain positional information

```python
        value = self.value_rotary_pe(value)
```

Multiply by values $\underset{seq}{\text{softmax}}\Big(\frac{Q K^\top}{\sqrt{d_k}}\Big) V$

```python
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
```
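The `einsum` takes, for each query position `i`, batch `b` and head `h`, the attention-weighted sum of value vectors over key positions `j`. An illustrative sketch (made-up shapes) of the same contraction written with explicit loops:

```python
import torch

attn = torch.rand(5, 5, 2, 3)    # [seq_len, seq_len, batch_size, heads]
value = torch.rand(5, 2, 3, 8)   # [seq_len, batch_size, heads, d_k]

out = torch.einsum("ijbh,jbhd->ibhd", attn, value)

# The same contraction as an explicit sum over key positions j
expected = torch.zeros(5, 2, 3, 8)
for i in range(5):
    for j in range(5):
        expected[i] += attn[i, j, :, :, None] * value[j]

assert torch.allclose(out, expected, atol=1e-6)
```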

Rotate in the opposite direction so that each embedding holds the relative positions

```python
        x = self.value_reverse_rotary_pe(x)
```

Save attentions for any other calculations

```python
        self.attn = attn.detach()
```

Concatenate multiple heads

```python
        x = x.reshape(seq_len, batch_size, -1)
```

Output layer

```python
        return self.output(x)
```
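To put the module together, here's a hedged usage sketch (the tensor sizes and the causal-mask construction are illustrative assumptions based on the shape conventions described above):

```python
import torch

seq_len, batch_size, d_model, heads = 10, 4, 64, 8
mha = RotaryValuePEMultiHeadAttention(heads, d_model)

x = torch.randn(seq_len, batch_size, d_model)

# Causal mask: query at position i attends to key positions j <= i,
# shaped [seq_len, seq_len, batch_size] as described above
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1).expand(-1, -1, batch_size)

out = mha(query=x, key=x, value=x, mask=mask)
assert out.shape == (seq_len, batch_size, d_model)
```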
