docs/RWKV/experiment.html
1importinspect2importmath34importtorch5importtorch.nnasnn6fromlabml\_nn.rwkv.configsimportRWKVConfigs78fromlabml\_nn.rwkvimportRWKV9fromlabml\_nn.rwkvimportTimeMixing10fromlabmlimportexperiment11fromlabml.configsimportoption12fromlabml\_nn.experiments.nlp\_autoregressionimportNLPAutoRegressionConfigs
This inherits from NLPAutoRegressionConfigs
15classConfigs(NLPAutoRegressionConfigs):
RWKV model
24model:RWKV2526rwkv:RWKVConfigs
number of warmup iterations
28warmup\_iters:int=2000
total number of training iterations
30max\_iters:int=600000
weight decay
32weight\_decay:float=1e-1
Custom optimizer
34beta1:float=0.935beta2:float=0.9536optimizer='rwkv\_optimizer'
39@option(Configs.rwkv,'RWKV')40def\_rwkv\_configs(c:Configs):
We use our configurable RWKV implementation
47conf=RWKVConfigs()
Set the vocabulary sizes for embeddings and generating logits
49conf.n\_src\_vocab=c.n\_tokens50conf.n\_tgt\_vocab=c.n\_tokens5152returnconf
55def\_init\_weights(module,rwkv:RWKVConfigs):
initialize Vector Parameters in TimeMixing
57ifisinstance(module,TimeMixing):58layer\_id=module.layer\_id59n\_layer=module.n\_layer60n\_embd=module.n\_embd61attn\_sz=n\_embd6263withtorch.no\_grad():64ratio\_0\_to\_1=layer\_id/(n\_layer-1)# 0 to 165ratio\_1\_to\_almost0=1.0-(layer\_id/n\_layer)# 1 to ~066ddd=torch.ones(1,1,n\_embd)67foriinrange(n\_embd):68ddd[0,0,i]=i/n\_embd6970decay\_speed=torch.ones(attn\_sz)71forhinrange(attn\_sz):72decay\_speed[h]=-5+8\*(h/(attn\_sz-1))\*\*(0.7+1.3\*ratio\_0\_to\_1)73module.time\_decay=nn.Parameter(decay\_speed)7475zigzag=torch.tensor([(i+1)%3-1foriinrange(attn\_sz)])\*0.576module.time\_first=nn.Parameter(torch.ones(attn\_sz)\*math.log(0.3)+zigzag)77module.time\_mix\_key=nn.Parameter(torch.pow(ddd,ratio\_1\_to\_almost0))78module.time\_mix\_value=nn.Parameter(torch.pow(ddd,ratio\_1\_to\_almost0)+0.3\*ratio\_0\_to\_1)79module.time\_mix\_receptance=nn.Parameter(torch.pow(ddd,0.5\*ratio\_1\_to\_almost0))
Create RWKV model and initialize weights
82@option(Configs.model)83def\_model(c:Configs):
87m=RWKV(c.rwkv).to(c.device)
Apply custom weight initialization
90m.apply(\_init\_weights,c.rwkv)9192returnm
95@option(NLPAutoRegressionConfigs.optimizer)96def\_configure\_optimizers(c:NLPAutoRegressionConfigs):
start with all of the candidate parameters
98param\_dict={pn:pforpn,pinc.model.named\_parameters()}
filter out those that do not require grad
100param\_dict={pn:pforpn,pinparam\_dict.items()ifp.requires\_grad}
create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
103decay\_params=[pforn,pinparam\_dict.items()ifp.dim()\>=2]104nodecay\_params=[pforn,pinparam\_dict.items()ifp.dim()\<2]105optim\_groups=[106{'params':decay\_params,'weight\_decay':c.weight\_decay},107{'params':nodecay\_params,'weight\_decay':0.0}108]109num\_decay\_params=sum(p.numel()forpindecay\_params)110num\_nodecay\_params=sum(p.numel()forpinnodecay\_params)111print(f"num decayed parameter tensors: {len(decay\_params)}, with {num\_decay\_params:,} parameters")112print(f"num non-decayed parameter tensors: {len(nodecay\_params)}, with {num\_nodecay\_params:,} parameters")
Create AdamW optimizer and use the fused version if it is available
114fused\_available='fused'ininspect.signature(torch.optim.AdamW).parameters115use\_fused=fused\_availableandc.device\_type=='cuda'116extra\_args=dict(fused=True)ifuse\_fusedelsedict()117optimizer=torch.optim.AdamW(optim\_groups,lr=c.learning\_rate,betas=c.betas,\*\*extra\_args)118print(f"using fused AdamW: {use\_fused}")119120returnoptimizer
123defmain():
Create experiment
125experiment.create(name="RWKV")
Create configs
127conf=Configs()128print(conf.model)
Override configurations
130experiment.configs(conf,{
Use character level tokenizer
132'tokenizer':'character',
Prompt separator is blank
134'prompt\_separator':'',
Starting prompt for sampling
136'prompt':'It is ',
Use Tiny Shakespeare dataset
138'text':'tiny\_shakespeare',
Use a context size of 128
141'seq\_len':128,
Train for 32 epochs
143'epochs':32,
Batch size 128
145'batch\_size':128,
Switch between training and validation for 10 times per epoch
148'inner\_iterations':10,149150'rwkv.block\_size':1024,
model
152'rwkv.n\_layer':12,153'rwkv.n\_heads':12,154'rwkv.n\_embd':768155})156157print(conf.model)
Set models for saving and loading
159experiment.add\_pytorch\_models({'model':conf.model})
Start the experiment
162withexperiment.start():
Run training
164conf.run()
168if\_\_name\_\_=='\_\_main\_\_':169main()