
# Receptance Weighted Key Value (RWKV)


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rwkv/__init__.py)


This is a tutorial/implementation of RWKV, from the paper *RWKV: Reinventing RNNs for the Transformer Era*, in PyTorch.

The full definition of an RWKV language model, all of it in this single file. References: 1) the official RWKV PyTorch implementation released by Bo Peng, and 2) the huggingface/transformers PyTorch implementation.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F


PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4
```
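These constants index the five per-layer slots of the recurrent state used at inference time. As a minimal sketch (this helper is not part of the original file), the state for generation could be allocated like this, with the `MAX_STATE` slot starting at a very large negative value to match the initialization used in `TimeMixing` below:

```python
def init_state(n_layer, batch_size, n_embd, device="cpu"):
    # One slot per constant above, for every layer and batch element.
    state = torch.zeros(n_layer, batch_size, 5, n_embd, device=device)
    # The running maximum used for numerical stability starts very negative.
    state[:, :, MAX_STATE, :] = -1e38
    return state
```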

## Layer normalization with bias

```python
class LayerNorm(nn.Module):

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
```

## L2 loss wrapper

```python
class L2Wrap(torch.autograd.Function):

    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
```

Add a small gradient on the maximal logit at each position to encourage the logits to stay close to 0.

```python
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return grad_output, gy
```
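A small, hypothetical check of what `L2Wrap` does: the loss gradient is passed through unchanged, while the largest logit at every position receives an extra gradient of `max_logit * 1e-4 / (batch * time)` that pulls it towards zero.

```python
import torch

logits = torch.randn(2, 3, 5, requires_grad=True)  # (batch, time, vocab)
loss = logits.pow(2).mean()                        # stand-in for the real cross entropy
wrapped = L2Wrap.apply(loss, logits)               # forward returns the loss unchanged
wrapped.backward()

# logits.grad now contains the usual loss gradient plus the small extra term
# on the arg-max logit of each (batch, time) position.
print(logits.grad.shape)  # torch.Size([2, 3, 5])
```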

## Channel Mixing

```python
class ChannelMixing(nn.Module):

    def __init__(self, config, layer_id):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
```

Token shifting: `time_shift` shifts the input one step along the time dimension, so each position also sees the previous token's representation (a short demo follows below).

```python
        self.layer_id = layer_id

        n_embd = config.n_embd
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
        )
```
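To make the token shift concrete, here is a small, self-contained check (not part of the model) of what `nn.ZeroPad2d((0, 0, 1, -1))` does to a `(Batch, Time, Channel)` tensor: it pads one step of zeros at the start of the time dimension and crops the last step, so position `t` sees the features of token `t-1`.

```python
import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))
x = torch.arange(1., 7.).reshape(1, 3, 2)  # (batch=1, time=3, channel=2)
print(x[0])
# tensor([[1., 2.],
#         [3., 4.],
#         [5., 6.]])
print(time_shift(x)[0])
# tensor([[0., 0.],
#         [1., 2.],
#         [3., 4.]])
```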

Learnable matrices

```python
        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)
```

Learnable vectors

```python
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
```

`x` has shape `(Batch, Time, Channel)`

```python
    def forward(self, x, state=None):
```

```python
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
        else:
            prev_x = self.time_shift(x)
```

$r_t = W_r \cdot (\mu_r x_t + (1 - \mu_r) x_{t-1})$

```python
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)
```

$k_t = W_k \cdot (\mu_k x_t + (1 - \mu_k) x_{t-1})$

```python
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)
```

$v_t = W_v \cdot \max(k_t, 0)^2$

```python
        value = self.value_proj(torch.square(torch.relu(key)))
```

$o_t = \sigma(r_t) \odot v_t$

```python
        out = F.sigmoid(receptance) * value
        return out, state
```
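As a quick, hypothetical shape check (the `SimpleNamespace` config is a stand-in, and the time-mix parameters are filled manually here because they are created with `torch.empty` and initialized elsewhere):

```python
import torch
from types import SimpleNamespace

config = SimpleNamespace(n_embd=8, intermediate_size=None)
channel_mixing = ChannelMixing(config, layer_id=0)
with torch.no_grad():
    channel_mixing.time_mix_key.fill_(0.5)
    channel_mixing.time_mix_receptance.fill_(0.5)

x = torch.randn(2, 4, 8)         # (batch, time, channel)
out, _ = channel_mixing(x)       # parallel mode: no recurrent state
print(out.shape)                 # torch.Size([2, 4, 8])
```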

## Time Mixing

```python
class TimeMixing(nn.Module):

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        attn_sz = n_embd
```

Learnable matrices

```python
        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)
```

Learnable vectors

```python
        self.time_decay = nn.Parameter(torch.empty(attn_sz))
        self.time_first = nn.Parameter(torch.empty(attn_sz))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
```

`x` has shape `(Batch, Time, Channel)`

```python
    def forward(self, x, state=None):
```

```python
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
            state[self.layer_id, :, [PREV_X_TIME], :] = x
        else:
            prev_x = self.time_shift(x)
```

$r_t = W_r \cdot (\mu_r x_t + (1 - \mu_r) x_{t-1})$

```python
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)
```

$k_t = W_k \cdot (\mu_k x_t + (1 - \mu_k) x_{t-1})$

```python
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)
```

$v_t = W_v \cdot (\mu_v x_t + (1 - \mu_v) x_{t-1})$

```python
        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
        value = self.value_proj(value)
```

WKV calculation

```python
        _, seq_length, _ = key.size()
        output = torch.zeros_like(key)

        if state is None:
            num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
        else:
            num_state = state[self.layer_id, :, NUM_STATE, :]
            den_state = state[self.layer_id, :, DEN_STATE, :]
            max_state = state[self.layer_id, :, MAX_STATE, :]

        time_decay = -torch.exp(self.time_decay)

        for current_index in range(seq_length):
            current_key = key[:, current_index].float()
            current_value = value[:, current_index]
```

$$wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} v_i + e^{u + k_t} v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}$$

```python
            max_for_output = torch.maximum(max_state, current_key + self.time_first)
            e1 = torch.exp(max_state - max_for_output)
            e2 = torch.exp(current_key + self.time_first - max_for_output)
            numerator = e1 * num_state + e2 * current_value
            denominator = e1 * den_state + e2
            output[:, current_index] = (numerator / denominator).to(output.dtype)
```

Update the state for the next iteration

```python
            max_for_state = torch.maximum(max_state + time_decay, current_key)
            e1 = torch.exp(max_state + time_decay - max_for_state)
            e2 = torch.exp(current_key - max_for_state)
            num_state = e1 * num_state + e2 * current_value
            den_state = e1 * den_state + e2
            max_state = max_for_state
```
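Note the running-maximum (log-sum-exp style) stabilization used in the last two code blocks: if $a_t$ and $b_t$ denote the raw numerator and denominator sums of $wkv_t$, the loop never stores them directly but keeps $\hat a_t = a_t e^{-m_t}$ and $\hat b_t = b_t e^{-m_t}$, where $m_t$ is the running maximum exponent `max_state`. The ratio $\hat a_t / \hat b_t = a_t / b_t$ is unchanged, while every exponential that is actually evaluated, both in the output above and in this state update, has a non-positive argument and therefore cannot overflow.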

Write the final state back (only when a recurrent state is being carried) and take the loop output as `wkv`

```python
        if state is not None:
            state[self.layer_id, :, NUM_STATE, :] = num_state
            state[self.layer_id, :, DEN_STATE, :] = den_state
            state[self.layer_id, :, MAX_STATE, :] = max_state
        wkv = output
```
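For reference, here is a minimal, non-stabilized sketch of the same $wkv$ recurrence on a single `(time, channel)` sequence, where `time_first` is the bonus $u$ and `time_decay` is the already-negated decay (`-torch.exp(self.time_decay)`, exactly as computed in the loop above). It omits the running-maximum trick, so it is only safe for small toy values; it is illustrative, not part of the original file.

```python
import torch

def naive_wkv(key, value, time_decay, time_first):
    # key, value: (time, channel); time_decay (negative) and time_first: (channel,)
    seq_length, _ = key.shape
    num = torch.zeros_like(key[0])
    den = torch.zeros_like(key[0])
    out = torch.zeros_like(key)
    for t in range(seq_length):
        # the current token enters with the "bonus" u instead of the decay
        e_current = torch.exp(time_first + key[t])
        out[t] = (num + e_current * value[t]) / (den + e_current)
        # fold the current token into the running sums, decaying the past by e^w
        e_key = torch.exp(key[t])
        num = torch.exp(time_decay) * num + e_key * value[t]
        den = torch.exp(time_decay) * den + e_key
    return out
```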

$o_t = W_o \cdot (\sigma(r_t) \odot wkv_t)$

```python
        rwkv = F.sigmoid(receptance) * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state
```

## RWKV block element

```python
class Block(nn.Module):

    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TimeMixing(config, layer_id)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffn = ChannelMixing(config, layer_id)
```

```python
    def forward(self, x, state=None):
```

`state` has shape `(n_layer, batch_size, 5, n_embd)`; the five slots per layer are indexed by the constants defined at the top of the file.

Time mixing

```python
        residual = x
        x, state = self.attn(self.ln_1(x), state=state)
        x = x + residual
```

Channel mixing

```python
        residual = x
        x, state = self.ffn(self.ln_2(x), state=state)
        x = x + residual
        return x, state
```

## RWKV

```python
class RWKV(nn.Module):

    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init  # used to initialize embedding parameters
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
```

Initialize the model layers

```python
        self.rwkv = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            ln_p=LayerNorm(config.n_embd, bias=config.bias),
            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
```

Output linear layer

```python
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
```

```python
    def forward(self, idx, targets=None, state=None, return_state=False):
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
```

Embedding layer

```python
        x = self.rwkv.wte(idx)
```

Layer norm

```python
        x = self.rwkv.ln_p(x)
```

RWKV blocks

```python
        for block_idx, block in enumerate(self.rwkv.h):
            x, state = block(x, state)
        x = self.rwkv.ln_f(x)
```

Logit layer and loss function (for training)

```python
        if targets is not None:
```

If we are given some desired targets, also calculate the loss.

```python
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                loss = L2Wrap.apply(loss, logits)
        else:
```

Inference-time mini-optimization: only forward the `lm_head` on the very last position.

```python
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None
```

Return logits and loss

```python
        if return_state:
            return logits, loss, state
        else:
            return logits, loss
```
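Finally, a hypothetical usage sketch. The `Config` dataclass and its values are illustrative (they are not defined in this file), and the time-mix and decay parameters are created with `torch.empty`, so a real run would first need the initialization step that the original training code performs.

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class Config:
    vocab_size: int = 256
    block_size: int = 64
    n_layer: int = 2
    n_embd: int = 32
    intermediate_size: Optional[int] = None  # falls back to 4 * n_embd
    bias: bool = True

config = Config()
model = RWKV(config)

# Parallel (training) mode: a whole sequence at once.
idx = torch.randint(0, config.vocab_size, (1, 16))      # (batch, time)
targets = torch.randint(0, config.vocab_size, (1, 16))
logits, loss = model(idx, targets)

# Recurrent (inference) mode: one token at a time, carrying the state.
state = torch.zeros(config.n_layer, 1, 5, config.n_embd)
state[:, :, MAX_STATE, :] = -1e38
logits, _, state = model(idx[:, :1], state=state, return_state=True)
```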
