
Multi-Headed Attention (MHA)

This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need, in PyTorch. The implementation is inspired by The Annotated Transformer.

Here is the training code that uses a basic transformer with MHA for NLP auto-regression.

Here is an experiment implementation that trains a simple transformer.

import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker

Prepare for multi-head attention

This module does a linear transformation and splits the vector into the given number of heads for multi-head attention. This is used to transform the key, query, and value vectors.

class PrepareForMultiHeadAttention(nn.Module):

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()

Linear layer for the linear transform

        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

Number of heads

        self.heads = heads

Number of dimensions in vectors in each head

        self.d_k = d_k

    def forward(self, x: torch.Tensor):

Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model]. We apply the linear transformation to the last dimension and split that into the heads.

        head_shape = x.shape[:-1]

Linear transform

        x = self.linear(x)

Split last dimension into heads

        x = x.view(*head_shape, self.heads, self.d_k)

Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]

        return x
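As a quick sanity check, here is a minimal usage sketch of this module, assuming illustrative sizes d_model = 512, heads = 8 and d_k = 64 (these numbers are not from the original text):

# Minimal sketch with illustrative sizes; shows how the module reshapes its input.
prepare = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)
x = torch.randn(10, 32, 512)    # [seq_len, batch_size, d_model]
out = prepare(x)
print(out.shape)                # torch.Size([10, 32, 8, 64]), i.e. [seq_len, batch_size, heads, d_k]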

Multi-Head Attention Module

This computes scaled multi-headed attention for given query, key and value vectors.

$$\mathop{Attention}(Q, K, V) = \underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

In simple terms, it finds keys that match the query, and gets the values of those keys.

It uses the dot-product of query and key as an indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.

Softmax is calculated along the axis of the sequence (or time).
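To make the formula above concrete, here is a minimal single-head sketch with made-up shapes (seq_len = 5, d_k = 4); it is an illustration, not part of the original implementation:

# Illustrative single-head scaled dot-product attention.
import torch
import torch.nn.functional as F

q = torch.randn(5, 4)              # queries [seq_len, d_k]
k = torch.randn(5, 4)              # keys    [seq_len, d_k]
v = torch.randn(5, 4)              # values  [seq_len, d_k]

scores = q @ k.T / 4 ** 0.5        # Q K^T / sqrt(d_k), shape [seq_len, seq_len]
attn = F.softmax(scores, dim=-1)   # softmax along the key dimension
out = attn @ v                     # weighted sum of values, shape [seq_len, d_k]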

class MultiHeadAttention(nn.Module):

  • heads is the number of heads.
  • d_model is the number of features in the query, key and value vectors.

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):

        super().__init__()

Number of features per head

        self.d_k = d_model // heads

Number of heads

        self.heads = heads

These transform the query, key and value vectors for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

Softmax for attention along the time dimension of key. The attention scores computed later have shape [seq_len_q, seq_len_k, batch_size, heads], so dim=1 is the key sequence dimension.

        self.softmax = nn.Softmax(dim=1)

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

Scaling factor before the softmax

        self.scale = 1 / math.sqrt(self.d_k)
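As a worked example with illustrative values (not fixed by this module): if d_model = 512 and heads = 8, then each head has d_k = 512 // 8 = 64 features and the scaling factor is 1 / √64 = 0.125.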

We store the attentions so that they can be used for logging, or for other computations if needed

        self.attn = None

Calculate scores between queries and keys

This method can be overridden for other variations like relative attention.

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$

        return torch.einsum('ibhd,jbhd->ijbh', query, key)
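For readers less familiar with einsum, the following sketch (with made-up shapes) shows an equivalent computation using an explicit permute and matmul; it is only an illustration, not how the module is written:

# Illustrative equivalence check; tensors have shape [seq_len, batch_size, heads, d_k].
q = torch.randn(5, 2, 8, 64)
k = torch.randn(7, 2, 8, 64)

scores_einsum = torch.einsum('ibhd,jbhd->ijbh', q, k)

# Same result via permute + matmul: [batch, heads, seq_q, d_k] @ [batch, heads, d_k, seq_k]
scores_matmul = torch.matmul(q.permute(1, 2, 0, 3), k.permute(1, 2, 3, 0)).permute(2, 3, 0, 1)
assert torch.allclose(scores_einsum, scores_matmul, atol=1e-5)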

mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension is equal to 1 it will be broadcast.

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):

        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

Same mask applied to all heads.

        mask = mask.unsqueeze(-1)

The resulting mask has shape [seq_len_q, seq_len_k, batch_size, heads]

        return mask
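As an illustration of the expected mask layout, here is one possible way (a sketch, not part of the original code; the helper name is hypothetical) to build a causal mask for auto-regressive attention in the [seq_len_q, seq_len_k, batch_size] convention:

# Hypothetical helper: lower-triangular (causal) mask, broadcast over the batch dimension.
def subsequent_mask(seq_len: int):
    # mask[i, j, 0] is 1 when query position i may attend to key position j <= i, else 0
    return torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)

mask = subsequent_mask(5)    # shape [5, 5, 1]; the batch dimension of size 1 is broadcast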

query, key and value are the tensors that store collections of query, key and value vectors. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether for batch b, the query at position i has access to the key-value at position j.

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):

query, key and value have shape [seq_len, batch_size, d_model]

        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].

        scores = self.get_scores(query, key)

Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$

        scores *= self.scale

Apply mask

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
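Masking with -inf works because softmax assigns those positions exactly zero weight; a tiny illustrative example (values made up):

# Illustrative only: a masked (-inf) score gets zero probability after softmax.
s = torch.tensor([1.0, 2.0, float('-inf')])
torch.softmax(s, dim=0)    # tensor([0.2689, 0.7311, 0.0000])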

Softmax attention along the key sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$

        attn = self.softmax(scores)

Save attentions if debugging

        tracker.debug('attn', attn)

Apply dropout

        attn = self.dropout(attn)

Multiply by values $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

Save attentions for any other calculations

        self.attn = attn.detach()

Concatenate multiple heads

        x = x.reshape(seq_len, batch_size, -1)

Output layer

        return self.output(x)
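To tie the pieces together, here is a minimal self-attention usage sketch with made-up sizes (d_model = 512, heads = 8, seq_len = 10, batch_size = 32), assuming the labml package is installed so the tracker call works; it illustrates the module's interface rather than reproducing anything from the original file:

# Minimal sketch with illustrative sizes; self-attention, where query = key = value.
mha = MultiHeadAttention(heads=8, d_model=512)

x = torch.randn(10, 32, 512)                             # [seq_len, batch_size, d_model]
causal = torch.tril(torch.ones(10, 10)).unsqueeze(-1)    # [seq_len, seq_len, 1], broadcast over batch

out = mha(query=x, key=x, value=x, mask=causal)
print(out.shape)         # torch.Size([10, 32, 512])
print(mha.attn.shape)    # torch.Size([10, 10, 32, 8])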
