Configurable Transformer Components
import copy

import torch.nn as nn

from labml.configs import BaseConfigs, option, calculate, aggregate
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
    Encoder, Decoder, Generator, EncoderDecoder
Creates a Position-wise FeedForward Network defined in feed_forward.py.
class FeedForwardConfigs(BaseConfigs):
Position-wise feedforward layer
    ffn: FeedForward
Number of features in the embedding
    d_model: int
Number of features in the hidden layer
    d_ff: int = 2048
Dropout probability
    dropout: float = 0.1
Activation in position-wise feedforward layer
    activation: nn.Module = 'ReLU'
Whether the FFN layer should be gated
    is_gated: bool = False
Whether the first fully connected layer should have a learnable bias
    bias1: bool = True
Whether the second fully connected layer should have a learnable bias
    bias2: bool = True
Whether the fully connected layer for the gate should have a learnable bias
    bias_gate: bool = False
Predefined GLU variants
    glu_variant: str = 'none'
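A minimal usage sketch: options are picked by name and the module is built only when needed. This assumes labml's lazy evaluation on attribute access (described further below); the sizes here are hypothetical.

conf = FeedForwardConfigs()
conf.d_model = 512        # required: no default is declared above
conf.activation = 'GELU'  # select a registered option by name
ffn = conf.ffn            # the FeedForward module is computed lazily here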
ReLU activation: $\max(0, x)$
@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
    return nn.ReLU()
GELU activation: $x \Phi(x)$, where $\Phi(x) = P(X \le x)$ for $X \sim \mathcal{N}(0, 1)$
It was introduced in the paper Gaussian Error Linear Units (arXiv:1606.08415).
@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
    return nn.GELU()
Initialize a feed forward network
@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
    return FeedForward(c.d_model, c.d_ff,
                       dropout=c.dropout,
                       activation=c.activation,
                       is_gated=c.is_gated,
                       bias1=c.bias1,
                       bias2=c.bias2,
                       bias_gate=c.bias_gate)
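For reference, the same module can be constructed directly, bypassing the config system. The sizes below are hypothetical; the keyword arguments simply mirror the call above.

ffn = FeedForward(512, 2048,
                  dropout=0.1,
                  activation=nn.ReLU(),
                  is_gated=False,
                  bias1=True,
                  bias2=True,
                  bias_gate=False)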
These are variants with gated hidden layers for the FFN, introduced in the paper GLU Variants Improve Transformer (arXiv:2002.05202). We have omitted the bias terms as specified in the paper.
$$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Sigmoid()))
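To make the formula concrete, here is a minimal self-contained sketch of the gated computation. The real implementation lives in feed_forward.py; this module is purely illustrative.

import torch
import torch.nn as nn


class GLUSketch(nn.Module):
    """Illustrative gated FFN: (sigma(x W1) ⊗ x V) W2, biases omitted."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.v = nn.Linear(d_model, d_ff, bias=False)   # value projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Element-wise product of the sigmoid gate and the value projection
        return self.w2(torch.sigmoid(self.w1(x)) * self.v(x))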
$$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Identity()))
$$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.ReLU()))
$$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.GELU()))
$$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2 \quad \text{where } \text{Swish}_\beta(x) = x \sigma(\beta x)$$
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.SiLU()))
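Selecting a variant by name applies the whole bundle of options at once. A minimal sketch with a hypothetical size:

conf = FeedForwardConfigs()
conf.d_model = 512
conf.glu_variant = 'SwiGLU'  # sets is_gated, the biases and nn.SiLU together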
This defines configurations for a transformer. The configurations are calculated using option functions. These are lazily evaluated, so only the necessary modules are computed.
class TransformerConfigs(BaseConfigs):
Number of attention heads
    n_heads: int = 8
Transformer embedding size
    d_model: int = 512
Number of layers
    n_layers: int = 6
Dropout probability
    dropout: float = 0.1
Number of tokens in the source vocabulary (for token embeddings)
    n_src_vocab: int
Number of tokens in the target vocabulary (to generate logits for prediction)
    n_tgt_vocab: int
The encoder self attention
    encoder_attn: MultiHeadAttention = 'mha'
The decoder self attention
    decoder_attn: MultiHeadAttention = 'mha'
The decoder memory attention
    decoder_mem_attn: MultiHeadAttention = 'mha'
Configurable Feedforward Layer
    ffn: FeedForwardConfigs
Encoder layer
    encoder_layer: TransformerLayer = 'default'
Decoder layer
    decoder_layer: TransformerLayer = 'default'
Encoder consisting of multiple encoder layers
    encoder: Encoder = 'default'
Decoder consisting of multiple decoder layers
    decoder: Decoder = 'default'
Embedding layer for source
    src_embed: nn.Module = 'fixed_pos'
Embedding layer for target (for decoder)
    tgt_embed: nn.Module = 'fixed_pos'
Logit generator for prediction
    generator: Generator = 'default'
Encoder-decoder
    encoder_decoder: EncoderDecoder
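A minimal usage sketch, again assuming labml's lazy evaluation on attribute access; the vocabulary sizes are hypothetical:

conf = TransformerConfigs()
conf.n_src_vocab = 10_000       # required: no defaults are declared
conf.n_tgt_vocab = 10_000
conf.encoder_attn = 'relative'  # e.g. swap in relative multi-head attention
model = conf.encoder_decoder    # builds only the modules this path needs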
def _mha(c: TransformerConfigs):
    return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


# Standard multi-head attention for encoder self, decoder self, and decoder memory attention
calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
def _relative_mha(c: TransformerConfigs):
    from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
    return RelativeMultiHeadAttention(c.n_heads, c.d_model)


# Relative multi-head attention (Transformer-XL style)
calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
Create feedforward layer configurations
@option(TransformerConfigs.ffn, 'default')
def _feed_forward(c: TransformerConfigs):
    conf = FeedForwardConfigs()
    # Inherit the transformer's sizes unless overridden explicitly
    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
    return conf
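Because the feed-forward configs are nested inside the transformer configs, individual FFN options can be overridden with dotted keys. A sketch, assuming labml's experiment.configs override mechanism and a hypothetical experiment name:

from labml import experiment

conf = TransformerConfigs()
experiment.create(name='glu_variants')  # hypothetical experiment name
experiment.configs(conf, {'ffn.glu_variant': 'GEGLU'})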
Encoder layer
@option(TransformerConfigs.encoder_layer, 'default')
def _encoder_layer(c: TransformerConfigs):
    # Deep-copy the computed FFN module so each layer gets its own weights
    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)
Decoder layer
@option(TransformerConfigs.decoder_layer, 'default')
def _decoder_layer(c: TransformerConfigs):
    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)
Encoder
@option(TransformerConfigs.encoder, 'default')
def _encoder(c: TransformerConfigs):
    return Encoder(c.encoder_layer, c.n_layers)
Decoder
@option(TransformerConfigs.decoder, 'default')
def _decoder(c: TransformerConfigs):
    return Decoder(c.decoder_layer, c.n_layers)
Logit generator
@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    return Generator(c.n_tgt_vocab, c.d_model)
Source embedding with fixed positional encodings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)
Target embedding with fixed positional encodings
@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
Source embedding with learned positional encodings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
Target embedding with learned positional encodings
@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
Source embedding without positional encodings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    return nn.Embedding(c.n_src_vocab, c.d_model)
# Target embedding without positional encodings
@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
    return nn.Embedding(c.n_tgt_vocab, c.d_model)


# Default encoder-decoder: assembles the configured components
@option(TransformerConfigs.encoder_decoder, 'default')
def _encoder_decoder(c: TransformerConfigs):
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)