
Train Fast Weights Transformer

This trains a fast weights transformer model for auto-regression.

Here’s a Colab notebook for training a fast weights transformer on the Tiny Shakespeare dataset.

import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

Auto-regressive model

class AutoregressiveModel(nn.Module):

    def __init__(self, n_vocab: int, d_model: int, transformer: nn.Module):
        super().__init__()

Token embedding module, the transformer, and a final linear layer that produces logits

        self.src_embed = nn.Embedding(n_vocab, d_model)
        self.transformer = transformer
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):

Embed the tokens

        x = self.src_embed(x)

Run it through the transformer

        res = self.transformer(x)

Generate logits of the next token

        return self.generator(res), None
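The second element of the returned pair is the model state; the NLPAutoRegressionConfigs training loop expects a (logits, state) pair, and since this model keeps no recurrent state it returns None. Below is a minimal shape check, a sketch only, with nn.Identity() standing in for the transformer:

# A sketch only: nn.Identity() stands in for the transformer, so the
# embeddings pass straight through to the generator.
model = AutoregressiveModel(n_vocab=65, d_model=512, transformer=nn.Identity())
tokens = torch.randint(0, 65, (128, 16))  # a batch of token ids
logits, state = model(tokens)
assert logits.shape == (128, 16, 65)  # logits over the vocabulary at every position
assert state is None  # no recurrent state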

Configurations

The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):

    model: AutoregressiveModel

    d_model: int = 512
    nu: int = 1
    heads: int = 8
    dropout: float = 0.0
    d_ff: int = 2048
    n_layers: int = 6
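Here nu is the expansion factor of the DPFP (deterministic parameter-free projection) feature map from the fast weights paper (Schlag et al., 2021): it turns each d_head-dimensional key or query into 2 * nu * d_head non-negative features. A rough sketch of the projection, following the paper's definition; the actual module is labml_nn.transformers.fast_weights.DPFP:

def dpfp_sketch(k: torch.Tensor, nu: int = 1) -> torch.Tensor:
    # ReLU of k and -k side by side gives 2 * d_head non-negative features
    x = torch.relu(torch.cat([k, -k], dim=-1))
    # Element-wise products with nu rolled copies give 2 * nu * d_head features
    rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, nu + 1)], dim=-1)
    repeated = torch.cat([x] * nu, dim=-1)
    return rolled * repeated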

Create the fast weights transformer.

@option(Configs.model)
def fast_weights_transformer(c: Configs):

    from labml_nn.transformers.fast_weights import FastWeightsAttentionTransformer, \
        FastWeightsAttentionTransformerLayer, FastWeightsAttention, FeedForward

    from labml_nn.transformers.fast_weights import DPFP

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FastWeightsAttentionTransformer(
            FastWeightsAttentionTransformerLayer(d_model=c.d_model,
                                                 attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
                                                 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                 dropout_prob=c.dropout),
            c.n_layers)).to(c.device)

def main():

Create experiment

    experiment.create(name="fast_weights_transformer")

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 16,
                        'inner_iterations': 25})
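A learning rate of 1.0 looks unusual, but with the Noam optimizer it is only a multiplier on the schedule from Attention Is All You Need, which warms up linearly and then decays with the inverse square root of the step count. A sketch of that multiplier, assuming a hypothetical warmup of 4000 steps (the actual value comes from the optimizer configuration):

def noam_multiplier(step: int, d_model: int = 512, warmup: int = 4000) -> float:
    # Linear warm-up for `warmup` steps, then 1 / sqrt(step) decay
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)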

Set models for saving and loading

    experiment.add_pytorch_models(get_modules(conf))

Start the experiment

    with experiment.start():

Run the training loop

        conf.run()


if __name__ == '__main__':
    main()
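Since the file has a main guard, the experiment can be run directly, e.g. with python -m labml_nn.transformers.fast_weights.experiment (assuming the labml-nn package is installed).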
