docs/si/transformers/aft/index.html
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/aft/ init.py)
මෙය PyTorch කඩදාසි ක්රියාත්මක කිරීමකි අවධානය නිදහස් ට්රාන්ස්ෆෝමර් .
මෙමලිපිය ස්වයං අවධානය ස්තරය නව කාර්යක්ෂම මෙහෙයුමකින් ප්රතිස්ථාපනය කරයි, T එය මතක සංකීර්ණතාවයක් ඇත O(Td), අනුක්රමයේ දිග කොතැනද d ? කාවැද්දීම් වල මානය.
කඩදාසිAFT සහ AFT හඳුන්වා දෙයි AFT සහ AFT-conv. මෙන්න අපි ස්වයංක්රීය ප්රතිගාමී ආකෘතියක් තුළ සමීප ටෝකන කෙරෙහි අවධානය යොමු කරන AFT- දේශීය ක්රියාත්මක කර ඇත්තෙමු.
AFT( MHAසමාන) පළමු විමසුම X බවට කාවැද්දීම් පරිවර්තනය Q=XWQ, සමග ප්රධාන K=XWK හා අගය V=XWV tensors ඉගෙන ගත් බර. එක් එක් ස්ථානය සඳහා ප්රතිදානය පහත සඳහන් මෙහෙයුම සමඟ ගණනය කරනු t∈[1,T] ලැබේ.
Yt=σ(Qt)⊙∑t′=1Texp(Kt′+wt,t′)∑t′=1Texp(Kt′+wt,t′)⊙Vt′
,මූලද්රව්ය-wise ානවන්ත නිෂ්පාදනයක් කොහෙද ⊙ , σ -ෙර්ඛීය නොවන w∈RT×T වේ (සිග්මෝයිඩ්) හා යුගල-නැණවත් තත්ත්වය අගතීන් උගත් න්යාසය වේ.
මෙයින්අදහස් කරන්නේ අපි අගයන් බරිත සාමාන්යය ගෙන විමසුම මගින් ඒවා ගුණ කරන බවයි. මෙමඟින් MHA අවශ්ය T×T අවධානය යොමු කිරීමේ අනුකෘතිය ගණනය කිරීමේ අවශ්යතාවය ඉවත් කරන අතර එම නිසා මතක අවශ්යතාවය අඩු කරයි.
AFTදේශීය වශයෙන් උගත් යුගල-නැණවත් ස්ථාන අගතීන් පමණක් දේශීයව අදාළ වේ:
wt,t′′={wt,t′,0,for ∣t−t′∣<sotherwise
,දේශීය කවුළු s≤T ප්රමාණය කොහේද?
දේශීයකවුළුවෙන් 0 පිටත wt,t′′ වුවද AFT මෙහෙයුම තවමත් වෙනත් ප්රදේශවලින් යතුරු වටිනාකම් යුගල භාවිතා කරයි. දේශීය කවුළුවෙන් පිටත කාවැද්දීම් සම්පූර්ණයෙන්ම නොපෙනෙන දේශීය ට්රාන්ස්ෆෝමර් වලට වඩා මෙය වෙනස් වේ.
AFTදේශීය ආකෘතියක් සඳහා පුහුණු කේතය මෙන්න.
61fromtypingimportOptional6263importtorch64fromtorchimportnn6566fromlabml\_helpers.moduleimportModule
Yt=σ(Qt)⊙∑t′=1Texp(Kt′+wt,t′)∑t′=1Texp(Kt′+wt,t′)⊙Vt′
කොහෙද,
wt,t′′={wt,t′,0,for ∣t−t′∣<sotherwise
69classAFTLocal(Module):
d_model යනු query , key සහ value දෛශිකවල ඇති ලක්ෂණ ගණන වේ.seq_len වේ Tlocal_window_size දේශීය කවුළු ප්රමාණයයි sbias සඳහා පරිවර්තනයන් සඳහා නැඹුරුව පරාමිතිය තිබිය යුතුද යන්න Q, K සහ V.88def\_\_init\_\_(self,d\_model:int,seq\_len:int,local\_window\_size:int,bias:bool=True):
96super().\_\_init\_\_()
දේශීයකවුළු ප්රමාණය s
99self.local\_window\_size=local\_window\_size
මේවාපරිණාමනය කරයි query , key සහ value දෛශික.
101self.query=nn.Linear(d\_model,d\_model,bias=bias)102self.key=nn.Linear(d\_model,d\_model,bias=bias)103self.value=nn.Linear(d\_model,d\_model,bias=bias)
යුගල-නැණවත්ස්ථානීය අගතීන් w∈RT×T
105self.pos\_bias=nn.Parameter(torch.zeros(seq\_len,seq\_len),requires\_grad=True)
සඳහාමාස්ක් wt,t′
107self.local\_mask=nn.Parameter(self.create\_local\_mask(seq\_len,local\_window\_size),requires\_grad=False)
සක්රීයකිරීම σ
109self.activation=nn.Sigmoid()
ප්රතිදානස්ථරය
111self.output=nn.Linear(d\_model,d\_model)
මෙයවෙස් මුහුණක් නිර්මාණය කරයි
mt,t′={1,0,for ∣t−t′∣<sotherwise
113@staticmethod114defcreate\_local\_mask(seq\_len,local\_window\_size):
ඒවාටමුල පුරන්න
130local\_mask=torch.ones(seq\_len,seq\_len,dtype=torch.bool)
t′−t≥s ශුන්ය කරන්න
132local\_mask=torch.tril(local\_mask,local\_window\_size-1)
t−t′≥s ශුන්ය කරන්න
134local\_mask=torch.triu(local\_mask,-(local\_window\_size-1))
137returnlocal\_mask
query , key සහ valueවිමසුම, _යතුර_සහ _වටිනාකම_සඳහා ටෝකන් කාවැද්දීම් එකතු කිරීම ගබඩා කරන ආතතීන් වේ. ඒවායේ හැඩය ඇත [seq_len, batch_size, d_model] .
mask හැඩය ඇති [seq_len, seq_len, batch_size] අතර කණ්ඩායම සඳහා b , ස්ථානයේ විමසුමට ප්රවේශය i තිබේද යන්න mask[i, j, b] දක්වයි ස්ථානයේ ප්රධාන-අගය j .
139defforward(self,\*,140query:torch.Tensor,141key:torch.Tensor,142value:torch.Tensor,143mask:Optional[torch.Tensor]=None):
query , key``value සහ හැඩය [seq_len, batch_size, d_model]
155seq\_len,\_,\_=query.shape156157ifmaskisnotNone:
mask හැඩය ඇත [seq_len_q, seq_len_k, batch_size] , එහිදී පළමු මානය විමසුම් මානයක් වේ. විමසුම මානයක් සමාන වේ නම් 1 එය විකාශනය කරනු ඇත.
161assertmask.shape[0]==1ormask.shape[0]==query.shape[0]162assertmask.shape[1]==key.shape[0]163assertmask.shape[2]==1ormask.shape[2]==query.shape[1]
විමසුම, යතුර සහ අගය කාවැද්දීම් පරිවර්තනය කරන්න
166query=self.query(query)167key=self.key(key)168value=self.value(value)
ලබාගන්න
wt,t′′={wt,t′,0,for ∣t−t′∣<sotherwise
වෙස්මුහුණභාවිතා කිරීම
181pos\_bias=self.pos\_bias[:seq\_len,:seq\_len]\*self.local\_mask[:seq\_len,:seq\_len]182pos\_bias=pos\_bias.unsqueeze(-1)183pos\_bias.masked\_fill\_(~mask,float('-inf'))
# Yt=σ(Qt)⊙∑t′=1Texp(Kt′+wt,t′)∑t′=1Texp(Kt′+wt,t′)⊙Vt′=σ(Qt)⊙∑t′=1Texp(wt,t′)⊙exp(Kt′)∑t′=1Texp(wt,t′)⊙exp(Kt′)⊙Vt′
අපිගණනය exp(Kt′)⊙Vt′ කර exp(wt,t′), exp(Kt′) වෙන වෙනම සහ අනුකෘති ගුණ කිරීමක් කරන්නෙමු. අපි පැහැදිලි කිරීම සඳහා einsum භාවිතා කරමු.
සොෆ්ට්මැක්ස්ගණනය ස්ථාවර කිරීම සඳහා ඝාතකයන් ගණනය කිරීමට maxt′(wt,t′) පෙර අපි අඩු maxt′(Kt′) කරන්නෙමු.
විශාල xi නම් විශාල exp(xi) වන අතර ගණනය කිරීම අස්ථායී ∑exp(xi)∑exp(xi)yiවේ. numerator සහ නිකායකය සිට ඝාතීය ගණනය කිරීමට පෙර නියතයක් අඩු කිරීම අවලංගු වනු ඇත. හා ගණනය ස්ථාවර උදව් විය හැක. එබැවින් අපි ගණනය කිරීම ස්ථාවර max(xi) කිරීමට අඩු කරමු.
205max\_key=key.max(dim=0,keepdims=True)[0]206max\_pos\_bias=pos\_bias.max(dim=1,keepdims=True)[0]
exp(Kt′−maxt′(Kt′))
209exp\_key=torch.exp(key-max\_key)
exp(wt,t′−maxt′(wt,t′))
211exp\_pos\_bias=torch.exp(pos\_bias-max\_pos\_bias)
සංඛ්යාකොටස ∑t′=1Texp(wt,t′)⊙exp(Kt′)⊙Vt′
214num=torch.einsum('ijb,jbd-\>ibd',exp\_pos\_bias,exp\_key\*value)
මෙමහරය කොටසක් ∑t′=1Texp(wt,t′)⊙exp(Kt′)
216den=torch.einsum('ijb,jbd-\>ibd',exp\_pos\_bias,exp\_key)
ප්රතිදාන Yt=σ(Qt)⊙∑t′=1Texp(wt,t′)⊙exp(Kt′)∑t′=1Texp(wt,t′)⊙exp(Kt′)⊙Vt′
221y=self.activation(query)\*num/den
ප්රතිදානස්ථරය
224returnself.output(y)
දේශීයවෙස් මුහුණ පරීක්ෂා කරන්න
227def\_test\_local\_mask():
231fromlabml.loggerimportinspect232inspect(AFTLocal.create\_local\_mask(10,4))
236if\_\_name\_\_=='\_\_main\_\_':237\_test\_local\_mask()