Back to Annotated Deep Learning Paper Implementations

රෙට්රෝආකෘතිය

docs/si/transformers/retro/model.html

latest30.9 KB
Original Source

hometransformersretro

View code on Github

#

රෙට්රෝආකෘතිය

RETROසඳහා ආදර්ශ අර්ථ දැක්වීම මෙයයි.

16importmath17fromtypingimportSet1819importtorch20fromtorchimportnn2122fromlabml.loggerimportinspect

#

කඹය කාවැද්දීම්

අපිස්වයං අවධානය ස්ථර භමණ තත්ත්වය කාවැද්දීම් භාවිතා කරන්න. ස්ථානීය තොරතුරු කාවැද්දීම් තුළට කාවැදී ඇති අතර එම නිසා ඒවා පොදු අවධානයට ලක් නොකරමු. හේතු නොවන ස්වයං අවධානයට පැහැදිලි ස්ථානීය තොරතුරු අවශ්ය වන්නේ එයට අනුමාන කළ නොහැකි බැවිනි.

25classRotaryPositionalEmbeddings(nn.Module):

#

  • d යනු විශේෂාංග ගණන d
  • base ගණනය කිරීම සඳහා භාවිතා කරන නියතය Θ
36def\_\_init\_\_(self,d:int,base:int=10\_000):

#

41super().\_\_init\_\_()

#

Θ=θi​=10000d2(i−1)​,i∈[1,2,...,2d​]

43self.theta=nn.Parameter(1./(base\*\*(torch.arange(0,d,2).float()/d)),requires\_grad=False)

#

  • x යනු යතුරක හිසෙහි ටෙන්සර් හෝ හැඩය සහිත විමසුමකි [batch_size, seq_len, n_heads, d]
45defforward(self,x:torch.Tensor):

#

හැඩයඋපුටා ගන්න

50batch\_size,seq\_len,n\_heads,d=x.shape

#

2d​

53d\_2=d//2

#

ස්ථානදර්ශක සාදන්න [0, 1, ..., seq_len - 1]

56seq\_idx=torch.arange(seq\_len,device=x.device).type\_as(self.theta)

#

ස්ථානදර්ශකයේ නිෂ්පාදිතය ගණනය කරන්න θi​

59idx\_theta=torch.einsum('n,d-\>nd',seq\_idx,self.theta)

#

පේළියසඳහා m අපට ඇති පරිදි සංයුක්ත කරන්න [mθ0​,mθ1​,...,mθ2d​​,mθ0,mθ1,...,mθ2d​​]

63idx\_theta2=torch.cat([idx\_theta,idx\_theta],dim=1)

#

ගණනයකරන්න [−x(2d​+1),−x(2d​+2),...,−x(d),x(1),x(2),...,−x(2d​)]

67neg\_half\_x=torch.cat([-x[:,:,:,d\_2:],x[:,:,:,:d\_2]],dim=-1)

#

ගණනයකරන්න

(xm(i)​cosmθi​−xm(i+2d​)​sinmθi​xm(i+2d​)​cosmθi​+xm(i)​sinmθi​​)​

සඳහා i∈1,2,...,2d​

79rx=(x\*idx\_theta2.cos()[None,:,None,:])+(neg\_half\_x\*idx\_theta2.sin()[None,:,None,:])

#

82returnrx

#

ස්වයංඅවධානය ස්ථරය ATTN

මෙයහේතු සහ හේතු නොවන බහු-හිස සහිත ස්වයං අවධානයඅදාළ වේ.

85classSelfAttention(nn.Module):

#

  • d_model ට්රාන්ස්ෆෝමර් කාවැද්දීම් වල විශේෂාංග ගණන වේ
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • d_k යනු හිසකට ඇති ලක්ෂණ ගණන
  • is_causal මෙය හේතුකාරක අවධානය (මැස්සෙඩ්) යන්න පෙන්නුම් කරයි
92def\_\_init\_\_(self,d\_model:int,n\_heads:int,d\_k:int,is\_causal:bool):

#

99super().\_\_init\_\_()100101self.is\_causal=is\_causal102self.n\_heads=n\_heads103self.d\_k=d\_k

#

සොෆ්ට්මැක්ස්වලට පෙර අවධානය පරිමාණය කිරීම dk​​1​

106self.scale=1/math.sqrt(self.d\_k)

#

