docs/transformers/mha.html
This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need in PyTorch. The implementation is inspired by The Annotated Transformer.
Here is the training code that uses a basic transformer with MHA for NLP auto-regression.
Here is an experiment implementation that trains a simple transformer.
import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker
This module does a linear transformation and splits the vector into a given number of heads for multi-head attention. This is used to transform the key, query, and value vectors.
class PrepareForMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
Linear layer for linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
Number of heads
        self.heads = heads
Number of dimensions in vectors in each head
        self.d_k = d_k
    def forward(self, x: torch.Tensor):
Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model]. We apply the linear transformation to the last dimension and split that into the heads.
        head_shape = x.shape[:-1]
Linear transform
        x = self.linear(x)
Split last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)
Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]
        return x
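As a quick sanity check on the shapes, here is a minimal usage sketch of the module above (the sizes are made up for illustration):

```python
import torch

# Hypothetical sizes, for illustration only
seq_len, batch_size, d_model, heads = 10, 4, 512, 8
d_k = d_model // heads  # 64 features per head

prep = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=True)
x = torch.randn(seq_len, batch_size, d_model)

# The last dimension (d_model) is transformed and split into heads
out = prep(x)
assert out.shape == (seq_len, batch_size, heads, d_k)
```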
This computes scaled multi-headed attention for given query, key and value vectors.
$$\mathrm{Attention}(Q, K, V) = \underset{seq}{\mathrm{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
In simple terms, it finds keys that match the query and gets the values of those keys.
It uses the dot-product of query and key as an indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.
Softmax is calculated along the axis of the sequence (or time).
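To see why the $\frac{1}{\sqrt{d_k}}$ factor has the right magnitude, here is a small illustrative sketch (not part of the implementation): the dot product of two unit-variance random vectors has a standard deviation of about $\sqrt{d_k}$, and the scaling brings it back to about $1$, keeping the softmax away from its saturated, small-gradient regime.

```python
import math
import torch

d_k = 512  # illustrative dimensionality
q, k = torch.randn(1000, d_k), torch.randn(1000, d_k)

# Each dot product is a sum of d_k unit-variance terms,
# so its standard deviation is about sqrt(d_k)
dots = (q * k).sum(dim=-1)
print(dots.std())                     # roughly sqrt(512) ~ 22.6
print((dots / math.sqrt(d_k)).std())  # roughly 1
```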
class MultiHeadAttention(nn.Module):
heads is the number of heads. d_model is the number of features in the query, key and value vectors.

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
Number of features per head
        self.d_k = d_model // heads
Number of heads
        self.heads = heads
These transform the query, key and value vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
Softmax for attention along the time dimension of key
        self.softmax = nn.Softmax(dim=1)
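Here dim=1 is the key-sequence axis of the [seq_len_q, seq_len_k, batch_size, heads] score tensor computed below, so the attention weights over all keys sum to one for each query position. A small sketch of that invariant (sizes are illustrative):

```python
import torch

scores = torch.randn(10, 12, 4, 8)  # [seq_len_q, seq_len_k, batch_size, heads]
attn = torch.softmax(scores, dim=1)

# The weights over the keys sum to 1 for every (query, batch, head)
assert torch.allclose(attn.sum(dim=1), torch.ones(10, 4, 8))
```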
Output layer
        self.output = nn.Linear(d_model, d_model)
Dropout
        self.dropout = nn.Dropout(dropout_prob)
Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
We store attentions so that they can be used for logging, or other computations if needed
        self.attn = None
This method can be overridden for other variations like relative attention.
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
Calculate $QK^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
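This einsum is equivalent to a batched matrix multiplication over the batch and head dimensions; a small sketch of the equivalence (tensor sizes are made up):

```python
import torch

q = torch.randn(10, 4, 8, 64)  # [seq_len_q, batch_size, heads, d_k]
k = torch.randn(12, 4, 8, 64)  # [seq_len_k, batch_size, heads, d_k]

s1 = torch.einsum('ibhd,jbhd->ijbh', q, k)

# Same result via permute + matmul:
# [batch, heads, seq_q, d_k] @ [batch, heads, d_k, seq_k] -> [batch, heads, seq_q, seq_k]
s2 = (q.permute(1, 2, 0, 3) @ k.permute(1, 2, 3, 0)).permute(2, 3, 0, 1)
assert torch.allclose(s1, s2, atol=1e-5)
```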
mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension is equal to 1 it will be broadcast.
    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
Same mask applied to all heads.
        mask = mask.unsqueeze(-1)
The resulting mask has shape [seq_len_q, seq_len_k, batch_size, heads]
        return mask
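For example, an auto-regressive (causal) mask, where each query position may only attend to the same or earlier key positions, can be built with torch.tril; this sketch follows the shape convention above:

```python
import torch

seq_len = 5

# mask[i, j] == 1 means query position i may attend to key position j;
# a lower-triangular matrix gives causal (auto-regressive) masking
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Add a batch dimension of size 1 so it broadcasts over the batch:
# shape [seq_len, seq_len, 1]
causal_mask = causal_mask.unsqueeze(-1)
```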
query, key and value are the tensors that store a collection of query, key and value vectors. They have shape [seq_len, batch_size, d_model].
mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
query, key and value have shape [seq_len, batch_size, d_model]
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
Compute attention scores $QK^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
        scores = self.get_scores(query, key)
Scale scores $\frac{QK^\top}{\sqrt{d_k}}$
        scores *= self.scale
Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
Softmax attention along the key sequence dimension $\underset{seq}{\mathrm{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)$
        attn = self.softmax(scores)
Save attentions if debugging
        tracker.debug('attn', attn)
Apply dropout
        attn = self.dropout(attn)
Multiply by values $\underset{seq}{\mathrm{softmax}}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
Save attentions for any other calculations
        self.attn = attn.detach()
Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)
Output layer
        return self.output(x)
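Putting it all together, here is a minimal self-attention usage sketch (sizes are hypothetical, and the tracker.debug call inside forward assumes labml is available):

```python
import torch

seq_len, batch_size, d_model, heads = 5, 2, 512, 8
mha = MultiHeadAttention(heads=heads, d_model=d_model)

x = torch.randn(seq_len, batch_size, d_model)

# Causal mask of shape [seq_len, seq_len, 1], broadcast over the batch
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)

# Self-attention: the same tensor serves as query, key and value
out = mha(query=x, key=x, value=x, mask=mask)
assert out.shape == (seq_len, batch_size, d_model)
```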