docs/transformers/mlp_mixer/experiment.html
This is an annotated PyTorch experiment to train an MLP-Mixer model.
from labml import experiment
from labml.configs import option
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.configs import FeedForwardConfigs
from labml_nn.transformers.mlm.experiment import TransformerMLM, Configs as MLMConfigs
This inherits from MLMConfigs where we define an experiment for Masked Language Models.
class Configs(MLMConfigs):
Configurable Feed-Forward Network for the MLP
    mix_mlp: FeedForwardConfigs
The mixing MLP configurations
@option(Configs.mix_mlp)
def _mix_mlp_configs(c: Configs):
    conf = FeedForwardConfigs()
The size of the MLP is the sequence length, because it is applied across tokens (see the shape-check sketch after this function).
    conf.d_model = c.seq_len
The paper suggests GELU activation
    conf.activation = 'GELU'
    return conf
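To make the shape reasoning above concrete, here is a minimal check of how an MLP with input size seq_len ends up mixing information across token positions. The [seq_len, batch_size, d_model] layout is an assumption about the model's inputs, and this snippet is illustrative only, not part of the experiment.

import torch

seq_len, batch_size, d_model = 32, 4, 128
x = torch.randn(seq_len, batch_size, d_model)

# Swapping the sequence and feature dimensions makes the last dimension seq_len,
# so a feed-forward network with input size seq_len mixes information across tokens.
x_t = x.transpose(0, 2)
print(x_t.shape)  # torch.Size([128, 4, 32])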
@option(Configs.transformer)
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
Embedding size
    conf.d_model = c.d_model
Replace the attention module with MLPMixer (a rough sketch of the token-mixing idea follows this function)
    from labml_nn.transformers.mlp_mixer import MLPMixer
    conf.encoder_attn = MLPMixer(c.mix_mlp.ffn)
    return conf
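The actual module lives in labml_nn.transformers.mlp_mixer. As a rough sketch of the idea (not the library's exact code), a token-mixing replacement for attention can look like the hypothetical TokenMixer below: it accepts an attention-style query, key, value call and applies the MLP across the token dimension.

import torch
import torch.nn as nn

class TokenMixer(nn.Module):
    """Hypothetical stand-in for an attention layer that mixes across tokens with an MLP."""
    def __init__(self, ffn: nn.Module):
        super().__init__()
        # `ffn` is an MLP whose input and output size equal the sequence length
        self.ffn = ffn

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask=None):
        # Used in self-attention style, so query, key and value are the same tensor
        # of shape [seq_len, batch_size, d_model]
        x = query.transpose(0, 2)   # [d_model, batch_size, seq_len]
        x = self.ffn(x)             # mix information across the token dimension
        return x.transpose(0, 2)    # back to [seq_len, batch_size, d_model]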
def main():
Create experiment
    experiment.create(name="mlp_mixer_mlm")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Batch size
        'batch_size': 64,
Sequence length of 32. We use a short sequence length to train faster; otherwise, MLM models take forever to train.
        'seq_len': 32,
Train for 1024 epochs.
        'epochs': 1024,
Switch between training and validation once per epoch
        'inner_iterations': 1,
Transformer configurations
        'd_model': 128,
        'transformer.ffn.d_ff': 256,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,
        'transformer.ffn.activation': 'GELU',
Mixer MLP hidden layer size
        'mix_mlp.d_ff': 128,
Use the Noam optimizer (a sketch of its learning-rate schedule follows this configuration block)
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
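The Noam schedule from "Attention Is All You Need" warms the learning rate up and then decays it with the inverse square root of the step count, scaled by d_model to the power of -1/2; the learning rate of 1. above presumably acts as the overall multiplicative factor. A minimal sketch of that schedule follows, with a hypothetical warmup value chosen only for illustration.

def noam_lr(step: int, d_model: int = 128, warmup: int = 4000, factor: float = 1.) -> float:
    # lr = factor * d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5})
    step = max(step, 1)
    return factor * d_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))

print(noam_lr(1))        # tiny at the start of warmup
print(noam_lr(4_000))    # peaks around the warmup step
print(noam_lr(100_000))  # then decays as 1/sqrt(step)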
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()
if __name__ == '__main__':
    main()