විමසුම, යතුරු සහ අගය හිස් සඳහා රේඛීය ස්ථර.

109self.query=nn.Linear(d\_model,n\_heads\*d\_k)110self.key=nn.Linear(d\_model,n\_heads\*d\_k)111self.value=nn.Linear(d\_model,n\_heads\*d\_k)

#

පූර්වසම්මත ස්තරය. කඩදාසි වෙනුවට RMSNorm භාවිතා කරයි.

114self.norm=nn.LayerNorm(d\_model)

#

අවධානයසම්භාවිතාව සඳහා සොෆ්ට්මැක්ස්

117self.softmax=nn.Softmax(dim=-1)

#

රොටරිස්ථානීය කාවැද්දීම්

120self.rotary\_pe=RotaryPositionalEmbeddings(self.d\_k)

#

අවසානරේඛීය ස්ථරය

123self.output=nn.Linear(n\_heads\*d\_k,d\_model)

#

හේතුකාරකඅවධානය සඳහා අවධානය යොමු කිරීමේ ස්තරය ආවරණය කරන්න

  • attn හැඩයේ අවධානය යොමු කිරීමේ අනුකෘතියකි [batch_size, n_heads, seq_len, seq_len]
125defmask\_attention(self,attn:torch.Tensor):

#

හේතුනොවන අවධානය සඳහා ආවරණ නොමැත

133ifnotself.is\_causal:134returnattn

#

ත්රිකෝණාකාරවෙස් මුහුණක් සාදන්න

137mask=torch.tril(attn.new\_ones(attn.shape[-2:]))

#

වෙස්මුහුණමගින් පෙරහන් කරන්න

139returnattn.masked\_fill(mask==0,float('-inf'))

#

  • h යනු හැඩයේ ට්රාන්ස්ෆෝමර් කාවැද්දීම් වේ [batch_size, seq_len, d_model]
141defforward(self,h:torch.Tensor):

#

අවශේෂසම්බන්ධතාවය

147h\_res=h

#

පූර්වසාමාන්යකරණය

150h=self.norm(h)

#

විමසුම, යතුර සහ අගයන් ලබා ගෙන ඒවා හිස් වලට බෙදන්න. මේවාට හැඩයන් ඇත [batch_size, seq_len, n_heads, d_k]

154mh\_shape=(\*h.shape[:-1],self.n\_heads,self.d\_k)155q=self.query(h).view(mh\_shape)156k=self.key(h).view(mh\_shape)157v=self.value(h).view(mh\_shape)

#

භ්රමණස්ථානීය කාවැද්දීම් යොදන්න

160q=self.rotary\_pe(q)161k=self.rotary\_pe(k)

#

අවධානයගණනය කරන්න

164attn=torch.einsum('bihd,bjhd-\>bhij',q,k)

#

විසින්එය පරිමාණය dk​​1​

166attn=attn\*self.scale

#

එයහේතු අවධානයක් නම් වෙස් මුහුණු යොදන්න

169attn=self.mask\_attention(attn)

#

අවධානයසම්භාවිතාව ගණනය කරන්න

172attn=self.softmax(attn)

#

වටිනාකම්ලබා ගන්න

175h=torch.einsum("bhij,bjhd-\>bihd",attn,v)

#

හැඩයෙන්වෙනස් [batch_size, seq_len, n_heads, d_k] කරන්න [batch_size, seq_len, n_heads * d_k]

179h=h.reshape(\*h.shape[:-2],-1)

#

අවසානරේඛීය ස්ථරය යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, seq_len, d_model]

183h=self.output(h)

#

අවශේෂසම්බන්ධතාවය එක් කරන්න

186returnh+h\_res

#

හරස්අවධානය ස්ථරය CA

මෙයඉහත අර්ථ දක්වා ඇති ස්වයං අවධානය ස්ථරයට සමාන වේ, එය විමසුම් වලට වඩා වෙනස් කාවැද්දීම් කට්ටලයකින් යතුරු සහ අගයන් ලබා ගනී.

ආදානකුට්ටි මත පදනම්ව නැවත ලබා ගත් කුට්ටි කේතනය කිරීම සඳහා මෙය එන්කෝඩරයේ භාවිතා වේ.

