Back to Annotated Deep Learning Paper Implementations

වේගයෙන්ට්රාන්ස්ෆෝමර් බර

docs/si/transformers/fast_weights/index.html

latest15.1 KB
Original Source

hometransformersfast_weights

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/fast_weights/ init.py)

#

වේගයෙන්ට්රාන්ස්ෆෝමර් බර

කඩදාසි රේඛීය ට්රාන්ස්ෆෝමර් පයිටෝර්ච් හි රහසින් වේගවත් බර මතක පද්ධති රේඛීය ස්වයං අවධානය සහ වේගවත් බර පද්ධති අතර සමානකම් සොයා ගන්නා අතර ඒ මත පදනම්ව ස්වයං අවධානය යාවත්කාලීන කිරීමේ රීතියට වෙනස් කිරීම් සිදු කරයි. එය සරල, නමුත් effective ලදායී කර්නල් ශ්රිතයක් ද හඳුන්වා දෙයි.

කතුවරුන්කඩදාසි සමඟ සංසන්දනය කරන වෙනත් ප්රභේද ඇතුළුව කඩදාසි නිල වශයෙන් ක්රියාත්මක කිරීමක් ලබා දී ඇත.

වේගවත්බර

යෙදවුම් {x(i)}i=1L​ හෝ දිග L අනුපිළිවෙලක් සලකා බලන්න. එක් එක් පියවර ප්රමාණයේ දෛශිකයකි din​; i.e x∈Rdin​. වේගවත් බර ආකෘතිය ප්රතිදානය නිෂ්පාදනය කිරීම සඳහා සෑම පියවරකදීම බර අනුකෘතියක් ජනනය කරයි {y(i)}i=1L​, y∈Rdout​

a(i),b(i)W(i)y(i)​=Wa​x(i),Wb​x(i)=σ(W(i−1)+a(i)⊗b(i))=W(i)x(i)​

⊗ යනු පිටත නිෂ්පාදිතය (a⊗b=ab⊤), අනුකෘතියක් ලබා දීම සඳහා දෛශික දෙකේ මූලද්රව්ය එකිනෙකා සමඟ ගුණ කරනු ලැබේ. σ යනු සක්රිය කිරීමේ කාර්යයකි. Wa​ පුහුණු කළ හැකි බර (පරාමිතීන්) වේ. Wb​ W(i) එක් එක් පියවරේදී ජනනය වන වේගවත් බර වේ.

රේඛීයස්වයං අවධානය

මුල්ට්රාන්ස්ෆෝමර් ස්වයං අවධානය යනු, (පැහැදිලි කිරීම dk​1​ සඳහා මඟ හැරීම)

y(i)​=[v(1),v(2),...,v(i)]softmax([k(1),k(2),...,k(i)]⊤q(i))=j=1∑i​∑j′=1i​κ(k(j′),q(i))v(j)κ(k(j),q(i))​​

කොහේද κ(k,q)=exp(k⋅q)

ස්වයංඅවධානය රේඛීයකරණය කිරීම පිටුපස ඇති අදහස නම්, සොෆ්ට්මැක්ස් කර්නලය වෙනත් කර්නලයක් κ සමඟ ප්රතිස්ථාපනය කිරීමයි, κ′ එවිට අපට ස්වයං අවධානය ක්රියාකාරිත්වයේ හරය වේගයෙන් ගණනය කළ හැකිය:

κ′(k,q)=ϕ(k)⊤ϕ(q)

මෙයලබා දෙයි

y(i)​=(∑j′=1i​ϕ(k(j′)))ϕ(q(i))(∑j=1i​v(j)⊗ϕ(k(j)))ϕ(q(i))​​

අපටඒවා කාර්යක්ෂමව ගණනය කළ හැකිය: W(i)=∑j=1i​v(j)⊗ϕ(k(j)) z(i)=∑j=1i​ϕ(k(j))

W(i)z(i)y(i)​=W(i−1)+v(i)⊗ϕ(k(i))=z(i)+ϕ(k(i))=z(i)⋅ϕ(q(i))1​W(i)ϕ(q(i))​

