experiment.py


import inspect
import math

import torch
import torch.nn as nn
from labml_nn.rwkv.configs import RWKVConfigs

from labml_nn.rwkv import RWKV
from labml_nn.rwkv import TimeMixing
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

Configurations

This inherits from NLPAutoRegressionConfigs

class Configs(NLPAutoRegressionConfigs):

RWKV model

    model: RWKV
    rwkv: RWKVConfigs

Number of warmup iterations

    warmup_iters: int = 2000

Total number of training iterations

    max_iters: int = 600000

Weight decay

    weight_decay: float = 1e-1

Custom optimizer

    beta1: float = 0.9
    beta2: float = 0.95
    optimizer = 'rwkv_optimizer'
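warmup_iters and max_iters are not consumed anywhere else in this file; in the nanoGPT-style training setups this configuration closely resembles, they usually drive a linear-warmup plus cosine-decay learning-rate schedule. A minimal sketch of such a schedule, assuming a peak learning rate and a made-up min_lr (neither is defined in this experiment):

import math

def get_lr(it: int, learning_rate: float, min_lr: float,
           warmup_iters: int = 2000, max_iters: int = 600000) -> float:
    # Linear warmup from 0 to the peak learning rate
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # After max_iters, hold the minimum learning rate
    if it > max_iters:
        return min_lr
    # Cosine decay from the peak down to min_lr in between
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)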

RWKV configurations

@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):

We use our configurable RWKV implementation

    conf = RWKVConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf

def _init_weights(module, rwkv: RWKVConfigs):

Initialize vector parameters in TimeMixing

    if isinstance(module, TimeMixing):
        layer_id = module.layer_id
        n_layer = module.n_layer
        n_embd = module.n_embd
        attn_sz = n_embd

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, n_embd)
            for i in range(n_embd):
                ddd[0, 0, i] = i / n_embd

            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            module.time_decay = nn.Parameter(decay_speed)

            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
            module.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
            module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
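To get a feel for these formulas, the following standalone snippet (toy sizes, made up for illustration; it is not part of the experiment) evaluates the same ramps for a 2-layer, 8-channel configuration and prints them:

import math
import torch

n_layer, n_embd = 2, 8  # toy sizes, for illustration only
for layer_id in range(n_layer):
    ratio_0_to_1 = layer_id / (n_layer - 1)        # 0 for the first layer, 1 for the last
    ratio_1_to_almost0 = 1.0 - layer_id / n_layer  # 1 for the first layer, approaching 0

    # time_decay: per-channel ramp from -5 towards +3; the exponent depends on layer depth
    decay_speed = torch.tensor(
        [-5 + 8 * (h / (n_embd - 1)) ** (0.7 + 1.3 * ratio_0_to_1) for h in range(n_embd)])

    # time_first: log(0.3) plus an alternating -0.5 / 0 / +0.5 "zigzag" offset
    zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(n_embd)]) * 0.5
    time_first = torch.full((n_embd,), math.log(0.3)) + zigzag

    # time_mix_key: channel position raised to a power that shrinks with depth
    ddd = torch.arange(n_embd) / n_embd
    time_mix_key = ddd ** ratio_1_to_almost0

    print(f'layer {layer_id}:')
    print('  time_decay   ', [round(float(v), 2) for v in decay_speed])
    print('  time_first   ', [round(float(v), 2) for v in time_first])
    print('  time_mix_key ', [round(float(v), 2) for v in time_mix_key])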

Create RWKV model and initialize weights

@option(Configs.model)
def _model(c: Configs):

    m = RWKV(c.rwkv).to(c.device)

Apply custom weight initialization

    # nn.Module.apply takes a single-argument callable, so bind the RWKV config here
    m.apply(lambda module: _init_weights(module, c.rwkv))

    return m
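nn.Module.apply recursively visits every submodule and expects a single-argument callable, which is why the configuration is bound with a lambda above; functools.partial works just as well. A tiny illustration with a made-up module:

import functools

import torch.nn as nn


def report(module, tag):
    # Called once for every submodule, innermost modules first
    print(tag, type(module).__name__)


toy = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
# Bind the extra argument so the callable matches apply's single-argument signature
toy.apply(functools.partial(report, tag='init:'))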

@option(NLPAutoRegressionConfigs.optimizer, 'rwkv_optimizer')
def _configure_optimizers(c: NLPAutoRegressionConfigs):

Start with all of the candidate parameters

    param_dict = {pn: p for pn, p in c.model.named_parameters()}

Filter out those that do not require grad

    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}

Create optim groups. Any parameter that is 2D will be weight decayed, otherwise not; i.e. all weight tensors in matmuls and embeddings decay, while biases and layernorms don't.

    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': c.weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

Create AdamW optimizer and use the fused version if it is available

    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and c.device.type == 'cuda'
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=(c.beta1, c.beta2), **extra_args)
    print(f"using fused AdamW: {use_fused}")

    return optimizer
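As a quick check of the grouping rule above, here is a made-up toy model and how its parameters split: the 2D embedding and linear weights land in the weight-decay group, while the 1D LayerNorm parameters and biases do not:

import torch.nn as nn

toy = nn.Sequential(nn.Embedding(100, 16), nn.LayerNorm(16), nn.Linear(16, 100))

params = {n: p for n, p in toy.named_parameters() if p.requires_grad}
decay = [n for n, p in params.items() if p.dim() >= 2]    # embedding and linear weights
no_decay = [n for n, p in params.items() if p.dim() < 2]  # LayerNorm weight/bias, linear bias

print('decay   :', decay)      # ['0.weight', '2.weight']
print('no decay:', no_decay)   # ['1.weight', '1.bias', '2.bias']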

def main():

Create experiment

    experiment.create(name="RWKV")

Create configs

    conf = Configs()
    print(conf.model)

Override configurations

    experiment.configs(conf, {

Use character level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of 128

        'seq_len': 128,

Train for 32 epochs

        'epochs': 32,

Batch size 128

        'batch_size': 128,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

        'rwkv.block_size': 1024,

Model size

        'rwkv.n_layer': 12,
        'rwkv.n_heads': 12,
        'rwkv.n_embd': 768
    })

    print(conf.model)

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()
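For a rough sense of scale with these settings: Tiny Shakespeare is roughly 1.1 million characters, so with the character-level tokenizer, a sequence length of 128 and a batch size of 128, an epoch is on the order of 70 optimizer steps before accounting for the validation split. A back-of-the-envelope check, with the dataset size treated as an approximation:

n_chars = 1_115_394                  # approximate size of Tiny Shakespeare in characters
seq_len, batch_size = 128, 128
tokens_per_batch = seq_len * batch_size
print(n_chars // tokens_per_batch)   # ~68 batches per epoch before the validation split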
