docs/transformers/aft/experiment.html
This is an annotated PyTorch experiment to train an AFT model.
It is based on the general training loop and configurations for auto-regressive NLP tasks.
import torch

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask
from torch import nn
The model consists of a token embedding layer, a transformer encoder, and a final linear layer that gives the token logits.
class AutoregressiveTransformer(nn.Module):
encoder is the transformer Encoder, src_embed is the token embedding module (with positional encodings), and generator is the final fully connected layer that gives the logits.
    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator
The mask will be initialized on the first call
        self.mask = None
    def forward(self, x: torch.Tensor):
Create the subsequent mask if it is not initialized yet, or if its size differs from the current sequence length
        if self.mask is None or self.mask.size(0) != len(x):
The subsequent mask prevents tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
Get the token embeddings with positional encodings
        x = self.src_embed(x)
Transformer encoder
        x = self.encoder(x, self.mask)
Get logits
        x = self.generator(x)
Return the results (the second value is for state, since our trainer is also used with RNNs)
        return x, None
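For intuition, the subsequent mask used above is a causal mask: position t may only attend to positions t' <= t. Below is a minimal sketch of an equivalent lower-triangular mask in plain PyTorch; labml_nn's subsequent_mask helper may return a slightly different shape (for example with an extra trailing dimension), so this is only illustrative.

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean mask: entry [t, t'] is True only when t' <= t
    return torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool)

mask = causal_mask(4)
# Row 0 can see only position 0; row 3 can see positions 0..3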
This inherits from NLPAutoRegressionConfigs
class Configs(NLPAutoRegressionConfigs):
Autoregressive model
    model: AutoregressiveTransformer
Transformer
    transformer: TransformerConfigs

    # Window size for the AFT-Local attention module
    local_window_size: int = 32
@option(Configs.transformer, 'Transformer')
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
Set the embedding size
    conf.d_model = c.d_model
Replace self-attention with an AFT Local Module
    from labml_nn.transformers.aft import AFTLocal
    conf.encoder_attn = AFTLocal(c.d_model, c.seq_len, c.local_window_size)
    return conf
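For reference, the operation that replaces dot-product attention here (paraphrasing the AFT paper) combines the values with a learned pair-wise position bias $w_{t,t'}$ instead of query-key dot products:

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'} \exp(K_{t'} + w_{t,t'})}$$

AFT-Local learns $w_{t,t'}$ only within a window $|t - t'| < s$ (the local_window_size configured above) and treats it as zero outside; for this autoregressive model the subsequent mask additionally restricts the sums to $t' \le t$.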
Create an auto-regressive model
@option(Configs.model)
def _model(c: Configs):
    m = AutoregressiveTransformer(c.transformer.encoder,
                                  c.transformer.src_embed,
                                  c.transformer.generator).to(c.device)

    return m
def main():
Create experiment
    experiment.create(name="aft")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use character level tokenizer
        'tokenizer': 'character',
Prompt separator is blank
        'prompt_separator': '',
Starting prompt for sampling
        'prompt': 'It is ',
Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',
Use a context size of 256
        'seq_len': 256,
Train for 128 epochs
        'epochs': 128,
Batch size of 32
        'batch_size': 32,
Switch between training and validation 10 times per epoch
        'inner_iterations': 10,
Embedding size
        'd_model': 128,
FFN hidden dimension size
        'transformer.ffn.d_ff': 256,
Optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
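A note on the optimizer settings: with the Noam schedule from Attention Is All You Need, the configured learning rate of 1.0 acts as a multiplier rather than a fixed rate; the effective learning rate roughly follows

$$lr(step) = d_{model}^{-0.5} \cdot \min\left(step^{-0.5},\; step \cdot warmup^{-1.5}\right)$$

i.e. a linear warm-up followed by inverse-square-root decay.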
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()
if __name__ == '__main__':
    main()
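Assuming the labml_nn package and its dependencies are installed, the experiment can be started by running this script directly (for example, python -m labml_nn.transformers.aft.experiment); labml then tracks the run created by experiment.create and saves the model registered with experiment.add_pytorch_models.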