
Attention with Linear Biases (ALiBi)

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/alibi/__init__.py)

This is an implementation of Attention with Linear Biases (ALiBi) from the paper Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.

This replaces positional encodings with biases added to the attention scores (the attention logits, before the softmax). It is a relative scheme, tested on autoregressive tasks: the bias is higher for nearby tokens and lower for far-away tokens. The biases decrease linearly in log scale (because they are added before the softmax), and each head has a different slope.

Here is the attention formula for the $i$-th token:

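$$a_i = \operatorname{softmax}\left(q_i K^\top + m \cdot \left[-(i-1), \dots, -1, 0\right]\right) = \operatorname{softmax}\left(q_i K^\top + m \cdot \left[0, 1, \dots, (i-1)\right]\right)$$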

where $q_i \in \mathbb{R}^d$ is the query of the $i$-th token, $K \in \mathbb{R}^{i \times d}$ are the keys up to $i$, and $d$ is the number of features per head. The equality above holds because the two bias vectors differ by the same constant $m \cdot (i-1)$ in every element, and softmax is invariant to translations (you can add any constant to all elements without changing the result).
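The equality can be checked numerically. The sketch below is plain Python rather than PyTorch; the logits and the slope `m = 0.5` are made-up values, and the `softmax` helper is defined here only for illustration. It adds each of the two bias vectors from the formula to the same logits and confirms that the resulting attention weights match.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    # Subtracting the max is itself a translation, so it does not change the result
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits q_i K^T for query i = 4 over keys 1..4
logits = [0.3, -1.2, 0.8, 0.1]
m = 0.5  # an example head slope
i = len(logits)

# First form: m * [-(i-1), ..., -1, 0]
relative_bias = [m * (j - (i - 1)) for j in range(i)]
# Second form: m * [0, 1, ..., (i-1)], i.e. the first form shifted by m * (i-1)
shifted_bias = [m * j for j in range(i)]

a_relative = softmax([s + b for s, b in zip(logits, relative_bias)])
a_shifted = softmax([s + b for s, b in zip(logits, shifted_bias)])

print(all(math.isclose(x, y) for x, y in zip(a_relative, a_shifted)))  # True
```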

Here is the training code for an ALiBi model.

    import math
    from typing import Optional

    import torch
    from torch import nn

    from labml.logger import inspect
    from labml_nn.transformers.mha import MultiHeadAttention

Get the head-specific slope $m$ for each head

  • `n_heads` is the number of heads in the attention layer, $n$

The slope for the first head is

$$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

The slopes for the rest of the heads form a geometric series whose ratio is the same as the first slope.

For instance, when the number of heads is 8, the slopes are $\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$.
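A minimal plain-Python sketch of the power-of-2 case, assuming `n` is a power of 2; `power_of_two_slopes` is a hypothetical helper, not part of labml_nn. It builds the geometric series and checks that the first slope is also the ratio of the series and that the last slope is always $2^{-8}$, whatever $n$ is.

```python
import math


def power_of_two_slopes(n: int) -> list[float]:
    """Slopes 2^(-8k/n) for k = 1..n, assuming n is a power of 2."""
    ratio = 2.0 ** (-8.0 / n)  # the first slope is also the ratio of the series
    return [ratio ** k for k in range(1, n + 1)]


print(power_of_two_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]

# The series always starts at 2^(-8/n) and ends at 2^(-8)
for n in (2, 4, 8, 16):
    s = power_of_two_slopes(n)
    assert math.isclose(s[0], 2.0 ** (-8.0 / n))
    assert math.isclose(s[-1], 2.0 ** -8)
    assert all(math.isclose(b / a, s[0]) for a, b in zip(s, s[1:]))
```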

    def get_slopes(n_heads: int):

Get the largest power of 2 that does not exceed `n_heads`. If `n_heads` is not a power of 2, we first calculate the slopes for this smaller power of 2 and then add the remaining slopes.

        n = 2 ** math.floor(math.log2(n_heads))

$2^{-\frac{8}{n}}$

        m_0 = 2.0 ** (-8.0 / n)

$2^{-1\frac{8}{n}}, 2^{-2\frac{8}{n}}, 2^{-3\frac{8}{n}}, \dots$

        m = torch.pow(m_0, torch.arange(1, 1 + n))

If `n_heads` is not a power of 2, we add the remaining slopes. We calculate them as if there were $2n$ heads, skipping the slopes that were already added, and keep only enough of them to reach `n_heads` in total.

        if n < n_heads:

$2^{-\frac{8}{2n}}$

            m_hat_0 = 2.0 ** (-4.0 / n)

$2^{-1\frac{8}{2n}}, 2^{-3\frac{8}{2n}}, 2^{-5\frac{8}{2n}}, \dots$ Note that we take steps by 2 to avoid slopes added previously.

            m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))

Concatenate the slopes with the remaining slopes.

            m = torch.cat([m, m_hat])

        return m
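For readers without PyTorch at hand, here is a plain-Python port of the same logic (the name `get_slopes_py` is hypothetical, and it returns a list instead of a tensor). It makes the non-power-of-2 case concrete: with 12 heads, the first 8 slopes come from the 8-head series and the remaining 4 are the odd powers of the 16-head series.

```python
import math


def get_slopes_py(n_heads: int) -> list[float]:
    # Largest power of 2 that does not exceed n_heads
    n = 2 ** math.floor(math.log2(n_heads))
    # Geometric series 2^(-8/n), 2^(-16/n), ..., 2^(-8)
    m_0 = 2.0 ** (-8.0 / n)
    m = [m_0 ** k for k in range(1, n + 1)]
    if n < n_heads:
        # Series for 2n heads, keeping only the odd powers 1, 3, 5, ...
        # and stopping once n_heads slopes have been collected
        m_hat_0 = 2.0 ** (-4.0 / n)
        m += [m_hat_0 ** k for k in range(1, 1 + 2 * (n_heads - n), 2)]
    return m


slopes = get_slopes_py(12)
print(len(slopes))  # 12
print(slopes)
```

The even powers are skipped because $\left(2^{-\frac{8}{2n}}\right)^{2k} = 2^{-\frac{8k}{n}}$, which is exactly the $k$-th slope already in `m`.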

Calculate the attention biases matrix

  • `n_heads` is the number of heads in the attention layer
  • `mask` is the attention mask of shape `[seq_len_q, seq_len_k]`

This returns a matrix of shape `[seq_len_q, seq_len_k, n_heads]` with ALiBi attention biases.

    @torch.no_grad()
    def get_alibi_biases(n_heads: int, mask: torch.Tensor):

Get the slopes $m$ for each head

        m = get_slopes(n_heads).to(mask.device)

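Calculate the distances $[0, 1, \dots, N]$. Here we calculate the distances using the mask; on the unmasked positions the cumulative sum gives $[1, 2, \dots]$, which differs from $[0, 1, \dots, N]$ only by a constant.

Since it's a causal mask, we could also just use $[0, 1, \dots, N]$ directly:

        # distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]

        distance = mask.cumsum(dim=-1)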

Multiply them pair-wise to get the ALiBi bias matrix

        return distance[:, :, None] * m[None, None, :]
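To see how the cumulative sum reproduces the relative biases of the formula at the top, here is a plain-Python sketch with two example slopes and a causal mask of length 4. `causal_mask` and `alibi_biases_py` are hypothetical helpers that mirror `mask.cumsum(dim=-1)` and the broadcasted product. For every visible key $j \le i$, the bias relative to the diagonal works out to $-m(i-j)$, i.e. the paper's $m \cdot [-(i-1), \dots, -1, 0]$ up to a per-row constant that the softmax ignores.

```python
def causal_mask(seq_len: int) -> list[list[int]]:
    # mask[i][j] == 1 when query i may attend to key j (j <= i)
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]


def alibi_biases_py(slopes: list[float], mask: list[list[int]]) -> list[list[list[float]]]:
    biases = []
    for row in mask:
        # Running count of visible keys along the row, like mask.cumsum(dim=-1)
        distance, total = [], 0
        for v in row:
            total += v
            distance.append(total)
        # Outer product with the slopes -> shape [seq_len_q, seq_len_k, n_heads]
        biases.append([[d * s for s in slopes] for d in distance])
    return biases


slopes = [0.5, 0.25]  # two example heads
b = alibi_biases_py(slopes, causal_mask(4))

# For visible keys j <= i, the bias relative to the diagonal is -m * (i - j)
for i in range(4):
    for j in range(i + 1):
        for h, m in enumerate(slopes):
            assert abs((b[i][j][h] - b[i][i][h]) - (-m * (i - j))) < 1e-12

print(b[3])  # [[0.5, 0.25], [1.0, 0.5], [1.5, 0.75], [2.0, 1.0]]
```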

Attention with Linear Biases (ALiBi)

We extend Multi-Head Attention and override its forward pass.

    class AlibiMultiHeadAttention(MultiHeadAttention):

        def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
            super().__init__(heads, d_model, dropout_prob)

To cache the ALiBi biases

            self.alibi_biases = None

`query`, `key` and `value` are the tensors that store collections of query, key and value vectors. They have shape `[seq_len, batch_size, d_model]`.

`mask` has shape `[seq_len, seq_len, batch_size]` and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.
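As an illustration of this layout, the sketch below builds a causal mask as nested Python lists indexed `mask[i][j][b]`, with a single batch entry. `causal_mask_3d` is a hypothetical stand-in for whatever helper creates the mask in practice (the test at the end of this file uses `subsequent_mask`); it has the square, batch-size-1 shape that `forward` asserts.

```python
def causal_mask_3d(seq_len: int) -> list[list[list[bool]]]:
    # mask[i][j][b] is True when query i may attend to key j. The batch
    # dimension has size 1, so the same mask applies to every sequence.
    return [[[j <= i] for j in range(seq_len)] for i in range(seq_len)]


mask = causal_mask_3d(5)
shape = (len(mask), len(mask[0]), len(mask[0][0]))
print(shape)  # (5, 5, 1)
print(mask[3][1][0], mask[1][3][0])  # True False
```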

        def forward(self, *,
                    query: torch.Tensor,
                    key: torch.Tensor,
                    value: torch.Tensor,
                    mask: Optional[torch.Tensor] = None):

ALiBi only works with causal masks. The mask must also be square and shared across the batch (its batch dimension has size 1).

            assert mask is not None
            assert mask.shape[0] == mask.shape[1] and mask.shape[2] == 1

`query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`

            seq_len, batch_size, _ = query.shape

Add a head dimension to the mask and check its shape.

            mask = self.prepare_mask(mask, query.shape, key.shape)

Prepare `query`, `key` and `value` for the attention computation. These will then have shape `[seq_len, batch_size, heads, d_k]`.

            query = self.query(query)
            key = self.key(key)
            value = self.value(value)

Compute the attention scores $QK^\top$. This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.

            scores = self.get_scores(query, key)

Scale the scores $\frac{QK^\top}{\sqrt{d_k}}$

            scores *= self.scale

Create the ALiBi biases if they are not cached yet, or if the cached biases are too small for the current sequence length

            if self.alibi_biases is None or self.alibi_biases.shape[1] < seq_len:

`mask` has shape `[seq_len, seq_len, 1, 1]`

                self.alibi_biases = get_alibi_biases(scores.shape[-1], mask[:, :, 0, 0])

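Add the ALiBi biases to the attention scores. The ALiBi biases have shape `[seq_len, seq_len, n_heads]` and the scores have shape `[seq_len, seq_len, batch_size, n_heads]`, so we slice the cached biases to the current sequence length and insert a batch dimension for broadcasting.

            scores += self.alibi_biases[:seq_len, :seq_len, None, :]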
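Slicing the cache with `[:seq_len, :seq_len]` relies on the fact that, for a causal mask, the cumulative-sum distance at position $(i, j)$ depends only on $i$ and $j$, not on the total sequence length. The plain-Python sketch below (with a hypothetical `causal_alibi_biases` helper and example slopes) checks that the top-left $5 \times 5$ block of biases built for length 8 equals the biases built directly for length 5.

```python
def causal_alibi_biases(slopes: list[float], seq_len: int) -> list[list[list[float]]]:
    # For a causal mask, mask.cumsum(dim=-1) gives row i = [1, 2, ..., i+1, i+1, ..., i+1],
    # i.e. min(j, i) + 1 at key position j
    return [[[(min(j, i) + 1) * m for m in slopes] for j in range(seq_len)]
            for i in range(seq_len)]


slopes = [0.5, 0.25, 0.125]
cached = causal_alibi_biases(slopes, 8)  # built once for a longer sequence
fresh = causal_alibi_biases(slopes, 5)   # what a length-5 sequence would need

# Plain-Python equivalent of self.alibi_biases[:seq_len, :seq_len] with seq_len = 5
sliced = [row[:5] for row in cached[:5]]
print(sliced == fresh)  # True
```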

Apply the mask

            scores = scores.masked_fill(mask == 0, float('-inf'))

Softmax attention along the key sequence dimension $\underset{seq}{\operatorname{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)$

            attn = self.softmax(scores)
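Filling masked scores with $-\infty$ makes their softmax weight exactly zero, since $e^{-\infty} = 0$. A plain-Python sketch of the mask-then-softmax step for a single row; the `masked_softmax` helper and the scores are illustrative and not part of labml_nn.

```python
import math


def masked_softmax(scores: list[float], mask: list[int]) -> list[float]:
    # Equivalent of masked_fill(mask == 0, float('-inf')) followed by softmax
    filled = [s if keep else float('-inf') for s, keep in zip(scores, mask)]
    # Finite as long as at least one position is unmasked (a causal row
    # always keeps its diagonal)
    top = max(filled)
    exps = [math.exp(x - top) for x in filled]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


row = masked_softmax([1.0, 2.0, 3.0, 4.0], [1, 1, 0, 0])
print(row)
```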

Apply dropout

            attn = self.dropout(attn)

Multiply by the values $\underset{seq}{\operatorname{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$

            x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
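The einsum string `"ijbh,jbhd->ibhd"` sums over the key index `j` for every query `i`, batch `b`, head `h` and feature `d`. Written out as explicit loops over nested lists (a hypothetical helper for illustration, with tiny made-up tensors of one batch entry and one head), it is:

```python
def einsum_ijbh_jbhd(attn, value):
    # x[i][b][h][d] = sum over j of attn[i][j][b][h] * value[j][b][h][d]
    q_len, k_len = len(attn), len(attn[0])
    batch, heads, d_k = len(value[0]), len(value[0][0]), len(value[0][0][0])
    return [[[[sum(attn[i][j][b][h] * value[j][b][h][d] for j in range(k_len))
               for d in range(d_k)]
              for h in range(heads)]
             for b in range(batch)]
            for i in range(q_len)]


# attn[i][j][b][h]: query 0 sees only key 0; query 1 splits 0.25 / 0.75
attn = [[[[1.0]], [[0.0]]],
        [[[0.25]], [[0.75]]]]
# value[j][b][h][d]
value = [[[[1.0, 2.0]]],
         [[[3.0, 4.0]]]]

x = einsum_ijbh_jbhd(attn, value)
print(x)  # [[[[1.0, 2.0]]], [[[2.5, 3.5]]]]
```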

Concatenate multiple heads

            x = x.reshape(seq_len, batch_size, -1)

Output layer

            return self.output(x)

Simple test function to see the slopes.

    def _test_alibi():

        inspect(get_slopes(12).tolist(), _n=-1)
        from labml_nn.transformers.utils import subsequent_mask

        mask = subsequent_mask(8)[:, :, 0]
        inspect(mask)

        inspect(get_alibi_biases(12, mask)[:, :, 3], _n=-1)

    if __name__ == '__main__':
        _test_alibi()