අපිමෙහි කිසිදු පැහැදිලි ස්ථානීය කාවැද්දීමක් භාවිතා නොකරමු. ආකෘතියට කාවැද්දීම් වල ස්ථානීය තොරතුරු ව්යංගයෙන් නියෝජනය කළ හැකි යැයි අපි උපකල්පනය කරමු.

189classCrossAttention(nn.Module):

#

  • d_model ට්රාන්ස්ෆෝමර් කාවැද්දීම් වල විශේෂාංග ගණන වේ
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • d_k යනු හිසකට ඇති ලක්ෂණ ගණන
203def\_\_init\_\_(self,d\_model:int,n\_heads:int,d\_k:int):

#

209super().\_\_init\_\_()210211self.n\_heads=n\_heads212self.d\_k=d\_k

#

සොෆ්ට්මැක්ස්වලට පෙර අවධානය පරිමාණය කිරීම dk​​1​

215self.scale=1/math.sqrt(self.d\_k)

#

විමසුම, යතුරු සහ අගය හිස් සඳහා රේඛීය ස්ථර.

218self.query=nn.Linear(d\_model,n\_heads\*d\_k)219self.key=nn.Linear(d\_model,n\_heads\*d\_k)220self.value=nn.Linear(d\_model,n\_heads\*d\_k)

#

විමසුම්කාවැද්දීම් සඳහා පූර්ව සම්මත ස්තරය. කඩදාසි වෙනුවට RMSNorm භාවිතා කරයි.

223self.norm=nn.LayerNorm(d\_model)

#

අවධානයසම්භාවිතාව සඳහා සොෆ්ට්මැක්ස්

226self.softmax=nn.Softmax(dim=-1)

#

අවසානරේඛීය ස්ථරය

229self.output=nn.Linear(n\_heads\*d\_k,d\_model)

#

  • e හැඩයෙන් යුත් ළඟම අසල්වැසියාගේ කුට්ටිය කාවැද්දීම් ලබා ගත හැකිය [batch_size, chunks, neighbors, neighbor_len, d_model]
  • h ආසන්නතම අසල්වැසියන් හැඩයෙන් ලබා ගන්නා ලද ආදාන කුට්ටි [batch_size, chunks, chunk_len, d_model] වේ. මෙය දැනටමත් සාමාන්යකරණය වී ඇත.
231defforward(self,e:torch.Tensor,h:torch.Tensor):

#

අවශේෂසම්බන්ධතාවය

240e\_res=e

#

ලබාගත් කුට්ටි සාමාන්යකරණය කරන්න

243e=self.norm(e)

#

ලබාගත් කුට්ටි වලින් විමසුම ලබා ගන්න

246q=self.query(e).view(\*e.shape[:-1],self.n\_heads,self.d\_k)

#

ආදානකුට්ටි වලින් යතුරු සහ අගයන් ලබා ගන්න

248k=self.key(h).view(\*h.shape[:-1],self.n\_heads,self.d\_k)249v=self.value(h).view(\*h.shape[:-1],self.n\_heads,self.d\_k)

#

සියලුමකුට්ටි සඳහා අවධානය ලකුණු ගණනය කරන්න. ලබා ගත් සෑම අසල්වැසියෙකුම එය නැවත ලබා ගත් මුල් කුට්ටිය කෙරෙහි අවධානය යොමු කරනු ඇත. මෙම හැඩය ඇත [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]

254attn=torch.einsum('bcnihd,bcjhd-\>bcnhij',q,k)

#

පරිමාණඅවධානය ලකුණු

256attn=attn\*self.scale

#

අවසානමානය හරහා සොෆ්ට්මැක්ස් ගණනය කරන්න

259attn=self.softmax(attn)

#

වටිනාකම්එකතු කරන්න

262e=torch.einsum("bcnhij,bcjhd-\>bcnihd",attn,v)

#

හැඩයෙන්වෙනස් [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] කරන්න [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]

266e=e.reshape(\*e.shape[:-2],-1)

#

අවසානරේඛීය ස්ථරය යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, chunks, neighbors, neighbor_len, d_model]

270e=self.output(e)

#

අවශේෂසම්බන්ධතාවය එක් කරන්න

273returne+e\_res

#

තැළුණුහරස් අවධානය ස්ථරය CCA

මෙයඉහත අර්ථ දක්වා ඇති හරස් අවධානය ස්ථරයට සමාන වේ.

මෙයනැවත ලබා ගත් අසල්වැසියා කුට්ටි වෙත අවධානය යොමු කිරීම සඳහා විකේතකය තුළ භාවිතා වේ.

