Back to Annotated Deep Learning Paper Implementations

අවධානයරහිත ට්රාන්ස්ෆෝමරයක්

docs/si/transformers/aft/index.html

latest11.0 KB
Original Source

hometransformersaft

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/aft/ init.py)

#

අවධානයරහිත ට්රාන්ස්ෆෝමරයක්

මෙය PyTorch කඩදාසි ක්රියාත්මක කිරීමකි අවධානය නිදහස් ට්රාන්ස්ෆෝමර් .

මෙමලිපිය ස්වයං අවධානය ස්තරය නව කාර්යක්ෂම මෙහෙයුමකින් ප්රතිස්ථාපනය කරයි, T එය මතක සංකීර්ණතාවයක් ඇත O(Td), අනුක්රමයේ දිග කොතැනද d ? කාවැද්දීම් වල මානය.

කඩදාසිAFT සහ AFT හඳුන්වා දෙයි AFT සහ AFT-conv. මෙන්න අපි ස්වයංක්රීය ප්රතිගාමී ආකෘතියක් තුළ සමීප ටෝකන කෙරෙහි අවධානය යොමු කරන AFT- දේශීය ක්රියාත්මක කර ඇත්තෙමු.

අවධානයනිදහස් ට්රාන්ස්ෆෝමර්

AFT( MHAසමාන) පළමු විමසුම X බවට කාවැද්දීම් පරිවර්තනය Q=XWQ, සමග ප්රධාන K=XWK හා අගය V=XWV tensors ඉගෙන ගත් බර. එක් එක් ස්ථානය සඳහා ප්රතිදානය පහත සඳහන් මෙහෙයුම සමඟ ගණනය කරනු t∈[1,T] ලැබේ.

Yt​=σ(Qt​)⊙∑t′=1T​exp(Kt′​+wt,t′​)∑t′=1T​exp(Kt′​+wt,t′​)⊙Vt′​​

,මූලද්රව්ය-wise ානවන්ත නිෂ්පාදනයක් කොහෙද ⊙ , σ -ෙර්ඛීය නොවන w∈RT×T වේ (සිග්මෝයිඩ්) හා යුගල-නැණවත් තත්ත්වය අගතීන් උගත් න්යාසය වේ.

මෙයින්අදහස් කරන්නේ අපි අගයන් බරිත සාමාන්යය ගෙන විමසුම මගින් ඒවා ගුණ කරන බවයි. මෙමඟින් MHA අවශ්ය T×T අවධානය යොමු කිරීමේ අනුකෘතිය ගණනය කිරීමේ අවශ්යතාවය ඉවත් කරන අතර එම නිසා මතක අවශ්යතාවය අඩු කරයි.

AFTදේශීය

AFTදේශීය වශයෙන් උගත් යුගල-නැණවත් ස්ථාන අගතීන් පමණක් දේශීයව අදාළ වේ:

