docs/transformers/fnet/experiment.html
This is an annotated PyTorch experiment that trains a FNet model.
It is based on the general training loop and configurations for the AG News classification task.
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_classification import NLPClassificationConfigs
from labml_nn.transformers import Encoder
from labml_nn.transformers import TransformerConfigs
class TransformerClassifier(nn.Module):
encoder is the transformer Encoder, src_embed is the token embedding module (with positional encodings), and generator is the final fully connected layer that gives the logits.

    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Linear):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator
    def forward(self, x: torch.Tensor):
Get the token embeddings with positional encodings
        x = self.src_embed(x)
Transformer encoder
        x = self.encoder(x, None)
Get logits for classification.
We place the [CLS] token at the last position of the sequence, so it is extracted by x[-1], where x has shape [seq_len, batch_size, d_model].
        x = self.generator(x[-1])
Return results (the second value is the state, since our trainer is also used with RNNs)
        return x, None
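As a quick illustration of the shapes involved (a standalone sketch with made-up sizes, not part of the model), x[-1] selects the last-position representation for every sequence in the batch:

import torch

seq_len, batch_size, d_model = 16, 4, 512
x = torch.randn(seq_len, batch_size, d_model)  # stand-in for the encoder output

cls = x[-1]  # representation at the last position, where the [CLS] token sits
assert cls.shape == (batch_size, d_model)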
This inherits from NLPClassificationConfigs
class Configs(NLPClassificationConfigs):
Classification model
    model: TransformerClassifier
Transformer
    transformer: TransformerConfigs
@option(Configs.transformer)
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    return conf
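As a rough mental model of how these calculated configurations behave (an illustrative analogy only, not labml's actual implementation), @option registers a factory that runs lazily the first time the configuration value is needed:

class LazyOptions:
    def __init__(self):
        self._factories = {}
        self._cache = {}

    def option(self, name):
        # decorator that registers a factory for a config value
        def register(fn):
            self._factories[name] = fn
            return fn
        return register

    def get(self, name):
        # compute on first access, then cache
        if name not in self._cache:
            self._cache[name] = self._factories[name](self)
        return self._cache[name]

opts = LazyOptions()

@opts.option('transformer')
def make_transformer(c):
    return 'built lazily'

assert opts.get('transformer') == 'built lazily'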
Create a FNetMix module that can replace the self-attention in the transformer encoder layer.
@option(TransformerConfigs.encoder_attn)
def fnet_mix():
    from labml_nn.transformers.fnet import FNetMix
    return FNetMix()
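For reference, the core of FNet's token mixing is two Fourier transforms, one over the hidden dimension and one over the sequence dimension, keeping only the real part; see labml_nn.transformers.fnet for the actual FNetMix module. A minimal sketch, assuming the usual [seq_len, batch_size, d_model] layout:

import torch

def fourier_mix(x: torch.Tensor) -> torch.Tensor:
    # FFT over the hidden dimension, then over the sequence dimension;
    # the real part stands in for the self-attention output.
    return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=0).real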
Create classification model
@option(Configs.model)
def _model(c: Configs):
    m = TransformerClassifier(c.transformer.encoder,
                              c.transformer.src_embed,
                              nn.Linear(c.d_model, c.n_classes)).to(c.device)

    return m
def main():
Create experiment
    experiment.create(name="fnet")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use word level tokenizer
        'tokenizer': 'basic_english',
Train for 32 epochs
        'epochs': 32,
Switch between training and validation 10 times per epoch
        'inner_iterations': 10,
Transformer configurations (same as defaults)
        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,
Use FNet instead of self-attention
        'transformer.encoder_attn': 'fnet_mix',
Use Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
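The Noam schedule combines a linear warm-up with inverse-square-root decay, which is why the base learning rate can be set to 1. A sketch of the rate formula (the warmup value here is an assumption; the experiment uses labml_nn's Noam optimizer wrapper):

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000) -> float:
    # lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), for step >= 1
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)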
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()
if __name__ == '__main__':
    main()