අපිමෙහි කිසිදු පැහැදිලි ස්ථානීය කාවැද්දීමක් භාවිතා නොකරමු. ආකෘතියට කාවැද්දීම් වල ස්ථානීය තොරතුරු ව්යංගයෙන් නියෝජනය කළ හැකි යැයි අපි උපකල්පනය කරමු.

276classChunkedCrossAttention(nn.Module):

#

  • d_model ට්රාන්ස්ෆෝමර් කාවැද්දීම් වල විශේෂාංග ගණන වේ
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • d_k යනු හිසකට ඇති ලක්ෂණ ගණන
  • chunk_len යනු කුට්ටියක දිග
288def\_\_init\_\_(self,d\_model:int,n\_heads:int,d\_k:int,chunk\_len:int):

#

296super().\_\_init\_\_()297298self.chunk\_len=chunk\_len299self.n\_heads=n\_heads300self.d\_k=d\_k

#

සොෆ්ට්මැක්ස්වලට පෙර අවධානය පරිමාණය කිරීම dk​​1​

303self.scale=1/math.sqrt(self.d\_k)

#

විමසුම, යතුරු සහ අගය හිස් සඳහා රේඛීය ස්ථර.

306self.query=nn.Linear(d\_model,n\_heads\*d\_k)307self.key=nn.Linear(d\_model,n\_heads\*d\_k)308self.value=nn.Linear(d\_model,n\_heads\*d\_k)

#

විමසුම්කාවැද්දීම් සඳහා පූර්ව සම්මත ස්තරය. කඩදාසි වෙනුවට RMSNorm භාවිතා කරයි.

311self.norm=nn.LayerNorm(d\_model)

#

අවධානයසම්භාවිතාව සඳහා සොෆ්ට්මැක්ස්

314self.softmax=nn.Softmax(dim=-1)

#

අවසානරේඛීය ස්ථරය

317self.output=nn.Linear(n\_heads\*d\_k,d\_model)

#

h හැඩයේ ආදාන [batch_size, seq_len, d_model]``e කාවැද්දීම් යනු හැඩයේ ආසන්නතම අසල්වැසියන් ලබා ගත හැකිය [batch_size, chunks, neighbors, neighbor_len, d_model]

319defforward(self,h:torch.Tensor,e:torch.Tensor):

#

හැඩයලබා ගන්න

326batch\_size,chunks,neighbors,neighbor\_len,d\_model=e.shape

#

කුට්ටිනොමැති නම් අවධානය යොමු නොකරයි (නියැදීමේදී කෙටි යෙදවුම් සඳහා)

329ifchunks==0:330returnh

#

අවශේෂසම්බන්ධතාවය

333h\_res=h

#

පළමු chunk_len - 1 කාවැද්දීම් ඉවත් කරන්න. අතීත ටෝකන භාවිතයෙන් පමණක් ලබා ගත් සහ කේතනය කර ඇති අසල්වැසියන් කෙරෙහි ආදානය අවධානය යොමු කරයි; එවිට තොරතුරු කාන්දු වීමක් සිදු නොවේ. පළමු කුට්ටියේ සිට ලබාගත් අසල්වැසියන් පළමු කුට්ටියෙන් තොරතුරු ලැබෙනු ඇත. එබැවින් අනුක්රමය වමට මාරු කිරීමෙන් තොරතුරු පමණක් දකුණට ගලා යන බවට chunk_len - 1 අපි වග බලා ගනිමු.

341h=h[:,self.chunk\_len-1:]

#

පූර්වසම්මතය

343h=self.norm(h)

#

ආදානයකුට්ටි වලට බෙදීමට හැකිවන පරිදි හිස් කාවැද්දීම් අවසානය දක්වා එක් කරන්න

345ifh.shape[1]\<chunks\*self.chunk\_len:346h=torch.cat((h,h.new\_zeros(batch\_size,chunks\*self.chunk\_len-h.shape[1],d\_model)),dim=1)

#

ආදානයකුට්ටි බවට නැවත සකස් කරන්න.

348h=h.reshape(batch\_size,chunks,self.chunk\_len,d\_model)

#

ආදානයෙන්විමසුම ලබා ගන්න

351q=self.query(h).view(\*h.shape[:-1],self.n\_heads,self.d\_k)

