
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/aft/__init__.py)

# An Attention Free Transformer

This is a PyTorch implementation of the paper An Attention Free Transformer.

This paper replaces the self-attention layer with a new efficient operation that has memory complexity of $\mathcal{O}(Td)$, where $T$ is the sequence length and $d$ is the dimensionality of the embeddings.

The paper introduces AFT along with AFT-local and AFT-conv. Here we have implemented AFT-local, which pays attention to nearby tokens in an autoregressive model.

## Attention Free Transformer

AFT (similar to MHA) first transforms the embeddings $X$ into query $Q = XW^Q$, key $K = XW^K$ and value $V = XW^V$ tensors with learned weights. The output for each position $t \in [1, T]$ is calculated with the following operation.

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})}$$

where $\odot$ is the element-wise product, $\sigma$ is a non-linearity (sigmoid) and $w \in \mathbb{R}^{T \times T}$ is a learned matrix of pair-wise position biases.

This means that we take a weighted average of the values and multiply it by the query. This eliminates the need to calculate the $T \times T$ attention matrix that MHA requires, and therefore reduces the memory requirement.
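As a rough reference, here is a minimal, unbatched sketch of the operation above in PyTorch; the sizes and the randomly initialized weights and biases are hypothetical stand-ins, not the repository's code (which follows below in a batched, numerically stable form).

```python
import torch

T, d = 8, 16                                   # hypothetical sequence length and feature size
X = torch.randn(T, d)                          # token embeddings
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
w = torch.randn(T, T)                          # pair-wise position biases (learned in practice)

Q, K, V = X @ W_q, X @ W_k, X @ W_v

# weights[t, t'] = exp(K_{t'} + w_{t, t'}) for every pair of positions
weights = torch.exp(K.unsqueeze(0) + w.unsqueeze(-1))                 # [T, T, d]
# Weighted average of the values, gated element-wise by sigmoid(Q_t)
Y = torch.sigmoid(Q) * (weights * V.unsqueeze(0)).sum(dim=1) / weights.sum(dim=1)
```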

## AFT Local

AFT Local applies the learned pair-wise position biases only locally:

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{for } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

where $s \le T$ is the local window size.

Although $w'_{t,t'}$ is zero outside the local window, the AFT operation still uses key-value pairs from other positions. This is different from local transformers, where embeddings outside the local window are completely invisible.
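As a tiny illustration (with made-up numbers): zeroing the bias outside the window still gives those positions a weight factor of $\exp(K_{t'} + 0)$, whereas a hard local mask would remove them entirely.

```python
import torch

w_row = torch.tensor([0.7, -0.3, 1.2, 0.5, -0.8])           # biases w_{t,t'} for one query position
in_window = torch.tensor([True, True, True, False, False])   # |t - t'| < s

aft_local = torch.exp(w_row * in_window)                               # out-of-window factor is exp(0) = 1
hard_local = torch.exp(w_row.masked_fill(~in_window, float('-inf')))   # out-of-window factor is exactly 0

print(aft_local)    # ≈ [2.01, 0.74, 3.32, 1.00, 1.00]
print(hard_local)   # ≈ [2.01, 0.74, 3.32, 0.00, 0.00]
```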

Here is the training code for an AFT Local model.

```python
from typing import Optional

import torch
from torch import nn
```

## AFT Local Operation

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})}$$

where

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{for } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

```python
class AFTLocal(nn.Module):
```

  • `d_model` is the number of features in the query, key and value vectors.
  • `seq_len` is $T$
  • `local_window_size` is the local window size $s$
  • `bias` is whether to have a bias parameter for the transformations of $Q$, $K$ and $V$.

```python
    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
```

```python
        super().__init__()
```

Local window size $s$

```python
        self.local_window_size = local_window_size
```

These transform the query, key and value vectors.

```python
        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)
```

Pair-wise positional biases $w \in \mathbb{R}^{T \times T}$

```python
        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
```

Mask for $w_{t,t'}$

```python
        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)
```

Activation $\sigma$

```python
        self.activation = nn.Sigmoid()
```

Output layer

```python
        self.output = nn.Linear(d_model, d_model)
```

Create local mask

This creates a mask for

$$m_{t,t'} = \begin{cases} 1, & \text{for } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

```python
    @staticmethod
    def create_local_mask(seq_len, local_window_size):
```

Initialize to ones

```python
        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
```

Make $t' - t \ge s$ zero

```python
        local_mask = torch.tril(local_mask, local_window_size - 1)
```

Make $t - t' \ge s$ zero

```python
        local_mask = torch.triu(local_mask, -(local_window_size - 1))
```

```python
        return local_mask
```
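For example, `create_local_mask(5, 2)` keeps only the band $|t - t'| < 2$:

```python
AFTLocal.create_local_mask(5, 2)
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [False,  True,  True,  True, False],
#         [False, False,  True,  True,  True],
#         [False, False, False,  True,  True]])
```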

`query`, `key` and `value` are the tensors that store the collection of token embeddings for query, key and value. They have shape `[seq_len, batch_size, d_model]`.

`mask` has shape `[seq_len, seq_len, batch_size]` and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.

```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
```

`query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`

```python
        seq_len, _, _ = query.shape

        if mask is not None:
```

`mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension. If the query dimension is equal to 1 it will be broadcast.

```python
            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]
```
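For instance, an autoregressive (causal) mask in this layout, with the batch dimension left as `1` so that it broadcasts, could be built as below (using `torch` as imported above); this is only an illustrative sketch of a compatible mask, not taken from the training code.

```python
seq_len = 6
# Shape [seq_len_q, seq_len_k, 1]: query position i may only attend to key positions j <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)
```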

Transform query, key and value embeddings

```python
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
```

Get

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{for } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

using the mask

```python
        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
        pos_bias = pos_bias.unsqueeze(-1)
        pos_bias.masked_fill_(~mask, float('-inf'))
```

$$\begin{aligned}
Y_t &= \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})} \\
    &= \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(w_{t,t'}) \odot \exp(K_{t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(w_{t,t'}) \odot \exp(K_{t'})}
\end{aligned}$$

We compute $\exp(w_{t,t'})$, $\exp(K_{t'}) \odot V_{t'}$ and $\exp(K_{t'})$ separately and do a matrix multiplication. We use `einsum` for clarity.
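The `einsum` below sums over the key position for every query position and batch element; as a sanity-check sketch with hypothetical tensors (not the layer's actual variables), it is equivalent to a per-batch matrix product:

```python
import torch

T, B, D = 6, 2, 8                       # hypothetical sequence length, batch size, features
bias = torch.rand(T, T, B)              # plays the role of exp(w_{t,t'})
kv = torch.rand(T, B, D)                # plays the role of exp(K_{t'}) * V_{t'}

num = torch.einsum('ijb,jbd->ibd', bias, kv)

# Equivalent explicit form: one matrix product per batch element
ref = torch.stack([bias[:, :, b] @ kv[:, b, :] for b in range(B)], dim=1)
assert torch.allclose(num, ref, atol=1e-6)
```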

We subtract $\max_{t'}(K_{t'})$ and $\max_{t'}(w_{t,t'})$ before calculating the exponents to stabilize the softmax-like calculation.

If $x_i$ is large, $\exp(x_i)$ becomes huge and the computation of $\frac{\sum \exp(x_i) y_i}{\sum \exp(x_i)}$ becomes unstable. Subtracting a constant from the exponents cancels out between the numerator and the denominator, so it does not change the result but helps stabilize the computation. We therefore subtract $\max(x_i)$.

```python
        max_key = key.max(dim=0, keepdims=True)[0]
        max_pos_bias = pos_bias.max(dim=1, keepdims=True)[0]
```
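Concretely, for any constant $c$ the shift cancels between the numerator and the denominator, so the result is unchanged; choosing $c = \max_{t'} x_{t'}$ keeps every exponent at most $1$:

$$\frac{\sum_{t'} \exp(x_{t'} - c)\, y_{t'}}{\sum_{t'} \exp(x_{t'} - c)} = \frac{e^{-c} \sum_{t'} \exp(x_{t'})\, y_{t'}}{e^{-c} \sum_{t'} \exp(x_{t'})} = \frac{\sum_{t'} \exp(x_{t'})\, y_{t'}}{\sum_{t'} \exp(x_{t'})}$$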

$\exp \big(K_{t'} - \max_{t'}(K_{t'})\big)$

```python
        exp_key = torch.exp(key - max_key)
```

$\exp \big(w_{t,t'} - \max_{t'}(w_{t,t'})\big)$

```python
        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)
```

The numerator part $\sum_{t'=1}^{T} \exp(w_{t,t'}) \odot \exp(K_{t'}) \odot V_{t'}$

```python
        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)
```

The denominator part $\sum_{t'=1}^{T} \exp(w_{t,t'}) \odot \exp(K_{t'})$

```python
        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)
```

Output

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(w_{t,t'}) \odot \exp(K_{t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(w_{t,t'}) \odot \exp(K_{t'})}$$

```python
        y = self.activation(query) * num / den
```

Output layer

```python
        return self.output(y)
```
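As a rough usage sketch (not part of the original file, and assuming the imports and the class above), a quick smoke test might look like this; the sizes and the causal mask are arbitrary choices for illustration.

```python
seq_len, batch_size, d_model = 16, 4, 32
aft = AFTLocal(d_model=d_model, seq_len=seq_len, local_window_size=4)

x = torch.randn(seq_len, batch_size, d_model)
# Causal mask with a broadcastable batch dimension of 1
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

y = aft(query=x, key=x, value=x, mask=causal_mask)
assert y.shape == (seq_len, batch_size, d_model)
```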

Test local mask

```python
def _test_local_mask():
```

```python
    from labml.logger import inspect

    inspect(AFTLocal.create_local_mask(10, 4))
```

```python
if __name__ == '__main__':
    _test_local_mask()
```