මෙයවේගවත් බරට බෙහෙවින් සමාන ය.

කඩදාසිනව රේඛීය අවධානය ප්රක්ෂේපණ ϕ ශ්රිතයක් හඳුන්වා දෙන W(i)=f(W(i−1)) අතර සාමාන්යකරණය වෙනස් කිරීම සඳහා නව යාවත්කාලීන රීතියක් z(i)⋅ϕ(q(i))1​

කුඩාෂේක්ස්පියර් දත්ත කට්ටලයෙහි වේගවත් බර ට්රාන්ස්ෆෝමරයක් පුහුණු කිරීම සඳහා පුහුණු කේතය සහ සටහන් පොතක් මෙන්න.

96importtorch97fromtorchimportnn9899fromlabml\_helpers.moduleimportModule100fromlabml\_nn.transformers.feed\_forwardimportFeedForward101fromlabml\_nn.transformers.mhaimportPrepareForMultiHeadAttention102fromlabml\_nn.utilsimportclone\_module\_list

#

නිර්නවාදීපරාමිති නිදහස් ව්යාපෘතිය (DPFP)

කඩදාසිතුළ ϕ හඳුන්වා දී ඇති නව ප්රක්ෂේපණ ශ්රිතය මෙයයි. ඩීපීඑෆ්පී ව්යාපෘති k dkey​ මානයන්හි මානයන්හි මානයන් ddot​=2dkey​ν, අධි-පරාමිතියක් කොහෙද? ν∈1,2,...,2dkey​−1

ϕ2dkey​(i−1)+j​(k)=ReLU([k,−k])j​ReLU([k,−k])i+j​

ප්රමාණයේදෛශිකයක් ලබා −k දීම k සහ [k,−k] කොතැනද 2dkey​i∈1,2,...,ν, සහ j∈1,2,...,2dkey​. xi​ යනු දෛශිකයේ i-th මූලද්රව්යය වන x අතර එය මූලද්රව්ය ගණනට වඩා විශාල නම් i වටා රෝල් කර ඇත x.

මූලිකවශයෙන්, එය [k,−k] මාරු කරන ලද මූලද්රව්ය ගුණ කිරීමෙන් නව දෛශිකයක් නිර්මාණය කරයි i.