#

ලබාගත් අසල්වැසියන්ගෙන් යතුරු සහ වටිනාකම් ලබා ගන්න

353k=self.key(e).view(\*e.shape[:-1],self.n\_heads,self.d\_k)354v=self.value(e).view(\*e.shape[:-1],self.n\_heads,self.d\_k)

#

ආදානකුට්ටි සඳහා අවධානය ලකුණු ගණනය කරන්න. සෑම කුට්ටියක්ම කලින් කුට්ටිය විසින් ලබා ගන්නා ලද අසල්වැසියන් කෙරෙහි අවධානය යොමු කරනු ඇත. මෙම හැඩය ඇත [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]

359attn=torch.einsum('bcihd,bcnjhd-\>bchinj',q,k)

#

පරිමාණඅවධානය ලකුණු

361attn=attn\*self.scale

#

අවසානමානයන් දෙකට වඩා සොෆ්ට්මැක්ස් යොදන්න neighbors, neighbor_len

364attn=self.softmax(attn.view(\*attn.shape[:-2],-1)).view(attn.shape)

#

වටිනාකම්එකතු කරන්න

367h=torch.einsum("bchinj,bcnjhd-\>bcihd",attn,v)

#

හැඩයෙන්වෙනස් [batch_size, chunks, chunk_len, n_heads, d_k] කරන්න [batch_size, chunks * chunk_len, n_heads * d_k]

371h=h.reshape(batch\_size,chunks\*self.chunk\_len,-1)

#

අවසානරේඛීය ස්ථරය යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, chunks * chunk_len, d_model]

375h=self.output(h)

#

වමට chunk_len - 1 ශුන්ය කාවැද්දීම එක් කරන්න; එනම් දකුණු එය ආපසු මාරු කරන්න

378h=torch.cat((h.new\_zeros(batch\_size,self.chunk\_len-1,d\_model),h),dim=1)

#

අවශේෂසම්බන්ධතාවය ඉවත් කර එකතු කරන්න

381returnh[:,:h\_res.shape[1]]+h\_res

#

ස්ථාන-නැණවත්පෝෂණය ඉදිරි ස්ථරය FFW

මෙයරේඛීය ස්ථර දෙකක් සහ මැද සක්රිය කිරීමකින් සමන්විත වේ.

384classFeedForward(nn.Module):

#

  • d_model ට්රාන්ස්ෆෝමර් කාවැද්දීම් වල විශේෂාංග ගණන වේ

  • d_ff සැඟවුණු ස්ථරයේ අංක ලක්ෂණ වේ

391def\_\_init\_\_(self,d\_model:int,d\_ff:int):

#

397super().\_\_init\_\_()

#

රේඛීයස්ථර දෙක

400self.lin1=nn.Linear(d\_model,d\_ff)401self.lin2=nn.Linear(d\_ff,d\_model)

#

Reluසක්රිය කිරීම

404self.act=nn.ReLU()

#

පෙර-සම්මතස්තරය

407self.norm=nn.LayerNorm(d\_model)

#

h හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, d_model]

409defforward(self,h:torch.Tensor):

#

අවශේෂ

415h\_res=h

#

පූර්වසම්මතය

417h=self.norm(h)

#

පළමුරේඛීය ස්ථරය

419h=self.lin1(h)

#

සක්‍රීයකිරීම

421h=self.act(h)

#

දෙවනරේඛීය ස්ථරය

423h=self.lin2(h)

#

අවශේෂසම්බන්ධතාවය එක් කරන්න

426returnh+h\_res

#

ළඟමඅසල්වැසි ආකේතකය ENCODER(RET(Cu​)1≤u≤l​,H)

මෙමමොඩියුලය ලබා ගත් ආසන්නතම අසල්වැසියන් සංකේතවත් කරයි

429classNearestNeighborEncoder(nn.Module):

#

  • chunk_len යනු කුට්ටියක දිග
  • n_layer එන්කෝඩරයේ ස්ථර ගණන Lenc​
  • ca_layers හරස් අවධානය ඇති ස්ථර වේ Penc​
  • d_model යනු කාවැද්දීම් වල විශේෂාංග ගණන
  • n_heads අවධානය යොමු ස්ථර වල හිස් ගණන
  • d_k අවධානය යොමු ප්රධානීන් ප්රමාණය
  • d_ff යනු පෝෂක ඉදිරි ජාලයේ සැඟවුණු ස්ථර වල ප්රමාණයයි
