
Fast Weight Systems


from typing import Optional

import torch
from torch import nn

from labml_nn.transformers.fast_weights import DPFP
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list

Fast weights attention

class FastWeightsAttention(nn.Module):

    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        super().__init__()

Number of features per head

        self.d_k = d_model // heads

Number of heads

        self.heads = heads

These transform the query for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

These transform the key and value for multi-headed attention. The gate computes the interpolation weight for the fast weight update, and phi is the DPFP feature map applied to the queries and keys.

        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

        self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                  nn.Sigmoid())

        self.phi = phi
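A quick shape check may help here. This is a minimal sketch, assuming DPFP's default constructor and that the feature map expands the last dimension from d_k to 2 * nu * d_k; neither is stated in this file.

phi = DPFP(nu=1)            # assumed constructor; nu controls the feature expansion
k = torch.randn(32, 4, 16)  # (batch, heads, d_k)
print(phi(k).shape)         # expected torch.Size([32, 4, 32]) if d_k expands to 2 * nu * d_k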

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        # Apply the feature map to the query and key
        query = self.phi(self.query(x))
        key = self.phi(self.key(x))
        value = self.value(x)

        # Start with zero fast weights if this is the first step
        if weights is None:
            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))

        # Value currently stored in the fast weight memory for this key
        value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

        # Interpolation gate
        beta = self.gate(x)

        # Update the fast weights
        weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (value - value_existing), key)

        # Query the updated fast weights
        x = torch.einsum('bhvk,bhk->bhv', weights, query)
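Written out, the snippet above implements the fast weight update rule. The notation below is a reconstruction from the code rather than a copy from this page, with $W$ denoting weights, $\sigma(\beta)$ the gate output, $\phi'$ the feature map, and $\otimes$ the outer product:

$$\bar{v}^{(i)} = W^{(i-1)} \phi'\big(k^{(i)}\big)$$
$$W^{(i)} = W^{(i-1)} + \sigma\big(\beta^{(i)}\big)\,\big(v^{(i)} - \bar{v}^{(i)}\big) \otimes \phi'\big(k^{(i)}\big)$$
$$y^{(i)} = W^{(i)} \phi'\big(q^{(i)}\big)$$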

Concatenate multiple heads

        x = x.reshape(x.shape[0], -1)

Output layer

        return self.output(x), weights
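As a usage sketch (the hyper-parameters, shapes, and DPFP arguments below are assumptions for illustration, not taken from this page): the module processes one token at a time and returns the updated fast weight memory alongside the output, so the memory can be carried to the next step.

phi = DPFP(nu=1)  # assumed DPFP configuration
attn = FastWeightsAttention(heads=4, d_model=128, dropout_prob=0.1, phi=phi)

x_step = torch.randn(32, 128)  # one token for a batch of 32
weights = None                 # no fast weight memory before the first step
out, weights = attn(x_step, weights)
print(out.shape)               # torch.Size([32, 128])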

Fast weights attention transformer layer

class FastWeightsAttentionTransformerLayer(nn.Module):

    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()

Transformer size $d_{model}$

        self.size = d_model

        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

Normalization layers

        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        attn, weights = self.attn(x, weights)

Add the self-attention results

        x = x + self.dropout(attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x, weights

Fast weights attention transformer

class FastWeightsAttentionTransformer(nn.Module):

    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x_seq: torch.Tensor):

Split the input to a list along the sequence axis

        x_seq = torch.unbind(x_seq, dim=0)

List to store the outputs

        res = []

Start with no fast weights for any layer, then iterate over the input steps

        weights = [None for _ in range(len(self.layers))]

        for x in x_seq:

Run through each layer

            for i, layer in enumerate(self.layers):

Get the layer output; after the last layer, collect the result for this step

                x, weights[i] = layer(x, weights[i])

            res.append(x)

Stack the output tensors

        res = torch.stack(res)

Normalize the output

        return self.norm(res)
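A minimal end-to-end sketch follows. All hyper-parameters, and the FeedForward and DPFP constructor arguments, are assumptions for illustration; the input is given as (seq_len, batch, d_model) since the forward pass unbinds along dimension 0.

phi = DPFP(nu=1)                 # assumed DPFP configuration
attn = FastWeightsAttention(heads=4, d_model=128, dropout_prob=0.1, phi=phi)
ff = FeedForward(128, 256, 0.1)  # assumed FeedForward(d_model, d_ff, dropout) signature
layer = FastWeightsAttentionTransformerLayer(d_model=128,
                                             attn=attn,
                                             feed_forward=ff,
                                             dropout_prob=0.1)
model = FastWeightsAttentionTransformer(layer, n_layers=2)

x_seq = torch.randn(10, 32, 128)  # (seq_len, batch, d_model)
out = model(x_seq)
print(out.shape)                  # torch.Size([10, 32, 128])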
