Back to Annotated Deep Learning Paper Implementations

Train Autoregressive Transformer

docs/transformers/knn/train_model.html

latest4.5 KB
Original Source

hometransformersknn

View code on Github

#

Train Autoregressive Transformer

This trains a simple transformer model for auto-regression.

12importtorch13fromtorchimportnn14fromlabmlimportexperiment15fromlabml.configsimportoption16fromlabml.utils.pytorchimportget\_modules1718fromlabml\_nn.experiments.nlp\_autoregressionimportNLPAutoRegressionConfigs19fromlabml\_nn.transformersimportEncoder,Generator,TransformerConfigs20fromlabml\_nn.transformers.utilsimportsubsequent\_mask

#

Autoregressive model

class AutoregressiveModel(nn.Module):
    """
    ## Autoregressive model

    Token embedding, followed by a transformer encoder run under a causal
    (subsequent) mask, followed by a next-token generator.
    """

    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: Generator, *,
                 is_save_ff_input: bool = False):
        """
        * `src_embed` is the token embedding module
        * `encoder` is the transformer based encoder
        * `generator` is the next token generation layer
        * `is_save_ff_input` sets whether the last encoder layer saves the input to its
          feed-forward layer, i.e. $f(c_t)$, the embedding of the context
        """
        super().__init__()

        # Token embedding module
        self.src_embed = src_embed
        # Transformer based encoder
        self.encoder = encoder
        # Whether the last layer of the encoder should save the input to the feed-forward layer.
        # This is our $f(c_t)$, the embedding of the context.
        self.encoder.layers[-1].is_save_ff_input = is_save_ff_input
        # Next token generation layer; this gives logits of the next token
        self.generator = generator
        # Causal mask; built lazily on the first call to `forward`
        self.src_mask = None

    @property
    def ff_input(self) -> torch.Tensor:
        """
        Retrieve the saved $f(c_t)$ from the last encoder layer
        """
        return self.encoder.layers[-1].ff_input

    def forward(self, src: torch.Tensor):
        # NOTE(review): `len(src)` is used as the sequence length, so `src` is presumably
        # sequence-first (`[seq_len, batch_size]`) — confirm against the data loader.
        seq_len = len(src)

        # Create subsequent mask, so that the transformer can only pay attention to past tokens.
        # The cached mask is rebuilt when the sequence length changes, and also when it lives
        # on a different device than `src`. `src_mask` is a plain attribute (not a registered
        # buffer), so `Module.to()` does not move it along with the model.
        if (self.src_mask is None
                or self.src_mask.size(0) != seq_len
                or self.src_mask.device != src.device):
            self.src_mask = subsequent_mask(seq_len).to(src.device)

        # Embed the tokens (`src`) and run them through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)

        # Generate logits of the next token; the second element of the tuple is always `None`
        return self.generator(res), None

#

Configurations

The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    Experiment configurations for training the autoregressive transformer.
    These are defaults; the experiment overrides many of them at start-up.
    """

    # Configurable transformer; the model's embedding, encoder and generator are taken from it
    transformer: TransformerConfigs
    # The autoregressive model being trained
    model: AutoregressiveModel

    # Whether the last encoder layer saves $f(c_t)$ (forwarded to `AutoregressiveModel`)
    is_save_ff_input = False

#

Initialize the auto-regressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Build the autoregressive model from the configurable transformer
    and move it to `c.device`.
    """
    # The token embedding layer, the encoder and the final token generator
    # all come from the configurable transformer
    transformer = c.transformer
    model = AutoregressiveModel(transformer.src_embed,
                                transformer.encoder,
                                transformer.generator,
                                # Whether to save $f(c_t)$
                                is_save_ff_input=c.is_save_ff_input)
    return model.to(c.device)

#

Initialize the configurable transformer encoder for our autoregressive model

@option(Configs.transformer)
def transformer_c(c: Configs):
    """
    Create the configurable transformer encoder for the autoregressive model.
    """
    conf = TransformerConfigs()
    # Source and target vocabularies are both the dataset's token set
    vocab_size = c.n_tokens
    conf.n_src_vocab = vocab_size
    conf.n_tgt_vocab = vocab_size
    return conf

#

def main():
    """
    Create, configure and run the `knn_lm` training experiment.
    """
    # Create experiment
    experiment.create(name="knn_lm")
    # Create configs
    conf = Configs()

    # Model width, used by both the optimizer configuration and the transformer
    d_model = 256
    # A dictionary of configurations to override
    overrides = {
        'tokenizer': 'character',
        'prompt_separator': '',
        'prompt': 'It is ',
        'text': 'tiny_shakespeare',

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
        'optimizer.d_model': d_model,

        'seq_len': 1024,
        'epochs': 128,
        'batch_size': 6,
        'inner_iterations': 10,

        # Transformer configurations
        'transformer.d_model': d_model,
        'transformer.ffn.d_ff': 1024,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,
    }
    # Load configurations
    experiment.configs(conf, overrides)

    # This is needed to initialize models
    conf.n_tokens = conf.text.n_tokens

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment and run training (`TrainValidConfigs.run`)
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()

labml.ai