
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/fast_weights/__init__.py)

# Fast weights transformer

This is a PyTorch implementation of the paper Linear Transformers Are Secretly Fast Weight Memory Systems. The paper finds similarities between linear self-attention and fast weight systems and modifies the self-attention update rule based on them. It also introduces a simpler, yet effective, kernel function.

The authors have provided an official implementation, which includes the other variants they compare against in the paper.

## Fast weights

Consider a sequence of inputs $\big\{x^{(i)}\big\}^L_{i=1}$ of length $L$, where each step is a vector of size $d_{in}$; i.e. $x \in \mathbb{R}^{d_{in}}$. The fast weight model generates a weight matrix at each step to produce outputs $\big\{y^{(i)}\big\}^L_{i=1}$, $y \in \mathbb{R}^{d_{out}}$.

$$\begin{aligned}
a^{(i)}, b^{(i)} &= W_a x^{(i)}, W_b x^{(i)} \\
W^{(i)} &= \sigma\Big(W^{(i-1)} + a^{(i)} \otimes b^{(i)}\Big) \\
y^{(i)} &= W^{(i)} x^{(i)}
\end{aligned}$$

$\otimes$ is the outer product ($a \otimes b = a b^\top$), where elements of the two vectors are multiplied with each other to give a matrix. $\sigma$ is an activation function. $W_a$ and $W_b$ are trainable weights (parameters). $W^{(i)}$ are the fast weights that are generated at each step.
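To make the recurrence concrete, here is a minimal sketch of a fast weight generator in PyTorch. The names `W_a` and `W_b` and the choice of `tanh` for $\sigma$ are illustrative assumptions, not taken from the paper's code:

```python
import torch

# Illustrative sizes (not from the paper)
d_in, d_out, L = 8, 4, 16

# Slow (trainable) weights that generate the fast weights
W_a = torch.randn(d_out, d_in) * 0.1
W_b = torch.randn(d_in, d_in) * 0.1

xs = torch.randn(L, d_in)       # input sequence {x^(i)}
W = torch.zeros(d_out, d_in)    # fast weights W^(0)
ys = []
for x in xs:
    a, b = W_a @ x, W_b @ x                   # a^(i), b^(i)
    W = torch.tanh(W + torch.outer(a, b))     # W^(i) = sigma(W^(i-1) + a^(i) outer b^(i))
    ys.append(W @ x)                          # y^(i) = W^(i) x^(i)
```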

## Linear self-attention

The original transformer self-attention is (omitting $\frac{1}{\sqrt{d_k}}$ for clarity):

$$\begin{aligned}
y^{(i)} &= \Big[v^{(1)}, v^{(2)}, ..., v^{(i)}\Big] \text{softmax}\bigg(\Big[k^{(1)}, k^{(2)}, ..., k^{(i)}\Big]^\top q^{(i)}\bigg) \\
&= \sum^i_{j=1} \frac{\kappa\big(k^{(j)}, q^{(i)}\big)}{\sum^i_{j'=1} \kappa\big(k^{(j')}, q^{(i)}\big)} v^{(j)}
\end{aligned}$$

where $\kappa(k, q) = \exp(k \cdot q)$.

The idea behind linearizing self-attention is to replace the softmax kernel $\kappa$ with a different kernel $\kappa'$ so that we can calculate the denominator of the self-attention function faster:

$$\kappa'(k, q) = \phi(k)^\top \phi(q)$$

This gives

$$y^{(i)} = \frac{\Big(\sum^i_{j=1} v^{(j)} \otimes \phi\big(k^{(j)}\big)\Big) \phi\big(q^{(i)}\big)}{\Big(\sum^i_{j'=1} \phi\big(k^{(j')}\big)\Big) \cdot \phi\big(q^{(i)}\big)}$$

With $W^{(i)} = \sum^i_{j=1} v^{(j)} \otimes \phi\big(k^{(j)}\big)$ and $z^{(i)} = \sum^i_{j=1} \phi\big(k^{(j)}\big)$, we can calculate them efficiently:

$$\begin{aligned}
W^{(i)} &= W^{(i-1)} + v^{(i)} \otimes \phi\big(k^{(i)}\big) \\
z^{(i)} &= z^{(i-1)} + \phi\big(k^{(i)}\big) \\
y^{(i)} &= \frac{1}{z^{(i)} \cdot \phi\big(q^{(i)}\big)} W^{(i)} \phi\big(q^{(i)}\big)
\end{aligned}$$

This is quite similar to fast weights.
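This recurrence can be written as a short sketch. The elementwise `elu(x) + 1` feature map used for $\phi$ here is only a common illustrative choice, not the DPFP function the paper introduces below:

```python
import torch
import torch.nn.functional as F

def phi(x):
    # A common positive feature map for linear attention (illustrative choice, not DPFP)
    return F.elu(x) + 1

d_k, d_v, L = 8, 8, 16
ks, qs, vs = torch.randn(L, d_k), torch.randn(L, d_k), torch.randn(L, d_v)

W = torch.zeros(d_v, d_k)   # W^(0)
z = torch.zeros(d_k)        # z^(0)
ys = []
for k, q, v in zip(ks, qs, vs):
    W = W + torch.outer(v, phi(k))            # W^(i) = W^(i-1) + v^(i) outer phi(k^(i))
    z = z + phi(k)                            # z^(i) = z^(i-1) + phi(k^(i))
    ys.append((W @ phi(q)) / (z @ phi(q)))    # y^(i) = W^(i) phi(q^(i)) / (z^(i) . phi(q^(i)))
```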

The paper introduces a new linear-attention projection function $\phi$, a new update rule for $W^{(i)} = f\big(W^{(i-1)}\big)$, and changes the normalization $\frac{1}{z^{(i)} \cdot \phi(q^{(i)})}$.

Here are the training code and a notebook for training a fast weights transformer on the Tiny Shakespeare dataset.

```python
import torch
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list
```

# Deterministic Parameter Free Projection (DPFP)

This is the new projection function $\phi$ introduced in the paper. DPFP projects $k$ of dimensionality $d_{key}$ to dimensionality $d_{dot} = 2 d_{key} \nu$, where $\nu \in \{1, 2, ..., 2 d_{key} - 1\}$ is a hyper-parameter.

$$\phi_{2 d_{key} (i - 1) + j}(k) = \text{ReLU}\Big(\big[k, -k\big]\Big)_j \, \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$$

where $\big[k, -k\big]$ is the concatenation of $k$ and $-k$ to give a vector of size $2 d_{key}$, $i \in \{1, 2, ..., \nu\}$, and $j \in \{1, 2, ..., 2 d_{key}\}$. $x_i$ is the $i$-th element of vector $x$ and is rolled around if $i$ is larger than the number of elements in $x$.

Basically, it creates a new vector by multiplying elements of $[k, -k]$ shifted by $i$.

This produces projections that are sparse (only a few elements of $\phi$ are non-zero) and orthogonal ($\phi(k^{(i)}) \cdot \phi(k^{(j)}) \approx 0$ for most $i, j$ unless $k^{(i)}$ and $k^{(j)}$ are very similar).
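As a tiny illustration of this shift-and-multiply construction (mirroring the implementation below, with an arbitrary two-dimensional key and $\nu = 1$):

```python
import torch

k = torch.tensor([1.0, -2.0])            # a key with d_key = 2
x = torch.relu(torch.cat([k, -k]))       # ReLU([k, -k]) = [1, 0, 0, 2]

nu = 1
# Multiply x element-wise with itself rolled by i positions, for i = 1..nu
phi_k = torch.cat([x * x.roll(shifts=i, dims=-1) for i in range(1, nu + 1)])
print(phi_k)                             # tensor([2., 0., 0., 0.]), mostly zeros (sparse)
```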

## Normalization

The paper introduces a simple normalization for $\phi$,

$$\phi'(k) = \frac{\phi(k)}{\sum^{d_{dot}}_{j=1} \phi(k)_j}$$

Check the paper for derivation.

```python
class DPFP(nn.Module):
```

- `nu` is the hyper-parameter $\nu$.
- `eps` is the small value used to make sure there is no division-by-zero when normalizing.

```python
    def __init__(self, nu: int = 1, eps: float = 1e-6):
        super().__init__()
        self.nu = nu
        self.relu = nn.ReLU()
        self.eps = eps
```

```python
    def forward(self, k: torch.Tensor):
```

Get $\phi(k)$

```python
        k = self.dpfp(k)
```

Normalize by $\sum^{d_{dot}}_{j=1} \phi(k)_j$

```python
        return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)
```

$\phi(k)$

```python
    def dpfp(self, k: torch.Tensor):
```

$x = \text{ReLU}\Big(\big[k, -k\big]\Big)$

```python
        x = self.relu(torch.cat([k, -k], dim=-1))
```

Shift and roll by $i \in \{1, 2, ..., \nu\}$, to get $x'_{i, j} = \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$

```python
        x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]
```

Concatenate to get $x'_{2 d_{key} (i - 1) + j} = \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$

```python
        x_rolled = torch.cat(x_rolled, dim=-1)
```

Concatenate copies of $x$

```python
        x_repeat = torch.cat([x] * self.nu, dim=-1)
```

Multiply them, $\phi_{2 d_{key} (i - 1) + j}(k) = \text{ReLU}\Big(\big[k, -k\big]\Big)_j \, \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$

```python
        return x_repeat * x_rolled
```
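A quick usage sketch of the module; the shapes here are arbitrary assumptions:

```python
phi = DPFP(nu=3)
k = torch.randn(10, 4, 16)    # e.g. (seq_len, heads, d_key); any trailing dimension works
p = phi(k)
print(p.shape)                # torch.Size([10, 4, 96]); d_dot = 2 * d_key * nu = 96
print(p[0, 0].sum())          # approximately 1, since the output is normalized to sum to one
```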

# Fast Weights Attention

The paper introduces a new update rule for calculating $W^{(i)}$. The model first retrieves the current value $\bar{v}^{(i)}$ paired with the key $k^{(i)}$. Then it stores a combination $v^{(i)}_{new}$ of the retrieved value $\bar{v}^{(i)}$ and the input $v^{(i)}$.

$$\begin{aligned}
k^{(i)}, v^{(i)}, q^{(i)} &= W_k x^{(i)}, W_v x^{(i)}, W_q x^{(i)} \\
\bar{v}^{(i)} &= W^{(i-1)} \phi'\big(k^{(i)}\big) \\
\beta^{(i)} &= \sigma\Big(W_\beta x^{(i)}\Big) \\
v^{(i)}_{new} &= \beta^{(i)} v^{(i)} + \Big(1 - \beta^{(i)}\Big) \bar{v}^{(i)} \\
W^{(i)} &= W^{(i-1)} + v^{(i)}_{new} \otimes \phi'\big(k^{(i)}\big) - \bar{v}^{(i)} \otimes \phi'\big(k^{(i)}\big) \\
&= W^{(i-1)} + \beta^{(i)} \Big(v^{(i)} - \bar{v}^{(i)}\Big) \otimes \phi'\big(k^{(i)}\big) \\
y^{(i)} &= W^{(i)} \phi'\big(q^{(i)}\big)
\end{aligned}$$

where $W_\beta$ is a trainable parameter and $\sigma$ is the sigmoid function.

Note that we don't need the normalization term $z$ because $\phi'$ is normalized.

```python
class FastWeightsAttention(nn.Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        super().__init__()
```

Number of features per head $d_k$

```python
        self.d_k = d_model // heads
```

Number of heads

```python
        self.heads = heads
```

These transform the `query`, `key` and `value` vectors for multi-headed attention.

```python
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
```

Interpolation weight function $\sigma\Big(W_\beta x^{(i)}\Big)$ for each head

```python
        self.interpolation_weight = nn.Sequential(
            PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
            nn.Sigmoid()
        )
```

$\phi'$

```python
        self.phi = phi
```

Output layer

```python
        self.output = nn.Linear(d_model, d_model)
```

Dropout

```python
        self.dropout = nn.Dropout(dropout_prob)
```

```python
    def forward(self, x: torch.Tensor):
```

Get the number of steps $L$

```python
        seq_len = x.shape[0]
```

$\phi'(q^{(i)})$ for all steps and heads

```python
        query = self.phi(self.query(x))
```

$\phi'(k^{(i)})$ for all steps and heads

```python
        key = self.phi(self.key(x))
```

$v^{(i)}$ for all steps and heads

```python
        value = self.value(x)
```

$\beta^{(i)}$ for all steps and heads

```python
        beta = self.interpolation_weight(x)
```

$W^{(0)}$

```python
        weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))
```

List to store outputs $y^{(i)}$

```python
        outputs = []
```

Iterate through steps

```python
        for i in range(seq_len):
```

$\bar{v}^{(i)} = W^{(i-1)} \phi'\big(k^{(i)}\big)$

```python
            value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])
```

$W^{(i)} = W^{(i-1)} + \beta^{(i)} \Big(v^{(i)} - \bar{v}^{(i)}\Big) \otimes \phi'\big(k^{(i)}\big)$

```python
            weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])
```

$y^{(i)} = W^{(i)} \phi'\big(q^{(i)}\big)$

```python
            y = torch.einsum('bhvk,bhk->bhv', weights, query[i])
```

Merge multiple heads and append to `outputs`

```python
            outputs.append(y.reshape(y.shape[0], -1))
```

Stack outputs at each step into a single tensor

```python
        x = torch.stack(outputs)
```

Output layer

```python
        return self.output(x)
```
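A usage sketch of the attention module; the hyper-parameters and shapes below are illustrative assumptions:

```python
phi = DPFP(nu=1)
attn = FastWeightsAttention(heads=4, d_model=64, dropout_prob=0.1, phi=phi)

x = torch.randn(20, 8, 64)    # (seq_len, batch_size, d_model)
out = attn(x)
print(out.shape)              # torch.Size([20, 8, 64])
```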

This is a general transformer layer that combines self-attention and a feed-forward network.

```python
class FastWeightsAttentionTransformerLayer(nn.Module):
```

```python
    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
```

Transformer size $d_{model}$

```python
        self.size = d_model
```

Fast weights attention module

```python
        self.attn = attn
```

Feed-forward network

```python
        self.feed_forward = feed_forward
```

Dropout layer

```python
        self.dropout = nn.Dropout(dropout_prob)
```

Normalization layers

```python
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
```

```python
    def forward(self, x: torch.Tensor):
```

Calculate fast weights self-attention

```python
        attn = self.attn(x)
```

Add the self-attention results

```python
        x = x + self.dropout(attn)
```

Normalize for feed-forward

```python
        z = self.norm_ff(x)
```

Pass through the feed-forward network

```python
        ff = self.feed_forward(z)
```

Add the feed-forward results back

```python
        x = x + self.dropout(ff)
```

```python
        return x
```

This is a general transformer module with multiple transformer layers.

```python
class FastWeightsAttentionTransformer(nn.Module):
```

```python
    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        super().__init__()
```

Make copies of the transformer layer

```python
        self.layers = clone_module_list(layer, n_layers)
```

Final normalization layer

```python
        self.norm = nn.LayerNorm([layer.size])
```

```python
    def forward(self, x: torch.Tensor):
        for i, layer in enumerate(self.layers):
```

Get layer output

```python
            x = layer(x)
```

Normalize the output

```python
        return self.norm(x)
```
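Finally, a sketch of assembling the full model. The hyper-parameters are arbitrary, and the `FeedForward` arguments (`d_ff`, `dropout`) are assumed from its signature in `labml_nn.transformers.feed_forward`:

```python
d_model = 64
layer = FastWeightsAttentionTransformerLayer(
    d_model=d_model,
    attn=FastWeightsAttention(heads=4, d_model=d_model, dropout_prob=0.1, phi=DPFP(nu=1)),
    feed_forward=FeedForward(d_model, d_ff=256, dropout=0.1),
    dropout_prob=0.1,
)
model = FastWeightsAttentionTransformer(layer, n_layers=2)

x = torch.randn(20, 8, d_model)   # (seq_len, batch_size, d_model)
print(model(x).shape)             # torch.Size([20, 8, 64])
```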
