docs/transformers/models.html
import math

import torch
import torch.nn as nn

from labml_nn.utils import clone_module_list
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding
Embed tokens and add a fixed (sinusoidal) positional encoding.

class EmbeddingsWithPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe
Embed tokens and add parameterized (learned) positional encodings.

class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe
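As a quick check of the expected shapes, here is a minimal usage sketch (the sizes are arbitrary and chosen only for illustration). The input is a tensor of token indices of shape [seq_len, batch_size]; the output is [seq_len, batch_size, d_model]:

# A minimal sketch; the sizes below are arbitrary, not from the original code.
emb = EmbeddingsWithLearnedPositionalEncoding(d_model=512, n_vocab=10_000)
tokens = torch.randint(0, 10_000, (7, 3))  # [seq_len, batch_size] token indices
out = emb(tokens)                          # [seq_len, batch_size, d_model]
assert out.shape == (7, 3, 512)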
This can act as an encoder layer or a decoder layer. We use pre-norm: layer normalization is applied to the input of each sub-layer rather than to its output.
class TransformerLayer(nn.Module):
* d_model is the token embedding size
* self_attn is the self attention module
* src_attn is the source attention module (when this is used in a decoder)
* feed_forward is the feed forward module
* dropout_prob is the probability of dropping out after self attention and FFN

    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
Whether to save input to the feed forward layer
        self.is_save_ff_input = False
    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor,
                src: torch.Tensor = None,
                src_mask: torch.Tensor = None):
Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
Run through self attention, i.e. keys and values are from self
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
Add the self attention results
        x = x + self.dropout(self_attn)
If a source is provided, get results from attention to source. This is when you have a decoder layer that pays attention to encoder outputs
        if src is not None:
Normalize vectors
            z = self.norm_src_attn(x)
Attention to source, i.e. keys and values are from the source (the encoder outputs)
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
Add the source attention results
            x = x + self.dropout(attn_src)
Normalize for feed-forward
        z = self.norm_ff(x)
Save the input to the feed forward layer if specified
        if self.is_save_ff_input:
            self.ff_input = z.clone()
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)

        return x
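For orientation, here is a minimal sketch of building and running this layer as an encoder layer. The constructor arguments of MultiHeadAttention (heads, d_model, dropout) and FeedForward (d_model, d_ff, dropout) are assumptions about those imported modules, as is the assumption that passing mask=None means "no masking"; check the linked implementations for the exact signatures.

# A sketch only; MultiHeadAttention / FeedForward signatures are assumed.
d_model, heads, d_ff, dropout = 512, 8, 2048, 0.1
layer = TransformerLayer(d_model=d_model,
                         self_attn=MultiHeadAttention(heads, d_model, dropout),
                         src_attn=None,                      # None => plain encoder layer
                         feed_forward=FeedForward(d_model, d_ff, dropout),
                         dropout_prob=dropout)
x = torch.randn(10, 2, d_model)   # [seq_len, batch_size, d_model]
out = layer(x=x, mask=None)       # assumes the attention module accepts mask=None
assert out.shape == x.shape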
The encoder is a stack of transformer layers followed by a final layer normalization.

class Encoder(nn.Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
    def forward(self, x: torch.Tensor, mask: torch.Tensor):
Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=mask)
Finally, normalize the vectors
        return self.norm(x)
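Continuing the hypothetical layer and x from the sketch above, stacking is then a one-liner; clone_module_list is assumed to make n_layers independent deep copies of the layer as an nn.ModuleList:

# A sketch: six independent copies of the layer above, plus the final norm.
encoder = Encoder(layer, n_layers=6)
enc_out = encoder(x, mask=None)   # [seq_len, batch_size, d_model]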
The decoder is a stack of transformer layers (each with source attention) followed by a final layer normalization.

class Decoder(nn.Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
Finally, normalize the vectors
        return self.norm(x)
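A matching sketch for the decoder, continuing from the encoder output above. Here layer_with_src_attn is a hypothetical TransformerLayer built with src_attn set, and the causal mask layout [seq_len_q, seq_len_k, batch] is an assumption about how the attention module broadcasts masks; adjust it to whatever MultiHeadAttention actually expects.

# A sketch only; the mask layout is an assumption, not taken from this code.
decoder = Decoder(layer_with_src_attn, n_layers=6)  # layer_with_src_attn: hypothetical decoder layer
tgt = torch.randn(12, 2, d_model)                   # target embeddings [seq_len, batch, d_model]
tgt_mask = torch.tril(torch.ones(12, 12)).bool().unsqueeze(-1)  # causal mask, broadcast over batch
dec_out = decoder(tgt, memory=enc_out, src_mask=None, tgt_mask=tgt_mask)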
This predicts the tokens and gives the log softmax of those. You don't need this if you are using nn.CrossEntropyLoss.
class Generator(nn.Module):
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)
    def forward(self, x):
        return self.projection(x)
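The projection above returns logits; nn.CrossEntropyLoss applies the log-softmax internally, which is why a separate log-softmax step can be skipped. A minimal sketch with made-up sizes:

# A sketch: feed the logits straight to cross entropy, or take log_softmax
# when log probabilities are needed explicitly.
generator = Generator(n_vocab=10_000, d_model=512)
hidden = torch.randn(12, 2, 512)             # decoder output [seq_len, batch, d_model]
targets = torch.randint(0, 10_000, (12, 2))  # target token ids [seq_len, batch]
logits = generator(hidden)                   # [seq_len, batch, n_vocab]
loss = nn.CrossEntropyLoss()(logits.view(-1, 10_000), targets.view(-1))
log_probs = torch.log_softmax(logits, dim=-1)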
This combines the encoder, the decoder, the source and target embeddings, and the generator into a full encoder-decoder transformer.

class EncoderDecoder(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
This was important from their code. Initialize parameters with Glorot / fan_avg.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
Run the source through encoder
        enc = self.encode(src, src_mask)
Run encodings and targets through decoder
        return self.decode(enc, src_mask, tgt, tgt_mask)
    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)
    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
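Finally, a minimal end-to-end sketch of wiring these pieces together. The hyper-parameters and the MultiHeadAttention / FeedForward constructor arguments are assumptions for illustration, and the masks are left as None for brevity (assuming the attention module treats a missing mask as no masking):

# A sketch of assembling the full model; the attention and feed-forward
# constructor signatures are assumed, not taken from this file.
d_model, heads, d_ff, n_vocab, dropout = 512, 8, 2048, 10_000, 0.1

def make_layer(with_src_attn: bool) -> TransformerLayer:
    return TransformerLayer(d_model=d_model,
                            self_attn=MultiHeadAttention(heads, d_model, dropout),
                            src_attn=MultiHeadAttention(heads, d_model, dropout) if with_src_attn else None,
                            feed_forward=FeedForward(d_model, d_ff, dropout),
                            dropout_prob=dropout)

model = EncoderDecoder(encoder=Encoder(make_layer(False), n_layers=6),
                       decoder=Decoder(make_layer(True), n_layers=6),
                       src_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
                       tgt_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
                       generator=Generator(n_vocab, d_model))

src = torch.randint(0, n_vocab, (10, 2))             # source token ids [src_seq_len, batch]
tgt = torch.randint(0, n_vocab, (12, 2))              # target token ids [tgt_seq_len, batch]
out = model(src, tgt, src_mask=None, tgt_mask=None)   # [tgt_seq_len, batch, d_model]
logits = model.generator(out)                         # [tgt_seq_len, batch, n_vocab]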