Back to Annotated Deep Learning Paper Implementations

වේගයෙන් සිරුරේ බර පද්ධති

docs/si/transformers/fast_weights/token_wise.html

latest5.4 KB
Original Source

hometransformersfast_weights

View code on Github

#

9fromtypingimportOptional1011importtorch12fromtorchimportnn1314fromlabml\_helpers.moduleimportModule15fromlabml\_nn.transformers.fast\_weightsimportDPFP16fromlabml\_nn.transformers.feed\_forwardimportFeedForward17fromlabml\_nn.transformers.mhaimportPrepareForMultiHeadAttention18fromlabml\_nn.utilsimportclone\_module\_list

#

21classFastWeightsAttention(Module):

#

22def\_\_init\_\_(self,heads:int,d\_model:int,dropout\_prob:float,phi:DPFP):23super().\_\_init\_\_()

#

හිසකටවිශේෂාංග ගණන

26self.d\_k=d\_model//heads

#

28self.heads=heads

#

මේවා query බහු-ශීර්ෂ අවධානය පරිවර්තනය කරයි.

31self.query=PrepareForMultiHeadAttention(d\_model,heads,self.d\_k,bias=False)

#

මේවාබහු ශීර්ෂ අවධානය value සඳහා පරිවර්තනය කරයි. key

33self.key=PrepareForMultiHeadAttention(d\_model,heads,self.d\_k,bias=False)34self.value=PrepareForMultiHeadAttention(d\_model,heads,self.d\_k,bias=False)3536self.gate=nn.Sequential(PrepareForMultiHeadAttention(d\_model,heads,1,bias=False),37nn.Sigmoid())3839self.phi=phi

#

ප්රතිදානස්ථරය

42self.output=nn.Linear(d\_model,d\_model)

#

හැලීම

44self.dropout=nn.Dropout(dropout\_prob)

#

46defforward(self,x:torch.Tensor,weights:Optional[torch.Tensor]):47query=self.phi(self.query(x))48key=self.phi(self.key(x))49value=self.value(x)5051ifweightsisNone:52weights=key.new\_zeros((key.shape[0],key.shape[1],value.shape[2],key.shape[2]))5354value\_existing=torch.einsum('bhvk,bhk-\>bhv',weights,key)5556beta=self.gate(x)5758weights=weights+torch.einsum('bhv,bhk-\>bhvk',beta\*(value-value\_existing),key)5960x=torch.einsum('bhvk,bhk-\>bhv',weights,query)

#

බහුහිස් සංයුක්ත කරන්න

63x=x.reshape(x.shape[0],-1)

#

ප්රතිදානස්ථරය

66returnself.output(x),weights

#

69classFastWeightsAttentionTransformerLayer(Module):

#

70def\_\_init\_\_(self,\*,71d\_model:int,72attn:FastWeightsAttention,73feed\_forward:FeedForward,74dropout\_prob:float):75super().\_\_init\_\_()

#

ට්රාන්ස්ෆෝමර්ප්රමාණය dmodel​

77self.size=d\_model

#

79self.attn=attn80self.feed\_forward=feed\_forward81self.dropout=nn.Dropout(dropout\_prob)

#

සාමාන්යකරණයස්ථර

84self.norm\_self\_attn=nn.LayerNorm([d\_model])85self.norm\_ff=nn.LayerNorm([d\_model])

#

87defforward(self,x:torch.Tensor,weights:Optional[torch.Tensor]):88attn,weights=self.attn(x,weights)

#

ස්වයංඅවධානය ප්රතිඵල එකතු

90x=x+self.dropout(attn)

#

පෝෂණයසඳහා සාමාන්යකරණය කරන්න

93z=self.norm\_ff(x)

#

Feed-forwardජාලය හරහා ගමන් කරන්න

95ff=self.feed\_forward(z)

#

ප්රතිපෝෂණඉදිරි ප්රති results ල නැවත එක් කරන්න

97x=x+self.dropout(ff)

#

100returnx,weights

#

103classFastWeightsAttentionTransformer(Module):

#

104def\_\_init\_\_(self,layer:FastWeightsAttentionTransformerLayer,n\_layers:int):105super().\_\_init\_\_()

#

ට්රාන්ස්ෆෝමර්ස්ථරයේ පිටපත් සාදන්න

107self.layers=clone\_module\_list(layer,n\_layers)

#

අවසානසාමාන්යකරණ ස්තරය

109self.norm=nn.LayerNorm([layer.size])

#

111defforward(self,x\_seq:torch.Tensor):

#

අනුක්රමිකඅක්ෂය දිගේ ලැයිස්තුවකට ආදානය බෙදන්න

113x\_seq=torch.unbind(x\_seq,dim=0)

#

ප්රතිදානයන්ගබඩා කිරීම සඳහා ලැයිස්තුව

115res=[]

#

එක්එක් ආදාන පියවර සඳහා

117weights=[Nonefor\_inrange(len(self.layers))]118119forxinx\_seq:

#

එක්එක් ස්ථරය හරහා ධාවනය කරන්න

121fori,layerinenumerate(self.layers):

#

ස්ථරප්රතිදානය ලබා ගන්න

123x,weights[i]=layer(x,weights[i])124125res.append(x)

#

නිමැවුම්ආතතීන් ගොඩගසන්න

128res=torch.stack(res)

#

ප්රතිදානයසාමාන්යකරණය කරන්න

130returnself.norm(res)

Trending Research Paperslabml.ai