docs/RWKV/index.html
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rwkv/__init__.py)
This is a tutorial/implementation of RWKV from the paper *RWKV: Reinventing RNNs for the Transformer Era* in PyTorch.
Full definition of a RWKV Language Model, all of it in this single file. References: 1) the official RWKV PyTorch implementation released by Bo Peng 2) huggingface/transformers PyTorch implementation
import torch
import torch.nn as nn
from torch.nn import functional as F

# Indices into dim 2 of the recurrent state tensor, which is indexed as
# state[layer, batch, slot, channel] (see Block.forward: per layer the state
# holds 5 slots of size n_embd).
PREV_X_TIME = 0      # previous token's input, used by TimeMixing's token shift
NUM_STATE = 1        # WKV numerator accumulator
DEN_STATE = 2        # WKV denominator accumulator
MAX_STATE = 3        # running max of the exponent, for numerical stability
PREV_X_CHANNEL = 4   # previous token's input, used by ChannelMixing's token shift
class LayerNorm(nn.Module):
    """Layer normalization with an optional bias term.

    Wraps `F.layer_norm` directly because it accepts `bias=None`, letting the
    bias be switched off by configuration.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        # The scale is always learned; the bias only when requested.
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
class L2Wrap(torch.autograd.Function):
    """Identity on the loss in the forward pass; in the backward pass it adds a
    small gradient that pushes each position's largest logit toward zero
    (an L2-style regularization on the logits)."""

    @staticmethod
    def forward(ctx, loss, y):
        # Keep the logits around for the backward pass; the loss is untouched.
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        # Encourage the logits to be close to 0: scale by the max logit so the
        # penalty gradient lands only on the argmax entry of each position.
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, dim=-1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        # Pass the incoming gradient through to the loss, add gy to the logits.
        return grad_output, gy
class ChannelMixing(nn.Module):
    """RWKV channel-mixing (feed-forward) sub-block.

    Interpolates each token with the previous token along the time axis
    ("token shifting"), then applies a squared-ReLU MLP gated by a sigmoid
    receptance.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        # Token shifting: pad one step at the front of the time axis and drop
        # the last step, so position t sees position t-1 (zeros at t = 0).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        embed_dim = config.n_embd
        hidden_dim = 4 * embed_dim if config.intermediate_size is None else config.intermediate_size

        # Learnable matrices
        self.key_proj = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.value_proj = nn.Linear(hidden_dim, embed_dim, bias=False)
        self.receptance_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # Learnable mixing vectors: interpolation weights between x_t and x_{t-1}
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, embed_dim))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, embed_dim))

    def forward(self, x, state=None):
        """x: (batch, time, channel); state: optional recurrent state tensor.

        Returns (output, state)."""
        # Fetch x_{t-1}: from the recurrent state during single-token
        # inference, or by shifting the whole sequence otherwise.
        if state is None:
            shifted = self.time_shift(x)
        else:
            shifted = state[self.layer_id, :, [PREV_X_CHANNEL], :]
            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x

        # k_t = W_k . (mu_k * x_t + (1 - mu_k) * x_{t-1})
        mixed_key = x * self.time_mix_key + shifted * (1 - self.time_mix_key)
        k = self.key_proj(mixed_key)

        # r_t = W_r . (mu_r * x_t + (1 - mu_r) * x_{t-1})
        mixed_receptance = x * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
        r = self.receptance_proj(mixed_receptance)

        # v_t = W_v . max(k_t, 0)^2   (squared ReLU)
        v = self.value_proj(torch.square(torch.relu(k)))

        # o_t = sigmoid(r_t) * v_t
        return F.sigmoid(r) * v, state
