This is a PyTorch implementation of the paper "An Attention Free Transformer".
The paper replaces the self-attention layer with a new efficient operation that has a memory complexity of $O(Td)$, where $T$ is the sequence length and $d$ is the dimensionality of the embeddings.
The paper introduces AFT along with AFT-local and AFT-conv. Here we have implemented AFT-local, which attends to nearby tokens in an autoregressive model.
AFT (similar to MHA) first transforms the embeddings $X$ into query $Q = XW^Q$, key $K = XW^K$ and value $V = XW^V$ tensors with learned weights. The output for each position $t \in [1, T]$ is calculated with the following operation:
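$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'})}$$

where $\odot$ is the element-wise product, $\sigma$ is a non-linearity (sigmoid), and $w \in \mathbb{R}^{T \times T}$ is a learned matrix of pair-wise position biases.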
This means that we take a weighted average of the values, with weights $\exp(K_{t'} + w_{t,t'})$ computed separately for each feature, and multiply it element-wise by $\sigma(Q_t)$. This eliminates the need to calculate the $T \times T$ attention matrix that MHA requires, and therefore reduces the memory requirement.
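To make the operation concrete, here is a minimal pure-Python sketch of full AFT for a single sequence. It assumes $Q$, $K$ and $V$ have already been projected and uses nested lists instead of tensors; the helper names `sigmoid` and `aft_full` are hypothetical and not part of this implementation. It loops over positions and features directly, so it is only meant for checking the math, not for speed.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def aft_full(Q, K, V, w):
    """Hypothetical reference AFT for one sequence.

    Q, K, V: T rows of d floats (already projected).
    w: T x T list of pair-wise position biases.
    Returns Y as T rows of d floats.
    """
    T, d = len(Q), len(Q[0])
    Y = []
    for t in range(T):
        row = []
        for c in range(d):  # each feature channel gets its own weights
            weights = [math.exp(K[tp][c] + w[t][tp]) for tp in range(T)]
            num = sum(weights[tp] * V[tp][c] for tp in range(T))
            den = sum(weights)
            # Weighted average of values, gated element-wise by sigmoid(Q_t)
            row.append(sigmoid(Q[t][c]) * num / den)
        Y.append(row)
    return Y
```

With $K = 0$ and $w = 0$ every weight is $1$, so each $Y_t$ reduces to $\sigma(Q_t)$ times the mean of the values.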
AFT-local only applies the learned pair-wise position biases locally:

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{for } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

where $s \le T$ is the local window size.
Although $w'_{t,t'}$ is $0$ outside the local window, the AFT operation still uses key-value pairs from other areas. This is different from local transformers, where embeddings outside the local window are not visible at all.
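The following pure-Python sketch illustrates both points: `local_bias` (a hypothetical helper) zeroes the bias outside the window, and `aft_row` (also hypothetical, with the $\sigma(Q_t)$ gate left out for brevity) shows that changing a value at distance $\ge s$ still changes the output, because $\exp(K_{t'} + 0)$ is not zero.

```python
import math


def local_bias(w, s):
    """Return w' with w'[t][tp] = w[t][tp] if |t - tp| < s, else 0."""
    T = len(w)
    return [[w[t][tp] if abs(t - tp) < s else 0.0 for tp in range(T)] for t in range(T)]


def aft_row(t, K, V, w):
    """Weighted average of values for position t (sigma(Q_t) gate omitted)."""
    T, d = len(K), len(K[0])
    out = []
    for c in range(d):
        weights = [math.exp(K[tp][c] + w[t][tp]) for tp in range(T)]
        out.append(sum(wt * V[tp][c] for tp, wt in enumerate(weights)) / sum(weights))
    return out


T, s = 4, 2
w_local = local_bias([[5.0] * T for _ in range(T)], s)
K = [[0.0] for _ in range(T)]
V_a = [[1.0], [1.0], [1.0], [1.0]]
V_b = [[1.0], [1.0], [1.0], [9.0]]  # only the far token t' = 3 changes
print(aft_row(0, K, V_a, w_local), aft_row(0, K, V_b, w_local))
```

A local transformer with a hard window would instead give that far token zero weight, so the two outputs would be identical.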
Here is the training code for an AFT Local model.
```python
from typing import Optional

import torch
from torch import nn
```
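The AFT Local operation computes

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp(K_{t'} + w'_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^T \exp(K_{t'} + w'_{t,t'})}$$

where

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{for } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$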
```python
class AFTLocal(nn.Module):
```
`d_model` is the number of features in the query, key and value vectors.
`seq_len` is $T$.
`local_window_size` is the local window size $s$.
`bias` is whether to have a bias parameter for the transformations for $Q$, $K$ and $V$.
```python
    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
```
```python
        super().__init__()
```
Local window size $s$
```python
        self.local_window_size = local_window_size
```
These transform the query, key and value vectors.
```python
        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)
```
Pair-wise positional biases $w \in \mathbb{R}^{T \times T}$
```python
        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
```
Mask for $w_{t,t'}$
```python
        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)
```
Activation $\sigma$
```python
        self.activation = nn.Sigmoid()
```
Output layer
```python
        self.output = nn.Linear(d_model, d_model)
```
This creates a mask for

$$m_{t,t'} = \begin{cases} 1, & \text{for } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$
```python
    @staticmethod
    def create_local_mask(seq_len, local_window_size):
```
Initialize to ones
```python
        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
```
Make $t' - t \ge s$ zero
```python
        local_mask = torch.tril(local_mask, local_window_size - 1)
```
Make $t - t' \ge s$ zero
```python
        local_mask = torch.triu(local_mask, -(local_window_size - 1))
```
```python
        return local_mask
```
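`torch.tril(x, k)` keeps the entries with $t' - t \le k$ and `torch.triu(x, k)` keeps those with $t' - t \ge k$, so applying both with $k = \pm(s - 1)$ leaves exactly the band $|t - t'| < s$. A pure-Python sketch with list-based stand-ins for the two functions (hypothetical helpers `tril`, `triu` and `local_mask_lists`, written only to mirror the diagonal convention) makes this easy to check:

```python
def tril(m, k=0):
    """List version of torch.tril's rule: keep m[i][j] where j - i <= k, else False."""
    return [[v if j - i <= k else False for j, v in enumerate(row)] for i, row in enumerate(m)]


def triu(m, k=0):
    """List version of torch.triu's rule: keep m[i][j] where j - i >= k, else False."""
    return [[v if j - i >= k else False for j, v in enumerate(row)] for i, row in enumerate(m)]


def local_mask_lists(seq_len, local_window_size):
    mask = [[True] * seq_len for _ in range(seq_len)]  # initialize to ones
    mask = tril(mask, local_window_size - 1)  # zero entries with t' - t >= s
    mask = triu(mask, -(local_window_size - 1))  # zero entries with t - t' >= s
    return mask


for row in local_mask_lists(6, 2):
    print("".join("1" if v else "." for v in row))
# 11....
# 111...
# .111..
# ..111.
# ...111
# ....11
```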
`query`, `key` and `value` are the tensors that store collections of token embeddings for query, key and value. They have shape `[seq_len, batch_size, d_model]`.
`mask` has shape `[seq_len, seq_len, batch_size]` and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.
```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
```
`query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
```python
        seq_len, _, _ = query.shape

        if mask is not None:
```
`mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension. If the query dimension is equal to $1$ it will be broadcast.
```python
            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]
```
Transform query, key and value embeddings
```python
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
```
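Get

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{for } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

using the local mask, and set the bias to $-\infty$ wherever `mask` hides a key, so that its weight $\exp(w'_{t,t'})$ becomes $0$.
```python
        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
        # Add a batch dimension: [seq_len, seq_len, 1]
        pos_bias = pos_bias.unsqueeze(-1)
        # Out-of-place masked_fill lets a [seq_len, seq_len, batch_size] mask broadcast
        # against the size-1 batch dimension; the check also allows mask=None.
        if mask is not None:
            pos_bias = pos_bias.masked_fill(~mask, float('-inf'))
```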
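Assuming a causal (autoregressive) mask, which the caller would supply, every entry of the resulting bias falls into one of three cases. This pure-Python sketch with a hypothetical `masked_local_bias` helper (batch dimension dropped, $K = 0$ for the weights) spells them out:

```python
import math

NEG_INF = float("-inf")


def masked_local_bias(w, s, mask):
    """Bias seen by query i for key j, before exponentiation (one batch element)."""
    T = len(w)
    out = []
    for i in range(T):
        row = []
        for j in range(T):
            if not mask[i][j]:
                row.append(NEG_INF)  # key hidden by the mask
            elif abs(i - j) < s:
                row.append(w[i][j])  # inside the local window: learned bias
            else:
                row.append(0.0)  # outside the window: zero bias, key still visible
        out.append(row)
    return out


T, s = 4, 2
causal = [[j <= i for j in range(T)] for i in range(T)]
bias = masked_local_bias([[2.0] * T for _ in range(T)], s, causal)
weights = [[math.exp(b) for b in row] for row in bias]  # exp(-inf) == 0.0
```

Masked keys therefore drop out of both sums entirely, while keys outside the window keep a bias of $0$ and still contribute with weight $\exp(K_{t'})$.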
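$$
\begin{aligned}
Y_t &= \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'})} \\
&= \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'}) \odot V_{t'}}{\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'})}
\end{aligned}
$$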
We compute $\exp(w_{t,t'})$, $\exp(K_{t'}) \odot V_{t'}$ and $\exp(K_{t'})$ separately and combine them with matrix multiplications. We use `einsum` for clarity.
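Per batch element, with $\exp$ applied element-wise, $\exp(w) \in \mathbb{R}^{T \times T}$ and $\exp(K), V \in \mathbb{R}^{T \times d}$, the two sums are ordinary matrix products; this is what the `'ijb,jbd->ibd'` einsum computes for each batch index $b$ ($\oslash$ denotes element-wise division):

$$
\begin{aligned}
N &= \exp(w)\,\big(\exp(K) \odot V\big) \in \mathbb{R}^{T \times d}, \qquad N_{t,c} = \sum_{t'=1}^T \exp(w_{t,t'}) \exp(K_{t',c}) V_{t',c} \\
D &= \exp(w)\,\exp(K) \in \mathbb{R}^{T \times d}, \qquad D_{t,c} = \sum_{t'=1}^T \exp(w_{t,t'}) \exp(K_{t',c}) \\
Y &= \sigma(Q) \odot N \oslash D
\end{aligned}
$$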
We subtract $\max_{t'}(K_{t'})$ and $\max_{t'}(w_{t,t'})$ before calculating the exponents to stabilize the softmax calculation.

If $x_i$ is large, $\exp(x_i)$ becomes huge and the computation of $\frac{\sum \exp(x_i) y_i}{\sum \exp(x_i)}$ becomes unstable. Subtracting a constant from every $x_i$ before calculating the exponents multiplies the numerator and the denominator by the same factor, so it cancels out and does not change the result, while keeping the exponents in a safe range. So we subtract $\max(x_i)$ to stabilize the computation.
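A small pure-Python demonstration of why the shift matters (the function names are hypothetical). Python's `math.exp` raises `OverflowError` for large inputs; `torch.exp` would instead return `inf`, and `inf / inf` would give `nan`.

```python
import math


def weighted_average_naive(x, y):
    """sum(exp(x_i) * y_i) / sum(exp(x_i)) computed directly."""
    weights = [math.exp(xi) for xi in x]
    return sum(wi * yi for wi, yi in zip(weights, y)) / sum(weights)


def weighted_average_stable(x, y):
    """Same quantity, shifting x by max(x) first; the shift cancels in the ratio."""
    m = max(x)
    weights = [math.exp(xi - m) for xi in x]  # every exponent is <= 0, so no overflow
    return sum(wi * yi for wi, yi in zip(weights, y)) / sum(weights)


y = [10.0, 20.0]
print(weighted_average_stable([1000.0, 1001.0], y))
# weighted_average_naive([1000.0, 1001.0], y) raises OverflowError
```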
```python
        max_key = key.max(dim=0, keepdim=True)[0]
        max_pos_bias = pos_bias.max(dim=1, keepdim=True)[0]
```
$\exp(K_{t'} - \max_{t'}(K_{t'}))$
```python
        exp_key = torch.exp(key - max_key)
```
$\exp(w_{t,t'} - \max_{t'}(w_{t,t'}))$
```python
        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)
```
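The numerator part $\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'}) \odot V_{t'}$
```python
        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)
```
The denominator part $\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'})$
```python
        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)
```
Output

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'}) \odot V_{t'}}{\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'})}$$
```python
        y = self.activation(query) * num / den
```
Output layer
```python
        return self.output(y)
```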
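As a sanity check of the factorization used in `forward`, this pure-Python sketch (hypothetical helpers `aft_direct` and `aft_factored`, one batch element, lists instead of tensors) computes $Y$ straight from the definition and again from the max-shifted $\exp(w)$ and $\exp(K)$ factors, the way the two `einsum` calls do. A causal bias with $-\infty$ above the diagonal is assumed, to mimic the masking step.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def aft_direct(Q, K, V, w):
    """Y_t from the definition, with weights exp(K_t' + w_t,t')."""
    T, d = len(Q), len(Q[0])
    Y = [[0.0] * d for _ in range(T)]
    for i in range(T):
        for c in range(d):
            a = [math.exp(K[j][c] + w[i][j]) for j in range(T)]
            Y[i][c] = sigmoid(Q[i][c]) * sum(a[j] * V[j][c] for j in range(T)) / sum(a)
    return Y


def aft_factored(Q, K, V, w):
    """Same Y via max-shifted exp(w) and exp(K) factors, like the einsum version."""
    T, d = len(Q), len(Q[0])
    max_key = [max(K[j][c] for j in range(T)) for c in range(d)]  # max over t' per feature
    max_w = [max(w[i]) for i in range(T)]  # max over t' per query position
    exp_w = [[math.exp(w[i][j] - max_w[i]) for j in range(T)] for i in range(T)]
    exp_k = [[math.exp(K[j][c] - max_key[c]) for c in range(d)] for j in range(T)]
    Y = [[0.0] * d for _ in range(T)]
    for i in range(T):
        for c in range(d):
            # One batch element of 'ijb,jbd->ibd': a plain matrix product over j
            num = sum(exp_w[i][j] * exp_k[j][c] * V[j][c] for j in range(T))
            den = sum(exp_w[i][j] * exp_k[j][c] for j in range(T))
            Y[i][c] = sigmoid(Q[i][c]) * num / den
    return Y


T, d = 5, 3
Q = [[math.sin(i + 2 * c) for c in range(d)] for i in range(T)]
K = [[math.cos(3 * i - c) for c in range(d)] for i in range(T)]
V = [[i - 0.5 * c for c in range(d)] for i in range(T)]
w = [[0.3 * (i - j) if j <= i else float("-inf") for j in range(T)] for i in range(T)]
Yd, Yf = aft_direct(Q, K, V, w), aft_factored(Q, K, V, w)
```

The shifts only rescale the numerator and denominator by the same factors, so both functions agree up to floating-point rounding, and under the causal bias the first position sees only its own value.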
Test local mask
```python
def _test_local_mask():
    from labml.logger import inspect
    inspect(AFTLocal.create_local_mask(10, 4))


if __name__ == '__main__':
    _test_local_mask()
```