Back to Annotated Deep Learning Paper Implementations

Transformer Encoder and Decoder Models

docs/transformers/models.html

latest7.2 KB
Original Source

hometransformers

View code on Github

#

Transformer Encoder and Decoder Models

13importmath1415importtorch16importtorch.nnasnn1718fromlabml\_nn.utilsimportclone\_module\_list19from.feed\_forwardimportFeedForward20from.mhaimportMultiHeadAttention21from.positional\_encodingimportget\_positional\_encoding

#

Embed tokens and add fixed positional encoding

class EmbeddingsWithPositionalEncoding(nn.Module):
    """
    ## Embed tokens and add fixed positional encoding

    Token embeddings are scaled by `sqrt(d_model)` and the fixed encodings
    produced by `get_positional_encoding(d_model, max_len)` are added to them.
    """

    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        """
        * `d_model` is the token embedding size
        * `n_vocab` is the number of rows of the embedding table (vocabulary size)
        * `max_len` is passed to `get_positional_encoding`
        """
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        # Stored as a buffer rather than a parameter, so it is not returned by
        # `parameters()` and is therefore not updated by the optimizer.
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def forward(self, x: torch.Tensor):
        """
        * `x` holds token indices; its first dimension is used as the sequence
          position when slicing the positional encodings

        Returns the scaled token embeddings with the positional encodings added.

        Raises `ValueError` if `x` has more positions than there are stored
        encodings.
        """
        seq_len = x.shape[0]
        # NOTE(review): assumes the buffer is indexed by position along dim 0,
        # which is how it is sliced below - confirm against
        # `get_positional_encoding`.
        n_positions = self.positional_encodings.shape[0]
        if seq_len > n_positions:
            # Without this check the slice below would silently be shorter than
            # the sequence and the addition would fail with an opaque broadcast
            # error (or broadcast incorrectly when only one position exists).
            raise ValueError(
                f"Sequence length {seq_len} exceeds the {n_positions} available positional encodings"
            )
        pe = self.positional_encodings[:seq_len].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe

#

Embed tokens and add parameterized positional encodings

class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
    """
    ## Embed tokens and add parameterized positional encodings

    Token embeddings are scaled by `sqrt(d_model)` and a trainable positional
    encoding (initialized to zeros) is added to them.
    """

    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        """
        * `d_model` is the token embedding size
        * `n_vocab` is the number of rows of the embedding table (vocabulary size)
        * `max_len` is the number of positions for which an encoding is learned
        """
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        # Shape `[max_len, 1, d_model]`; the middle dimension of size one
        # broadcasts over the second dimension of the embedded input.
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def forward(self, x: torch.Tensor):
        """
        * `x` holds token indices; its first dimension is the sequence position

        Returns the scaled token embeddings with the learned positional
        encodings added.

        Raises `ValueError` if `x` has more positions than `max_len`.
        """
        seq_len = x.shape[0]
        max_len = self.positional_encodings.shape[0]
        if seq_len > max_len:
            # Without this check the slice below would silently be shorter than
            # the sequence; the addition then fails with an opaque broadcast
            # error, or broadcasts a single position when `max_len == 1`.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the {max_len} available positional encodings"
            )
        pe = self.positional_encodings[:seq_len]
        return self.linear(x) * math.sqrt(self.d_model) + pe

#

Transformer Layer

This can act as an encoder layer or a decoder layer. We use pre-norm.

class TransformerLayer(nn.Module):
    """
    ## Transformer Layer

    This can act as an encoder layer or a decoder layer. We use pre-norm:
    each sub-layer normalizes its input, and the sub-layer output (after
    dropout) is added back to the un-normalized input.
    """

    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        """
        * `d_model` is the token embedding size
        * `self_attn` is the self attention module
        * `src_attn` is the source attention module (when this is used in a decoder)
        * `feed_forward` is the feed forward module
        * `dropout_prob` is the probability of dropping out after self attention and FFN
        """
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        # The source-attention norm only exists when a source attention module was given
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
        # Whether to save input to the feed forward layer
        self.is_save_ff_input = False
        # Holds the saved feed-forward input; defined up front so that reading it
        # before the first forward pass yields `None` instead of an `AttributeError`
        self.ff_input = None

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor,
                src: torch.Tensor = None,
                src_mask: torch.Tensor = None):
        """
        * `x` is the input to the layer
        * `mask` is the mask passed to self attention
        * `src` is the source to attend to (e.g. encoder outputs); optional
        * `src_mask` is the mask passed to source attention

        Returns `x` with the self attention, optional source attention and
        feed-forward results added.

        Raises `ValueError` if `src` is given but the layer was created
        without `src_attn`.
        """
        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # Run through self attention, i.e. keys and values are from self
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
        # Add the self attention results
        x = x + self.dropout(self_attn)

        # If a source is provided, get results from attention to source.
        # This is when you have a decoder layer that pays attention to
        # encoder outputs
        if src is not None:
            if self.src_attn is None:
                # `norm_src_attn` was never created in this case, so fail with
                # a clear message instead of an opaque attribute error
                raise ValueError("`src` was provided but this layer was created without `src_attn`")
            # Normalize vectors
            z = self.norm_src_attn(x)
            # Attention to source. i.e. keys and values are from source
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
            # Add the source attention results
            x = x + self.dropout(attn_src)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Save the input to the feed forward layer if specified
        if self.is_save_ff_input:
            self.ff_input = z.clone()
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x

