docs/si/transformers/fast_weights/token_wise.html
9fromtypingimportOptional1011importtorch12fromtorchimportnn1314fromlabml\_helpers.moduleimportModule15fromlabml\_nn.transformers.fast\_weightsimportDPFP16fromlabml\_nn.transformers.feed\_forwardimportFeedForward17fromlabml\_nn.transformers.mhaimportPrepareForMultiHeadAttention18fromlabml\_nn.utilsimportclone\_module\_list
21classFastWeightsAttention(Module):
22def\_\_init\_\_(self,heads:int,d\_model:int,dropout\_prob:float,phi:DPFP):23super().\_\_init\_\_()
හිසකටවිශේෂාංග ගණන
26self.d\_k=d\_model//heads
28self.heads=heads
මේවා query බහු-ශීර්ෂ අවධානය පරිවර්තනය කරයි.
31self.query=PrepareForMultiHeadAttention(d\_model,heads,self.d\_k,bias=False)
මේවාබහු ශීර්ෂ අවධානය value සඳහා පරිවර්තනය කරයි. key
33self.key=PrepareForMultiHeadAttention(d\_model,heads,self.d\_k,bias=False)34self.value=PrepareForMultiHeadAttention(d\_model,heads,self.d\_k,bias=False)3536self.gate=nn.Sequential(PrepareForMultiHeadAttention(d\_model,heads,1,bias=False),37nn.Sigmoid())3839self.phi=phi
ප්රතිදානස්ථරය
42self.output=nn.Linear(d\_model,d\_model)
හැලීම
44self.dropout=nn.Dropout(dropout\_prob)
46defforward(self,x:torch.Tensor,weights:Optional[torch.Tensor]):47query=self.phi(self.query(x))48key=self.phi(self.key(x))49value=self.value(x)5051ifweightsisNone:52weights=key.new\_zeros((key.shape[0],key.shape[1],value.shape[2],key.shape[2]))5354value\_existing=torch.einsum('bhvk,bhk-\>bhv',weights,key)5556beta=self.gate(x)5758weights=weights+torch.einsum('bhv,bhk-\>bhvk',beta\*(value-value\_existing),key)5960x=torch.einsum('bhvk,bhk-\>bhv',weights,query)
බහුහිස් සංයුක්ත කරන්න
63x=x.reshape(x.shape[0],-1)
ප්රතිදානස්ථරය
66returnself.output(x),weights
69classFastWeightsAttentionTransformerLayer(Module):
70def\_\_init\_\_(self,\*,71d\_model:int,72attn:FastWeightsAttention,73feed\_forward:FeedForward,74dropout\_prob:float):75super().\_\_init\_\_()
ට්රාන්ස්ෆෝමර්ප්රමාණය dmodel
77self.size=d\_model
79self.attn=attn80self.feed\_forward=feed\_forward81self.dropout=nn.Dropout(dropout\_prob)
සාමාන්යකරණයස්ථර
84self.norm\_self\_attn=nn.LayerNorm([d\_model])85self.norm\_ff=nn.LayerNorm([d\_model])
87defforward(self,x:torch.Tensor,weights:Optional[torch.Tensor]):88attn,weights=self.attn(x,weights)
ස්වයංඅවධානය ප්රතිඵල එකතු
90x=x+self.dropout(attn)
පෝෂණයසඳහා සාමාන්යකරණය කරන්න
93z=self.norm\_ff(x)
Feed-forwardජාලය හරහා ගමන් කරන්න
95ff=self.feed\_forward(z)
ප්රතිපෝෂණඉදිරි ප්රති results ල නැවත එක් කරන්න
97x=x+self.dropout(ff)
100returnx,weights
103classFastWeightsAttentionTransformer(Module):
104def\_\_init\_\_(self,layer:FastWeightsAttentionTransformerLayer,n\_layers:int):105super().\_\_init\_\_()
ට්රාන්ස්ෆෝමර්ස්ථරයේ පිටපත් සාදන්න
107self.layers=clone\_module\_list(layer,n\_layers)
අවසානසාමාන්යකරණ ස්තරය
109self.norm=nn.LayerNorm([layer.size])
111defforward(self,x\_seq:torch.Tensor):
අනුක්රමිකඅක්ෂය දිගේ ලැයිස්තුවකට ආදානය බෙදන්න
113x\_seq=torch.unbind(x\_seq,dim=0)
ප්රතිදානයන්ගබඩා කිරීම සඳහා ලැයිස්තුව
115res=[]
එක්එක් ආදාන පියවර සඳහා
117weights=[Nonefor\_inrange(len(self.layers))]118119forxinx\_seq:
එක්එක් ස්ථරය හරහා ධාවනය කරන්න
121fori,layerinenumerate(self.layers):
ස්ථරප්රතිදානය ලබා ගන්න
123x,weights[i]=layer(x,weights[i])124125res.append(x)
නිමැවුම්ආතතීන් ගොඩගසන්න
128res=torch.stack(res)
ප්රතිදානයසාමාන්යකරණය කරන්න
130returnself.norm(res)