class TimeMixing(nn.Module):
    """RWKV time-mixing (attention-like) sub-block.

    Computes the WKV weighted average over past tokens with a numerically
    stable streaming recurrence, gated by a sigmoid receptance.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        # Token shifting: position t sees position t-1 (zeros at t = 0).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        attn_sz = n_embd

        # Learnable matrices
        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)

        # Learnable vectors: per-channel decay w, bonus u, and mixing weights
        self.time_decay = nn.Parameter(torch.empty(attn_sz))
        self.time_first = nn.Parameter(torch.empty(attn_sz))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

    def forward(self, x, state=None):
        """x: (batch, time, channel); state: optional recurrent state tensor.

        Returns (output, state)."""
        # Token shifting: x_{t-1} from the recurrent state (single-token
        # inference) or by shifting the sequence.
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
            state[self.layer_id, :, [PREV_X_TIME], :] = x
        else:
            prev_x = self.time_shift(x)

        # r_t = W_r . (mu_r * x_t + (1 - mu_r) * x_{t-1})
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)

        # k_t = W_k . (mu_k * x_t + (1 - mu_k) * x_{t-1})
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

        # v_t = W_v . (mu_v * x_t + (1 - mu_v) * x_{t-1})
        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
        value = self.value_proj(value)

        # WKV calculation:
        #           sum_{i<t} e^{-(t-1-i)w + k_i} v_i + e^{u + k_t} v_t
        #   wkv_t = ------------------------------------------------------
        #           sum_{i<t} e^{-(t-1-i)w + k_i}     + e^{u + k_t}
        # computed as a left-to-right recurrence over numerator / denominator,
        # with a running max of the exponent for numerical stability.
        _, seq_length, _ = key.size()
        output = torch.zeros_like(key)

        if state is None:
            num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
        else:
            num_state = state[self.layer_id, :, NUM_STATE, :]
            den_state = state[self.layer_id, :, DEN_STATE, :]
            max_state = state[self.layer_id, :, MAX_STATE, :]

        # w is stored as log-decay; the effective per-step decay is -e^{w} < 0.
        time_decay = -torch.exp(self.time_decay)

        for current_index in range(seq_length):
            current_key = key[:, current_index].float()
            current_value = value[:, current_index]

            # Output at t includes the "bonus" u (time_first) on the current token.
            max_for_output = torch.maximum(max_state, current_key + self.time_first)
            e1 = torch.exp(max_state - max_for_output)
            e2 = torch.exp(current_key + self.time_first - max_for_output)
            numerator = e1 * num_state + e2 * current_value
            denominator = e1 * den_state + e2
            output[:, current_index] = (numerator / denominator).to(output.dtype)

            # Update the recurrence (decay the history, absorb the current token).
            max_for_state = torch.maximum(max_state + time_decay, current_key)
            e1 = torch.exp(max_state + time_decay - max_for_state)
            e2 = torch.exp(current_key - max_for_state)
            num_state = e1 * num_state + e2 * current_value
            den_state = e1 * den_state + e2
            max_state = max_for_state

        # BUG FIX: the original wrote the state back unconditionally (TypeError
        # when state is None, i.e. during training) and then called an undefined
        # self.wkv_function, discarding the loop result. Guard the write-back
        # and use the computed output directly.
        if state is not None:
            state[self.layer_id, :, NUM_STATE, :] = num_state
            state[self.layer_id, :, DEN_STATE, :] = den_state
            state[self.layer_id, :, MAX_STATE, :] = max_state
        wkv = output

        # o_t = W_o . (sigmoid(r_t) * wkv_t)
        rwkv = F.sigmoid(receptance) * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state
class Block(nn.Module):
    """One RWKV residual block: time mixing then channel mixing, each behind a
    pre-LayerNorm with a residual connection (pre-LN transformer layout)."""

    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TimeMixing(config, layer_id)   # time mixing ("attention")
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffn = ChannelMixing(config, layer_id)  # channel mixing ("feed-forward")

    def forward(self, x, state=None):
        """x: (batch, time, channel); state: per layer (batch, 5, n_embd) slots."""
        # Time mixing with residual connection.
        mixed, state = self.attn(self.ln_1(x), state=state)
        x = x + mixed
        # Channel mixing with residual connection.
        mixed, state = self.ffn(self.ln_2(x), state=state)
        x = x + mixed
        return x, state
class RWKV(nn.Module):
    """RWKV language model: token embedding -> LayerNorm -> stack of RWKV
    blocks -> final LayerNorm -> linear head over the vocabulary."""

    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init  # used to initialize embedding parameters
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd

        # Model layers
        self.rwkv = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            ln_p=LayerNorm(config.n_embd, bias=config.bias),
            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        # Output projection to vocabulary logits
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None, state=None, return_state=False):
        """idx: (batch, time) token ids; targets: optional (batch, time) labels.

        Returns (logits, loss), or (logits, loss, state) when return_state."""
        batch_size, seq_len = idx.size()
        assert seq_len <= self.config.block_size, \
            f"Cannot forward sequence of length {seq_len}, block size is only {self.config.block_size}"

        # Embedding layer, then pre-stack layer norm.
        x = self.rwkv.wte(idx)
        x = self.rwkv.ln_p(x)

        # RWKV blocks, threading the recurrent state through.
        for block in self.rwkv.h:
            x, state = block(x, state)
        x = self.rwkv.ln_f(x)

        if targets is None:
            # Inference-time mini-optimization: only forward the lm_head on the
            # very last position (list index [-1] preserves the time dim).
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        else:
            # Training: logits everywhere, cross-entropy against the targets.
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                # Extra regularization nudging the logits toward zero.
                loss = L2Wrap.apply(loss, logits)

        if return_state:
            return logits, loss, state
        return logits, loss