#

Transformer Encoder

class Encoder(nn.Module):
    """
    ## Transformer Encoder

    Applies the copies of a transformer layer one after another and
    normalizes the result.
    """

    def __init__(self, layer: TransformerLayer, n_layers: int):
        """
        * `layer` is the transformer layer to copy
        * `n_layers` is the number of copies to make
        """
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer, sized from the layer's `size` attribute
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """
        * `x` is the input to the first layer
        * `mask` is passed unchanged to every layer

        Returns the normalized output of the last layer.
        """
        hidden = x
        # Run through each transformer layer
        for block in self.layers:
            hidden = block(x=hidden, mask=mask)
        # Finally, normalize the vectors
        normalized = self.norm(hidden)
        return normalized

#

Transformer Decoder

class Decoder(nn.Module):
    """
    ## Transformer Decoder

    Applies the copies of a transformer layer one after another, each one
    attending to `memory`, and normalizes the result.
    """

    def __init__(self, layer: TransformerLayer, n_layers: int):
        """
        * `layer` is the transformer layer to copy
        * `n_layers` is the number of copies to make
        """
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer, sized from the layer's `size` attribute
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        """
        * `x` is the input to the first layer
        * `memory` is given to every layer as the source to attend to
        * `src_mask` is the mask for attention to `memory`
        * `tgt_mask` is the mask for self attention

        Returns the normalized output of the last layer.
        """
        hidden = x
        # Run through each transformer layer
        for block in self.layers:
            hidden = block(x=hidden, mask=tgt_mask, src=memory, src_mask=src_mask)
        # Finally, normalize the vectors
        normalized = self.norm(hidden)
        return normalized

#

Generator

This predicts the tokens and gives the log softmax of those. You don't need this if you are using nn.CrossEntropyLoss .

class Generator(nn.Module):
    """
    ## Generator

    Projects `d_model`-sized vectors to `n_vocab`-sized scores used to predict
    tokens. Note that `forward` returns the raw output of the linear
    projection; no softmax or log softmax is applied in this module.
    """

    def __init__(self, n_vocab: int, d_model: int):
        """
        * `n_vocab` is the size of the output (vocabulary size)
        * `d_model` is the size of the input vectors
        """
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        """
        * `x` is a tensor whose last dimension has size `d_model`

        Returns the linear projection of `x`, with last dimension `n_vocab`.
        """
        return self.projection(x)

#

Combined Encoder-Decoder

class EncoderDecoder(nn.Module):
    """
    ## Combined Encoder-Decoder

    Bundles the source/target embeddings, the encoder, the decoder and the
    generator into a single module.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
        """
        * `encoder` is the encoder module
        * `decoder` is the decoder module
        * `src_embed` embeds the source tokens
        * `tgt_embed` embeds the target tokens
        * `generator` is stored on the module; it is not called by `forward`
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

        # This was important from their code.
        # Initialize parameters with Glorot / fan_avg.
        # Parameters with a single dimension or less (e.g. biases) are left untouched.
        for param in self.parameters():
            if p_dim_at_most_one(param):
                continue
            nn.init.xavier_uniform_(param)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        """
        * `src` is the source tokens
        * `tgt` is the target tokens
        * `src_mask` is the mask for the source
        * `tgt_mask` is the mask for the target

        Returns the decoder output (the generator is not applied).
        """
        # Run the source through encoder, then run encodings and targets through decoder
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        """Embed `src` and run it through the encoder."""
        embedded = self.src_embed(src)
        return self.encoder(embedded, src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        """Embed `tgt` and run it through the decoder together with `memory`."""
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, memory, src_mask, tgt_mask)


def p_dim_at_most_one(p: torch.Tensor) -> bool:
    """Return `True` when `p` has at most one dimension, i.e. when `p.dim() > 1` is false."""
    return not p.dim() > 1

labml.ai