
FNet Experiment


This is an annotated PyTorch experiment that trains an FNet model.

It is based on the general training loop and configurations for the AG News classification task.

import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_classification import NLPClassificationConfigs
from labml_nn.transformers import Encoder
from labml_nn.transformers import TransformerConfigs

Transformer-based classifier model

class TransformerClassifier(nn.Module):

    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Linear):

        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

    def forward(self, x: torch.Tensor):

Get the token embeddings with positional encodings

        x = self.src_embed(x)

Run the transformer encoder (no masking is needed for classification, so the mask is None)

        x = self.encoder(x, None)

Get logits for classification.

We set the [CLS] token at the last position of the sequence. It is extracted by x[-1], where x is of shape [seq_len, batch_size, d_model] (see the shape-check sketch after this class).

        x = self.generator(x[-1])

Return the results (the second value is the state; it is None here, since our trainer is also used with RNNs)

        return x, None
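As a quick sanity check of the [CLS] extraction above, here is a standalone sketch (the tensor sizes are illustrative, not taken from this experiment):

import torch

x = torch.zeros(128, 32, 512)  # [seq_len, batch_size, d_model]
cls = x[-1]                    # last position holds the [CLS] token: [batch_size, d_model]
assert cls.shape == (32, 512)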

Configurations

This inherits from NLPClassificationConfigs.

class Configs(NLPClassificationConfigs):

Classification model

    model: TransformerClassifier

Transformer

    transformer: TransformerConfigs

Transformer configurations

@option(Configs.transformer)
def _transformer_configs(c: Configs):

We use our configurable transformer implementation

    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf
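As background, labml's @option decorator registers a function that computes a configuration value lazily, the first time it is accessed. A minimal sketch of the pattern (MyConfigs, value, and doubled are hypothetical names for illustration):

from labml.configs import BaseConfigs, option

class MyConfigs(BaseConfigs):
    value: int = 2   # a plain config with a default
    doubled: int     # computed by the registered option below

@option(MyConfigs.doubled)
def _doubled(c: MyConfigs):
    # Runs on first access to c.doubled, so it can depend on other configs
    return c.value * 2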

Create an FNetMix module that can replace the self-attention layer in the transformer encoder; a sketch of the underlying Fourier-mixing idea follows the code below.

@option(TransformerConfigs.encoder_attn)
def fnet_mix():

    from labml_nn.transformers.fnet import FNetMix
    return FNetMix()
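For context, FNet (Lee-Thorp et al., 2021) replaces self-attention with a parameter-free Fourier transform applied over the hidden and sequence dimensions, keeping only the real part. A minimal sketch of that mixing step (an illustration of the idea, not necessarily the exact labml FNetMix code):

import torch

def fourier_mix(x: torch.Tensor) -> torch.Tensor:
    # x: [seq_len, batch_size, d_model]
    # FFT along the hidden dimension, then along the sequence dimension;
    # keep the real component as the mixed output
    return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=0).real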

Create the classification model

@option(Configs.model)
def _model(c: Configs):

    m = TransformerClassifier(c.transformer.encoder,
                              c.transformer.src_embed,
                              nn.Linear(c.d_model, c.n_classes)).to(c.device)

    return m

def main():

Create experiment

    experiment.create(name="fnet")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use a word-level tokenizer

        'tokenizer': 'basic_english',

Train for 32 epochs

        'epochs': 32,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Transformer configurations (same as defaults)

        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

Use FNet instead of self-attention

        'transformer.encoder_attn': 'fnet_mix',

Use the Noam optimizer (a sketch of its learning-rate schedule follows this configuration block)

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
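The Noam schedule from "Attention Is All You Need" warms the learning rate up linearly and then decays it with the inverse square root of the step count, scaled by d_model ** -0.5. A minimal sketch (the warmup of 4000 steps is the paper's default and an assumption here; the learning_rate of 1.0 set above is treated as an overall scale factor, which is an assumption about labml's implementation):

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    # factor plays the role of 'optimizer.learning_rate' in the configuration above
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)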

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()