මෙයවිරල (ශුන්ය නොවන මූලද්රව්ය කිහිපයක් පමණි) සහ විකලාංග (ϕ(k(i))⋅ϕ(k(j))≈0 බොහෝ විට i,j මිස k(i) සහ phi k(j) ඉතා සමාන ය.

සාමාන්‍යකරණය

කඩදාසිසඳහා සරල සාමාන්යකරණයක් හඳුන්වා දෙයි ϕ,

ϕ′(k)=∑j=1ddot​​ϕ(k)j​ϕ(k)​

ව්යුත්පන්නකිරීම සඳහා කඩදාසි පරීක්ෂා කරන්න.

105classDPFP(Module):

#

  • nu අධි-පරාමිතිය νවේ.
  • eps සාමාන්යකරණය කිරීමේදී බෙදීම් ශුන්ය නොවන බවට වග බලා ගැනීම සඳහා භාවිතා කරන කුඩා අගයයි.
139def\_\_init\_\_(self,nu:int=1,eps:float=1e-6):

#

144super().\_\_init\_\_()145self.nu=nu146self.relu=nn.ReLU()147self.eps=eps

#

149defforward(self,k:torch.Tensor):

#

ලබාගන්න ϕ(k)

151k=self.dpfp(k)

#

විසින්සාමාන්යකරණය ∑j=1ddot​​ϕ(k)j​

153returnk/(torch.sum(k,dim=-1,keepdim=True)+self.eps)

#

ϕ(k)

155defdpfp(self,k:torch.Tensor):

#

x=ReLU([k,−k])

160x=self.relu(torch.cat([k,-k],dim=-1))

#

ලබාගැනීම සඳහා මාරුවීම සහ රෝල් කරන්න i∈1,2,...,νxi,j′​=ReLU([k,−k])i+j​

163x\_rolled=[x.roll(shifts=i,dims=-1)foriinrange(1,self.nu+1)]

#

ලබාගැනීමට එකඟ වන්න x2dkey​(i−1)+j′​=ReLU([k,−k])i+j​

166x\_rolled=torch.cat(x\_rolled,dim=-1)

#

පිටපත්සංයුක්ත කරන්න x

168x\_repeat=torch.cat([x]\*self.nu,dim=-1)

#

ඒවාගුණ කරන්න, ϕ2dkey​(i−1)+j​(k)=ReLU([k,−k])j​ReLU([k,−k])i+j​

174returnx\_repeat\*x\_rolled

#

වේගවත්අවධානය බර

කඩදාසිගණනය කිරීම සඳහා නව යාවත්කාලීන රීතියක් හඳුන්වා දෙයි W(i). ආකෘතිය මුලින්ම යතුර සමඟ vˉ(i) යුගලනය කරන ලද වත්මන් අගය ලබා k(i)ගනී. ඉන්පසු නැවත ලබා ගත් අගය vˉ(i) සහ v(i)new​ ආදානයේ සංයෝජනයක් ගබඩා v(i)කරයි.

k(i),v(i),q(i)vˉ(i)β(i)v(i)new​W(i)y(i)​=Wk​x(i),Wv​x(i),Wq​x(i)=W(i−1)ϕ′(k(i))=σ(Wβ​x(i))=β(i)v(i)+(1−β(i))vˉ(i)=W(i−1)+v(i)new​⊗ϕ′(k(i))=W(i−1)+β(i)(v(i)−vˉ(i))⊗ϕ′(k(i))=W(i)ϕ′(q(i))​

පුහුණුකළ හැකි පරාමිතියක් σ වන අතර සිග්මෝයිඩ් ශ්රිතය වේ. Wβ​

සාමාන්යකරණයවී ඇති z නිසා අපට සාමාන්යකරණ පදය අවශ්ය ϕ′ නොවන බව සලකන්න.

177classFastWeightsAttention(Module):

#

205def\_\_init\_\_(self,heads:int,d\_model:int,dropout\_prob:float,phi:DPFP):206super().\_\_init\_\_()

#

හිසකටවිශේෂාංග ගණන dk​

209self.d\_k=d\_model//heads

#

හිස්ගණන

211self.heads=heads

#

මේවාපරිවර්තනය කරයි query , key සහ value බහු-හිස අවධානය.

214self.query=PrepareForMultiHeadAttention(d\_model,heads,self.d\_k,bias=False)215self.key=PrepareForMultiHeadAttention(d\_model,heads,self.d\_k,bias=False)216self.value=PrepareForMultiHeadAttention(d\_model,heads,self.d\_k,bias=False)

#

එක්එක් හිස σ(Wβ​x(i)) සඳහා අන්තර්නිවේෂණය බර කාර්යය

219self.interpolation\_weight=nn.Sequential(220PrepareForMultiHeadAttention(d\_model,heads,1,bias=False),221nn.Sigmoid()222)

#

ϕ′

225self.phi=phi

#

ප්රතිදානස්ථරය

228self.output=nn.Linear(d\_model,d\_model)

#

හැලීම

230self.dropout=nn.Dropout(dropout\_prob)

#

232defforward(self,x:torch.Tensor):

#

පියවරගණන ලබා ගන්න L

234seq\_len=x.shape[0]

#

ϕ′(q(i)) සියලු පියවර සහ හිස් සඳහා

236query=self.phi(self.query(x))

#

ϕ′(k(i)) සියලු පියවර සහ හිස් සඳහා

238key=self.phi(self.key(x))

#

v(i) සියලු පියවර සහ හිස් සඳහා

240value=self.value(x)

#

β(i) සියලු පියවර සහ හිස් සඳහා

242beta=self.interpolation\_weight(x)

#

W(0)

245weights=key.new\_zeros((key.shape[1],key.shape[2],value.shape[3],key.shape[3]))

#

ප්රතිදානයන්ගබඩා කිරීමට ලැයිස්තුව y(i)

247outputs=[]

#

පියවරහරහා නැවත ක්රියාත්මක කරන්න

250foriinrange(seq\_len):

#

vˉ(i)=W(i−1)ϕ′(k(i))

252value\_existing=torch.einsum('bhvk,bhk-\>bhv',weights,key[i])

#

W(i)=W(i−1)+β(i)(v(i)−vˉ(i))⊗ϕ′(k(i))

257weights=weights+torch.einsum('bhv,bhk-\>bhvk',beta[i]\*(value[i]-value\_existing),key[i])

#

y(i)=W(i)ϕ′(q(i))

260y=torch.einsum('bhvk,bhk-\>bhv',weights,query[i])

#

බහුහිස් ඒකාබද්ධ කර ඊට සම්බන්ධ කරන්න outputs

263outputs.append(y.reshape(y.shape[0],-1))

#

එක්එක් පියවරේදී ප්රතිදානයන් තනි ටෙන්සරයකට ගොඩගසන්න

266x=torch.stack(outputs)

#

ප්රතිදානස්ථරය

269returnself.output(x)

#

මෙයස්වයං අවධානය සහ පෝෂක ජාලය ඒකාබද්ධ කරන සාමාන්ය ට්රාන්ස්ෆෝමර් ස්ථරයකි.

272classFastWeightsAttentionTransformerLayer(Module):

#

276def\_\_init\_\_(self,\*,277d\_model:int,278attn:FastWeightsAttention,279feed\_forward:FeedForward,280dropout\_prob:float):281super().\_\_init\_\_()

#

ට්රාන්ස්ෆෝමර්ප්රමාණය dmodel​

283self.size=d\_model

#

වේගවත්බර අවධානය මොඩියුලය

285self.attn=attn

#

Feed-ඉදිරිජාලය

287self.feed\_forward=feed\_forward

#

හැලෙනස්ථරය

289self.dropout=nn.Dropout(dropout\_prob)

#

සාමාන්යකරණයස්ථර

292self.norm\_self\_attn=nn.LayerNorm([d\_model])293self.norm\_ff=nn.LayerNorm([d\_model])

#

295defforward(self,x:torch.Tensor):

#

වේගවත්බර ගණනය කරන්න ස්වයං අවධානය

297attn=self.attn(x)

#

ස්වයංඅවධානය ප්රතිඵල එකතු

299x=x+self.dropout(attn)

#

පෝෂණයසඳහා සාමාන්යකරණය කරන්න

302z=self.norm\_ff(x)

#

Feed-forwardජාලය හරහා ගමන් කරන්න

304ff=self.feed\_forward(z)

#

ප්රතිපෝෂණඉදිරි ප්රති results ල නැවත එක් කරන්න

306x=x+self.dropout(ff)

#

309returnx

#

මෙයබහු ට්රාන්ස්ෆෝමර් ස්ථර සහිත සාමාන්ය ට්රාන්ස්ෆෝමර් මොඩියුලයකි

312classFastWeightsAttentionTransformer(Module):

#

316def\_\_init\_\_(self,layer:FastWeightsAttentionTransformerLayer,n\_layers:int):317super().\_\_init\_\_()

#

ට්රාන්ස්ෆෝමර්ස්ථරයේ පිටපත් සාදන්න

319self.layers=clone\_module\_list(layer,n\_layers)

#

අවසානසාමාන්යකරණ ස්තරය

321self.norm=nn.LayerNorm([layer.size])

#

323defforward(self,x:torch.Tensor):324fori,layerinenumerate(self.layers):

#

ස්ථරප්රතිදානය ලබා ගන්න

326x=layer(x)

#

ප්රතිදානයසාමාන්යකරණය කරන්න

329returnself.norm(x)

Trending Research Paperslabml.ai