docs/transformers/hour_glass/experiment.html
This is an annotated PyTorch experiment to train an hourglass model.
It is based on the training loop and configurations for a simple transformer auto-regressive NLP task.
import math
from typing import List

import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.hour_glass import HourGlass
from labml_nn.transformers.positional_encoding import PositionalEncoding
class AutoregressiveTransformer(nn.Module):

n_tokens is the vocabulary size, d_model is the size of the token embeddings, dropout is the dropout probability, and hour_glass is the hourglass model.

    def __init__(self, n_tokens: int, d_model: int, dropout: float, hour_glass: HourGlass):
        super().__init__()

Token embeddings

        self.embedding = nn.Embedding(n_tokens, d_model)
The official paper implementation uses relative attention; here we use fixed positional encodings instead.

        self.pos_embedding = PositionalEncoding(d_model, dropout)

        self.hour_glass = hour_glass
To normalize the final embeddings

        self.norm = nn.LayerNorm([d_model])

Embedding size

        self.d_model = d_model

Final linear layer to predict the logits

        self.output = nn.Linear(d_model, n_tokens)
x is the tensor with token indexes of shape [seq_len, batch_size]

    def __call__(self, x: torch.Tensor):

Get embeddings

        x = self.embedding(x)

        if self.pos_embedding is not None:
            x = self.pos_embedding(x * math.sqrt(self.d_model))

Hourglass

        x = self.hour_glass(x)

Get logits

        output = self.output(self.norm(x))

Return the logits

        return output, None
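To make the wrapper's shape bookkeeping concrete, here is a minimal runnable sketch that traces a batch through the same layers, with HourGlass replaced by an identity stand-in (IdentityHourGlass and the small dimensions are our assumptions, purely for illustration):

```python
import torch
from torch import nn

class IdentityHourGlass(nn.Module):
    # Hypothetical stand-in: a real hourglass shortens and up-samples
    # internally, but its input and output shapes match, like an identity.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

n_tokens, d_model, seq_len, batch_size = 100, 32, 16, 4
embedding = nn.Embedding(n_tokens, d_model)
norm = nn.LayerNorm([d_model])
output = nn.Linear(d_model, n_tokens)
hour_glass = IdentityHourGlass()

x = torch.randint(0, n_tokens, (seq_len, batch_size))  # token indexes
h = embedding(x)             # [seq_len, batch_size, d_model]
h = hour_glass(h)            # hourglass preserves the outer shape
logits = output(norm(h))     # [seq_len, batch_size, n_tokens]
```

The final linear layer maps each position's normalized embedding back to vocabulary-sized logits, one prediction per input position.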
This inherits from the training loop and configurations for a simple transformer auto-regressive NLP task (NLPAutoRegressionConfigs).
class Configs(NLPAutoRegressionConfigs):

Model

    model: AutoregressiveTransformer

Number of attention heads

    n_heads: int = 8

Dropout probability

    dropout: float = 0.1

Size of feed-forward hidden layer

    d_ff: int = 512

Token embedding size

    d_model: int = 256

Shortening factors

    shortening_factors: List[int] = [8, 4]
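With shortening factors [8, 4] and the context size of 256 used below, the sequence is shortened to 256 / 8 = 32 and then 32 / 4 = 8 tokens at the middle of the hourglass, and up-sampled back symmetrically. A small sketch of that bookkeeping (the helper name is ours, not part of labml_nn):

```python
from typing import List

def hour_glass_seq_lens(seq_len: int, factors: List[int]) -> List[int]:
    """Trace sequence lengths down and back up the hourglass (a sketch)."""
    lens = [seq_len]
    for f in factors:            # shortening half
        seq_len //= f
        lens.append(seq_len)
    for f in reversed(factors):  # up-sampling half mirrors it
        seq_len *= f
        lens.append(seq_len)
    return lens

print(hour_glass_seq_lens(256, [8, 4]))  # [256, 32, 8, 32, 256]
```

This is why each shortening factor should divide the (remaining) sequence length evenly.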
Create the model

@option(Configs.model)
def _model(c: Configs):

Create hourglass model

    hour_glass = HourGlass(c.n_heads, c.d_model, c.dropout, c.d_ff, c.shortening_factors)

Create the auto-regressive wrapper

    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.dropout, hour_glass).to(c.device)

    return m
def main():

Create experiment

    experiment.create(name="hour_glass")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use character level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of 256

        'seq_len': 256,

Train for 128 epochs

        'epochs': 128,

Batch size 32

        'batch_size': 32,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Use Noam optimizer

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
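The Noam optimizer pairs Adam with the warmup-then-decay learning-rate schedule from the original transformer paper, which is why the base learning rate can be set to 1. A sketch of that schedule (the parameter names and warmup value here are assumptions for illustration, not the labml optimizer's exact defaults):

```python
def noam_lr(step: int, d_model: int = 256, warmup: int = 4000, lr_mul: float = 1.0) -> float:
    """Noam schedule: linear warmup, then decay as 1/sqrt(step) (a sketch)."""
    return lr_mul * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# The rate rises during warmup, peaks near the warmup step, then decays.
print(noam_lr(4000))
```

The lr_mul factor corresponds to the 'optimizer.learning_rate' setting above: it scales the whole schedule rather than acting as a fixed learning rate.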
Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()