Back to Annotated Deep Learning Paper Implementations

experiment.py

docs/hypernetworks/experiment.html

latest · 2.9 KB
Original Source

home / hypernetworks

View code on Github

#

1importtorch2importtorch.nnasnn3fromlabmlimportexperiment4fromlabml.configsimportoption5fromlabml.utils.pytorchimportget\_modules67fromlabml\_nn.experiments.nlp\_autoregressionimportNLPAutoRegressionConfigs8fromlabml\_nn.hypernetworks.hyper\_lstmimportHyperLSTM9fromlabml\_nn.lstmimportLSTM

#

Autoregressive model

class AutoregressiveModel(nn.Module):
    """
    ## Auto-regressive language model

    Wraps a recurrent module (`HyperLSTM` or plain `LSTM`) with a token
    embedding layer and a linear projection back to vocabulary logits.
    """

    def __init__(self, n_vocab: int, d_model: int, rnn_model: nn.Module):
        """
        * `n_vocab` is the number of tokens in the vocabulary
        * `d_model` is the embedding size (also the RNN input/output size)
        * `rnn_model` is the recurrent module; it must return `(output, state)`
          when called on the embedded sequence
        """
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # The recurrent model (`HyperLSTM` or `LSTM`)
        self.lstm = rnn_model
        # Final layer mapping RNN outputs to logits over the vocabulary
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        # Embed the tokens
        x = self.src_embed(x)
        # Run the embeddings through the recurrent model
        res, state = self.lstm(x)
        # Generate logits of the next token; also return the recurrent state
        return self.generator(res), state

#

Configurations

The default configs can and will be overridden when we start the experiment

class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    Extends `NLPAutoRegressionConfigs`; these defaults can (and will) be
    overridden when the experiment is started.
    """

    # The auto-regressive model
    model: AutoregressiveModel
    # The recurrent module to wrap inside the model
    rnn_model: nn.Module

    # Embedding / hidden size
    d_model: int = 512
    # Hyper-network recurrence size (passed to `HyperLSTM`)
    n_rhn: int = 16
    # Hyper-network embedding size (passed to `HyperLSTM`)
    n_z: int = 16

#

Initialize the auto-regressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    """Initialize the auto-regressive model and move it to the configured device."""
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.rnn_model)
    return m.to(c.device)

#

56@option(Configs.rnn\_model)57defhyper\_lstm(c:Configs):58returnHyperLSTM(c.d\_model,c.d\_model,c.n\_rhn,c.n\_z,1)596061@option(Configs.rnn\_model)62deflstm(c:Configs):63returnLSTM(c.d\_model,c.d\_model,1)646566defmain():

#

Create experiment

68experiment.create(name="hyper\_lstm",comment='')

#

Create configs

70conf=Configs()

#

Load configurations

72experiment.configs(conf,

#

A dictionary of configurations to override

74{'tokenizer':'character',75'text':'tiny\_shakespeare',76'optimizer.learning\_rate':2.5e-4,77'optimizer.optimizer':'Adam',78'prompt':'It is',79'prompt\_separator':'',8081'rnn\_model':'hyper\_lstm',8283'train\_loader':'shuffled\_train\_loader',84'valid\_loader':'shuffled\_valid\_loader',8586'seq\_len':512,87'epochs':128,88'batch\_size':2,89'inner\_iterations':25})

#

Set models for saving and loading

92experiment.add\_pytorch\_models(get\_modules(conf))

#

Start the experiment

95withexperiment.start():

#

TrainValidConfigs.run

97conf.run()9899100if\_\_name\_\_=='\_\_main\_\_':101main()

labml.ai