This trains a fast weights transformer model for auto-regression.
Here’s a Colab notebook for training a fast weights transformer on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: nn.Module):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer to produce logits over the vocabulary
        self.generator = nn.Linear(d_model, n_vocab)
    def forward(self, x: torch.Tensor):
Embed the tokens
        x = self.src_embed(x)
Run it through the transformer
        res = self.transformer(x)
Generate logits of the next token
        # The second return value is the state; it's None because this
        # model doesn't carry state between training steps
        return self.generator(res), None
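For example, here is a hypothetical shape check (using nn.Identity() as a stand-in for the transformer built below) showing that the model maps token indices of shape (seq_len, batch_size) to per-position logits over the vocabulary:

model = AutoregressiveModel(n_vocab=65, d_model=512, transformer=nn.Identity())
tokens = torch.randint(0, 65, (128, 16))  # (seq_len, batch_size) token indices
logits, _ = model(tokens)
assert logits.shape == (128, 16, 65)  # logits over the vocabulary at each position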
The default configs can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

    # Token embedding size
    d_model: int = 512
    # DPFP feature map expansion factor
    nu: int = 1
    # Number of attention heads
    heads: int = 8
    # Dropout probability
    dropout: float = 0.0
    # Size of the position-wise feed-forward layer
    d_ff: int = 2048
    # Number of transformer layers
    n_layers: int = 6
Create a fast weights transformer.
@option(Configs.model)
def fast_weights_transformer(c: Configs):
    from labml_nn.transformers.fast_weights import FastWeightsAttentionTransformer, \
        FastWeightsAttentionTransformerLayer, FastWeightsAttention, FeedForward

    from labml_nn.transformers.fast_weights import DPFP
    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FastWeightsAttentionTransformer(
            FastWeightsAttentionTransformerLayer(d_model=c.d_model,
                                                 attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
                                                 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                 dropout_prob=c.dropout),
            c.n_layers)).to(c.device)
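DPFP (deterministic parameter-free projection) is the feature map applied to queries and keys; it expands a d-dimensional head vector to 2 * nu * d non-negative features. Here is a minimal standalone sketch, assuming the formulation from the fast weights paper (the actual module is labml_nn.transformers.fast_weights.DPFP):

def dpfp(k: torch.Tensor, nu: int = 1, eps: float = 1e-6) -> torch.Tensor:
    # ReLU of k and -k, so each input element yields one non-negative feature
    x = torch.relu(torch.cat([k, -k], dim=-1))
    # Multiply by nu shifted copies to get 2 * nu * d features
    x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, nu + 1)], dim=-1)
    x_repeat = torch.cat([x] * nu, dim=-1)
    phi = x_repeat * x_rolled
    # Normalize so the attention weights sum to one
    return phi / (phi.sum(dim=-1, keepdim=True) + eps)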
def main():
Create experiment
    experiment.create(name="fast_weights_transformer")
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 16,
                        'inner_iterations': 25})
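The 'Noam' option selects the warm-up learning rate schedule from Attention Is All You Need, with optimizer.learning_rate acting as a multiplier on it. A sketch of the schedule (the 4000-step warm-up is an assumed default; the experiment's configured value may differ):

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    # Linear warm-up for the first warmup steps (step starts at 1),
    # then decay proportional to step ** -0.5
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)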
Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))
Start the experiment
    with experiment.start():
Run the training loop
        conf.run()


if __name__ == '__main__':
    main()