docs/hypernetworks/experiment.html
1importtorch2importtorch.nnasnn3fromlabmlimportexperiment4fromlabml.configsimportoption5fromlabml.utils.pytorchimportget\_modules67fromlabml\_nn.experiments.nlp\_autoregressionimportNLPAutoRegressionConfigs8fromlabml\_nn.hypernetworks.hyper\_lstmimportHyperLSTM9fromlabml\_nn.lstmimportLSTM
class AutoregressiveModel(nn.Module):
    """
    ## Auto-regressive model

    Embeds input tokens, passes the embeddings through a recurrent
    module, and projects the recurrent outputs to next-token logits.
    """

    def __init__(self, n_vocab: int, d_model: int, rnn_model: nn.Module):
        """
        * `n_vocab` is the vocabulary size (number of embedding rows and
          number of output logits)
        * `d_model` is the embedding size; the recurrent module's output
          must also have this size because it feeds `generator`
        * `rnn_model` is the recurrent module; it must return a
          `(output, state)` pair when called on the embeddings
        """
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Recurrent model (kept under the name `lstm` so checkpoint keys stay stable)
        self.lstm = rnn_model
        # Projection from recurrent outputs to logits over the vocabulary
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        """
        `x` holds token indices. The layout is presumably
        `[seq_len, batch_size]` — TODO confirm against the RNN modules used.

        Returns the next-token logits together with the final state
        reported by the recurrent module.
        """
        # Embed the tokens
        embedded = self.src_embed(x)
        # Run the embeddings through the recurrent model
        rnn_out, rnn_state = self.lstm(embedded)
        # Generate logits of the next token
        logits = self.generator(rnn_out)
        return logits, rnn_state
The default configs can and will be overridden when the experiment starts.
class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    Extends the NLP auto-regression configurations with the model and
    recurrent-module settings used by this experiment. Any default below
    can be replaced by the overrides passed to `experiment.configs`.
    """

    # The auto-regressive model that is trained
    model: AutoregressiveModel
    # Recurrent module placed inside `model`; chosen by option name
    # (`hyper_lstm` or `lstm`)
    rnn_model: nn.Module

    # Embedding size; also passed as both size arguments to the RNN constructors
    d_model: int = 512
    # Passed as the third argument to `HyperLSTM`; presumably the
    # hyper-network's hidden size — confirm against HyperLSTM
    n_rhn: int = 16
    # Passed as the fourth argument to `HyperLSTM`; presumably the size of
    # the hyper-network's `z` feature vectors — confirm against HyperLSTM
    n_z: int = 16
Initialize the auto-regressive model
@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Build an `AutoregressiveModel` from the configured vocabulary size,
    embedding size and recurrent module, placed on `c.device`.
    """
    return AutoregressiveModel(c.n_tokens, c.d_model, c.rnn_model).to(c.device)
@option(Configs.rnn_model)
def hyper_lstm(c: Configs):
    """
    Create a HyperLSTM whose size arguments are both `d_model`, using the
    `n_rhn` and `n_z` settings. The trailing `1` is presumably the layer
    count — confirm against HyperLSTM.
    """
    return HyperLSTM(c.d_model, c.d_model, c.n_rhn, c.n_z, 1)


@option(Configs.rnn_model)
def lstm(c: Configs):
    """
    Create a plain LSTM whose size arguments are both `d_model`, as a
    baseline. The trailing `1` is presumably the layer count — confirm
    against LSTM.
    """
    return LSTM(c.d_model, c.d_model, 1)


def main():
    """Set up and run the `hyper_lstm` experiment."""
    # Create experiment
    experiment.create(name="hyper_lstm", comment='')
    # Create configs
    conf = Configs()
    # Values that replace the configuration defaults
    overrides = {
        'tokenizer': 'character',
        'text': 'tiny_shakespeare',
        'optimizer.learning_rate': 2.5e-4,
        'optimizer.optimizer': 'Adam',
        'prompt': 'It is',
        'prompt_separator': '',

        'rnn_model': 'hyper_lstm',

        'train_loader': 'shuffled_train_loader',
        'valid_loader': 'shuffled_valid_loader',

        'seq_len': 512,
        'epochs': 128,
        'batch_size': 2,
        'inner_iterations': 25,
    }
    # Load configurations
    experiment.configs(conf, overrides)
    # Register models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))
    # Start the experiment and run the training loop (`TrainValidConfigs.run`)
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()