docs/transformers/fast_weights/index.html
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/fast_weights/__init__.py)
This is a PyTorch implementation of the paper Linear Transformers Are Secretly Fast Weight Memory Systems. The paper finds similarities between linear self-attention and fast weight systems, and modifies the self-attention update rule based on them. It also introduces a simpler, yet effective kernel function.
The authors have provided an official implementation of the paper, including the other variants they compare against in the paper.
Consider a sequence of inputs $\{x^{(i)}\}^L_{i=1}$ of length $L$, where each step is a vector of size $d_{in}$; i.e. $x \in \mathbb{R}^{d_{in}}$. The fast weight model generates a weight matrix at each step to produce output $\{y^{(i)}\}^L_{i=1}$, $y \in \mathbb{R}^{d_{out}}$:
$$\begin{aligned}
a^{(i)}, b^{(i)} &= W_a x^{(i)}, W_b x^{(i)} \\
W^{(i)} &= \sigma\Big(W^{(i-1)} + a^{(i)} \otimes b^{(i)}\Big) \\
y^{(i)} &= W^{(i)} x^{(i)}
\end{aligned}$$
$\otimes$ is the outer product ($a \otimes b = a b^\top$), where elements of the two vectors are multiplied with each other to give a matrix. $\sigma$ is an activation function. $W_a$ and $W_b$ are trainable weights (parameters). $W^{(i)}$ are the fast weights that are generated at each step.
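As a concrete illustration, here is a minimal, self-contained sketch of this update rule. It is not from the paper's code; the sizes, the `tanh` activation used as $\sigma$, and the zero initialization of $W^{(0)}$ are arbitrary choices for the example.

```python
# Sketch of the fast weight update rule (illustrative sizes; tanh stands in for σ).
import torch
from torch import nn

d_in, d_out, seq_len = 16, 16, 10
w_a = nn.Linear(d_in, d_out, bias=False)  # slow weights W_a
w_b = nn.Linear(d_in, d_in, bias=False)   # slow weights W_b
sigma = torch.tanh                        # activation σ (arbitrary choice here)

x = torch.randn(seq_len, d_in)
W = torch.zeros(d_out, d_in)              # fast weights W^(0)
ys = []
for i in range(seq_len):
    a, b = w_a(x[i]), w_b(x[i])           # a^(i), b^(i)
    W = sigma(W + torch.outer(a, b))      # W^(i) = σ(W^(i-1) + a^(i) ⊗ b^(i))
    ys.append(W @ x[i])                   # y^(i) = W^(i) x^(i)
y = torch.stack(ys)                       # outputs for all steps, [seq_len, d_out]
```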
The original transformer self-attention is (omitting the $\frac{1}{\sqrt{d_k}}$ scaling for clarity):
$$y^{(i)} = \Big[v^{(1)}, v^{(2)}, ..., v^{(i)}\Big] \mathop{\text{softmax}}\bigg(\Big[k^{(1)}, k^{(2)}, ..., k^{(i)}\Big]^\top q^{(i)}\bigg)
= \sum_{j=1}^{i} \frac{\kappa\big(k^{(j)}, q^{(i)}\big)}{\sum_{j'=1}^{i} \kappa\big(k^{(j')}, q^{(i)}\big)} v^{(j)}$$
where $\kappa(k, q) = \exp(k \cdot q)$.
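The following small sketch (not part of the original implementation; the sizes and random inputs are made up) checks that the matrix form and the weighted-sum form above agree for a single step $i$:

```python
# Check that the two forms of step-wise softmax attention agree (made-up sizes).
import torch

d_k, i = 8, 5
torch.manual_seed(0)
k = torch.randn(i, d_k)  # keys k^(1)...k^(i)
v = torch.randn(i, d_k)  # values v^(1)...v^(i)
q = torch.randn(d_k)     # query q^(i)

# Matrix form: [v^(1),...,v^(i)] softmax([k^(1),...,k^(i)]^T q^(i))
y_matrix = v.T @ torch.softmax(k @ q, dim=0)

# Weighted-sum form with κ(k, q) = exp(k · q)
kappa = torch.exp(k @ q)
y_sum = ((kappa / kappa.sum()).unsqueeze(-1) * v).sum(dim=0)

assert torch.allclose(y_matrix, y_sum)
```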
The idea behind linearizing self-attention is to replace the softmax kernel $\kappa$ with a different kernel $\kappa'$ so that we can calculate the denominator of the self-attention function faster:
$$\kappa'(k, q) = \phi(k)^\top \phi(q)$$
This gives
$$y^{(i)} = \frac{\Big(\sum_{j=1}^{i} v^{(j)} \otimes \phi\big(k^{(j)}\big)\Big) \phi\big(q^{(i)}\big)}{\Big(\sum_{j'=1}^{i} \phi\big(k^{(j')}\big)\Big)^\top \phi\big(q^{(i)}\big)}$$
With $W^{(i)} = \sum_{j=1}^{i} v^{(j)} \otimes \phi\big(k^{(j)}\big)$ and $z^{(i)} = \sum_{j=1}^{i} \phi\big(k^{(j)}\big)$, we can calculate them efficiently:
$$\begin{aligned}
W^{(i)} &= W^{(i-1)} + v^{(i)} \otimes \phi\big(k^{(i)}\big) \\
z^{(i)} &= z^{(i-1)} + \phi\big(k^{(i)}\big) \\
y^{(i)} &= \frac{1}{z^{(i)} \cdot \phi\big(q^{(i)}\big)} W^{(i)} \phi\big(q^{(i)}\big)
\end{aligned}$$
This is quite similar to fast weights.
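Here is a minimal sketch of that recurrent computation. It is not from the original code and uses the common $\text{ELU} + 1$ feature map as a stand-in for $\phi$, rather than the DPFP projection introduced below:

```python
# Recurrent linear attention with W^(i) and z^(i) (made-up sizes; φ(x) = ELU(x) + 1 here).
import torch

def phi(x):
    return torch.nn.functional.elu(x) + 1.

d_k, seq_len = 8, 10
torch.manual_seed(0)
q, k, v = (torch.randn(seq_len, d_k) for _ in range(3))

W = torch.zeros(d_k, d_k)  # accumulates v^(j) ⊗ φ(k^(j))
z = torch.zeros(d_k)       # accumulates φ(k^(j))
ys = []
for i in range(seq_len):
    W = W + torch.outer(v[i], phi(k[i]))          # W^(i) = W^(i-1) + v^(i) ⊗ φ(k^(i))
    z = z + phi(k[i])                             # z^(i) = z^(i-1) + φ(k^(i))
    ys.append((W @ phi(q[i])) / (z @ phi(q[i])))  # y^(i)
y = torch.stack(ys)
```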
The paper introduces a new linear attention projection function $\phi$, a new update rule for $W^{(i)} = f\big(W^{(i-1)}\big)$, and changes the normalization $\frac{1}{z^{(i)} \cdot \phi(q^{(i)})}$.
Here are the training code and a notebook for training a fast weights transformer on the Tiny Shakespeare dataset.
```python
import torch
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list
```
This is the new projection function $\phi$ introduced in the paper. DPFP projects $k$ of dimensionality $d_{key}$ to dimensionality $d_{dot} = 2 d_{key} \nu$, where $\nu \in \{1, 2, ..., 2 d_{key} - 1\}$ is a hyper-parameter.
$$\phi_{2 d_{key} (i - 1) + j}(k) = \text{ReLU}\Big(\big[k, -k\big]\Big)_{j} \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$$
where $\big[k, -k\big]$ is the concatenation of $k$ and $-k$ to give a vector of size $2 d_{key}$, $i \in \{1, 2, ..., \nu\}$, and $j \in \{1, 2, ..., 2 d_{key}\}$. $x_i$ is the $i$-th element of vector $x$ and is rolled around if $i$ is larger than the number of elements in $x$.
Basically, it creates a new vector by multiplying elements of $[k, -k]$ shifted by $i$.
This produces projections that are sparse (only a few elements of $\phi$ are non-zero) and orthogonal ($\phi(k^{(i)}) \cdot \phi(k^{(j)}) \approx 0$ for most $i, j$, unless $k^{(i)}$ and $k^{(j)}$ are very similar).
The paper introduces a simple normalization for $\phi$:
$$\phi'(k) = \frac{\phi(k)}{\sum_{j=1}^{d_{dot}} \phi(k)_j}$$
Check the paper for derivation.
```python
class DPFP(nn.Module):
```

`nu` is the hyper-parameter $\nu$. `eps` is the small value used to make sure there is no division-by-zero when normalizing.

```python
    def __init__(self, nu: int = 1, eps: float = 1e-6):
        super().__init__()
        self.nu = nu
        self.relu = nn.ReLU()
        self.eps = eps
```
```python
    def forward(self, k: torch.Tensor):
```

Get $\phi(k)$

```python
        k = self.dpfp(k)
```

Normalize by $\sum_{j=1}^{d_{dot}} \phi(k)_j$

```python
        return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)
```
Compute $\phi(k)$

```python
    def dpfp(self, k: torch.Tensor):
```

$x = \text{ReLU}\Big(\big[k, -k\big]\Big)$

```python
        x = self.relu(torch.cat([k, -k], dim=-1))
```

Shift and roll by $i \in \{1, 2, ..., \nu\}$, to get $x'_{i,j} = \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$

```python
        x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]
```

Concatenate to get $x'_{2 d_{key} (i - 1) + j} = \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$

```python
        x_rolled = torch.cat(x_rolled, dim=-1)
```

Concatenate copies of $x$

```python
        x_repeat = torch.cat([x] * self.nu, dim=-1)
```

Multiply them, $\phi_{2 d_{key} (i - 1) + j}(k) = \text{ReLU}\Big(\big[k, -k\big]\Big)_{j} \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$

```python
        return x_repeat * x_rolled
```
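A small usage sketch of `DPFP` (the sizes are illustrative, not from the original): the projection maps the last dimension from $d_{key}$ to $2 d_{key} \nu$, and each projected vector sums to approximately one.

```python
# Usage sketch for DPFP (illustrative shapes, [seq_len, batch, heads, d_key]).
phi = DPFP(nu=2)
k = torch.randn(10, 4, 8, 16)
p = phi(k)
assert p.shape == (10, 4, 8, 2 * 16 * 2)  # last dimension becomes 2 * d_key * nu
# After normalization, each projected vector sums to (approximately) one.
assert torch.allclose(p.sum(dim=-1), torch.ones(10, 4, 8), atol=1e-3)
```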
The paper introduces a new update rule for calculating $W^{(i)}$. The model first retrieves the current value $\bar{v}^{(i)}$ paired with the key $k^{(i)}$. Then it stores a combination $v^{(i)}_{new}$ of the retrieved value $\bar{v}^{(i)}$ and the input $v^{(i)}$.

$$\begin{aligned}
k^{(i)}, v^{(i)}, q^{(i)} &= W_k x^{(i)}, W_v x^{(i)}, W_q x^{(i)} \\
\bar{v}^{(i)} &= W^{(i-1)} \phi'\big(k^{(i)}\big) \\
\beta^{(i)} &= \sigma\Big(W_\beta x^{(i)}\Big) \\
v^{(i)}_{new} &= \beta^{(i)} v^{(i)} + \Big(1 - \beta^{(i)}\Big) \bar{v}^{(i)} \\
W^{(i)} &= W^{(i-1)} - \bar{v}^{(i)} \otimes \phi'\big(k^{(i)}\big) + v^{(i)}_{new} \otimes \phi'\big(k^{(i)}\big) \\
&= W^{(i-1)} + \beta^{(i)} \Big(v^{(i)} - \bar{v}^{(i)}\Big) \otimes \phi'\big(k^{(i)}\big) \\
y^{(i)} &= W^{(i)} \phi'\big(q^{(i)}\big)
\end{aligned}$$

where $W_\beta$ is a trainable parameter and $\sigma$ is the sigmoid function.

Note that we don't need the normalization term $z$ because $\phi'$ is normalized.
```python
class FastWeightsAttention(nn.Module):
```

```python
    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        super().__init__()
```

Number of features per head $d_k$

```python
        self.d_k = d_model // heads
```

Number of heads

```python
        self.heads = heads
```

These transform the query, key and value vectors for multi-headed attention.

```python
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
```

Interpolation weight function $\sigma\Big(W_\beta x^{(i)}\Big)$ for each head

```python
        self.interpolation_weight = nn.Sequential(
            PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
            nn.Sigmoid()
        )
```

$\phi'$

```python
        self.phi = phi
```

Output layer

```python
        self.output = nn.Linear(d_model, d_model)
```

Dropout

```python
        self.dropout = nn.Dropout(dropout_prob)
```
```python
    def forward(self, x: torch.Tensor):
```

Get the number of steps $L$

```python
        seq_len = x.shape[0]
```

$\phi'\big(q^{(i)}\big)$ for all steps and heads

```python
        query = self.phi(self.query(x))
```

$\phi'\big(k^{(i)}\big)$ for all steps and heads

```python
        key = self.phi(self.key(x))
```

$v^{(i)}$ for all steps and heads

```python
        value = self.value(x)
```

$\beta^{(i)}$ for all steps and heads

```python
        beta = self.interpolation_weight(x)
```

$W^{(0)}$

```python
        weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))
```

List to store outputs $y^{(i)}$

```python
        outputs = []
```

Iterate through the steps

```python
        for i in range(seq_len):
```

$\bar{v}^{(i)} = W^{(i-1)} \phi'\big(k^{(i)}\big)$

```python
            value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])
```

$W^{(i)} = W^{(i-1)} + \beta^{(i)} \Big(v^{(i)} - \bar{v}^{(i)}\Big) \otimes \phi'\big(k^{(i)}\big)$

```python
            weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])
```

$y^{(i)} = W^{(i)} \phi'\big(q^{(i)}\big)$

```python
            y = torch.einsum('bhvk,bhk->bhv', weights, query[i])
```

Merge multiple heads and append to `outputs`

```python
            outputs.append(y.reshape(y.shape[0], -1))
```

Stack outputs at each step into a single tensor

```python
        x = torch.stack(outputs)
```

Output layer

```python
        return self.output(x)
```
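A usage sketch of `FastWeightsAttention` with made-up sizes (not part of the original module):

```python
# Usage sketch for FastWeightsAttention (illustrative sizes).
attn = FastWeightsAttention(heads=4, d_model=64, dropout_prob=0.1, phi=DPFP(nu=1))
x = torch.randn(10, 2, 64)  # [seq_len, batch_size, d_model]
out = attn(x)               # [seq_len, batch_size, d_model]
assert out.shape == (10, 2, 64)
```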
This is a general transformer layer that combines self-attention and a feed-forward network.
```python
class FastWeightsAttentionTransformerLayer(nn.Module):
```

```python
    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
```

Transformer size $d_{model}$

```python
        self.size = d_model
```

Fast weights attention module

```python
        self.attn = attn
```

Feed-forward network

```python
        self.feed_forward = feed_forward
```

Dropout layer

```python
        self.dropout = nn.Dropout(dropout_prob)
```

Normalization layers

```python
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
```
```python
    def forward(self, x: torch.Tensor):
```

Calculate fast weights self-attention

```python
        attn = self.attn(x)
```

Add the self-attention results

```python
        x = x + self.dropout(attn)
```

Normalize for the feed-forward network

```python
        z = self.norm_ff(x)
```

Pass through the feed-forward network

```python
        ff = self.feed_forward(z)
```

Add the feed-forward results back

```python
        x = x + self.dropout(ff)
```

```python
        return x
```
This is a general transformer module with multiple transformer layers.
```python
class FastWeightsAttentionTransformer(nn.Module):
```

```python
    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        super().__init__()
```

Make copies of the transformer layer

```python
        self.layers = clone_module_list(layer, n_layers)
```

Final normalization layer

```python
        self.norm = nn.LayerNorm([layer.size])
```

```python
    def forward(self, x: torch.Tensor):
        for i, layer in enumerate(self.layers):
```

Get layer output

```python
            x = layer(x)
```

Normalize the output

```python
        return self.norm(x)
```
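Finally, a sketch of assembling the full model with illustrative sizes; the `FeedForward(d_model, d_ff, dropout)` argument order is assumed from `labml_nn` and may differ from the actual class:

```python
# Sketch of building the full model (illustrative sizes; FeedForward signature assumed).
d_model = 64
layer = FastWeightsAttentionTransformerLayer(
    d_model=d_model,
    attn=FastWeightsAttention(heads=4, d_model=d_model, dropout_prob=0.1, phi=DPFP(nu=1)),
    feed_forward=FeedForward(d_model, 256, 0.1),
    dropout_prob=0.1)
model = FastWeightsAttentionTransformer(layer, n_layers=2)

x = torch.randn(10, 2, d_model)  # [seq_len, batch_size, d_model]
out = model(x)                   # [seq_len, batch_size, d_model]
```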