436def\_\_init\_\_(self,chunk\_len:int,n\_layers:int,ca\_layers:Set[int],437d\_model:int,n\_heads:int,d\_k:int,d\_ff:int):

#

448super().\_\_init\_\_()449self.ca\_layers=ca\_layers450self.chunk\_len=chunk\_len

#

හරස්අවධානය ස්ථර

452self.ca=nn.ModuleList([CrossAttention(d\_model,n\_heads,d\_k)for\_inrange(len(ca\_layers))])

#

ද්වි-දිශානුගතස්වයං අවධානය ස්ථර

454self.attn=nn.ModuleList([SelfAttention(d\_model,n\_heads,d\_k,is\_causal=False)for\_inrange(n\_layers)])

#

ඉදිරිස්ථර පෝෂණය කරන්න

456self.ffw=nn.ModuleList([FeedForward(d\_model,d\_ff)for\_inrange(n\_layers)])

#

සඳහාපූර්ව සාමාන්යකරණ ස්තරය H

459self.norm\_h=nn.LayerNorm(d\_model)

#

  • e ලබා ගත් ළඟම අසල්වැසියන්ගේ ටෝකන් කාවැද්දීම්, EMB(RET(Cu​)1≤u≤l​) හැඩයෙන් [batch_size, chunks, neighbors, neighbor_len, d_model]

  • h යනු ආදාන ටෝකන කාවැද්දීම්, H හැඩයෙන් [batch_size, seq_len, d_model]

කුට්ටි u∈[1,l] සහ අසල්වාසීන් සමාන්තරව සකස් j∈[1,k] කරනු ලැබේ.

461defforward(self,e:torch.Tensor,h:torch.Tensor):

#

හැඩයලබා ගන්න

474batch\_size,chunks,neighbors,neighbor\_len,d\_model=e.shape

#

(Hu​)u∈[1,l]​←SPLIT(H)

477h\_split=h[:,:self.chunk\_len\*chunks,:].reshape(batch\_size,chunks,self.chunk\_len,d\_model)

#

පූර්වසම්මතය

480h\_split=self.norm\_h(h\_split)

#

හරස්අවධානය ස්ථරයේ දර්ශකය තබා ගන්න

483p\_ca=0

#

සියලුමස්ථර සඳහා p′∈[1,Lenc​]

485forpinrange(len(self.attn)):

#

ද්වි-දිශානුගතස්වයං අවධානය Euj​←ATTNenc​(Euj​)

488e=self.attn[p](e.view(-1,neighbor\_len,d\_model)).view(e.shape)

#

හරස්අවධානය යොමු කරන්නේ නම් p′∈Penc​

491ifpinself.ca\_layers:

#

Euj​←CAenc​(Euj​,Hu​)

493e=self.ca[p\_ca](e,h\_split)

#

හරස්අවධානය දර්ශකය වැඩි කරන්න

495p\_ca+=1

#

ඉදිරිස්ථරය පෝෂණය කරන්න Euj​←FFWenc​(Euj​)

498e=self.ffw[p](e)

#

ආපසු E

501returne

#

රෙට්රෝආකෘතිය

මෙයරෙට්රෝ විකේතකය

504classRetroModel(nn.Module):

#

  • v_vocab යනු වචන මාලාවේ ටෝකන ගණන
  • d_model යනු කාවැද්දීම් වල විශේෂාංග ගණන
  • n_layers යනු විකේතකයේ ස්ථර ගණන L
  • ca_layers හරස් අවධානය ඇති ස්ථර වේ P
  • chunk_len යනු කුට්ටියක දිග
  • n_heads අවධානය යොමු ස්ථර වල හිස් ගණන
  • d_k අවධානය යොමු ප්රධානීන් ප්රමාණය
  • d_ff යනු පෝෂක ඉදිරි ජාලයේ සැඟවුණු ස්ථර වල ප්රමාණයයි
  • encoder ළඟම අසල්වැසියා එන්කෝඩරයයි
511def\_\_init\_\_(self,n\_vocab:int,d\_model:int,n\_layers:int,ca\_layers:Set[int],chunk\_len:int,512n\_heads:int,d\_k:int,d\_ff:int,encoder:NearestNeighborEncoder):

#