wt,t′′​={wt,t′​,0,​for ∣t−t′∣<sotherwise​​

,දේශීය කවුළු s≤T ප්රමාණය කොහේද?

දේශීයකවුළුවෙන් 0 පිටත wt,t′′​ වුවද AFT මෙහෙයුම තවමත් වෙනත් ප්රදේශවලින් යතුරු වටිනාකම් යුගල භාවිතා කරයි. දේශීය කවුළුවෙන් පිටත කාවැද්දීම් සම්පූර්ණයෙන්ම නොපෙනෙන දේශීය ට්රාන්ස්ෆෝමර් වලට වඩා මෙය වෙනස් වේ.

AFTදේශීය ආකෘතියක් සඳහා පුහුණු කේතය මෙන්න.

61fromtypingimportOptional6263importtorch64fromtorchimportnn6566fromlabml\_helpers.moduleimportModule

#

AFTදේශීය මෙහෙයුම

Yt​=σ(Qt​)⊙∑t′=1T​exp(Kt′​+wt,t′​)∑t′=1T​exp(Kt′​+wt,t′​)⊙Vt′​​

කොහෙද,

wt,t′′​={wt,t′​,0,​for ∣t−t′∣<sotherwise​​

69classAFTLocal(Module):

#

  • d_model යනු query , key සහ value දෛශිකවල ඇති ලක්ෂණ ගණන වේ.
  • seq_len වේ T
  • local_window_size දේශීය කවුළු ප්රමාණයයි s
  • bias සඳහා පරිවර්තනයන් සඳහා නැඹුරුව පරාමිතිය තිබිය යුතුද යන්න Q, K සහ V.
88def\_\_init\_\_(self,d\_model:int,seq\_len:int,local\_window\_size:int,bias:bool=True):

#

96super().\_\_init\_\_()

#

දේශීයකවුළු ප්රමාණය s

99self.local\_window\_size=local\_window\_size

#

මේවාපරිණාමනය කරයි query , key සහ value දෛශික.

101self.query=nn.Linear(d\_model,d\_model,bias=bias)102self.key=nn.Linear(d\_model,d\_model,bias=bias)103self.value=nn.Linear(d\_model,d\_model,bias=bias)

#

යුගල-නැණවත්ස්ථානීය අගතීන් w∈RT×T

105self.pos\_bias=nn.Parameter(torch.zeros(seq\_len,seq\_len),requires\_grad=True)

#

සඳහාමාස්ක් wt,t′​

107self.local\_mask=nn.Parameter(self.create\_local\_mask(seq\_len,local\_window\_size),requires\_grad=False)

#

සක්‍රීයකිරීම σ

109self.activation=nn.Sigmoid()

#

ප්රතිදානස්ථරය

111self.output=nn.Linear(d\_model,d\_model)

#

දේශීයවෙස් මුහුණ සාදන්න

මෙයවෙස් මුහුණක් නිර්මාණය කරයි

mt,t′​={1,0,​for ∣t−t′∣<sotherwise​​

113@staticmethod114defcreate\_local\_mask(seq\_len,local\_window\_size):

#

ඒවාටමුල පුරන්න

130local\_mask=torch.ones(seq\_len,seq\_len,dtype=torch.bool)

#

t′−t≥s ශුන්ය කරන්න

132local\_mask=torch.tril(local\_mask,local\_window\_size-1)

#

t−t′≥s ශුන්ය කරන්න

134local\_mask=torch.triu(local\_mask,-(local\_window\_size-1))

#

137returnlocal\_mask

#

query , key සහ valueවිමසුම, _යතුර_සහ _වටිනාකම_සඳහා ටෝකන් කාවැද්දීම් එකතු කිරීම ගබඩා කරන ආතතීන් වේ. ඒවායේ හැඩය ඇත [seq_len, batch_size, d_model] .

mask හැඩය ඇති [seq_len, seq_len, batch_size] අතර කණ්ඩායම සඳහා b , ස්ථානයේ විමසුමට ප්රවේශය i තිබේද යන්න mask[i, j, b] දක්වයි ස්ථානයේ ප්රධාන-අගය j .

139defforward(self,\*,140query:torch.Tensor,141key:torch.Tensor,142value:torch.Tensor,143mask:Optional[torch.Tensor]=None):

#

query , key``value සහ හැඩය [seq_len, batch_size, d_model]

155seq\_len,\_,\_=query.shape156157ifmaskisnotNone:

#

mask හැඩය ඇත [seq_len_q, seq_len_k, batch_size] , එහිදී පළමු මානය විමසුම් මානයක් වේ. විමසුම මානයක් සමාන වේ නම් 1 එය විකාශනය කරනු ඇත.

161assertmask.shape[0]==1ormask.shape[0]==query.shape[0]162assertmask.shape[1]==key.shape[0]163assertmask.shape[2]==1ormask.shape[2]==query.shape[1]

#

විමසුම, යතුර සහ අගය කාවැද්දීම් පරිවර්තනය කරන්න

166query=self.query(query)167key=self.key(key)168value=self.value(value)

#

ලබාගන්න

wt,t′′​={wt,t′​,0,​for ∣t−t′∣<sotherwise​​

වෙස්මුහුණභාවිතා කිරීම

181pos\_bias=self.pos\_bias[:seq\_len,:seq\_len]\*self.local\_mask[:seq\_len,:seq\_len]182pos\_bias=pos\_bias.unsqueeze(-1)183pos\_bias.masked\_fill\_(~mask,float('-inf'))

# Yt​​=σ(Qt​)⊙∑t′=1T​exp(Kt′​+wt,t′​)∑t′=1T​exp(Kt′​+wt,t′​)⊙Vt′​​=σ(Qt​)⊙∑t′=1T​exp(wt,t′​)⊙exp(Kt′​)∑t′=1T​exp(wt,t′​)⊙exp(Kt′​)⊙Vt′​​​

අපිගණනය exp(Kt′​)⊙Vt′​ කර exp(wt,t′​), exp(Kt′​) වෙන වෙනම සහ අනුකෘති ගුණ කිරීමක් කරන්නෙමු. අපි පැහැදිලි කිරීම සඳහා einsum භාවිතා කරමු.

#

සොෆ්ට්මැක්ස්ගණනය ස්ථාවර කිරීම සඳහා ඝාතකයන් ගණනය කිරීමට maxt′​(wt,t′​) පෙර අපි අඩු maxt′​(Kt′​) කරන්නෙමු.

විශාල xi​ නම් විශාල exp(xi​) වන අතර ගණනය කිරීම අස්ථායී ∑exp(xi​)∑exp(xi​)yi​​වේ. numerator සහ නිකායකය සිට ඝාතීය ගණනය කිරීමට පෙර නියතයක් අඩු කිරීම අවලංගු වනු ඇත. හා ගණනය ස්ථාවර උදව් විය හැක. එබැවින් අපි ගණනය කිරීම ස්ථාවර max(xi​) කිරීමට අඩු කරමු.

205max\_key=key.max(dim=0,keepdims=True)[0]206max\_pos\_bias=pos\_bias.max(dim=1,keepdims=True)[0]

#

exp(Kt′​−maxt′​(Kt′​))

209exp\_key=torch.exp(key-max\_key)

#

exp(wt,t′​−maxt′​(wt,t′​))

211exp\_pos\_bias=torch.exp(pos\_bias-max\_pos\_bias)

#

සංඛ්යාකොටස ∑t′=1T​exp(wt,t′​)⊙exp(Kt′​)⊙Vt′​

214num=torch.einsum('ijb,jbd-\>ibd',exp\_pos\_bias,exp\_key\*value)

#

මෙමහරය කොටසක් ∑t′=1T​exp(wt,t′​)⊙exp(Kt′​)

216den=torch.einsum('ijb,jbd-\>ibd',exp\_pos\_bias,exp\_key)

#

ප්රතිදාන Yt​=σ(Qt​)⊙∑t′=1T​exp(wt,t′​)⊙exp(Kt′​)∑t′=1T​exp(wt,t′​)⊙exp(Kt′​)⊙Vt′​​

221y=self.activation(query)\*num/den

#

ප්රතිදානස්ථරය

224returnself.output(y)

#

දේශීයවෙස් මුහුණ පරීක්ෂා කරන්න

227def\_test\_local\_mask():

#

231fromlabml.loggerimportinspect232inspect(AFTLocal.create\_local\_mask(10,4))

#

236if\_\_name\_\_=='\_\_main\_\_':237\_test\_local\_mask()

Trending Research Paperslabml.ai