docs/transformers/fast_weights/token_wise.html
from typing import Optional

import torch
from torch import nn

from labml_nn.transformers.fast_weights import DPFP
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list
class FastWeightsAttention(nn.Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        super().__init__()
Number of features per head
        self.d_k = d_model // heads
        self.heads = heads
This transforms the query for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
These transform the key and value for multi-headed attention.
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

        # Gate (beta) that scales the fast weight update; one sigmoid output per head
        self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                  nn.Sigmoid())

        # Feature map phi applied to queries and keys
        self.phi = phi
Output layer
        self.output = nn.Linear(d_model, d_model)
Dropout
        self.dropout = nn.Dropout(dropout_prob)
    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        # Queries and keys get the feature map phi; values are a plain projection
        query = self.phi(self.query(x))
        key = self.phi(self.key(x))
        value = self.value(x)

        # Start with a zero fast weight memory on the first step
        if weights is None:
            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))

        # Value currently stored in the fast weights for this key
        value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

        # Update strength (the gate beta)
        beta = self.gate(x)

        # Write the gated difference into the fast weights (outer product with the key)
        weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (value - value_existing), key)

        # Read the output by querying the updated fast weights
        x = torch.einsum('bhvk,bhk->bhv', weights, query)
Concatenate multiple heads
        x = x.reshape(x.shape[0], -1)
Output layer
        return self.output(x), weights
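In equation form, the forward pass above performs the following per-head update (this is just a restatement of the einsum operations; W^(i) is the fast weight matrix carried in weights, W_q, W_k, W_v, W_beta are the query, key, value and gate projections, sigma is the sigmoid gate and the last product is an outer product):

$$
\begin{aligned}
q^{(i)} &= \phi\big(W_q x^{(i)}\big), \qquad
k^{(i)} = \phi\big(W_k x^{(i)}\big), \qquad
v^{(i)} = W_v x^{(i)} \\
\bar{v}^{(i)} &= W^{(i-1)} k^{(i)} \\
\beta^{(i)} &= \sigma\big(W_\beta x^{(i)}\big) \\
W^{(i)} &= W^{(i-1)} + \beta^{(i)} \big(v^{(i)} - \bar{v}^{(i)}\big) \otimes k^{(i)} \\
y^{(i)} &= W^{(i)} q^{(i)}
\end{aligned}
$$

The `if weights is None` branch corresponds to starting from an all-zero memory, W^(0) = 0.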
class FastWeightsAttentionTransformerLayer(nn.Module):
    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
Transformer size d_model
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        attn, weights = self.attn(x, weights)
Add the self-attention results (residual connection)
        x = x + self.dropout(attn)
Normalize for feed-forward
        z = self.norm_ff(x)
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)
        return x, weights
class FastWeightsAttentionTransformer(nn.Module):
    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
    def forward(self, x_seq: torch.Tensor):
Split the input to a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
List to store the outputs
        res = []
Fast weight memories, one for each layer, initialized on the first step
        weights = [None for _ in range(len(self.layers))]

For each input step
        for x in x_seq:
Run through each layer
            for i, layer in enumerate(self.layers):
Get layer output
                x, weights[i] = layer(x, weights[i])

            res.append(x)
Stack the output tensors
        res = torch.stack(res)
Normalize the output
        return self.norm(res)
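As a usage sketch (not part of the original module): the snippet below builds a small model and runs a random sequence through it. The DPFP constructor argument (nu) and the FeedForward signature (d_model, d_ff, dropout) are assumptions based on the accompanying labml_nn modules, and the configuration values are purely illustrative; the input is shaped [seq_len, batch_size, d_model].

import torch

from labml_nn.transformers.fast_weights import DPFP
from labml_nn.transformers.feed_forward import FeedForward

# Hypothetical configuration for illustration
d_model, heads, d_ff, n_layers = 128, 4, 256, 2

# DPFP feature map; nu=1 is an assumed default for the feature expansion
phi = DPFP(nu=1)
attn = FastWeightsAttention(heads, d_model, dropout_prob=0.1, phi=phi)
# FeedForward(d_model, d_ff, dropout) is assumed from labml_nn.transformers.feed_forward
ffn = FeedForward(d_model, d_ff, 0.1)
layer = FastWeightsAttentionTransformerLayer(d_model=d_model,
                                             attn=attn,
                                             feed_forward=ffn,
                                             dropout_prob=0.1)
model = FastWeightsAttentionTransformer(layer, n_layers)

# A random sequence of 16 steps with batch size 2
x_seq = torch.randn(16, 2, d_model)
out = model(x_seq)
print(out.shape)  # expected: torch.Size([16, 2, 128])

The fast weight memories are created inside the model on the first step and carried across the sequence, so the caller only passes the input tensor.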