524super().\_\_init\_\_()525526self.ca\_layers=ca\_layers527self.encoder=encoder

#

ටෝකන්කාවැද්දීම ස්ථරය

530self.emb=nn.Embedding(n\_vocab,d\_model)

#

කපනලද හරස් අවධානය ස්ථර CCA

532self.cca=nn.ModuleList(533[ChunkedCrossAttention(d\_model,n\_heads,d\_k,chunk\_len)for\_inrange(len(ca\_layers))])

#

අවධානයස්ථර ATTN

535self.attn=nn.ModuleList([SelfAttention(d\_model,n\_heads,d\_k,is\_causal=True)for\_inrange(n\_layers)])

#

ඉදිරිස්ථර පෝෂණය කරන්න FFW

537self.ffw=nn.ModuleList([FeedForward(d\_model,d\_ff)for\_inrange(n\_layers)])

#

කියවීමේස්ථරය READ

539self.read=nn.Linear(d\_model,n\_vocab)

#

ආසන්නතමඅසල්වැසියාගේ කාවැද්දීම් සඳහා පූර්ව සාමාන්යකරණ ස්තරය ENCODER(RET(Cu​)1≤u≤l​,H)

543self.norm\_e=nn.LayerNorm(d\_model)

#

  • x ආදාන අනුක්රමය, X හැඩයෙන් [batch_size, seq_len]
  • ret හැඩයෙන් ලබා ගත් අසල්වැසියන් RET(Cu​)1≤u≤l​ වේ [batch_size, chunks, neighbors, neighbor_len]
545defforward(self,x:torch.Tensor,ret:torch.Tensor):

#

ආදානකාවැද්දීම් ලබා ගන්න H←EMB(X)

554h=self.emb(x)

#

ලබාගත් අසල්වාසීන්ගේ කාවැද්දීම් Euj​=EMBenc​(RET(Cu​)j).

ආදානසහ අසල්වැසියන් සඳහා අපි එකම කාවැද්දීම් භාවිතා කරමු

560ret\_emb=self.emb(ret)

#

කපනලද හරස් අවධානය ස්ථරයේ දර්ශකය තබා ගන්න

563p\_ca=0

#

සියලුමස්ථර සඳහා p∈[1,L]

565forpinrange(len(self.attn)):

#

හේතුකාරකස්වයං අවධානය H←ATTN(H)

567h=self.attn[p](h)

#

පළමු CCA ස්ථරයට පෙර එන්කෝඩර් කාවැද්දීම් ලබා ගන්න p=min(P)

571ifself.ca\_layersandp==min(self.ca\_layers):

#

E=ENCODER(RET(Cu​)1≤u≤l​,H)

අපි RET(Cu​)1≤u≤l​ එන්කෝඩරයට කාවැද්දීම් සම්මත කළෙමු.

575e=self.encoder(ret\_emb,h)

#

එන්කෝඩර්කාවැද්දීම් සාමාන්යකරණය කරන්න

577e=self.norm\_e(e)

#

කුරුස-හරස්අවධානය නම් p∈P

580ifpinself.ca\_layers:

#

H←CCA(H,E)

582h=self.cca[p\_ca](h,e)

#

වර්ධකකපන ලද හරස් අවධානය දර්ශකය

584p\_ca+=1

#

H←FFW(H)

587h=self.ffw[p](h)

#

O←READ(H)

590returnself.read(h)

#

ව්යාජදත්ත සමඟ ආකෘතිය පරීක්ෂා කරන්න

593def\_test():

#

597chunk\_len=4598d\_model=8599d\_ff=32600n\_heads=2601d\_k=4602603device=torch.device('cuda:0')604605m=RetroModel(5,d\_model,6,{2,5},chunk\_len,n\_heads,d\_k,d\_ff,606encoder=NearestNeighborEncoder(chunk\_len,2,{1},d\_model,n\_heads,d\_k,d\_ff))607608m.to(device)609x=[1,2,4,4,0,1,2,3,4,3]610ret=[611[[0,0,0,0,0,0],[1,1,1,1,1,1]],612[[0,0,0,0,0,0],[1,1,1,1,1,1]],613]614res=m(torch.tensor([x]\*10).to(device),torch.tensor([ret]\*10).to(device))615616inspect(res)

#

620if\_\_name\_\_=='\_\_main\_\_':621\_test()

Trending Research Paperslabml.ai