docs/transformers/rope/value_pe/arithmetic_experiment.html
```python
from labml import experiment
from labml.configs import calculate
from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
```
We inherit the RoPE experiment and use it for the arithmetic addition task.
Below, we add the option to switch attention to Rotary Positional Embeddings with Relative distance (RoPER).
```python
class Configs(RoPEConfigs, ArithmeticAutoregression):
    pass
```
Use Rotary Positional Embeddings with Relative distance (RoPER) in attention.
```python
def _rotary_value_pe_mha(c: TransformerConfigs):
    from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
```
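The value rotation at the heart of RoPER can be illustrated with a tiny standalone sketch (plain Python, not the actual `RotaryValuePEMultiHeadAttention` code): rotating a vector forward by one position and reverse-rotating by another leaves an effect that depends only on the *relative* offset between the two positions.

```python
import math

def rotate(x, pos, base=10000.0, reverse=False):
    """Rotate consecutive feature pairs of `x` by position-dependent angles.

    This is a toy sketch of the rotation used by rotary embeddings; RoPER
    applies the same idea to value vectors and reverse-rotates the
    attention output so only relative distances remain.
    """
    d = len(x)
    out = []
    for i in range(0, d, 2):
        theta = pos / (base ** (i / d))
        if reverse:
            theta = -theta
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out

v = [1.0, 0.5, -0.3, 0.8]

# Rotating at position 7 and reverse-rotating at position 7 recovers v.
w = rotate(rotate(v, pos=7), pos=7, reverse=True)
assert all(abs(a - b) < 1e-9 for a, b in zip(v, w))

# Rotating at position 9 and reverse-rotating at position 2 is the same
# as rotating at position 7: only the relative offset (9 - 2) survives.
u = rotate(rotate(v, pos=9), pos=2, reverse=True)
assert all(abs(a - b) < 1e-9 for a, b in zip(u, rotate(v, pos=7)))
```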
Register RoPER attention as a configuration option named `'rotary_value'` for the encoder, decoder, and decoder memory attention.
```python
calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha)
calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha)
```
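These `calculate` calls make the string `'rotary_value'` resolve to the factory function when it appears in the configuration dictionary. A toy registry (hypothetical, not labml's real implementation) shows the idea:

```python
# Toy illustration of string-keyed option registration: a (config, option)
# pair maps to a factory, so configs can name implementations by string.
REGISTRY = {}

def register(config_name, option_name, factory):
    """Associate a factory with a named option of a config field."""
    REGISTRY[(config_name, option_name)] = factory

def resolve(config_name, option_name, *args):
    """Look up the factory for an option and build the component."""
    return REGISTRY[(config_name, option_name)](*args)

# Register a stand-in factory the way `calculate` registers
# `_rotary_value_pe_mha` under the name 'rotary_value'.
register('encoder_attn', 'rotary_value', lambda n_heads, d_model: (n_heads, d_model))
assert resolve('encoder_attn', 'rotary_value', 4, 128) == (4, 128)
```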
```python
def main():
```
Create experiment
```python
    experiment.create(name="roper_addition", comment="rotary value 7", writers={'screen', 'labml'})
```
Create configs
```python
    conf = Configs()
```
Override configurations
```python
    experiment.configs(conf, {
        'max_digits': 7,
```
No fixed positional embeddings
```python
        'transformer.src_embed': 'no_pos',
        'transformer.tgt_embed': 'no_pos',
```
Encoder with RoPER attention
```python
        'transformer.encoder_attn': 'rotary_value',
```
To use plain RoPE attention in the encoder instead, replace that option with:

```python
        # 'transformer.encoder_attn': 'rotary',
```
```python
        'model': 'rotary_pe_transformer',
```
Use a context size of 512

```python
        'seq_len': 512,
```
Train for 20 epochs

```python
        'epochs': 20,
```
Batch size of 16

```python
        'batch_size': 16,
```
Model size
```python
        'd_model': 128,
        'transformer.ffn.d_ff': 512,
        'transformer.n_heads': 4,
        'transformer.dropout': 0.0,
```
Use Adam optimizer
```python
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    })
```
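As a sanity check on these sizes: even with `max_digits` set to 7, a single addition sample is tiny compared to the context size of 512. The sample format below is hypothetical (the real `ArithmeticAutoregression` encoding may differ), but the length budget holds either way:

```python
import random

def make_sample(max_digits=7):
    """Build a hypothetical addition sample as a character string "a+b;sum".

    This format is an assumption for illustration; the actual dataset
    encoding in labml_nn may differ.
    """
    a = random.randint(0, 10 ** max_digits - 1)
    b = random.randint(0, 10 ** max_digits - 1)
    return f"{a}+{b};{a + b}"

random.seed(0)
lengths = [len(make_sample()) for _ in range(1000)]

# Worst case is 7 + 1 + 7 + 1 + 8 = 24 characters: two 7-digit operands,
# an up-to-8-digit sum, and two separators -- far under seq_len = 512.
assert max(lengths) <= 24
```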
Set models for saving and loading
```python
    experiment.add_pytorch_models({'model': conf.model})
```
Start the experiment
```python
    with experiment.start():
```
Run training
```python
        conf.run()
```
```python
if __name__ == '__main__':
    main()
```