docs/transformers/rope/experiment.html
This is an annotated PyTorch experiment to train a transformer model with Rotary Positional Embeddings (RoPE).
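RoPE encodes position by rotating each consecutive pair of query/key features by an angle proportional to the token position, so attention scores depend only on the relative offset between tokens. The core rotation can be sketched independently of labml (a NumPy sketch; `rotary_embed`, the base of 10,000, and the shapes are illustrative assumptions, not the library's API):

```python
import numpy as np

def rotary_embed(x: np.ndarray, base: float = 10_000.) -> np.ndarray:
    """Rotate consecutive feature pairs of x by position-dependent angles.

    x has shape [seq_len, d] with d even. Illustrative sketch only.
    """
    seq_len, d = x.shape
    # One frequency per feature pair: theta_i = base^(-2i/d)
    theta = base ** (-np.arange(0, d, 2) / d)
    # Angle at position m for pair i is m * theta_i
    angles = np.arange(seq_len)[:, None] * theta[None, :]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    # Apply a 2D rotation to each (x1, x2) pair
    out = np.empty_like(x, dtype=float)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```

Since each pair is rotated by `m * theta_i`, the dot product between a rotated query at position `m` and a rotated key at position `n` depends only on the offset `n - m`.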
```python
from labml import experiment
from labml.configs import option, calculate
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.basic.autoregressive_experiment import AutoregressiveTransformer, Configs
```
```python
def _rotary_pe_mha(c: TransformerConfigs):
    from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 1.)
```
Configuration options
```python
calculate(TransformerConfigs.encoder_attn, 'rotary', _rotary_pe_mha)
calculate(TransformerConfigs.decoder_attn, 'rotary', _rotary_pe_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'rotary', _rotary_pe_mha)
```
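The `calculate` calls above register the factory under the string name `'rotary'`, so a configuration dict can later select it by name. A toy analogue of that registration pattern (not labml's actual implementation; all names below are hypothetical):

```python
# Hypothetical mini-registry mirroring the pattern: a (config key, option name)
# pair maps to a factory, and a string in the config dict picks the factory.
_options = {}

def register_option(key: str, name: str, factory):
    """Register a factory for a named option of a config key."""
    _options[(key, name)] = factory

def resolve(key: str, name: str, config: dict):
    """Build the value for `key` using the factory registered under `name`."""
    return _options[(key, name)](config)

register_option('encoder_attn', 'rotary',
                lambda c: ('RotaryPEMultiHeadAttention', c['n_heads'], c['d_model']))

attn = resolve('encoder_attn', 'rotary', {'n_heads': 16, 'd_model': 128})
```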
Create an autoregressive model and initialize weights
```python
@option(Configs.model, 'rotary_pe_transformer')
def _model(c: Configs):
    m = AutoregressiveTransformer(c.transformer.encoder,
                                  c.transformer.src_embed,
                                  c.transformer.generator).to(c.device)

    return m
```
```python
def main():
```
Create experiment
```python
    experiment.create(name="rotary_pe_transformer", writers={'screen'})
```
Create configs
```python
    conf = Configs()
```
Override configurations
```python
    experiment.configs(conf, {
```
No fixed positional embeddings
```python
        'transformer.src_embed': 'no_pos',
        'transformer.tgt_embed': 'no_pos',
```
Encoder with RoPE
```python
        'transformer.encoder_attn': 'rotary',
```
```python
        'model': 'rotary_pe_transformer',
```
Use character level tokenizer
```python
        'tokenizer': 'character',
```
Prompt separator is blank
```python
        'prompt_separator': '',
```
Starting prompt for sampling
```python
        'prompt': 'It is ',
```
Use Tiny Shakespeare dataset
```python
        'text': 'tiny_shakespeare',
```
Use a context size of 512
```python
        'seq_len': 512,
```
Train for 32 epochs
```python
        'epochs': 32,
```
Batch size 4
```python
        'batch_size': 4,
```
Switch between training and validation 10 times per epoch
```python
        'inner_iterations': 10,
```
Model size
```python
        'd_model': 128,
        'transformer.ffn.d_ff': 512,
        'transformer.n_heads': 16,
        'transformer.dropout': 0.0,
```
Use Noam optimizer
```python
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,

        'dataloader_shuffle_with_replacement': True
    })
```
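The Noam optimizer follows the schedule from the original Transformer paper: the learning rate grows linearly over a warmup period, then decays with the inverse square root of the step, scaled by `d_model ** -0.5`. A sketch of that schedule (the 4000-step warmup is an assumed default, not taken from this experiment's configuration):

```python
def noam_lr(step: int, d_model: int = 128, warmup: int = 4000,
            factor: float = 1.) -> float:
    """Noam learning-rate schedule: linear warmup, then
    inverse-square-root decay, scaled by d_model ** -0.5."""
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
```

The two branches of the `min` meet exactly at `step == warmup`, which is where the learning rate peaks.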
Set models for saving and loading
```python
    experiment.add_pytorch_models({'model': conf.model})
```
Start the experiment
```python
    with experiment.start():
```
Run training
```python
        conf.run()
```
```python
if __name__ == '__main__':
    main()
```