Back to Annotated Deep Learning Paper Implementations

ජීපීටී-නියෝක්ස්ආකෘතිය

docs/si/neox/model.html

latest32.5 KB
Original Source

homeneox

View code on Github

#

ජීපීටී-නියෝක්ස්ආකෘතිය

ජීපීටී-නියෝක්ස්ආකෘතියේ ස්ථර සඳහා කේතය සහ 20B මුරපොල පූරණය කිරීමේ කේතය මෙන්න.

ස්ථර load_state වල ඇති ක්රමය එම ස්ථරයේ මුරපොලවල් පූරණය කරයි. මුරපොලවල් පැටවීමේ සහායකයන් ක්රියාත්මක වේ checkpoint.py

16importcopy17importmath18fromtypingimportDict,Optional,Set,Callable,Any,Generator,Tuple1920importtorch21fromtorchimportnn22fromtorch.cuda.ampimportautocast2324fromlabmlimportmonit,logger25fromlabml.loggerimportText26fromlabml\_nn.neoximportcheckpoint27fromlabml\_nn.neox.utils.cacheimportget\_cache

#

30classNeoXModule(nn.Module):

#

31defload\_state(self,p1:Dict[str,torch.Tensor],p2:Dict[str,torch.Tensor]):32pass

#

කාවැද්දීමස්ථරය

මෙයමුරපොලට පැටවීම සඳහා කේතය සහිත සම්මත කාවැද්දීම් ස්ථරයකි.

35classEmbedding(NeoXModule):

#

  • n_vocab වචන මාලාවේ ප්රමාණය වේ
  • n_hidden මෙම කාවැද්දීම් ප්රමාණය
42def\_\_init\_\_(self,n\_vocab:int=50\_432,n\_hidden:int=6\_144):

#

47super().\_\_init\_\_()4849self.emb=nn.Embedding(n\_vocab,n\_hidden)

#

  • x හැඩයේ ටෝකන් හැඳුනුම් වේ [batch_size, seq_len]
51defforward(self,x:torch.Tensor):

#

55returnself.emb(x)

#

මුරපොලපූරණය කිරීමට කේතය

57defload\_state(self,p1:Dict[str,torch.Tensor],p2:Dict[str,torch.Tensor]):

#

61withmonit.section('Load embedding layer'):62checkpoint.merge\_params\_dim\_0(self.emb.weight,'word\_embeddings.weight',p1,p2)

#

රොටරිස්ථානීය කාවැද්දීම්

ජීපීටී-නියෝක්ස් භ්රමණ ස්ථානීය කාවැද්දීම් (කඹය)භාවිතා කරයි.

අපින්යාය වැඩි සටහන් සමඟ මෙහි කඹය ක්රියාත්මක කිරීම විස්තර කර ඇත.

65classRoPE(nn.Module):

#

  • d_rope කඹය කාවැද්දීම් සඳහා විශේෂාංග ගණන
  • base සඳහා පදනම වේ θi​=10000d2(i−1)​, එය පැහැර හරින 10000
75def\_\_init\_\_(self,d\_rope:int,base:float=10\_000.):

#

80super().\_\_init\_\_()

#

විශේෂාංග θi​ සඳහා ගබඩා කිරීමට

83self.theta=None

#

හැඹිලිය cosmθi​ සහ sinmθi​

85self.cos\_cached=None86self.sin\_cached=None

#

සඳහාමූලික θi​=10000d2(i−1)​

89self.base=base

#

කඹයසඳහා විශේෂාංග ගණන

91self.d\_rope=d\_rope

#

විශේෂාංගකරකවන්න

[−x(2d​+1),−x(2d​+2),...,−x(d),x(1),x(2),...,−x(2d​)]

93@staticmethod94defrotate\_half(x:torch.Tensor):

#

100x1,x2=x[...,:x.shape[-1]//2],x[...,x.shape[-1]//2:]101returntorch.cat((-x2,x1),dim=-1)

#

  • x හැඩය ඇත [..., seq, n_heads, d_k]
  • offset යනු ආරම්භක ස්ථානයයි x . පෙර තනතුරු වල යතුරු සහ විමසුම් අප හැඹිලි කළ >0 විට මෙය සිදු වේ
103defforward(self,x:torch.Tensor,offset:int=0):

#

සත්යඅනුක්රමයේ දිග ලබා ගන්න

111seq\_len=x.shape[-3]+offset

#

ආරම්භකරන්න θ

114ifself.thetaisNone:

#

θi​=10000d2(i−1)​

116theta=1.0/(self.base\*\*(torch.arange(0,self.d\_rope,2).float()/self.d\_rope))117self.theta=theta.to(x.device).to(x.dtype)

#

ආරම්භකරන්න cosmθi​ සහ sinmθi​ හැඹිලිය

120if(121self.cos\_cachedisNoneor122seq\_len\>self.cos\_cached.shape[1]or123self.cos\_cached.device!=x.deviceor124self.cos\_cached.dtype!=x.dtype125):

#

ස්ථානදර්ශක ලබා ගන්න m

127seq\_idx=torch.arange(seq\_len,device=x.device).type\_as(self.theta)

#

mθi​

129idx\_theta=torch.einsum("s,d-\>sd",seq\_idx,self.theta)

#

පේළියසඳහා m අපට ඇති ��රිදි සංයුක්ත කරන්න

[mθ0​,mθ1​,...,mθ2d​​,mθ0​,mθ1​,...,mθ2d​​]

133idx\_theta2=torch.cat((idx\_theta,idx\_theta),dim=-1).to(x.device)

#

ගණනයකරන්න cosmθi​ සහ sinmθi​ fp32

136withautocast(enabled=False):137idx\_theta2=idx\_theta2.float()

#

හිසමානයක් එක් කරන්න

139self.cos\_cached=idx\_theta2.cos()[:,None,:]140self.sin\_cached=idx\_theta2.sin()[:,None,:]

#

ඒවාහැඹිලිය

143self.cos\_cached=self.cos\_cached.to(x.dtype)144self.sin\_cached=self.sin\_cached.to(x.dtype)

#

විශේෂාංගබෙදන්න. d_rope විශේෂාංග සඳහා පමණක් අපි කඹය යොදන්නෙමු

147x\_rope,x\_pass=x[...,:self.d\_rope],x[...,self.d\_rope:]

#

හැඹිලියසිට පාපය සහ කෝස් අගයන් ලබා ගන්න

150cos,sin=self.cos\_cached[offset:seq\_len],self.sin\_cached[offset:seq\_len]

#

කඹයකාවැද්දීම්

(xm(i)​cosmθi​−xm(i+2d​)​sinmθi​xm(i+2d​)​cosmθi​+xm(i)​sinmθi​​)​

සඳහා i∈1,2,...,2d​

162x\_rope=(x\_rope\*cos)+(self.rotate\_half(x\_rope)\*sin)

#

කඹයකාවැද්දීම් ලබා නොගත් විශේෂාංග සමඟ සංයුක්ත වන්න

165returntorch.cat((x\_rope,x\_pass),dim=-1)

#

අවධානයස්ථරය

168classAttentionLayer(nn.Module):

#

  • n_hidden කාවැද්දීම් වල විශේෂාංග ගණන
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව
  • rope_percentage කඹය කාවැද්දීම් එකතු කිරීම සඳහා විශේෂාංග ප්රතිශතය
  • mask_fill අවධානය යොමු න්යාසය සඳහා ආවරණ පිරවුම් අගය
  • is_flash_attentionFlashAttention භාවිතා කළ යුතුද යන්න නියම කරයි
173def\_\_init\_\_(self,n\_hidden:int=6\_144,n\_heads:int=64,rope\_percentage:float=0.25,174mask\_fill:float=-10\_000.0,\*,is\_flash\_attention:bool=False):

#

183super().\_\_init\_\_()184185self.n\_heads=n\_heads186self.mask\_fill=mask\_fill

#

විමසුම, යතුර සහ වටිනාකම සඳහා රේඛීය ස්ථරය

189self.qkv\_lin=nn.Linear(n\_hidden,n\_hidden\*3)

#

අවසානරේඛීය ස්ථරය

191self.output=nn.Linear(n\_hidden,n\_hidden)

#

හිසකටවිශේෂාංග ගණන

194d\_k=n\_hidden//n\_heads

#

කඹයකාවැද්දීම මොඩියුලය

196self.rope=RoPE(int(d\_k\*rope\_percentage))

#

අවධානයපරිමාණ සාධකය

199self.scale=1/math.sqrt(d\_k)

#

හේතුආවරණ හැඹිලි කිරීමට

202self.causal\_mask=None

#

අවධානයයොමු කරන්න සොෆ්ට්මැක්ස් මොඩියුලය

205self.softmax=nn.Softmax(dim=-2)

#

ෆ්ලෑෂ් අවධානය

208ifis\_flash\_attention:209try:210fromflash\_attn.flash\_attentionimportFlashAttention211self.flash\_attention=FlashAttention()212exceptImportError:213logger.log('Install flash attention github.com/HazyResearch/flash-attention. '214'Falling back to normal attention',Text.warning)215self.flash\_attention=None216else:217self.flash\_attention=None

#

හේතුආවරණ ගණනය කරන්න

219def\_get\_mask(self,attn:torch.Tensor):

#

විමසුමසහ යතුරු දිග

227nq,nk=attn.shape[1:3]

#

වෙස්මුහුණ සාදන්න

230if(231self.causal\_maskisNoneor232self.causal\_mask.shape[0]!=nqor233self.causal\_mask.shape[1]!=nkor234self.causal\_mask.device!=attn.device235):236self.causal\_mask=torch.triu(attn.new\_ones([nq,nk],dtype=torch.bool),1+nk-nq)

#

හැඹිලියසිට ආපසු

239returnself.causal\_mask[None,:,:,None]

#

  • x හැඩය ඇත [batch_size, seq_len, n_hidden]
241defforward(self,x:torch.Tensor):

#

විමසුම, යතුර සහ වටිනාකම් කාවැද්දීම් ලබා ගන්න (සියල්ල සංයුක්ත කර ඇත). පසුගිය මානයක් ප්රමාණය n_hidden -> සිට වෙනස් වනු ඇත 3 x n_hidden

247qkv=self.qkv\_lin(x)

#

හැඩයවෙනස් කිරීමෙන් හිස් වලට බෙදන්න [batch_size, seq_len, n_heads, 3 * d_k]

250qkv=qkv.view(\*qkv.shape[:-1],self.n\_heads,-1)

#

විමසුමටබෙදන්න, යතුර සහ හැඩය එක් එක් අගය කරන්න [batch_size, seq_len, n_heads, 3 * d_k]

252q,k,v=torch.split(qkv,qkv.shape[-1]//3,dim=-1)

#

අපිපෙර ටෝකන වල තත්වයන් හැඹිලි කරන්නේ නම්

255ifget\_cache().get('use\_cache',False):

#

රාජ්යid ගේ ලබා ගන්න. අපි පෙර රාජ්යයන් ලබා ගැනීමට හා ඉදිරි රාජ්යයන් ගබඩා කිරීම සඳහා භාවිතා

257prev\_state\_id,next\_state\_id=get\_cache().get('state\_ids')

#

හැඹිලියතිබේ නම්

259ifprev\_state\_idisnotNone:

#

අතීතයතුරු සහ අගයන් ලබා ගන්න. මේවාට හැඩය ඇත [batch_size, prev_seq_len, n_heads, d_k]

261k\_past,v\_past=get\_cache().pop(f'attn\_kv\_{prev\_state\_id}')

#

වත්මන්කාවැද්දීම් වල ඕෆ්සෙට්

263offset=k\_past.shape[1]

#

කඹයකාවැද්දීම් එකතු කරන්න

266q=self.rope(q,offset=offset)267k=self.rope(k,offset=offset)

#

අතීතයසංයුක්ත කරන්න

270k=torch.cat([k\_past,k],dim=1)271v=torch.cat([v\_past,v],dim=1)272else:

#

කඹයකාවැද්දීම් එකතු කරන්න

274q=self.rope(q)275k=self.rope(k)

#

වත්මන්තත්වය සුරකින්න

278get\_cache().push(f'attn\_kv\_{next\_state\_id}',(k,v))279else:

#

හැඹිලියක්නැත - කඹය කාවැද්දීම් එකතු කරන්න

281q=self.rope(q)282k=self.rope(k)

#

ෆ්ලෑෂ් අවධානය භාවිතා කරන්න

285ifself.flash\_attentionisnotNoneandq.shape[1]==k.shape[1]andq.shape[-1]\<=128:286output=self.compute\_flash\_attention(q,k,v)

#

එසේ නොමැති නම්, සාමාන්ය අවධානය භාවිතා කරන්න

288else:289output=self.compute\_attention(q,k,v)

#

[batch_size, seq_len, n_heads, d_k] toBatch_size, seq_len, n_hidden`වෙතින් නැවත සකස් කරන්න

292output=output.reshape(\*x.shape)

#

අවසානරේඛීය ස්ථරය

295returnself.output(output)

#

297defcompute\_flash\_attention(self,q:torch.Tensor,k:torch.Tensor,v:torch.Tensor):

#

ඒවා හැඩයට ගොඩගසන්න[batch_size, seq_len, 3, n_heads, d_k]

299qkv=torch.stack((q,k,v),dim=2)300d\_k=qkv.shape[-1]301ifd\_k\<=32:302pad=32-d\_k303elifd\_k\<=64:304pad=64-d\_k305elifd\_k\<=128:306pad=128-d\_k307else:308raiseValueError(f'Head size {d\_k} too large for flash attention')309310ifpad\>0:311qkv=torch.cat((qkv,qkv.new\_zeros(\*qkv.shape[:-1],pad)),dim=-1)312313output,\_=self.flash\_attention(qkv,causal=True)

#

ප්රතිදානය හැඩයෙන් යුක්ත වේ[batch_size, seq_len, n_heads, d_k + padding]

315output=output[:,:,:,:d\_k]316317returnoutput

#

319defcompute\_attention(self,q:torch.Tensor,k:torch.Tensor,v:torch.Tensor):

#

අවධානයගණනය කිරීම සඳහා fp16 කිරීමට ස්වයංක්රීය-වාත්තු අක්රීය

321withautocast(enabled=False):322ifq.dtype==torch.float16:

#

වත්මන්dtype fp16 නම් fp32 බවට පරිවර්තනය කරන්න

324attn=torch.einsum('bihk,bjhk-\>bijh',q.float(),k.float())325else:

#

bfloatසඳහා වාත්තු නොකරන්න

327attn=torch.einsum('bihk,bjhk-\>bijh',q,k)

#

පරිමාණඅවධානය

330attn=attn\*self.scale

#

හේතුවෙස්මුහුණ ලබා ගන්න

333mask=self.\_get\_mask(attn)

#

වෙස්යොදන්න

335attn.masked\_fill\_(mask,self.mask\_fill)

#

අවධානයසොෆ්ට්මැක්ස්

338attn=self.softmax(attn)

#

අවධානයබර තැබූ අගයන් ලබා ගන්න

341output=torch.einsum('bijh,bjhk-\>bihk',attn.to(v.dtype),v)342343returnoutput

#

ප්රතිපෝෂණජාලය

346classFFNLayer(nn.Module):

#

  • n_hidden කාවැද්දීම ප්රමාණය වේ
351def\_\_init\_\_(self,n\_hidden:int=6\_144,d\_ff:int=0):

#

355super().\_\_init\_\_()356357ifnotd\_ff:358d\_ff=n\_hidden\*4

#

පුළුල්රේඛීය ස්ථරය

361self.dense\_h\_h4=nn.Linear(n\_hidden,d\_ff)

#

GELUසක්රිය කිරීම

363self.activation=nn.GELU()

#

සංකෝචනයරේඛීය ස්ථරය

365self.dense\_h4\_h=nn.Linear(d\_ff,n\_hidden)

#

  • x හැඩය ඇත [batch_size, seq_len, n_hidden]
367defforward(self,x:torch.Tensor):

#

371x=self.dense\_h\_h4(x)372x=self.activation(x)373x=self.dense\_h4\_h(x)374375returnx

#

ට්රාන්ස්ෆෝමර්ස්ථරය

378classTransformerLayer(NeoXModule):

#

  • n_hidden කාවැද්දීම ප්රමාණය වේ
  • n_heads හිස් සංඛ්යාව වේ
  • is_flash_attentionFlashAttention භාවිතා කළ යුතුද යන්න නියම කරයි

පිටත ක්රියාත්මක කිරීම අතහැර දැමීම ඇතුළත් නොවේ.

383def\_\_init\_\_(self,n\_hidden:int=6\_144,n\_heads:int=64,\*,is\_flash\_attention:bool=False):

#

392super().\_\_init\_\_()

#

අවධානයටපෙර ස්ථර සාමාන්යකරණය

395self.pre\_ln\_attn=nn.LayerNorm(n\_hidden)

#

FFNට පෙර ස්ථර සාමාන්යකරණය

397self.pre\_ln\_ffn=nn.LayerNorm(n\_hidden)

#

අවධානයස්ථරය

400self.attention=AttentionLayer(n\_hidden,n\_heads,is\_flash\_attention=is\_flash\_attention)

#

FFNස්ථරය

402self.ffn=FFNLayer(n\_hidden)

#

  • x හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
404defforward(self,x:torch.Tensor):

#

අවශේෂසම්බන්ධතාවය

410residual=x

#

නියෝක්ස්සමාන්තරව අවධානය සහ ප්රතිපෝෂණ ජාලය ක්රියාත්මක කරයි

412attn=self.attention(self.pre\_ln\_attn(x))413ffn=self.ffn(self.pre\_ln\_ffn(x))

#

ඒවාසහ අවශේෂ සම්බන්ධතාවය එකතු කරන්න

415returnattn+ffn+residual

#

මුරපොලපූරණය කිරීමට කේතය

417defload\_state(self,p1:Dict[str,torch.Tensor],p2:Dict[str,torch.Tensor]):

#

421withmonit.section('Load transformer layer'):

#

අවධානයයොමු ප්රතිදානය පරිණාමනය

423checkpoint.merge\_params\_sum(self.attention.output.bias,'attention.dense.bias',p1,p2)424checkpoint.merge\_params\_dim\_1(self.attention.output.weight,'attention.dense.weight',p1,p2)

#

අවධානයවිමසුම, යතුර සහ අගය පරිවර්තනය කිරීම

427checkpoint.merge\_params\_dim\_0(self.attention.qkv\_lin.bias,'attention.query\_key\_value.bias',p1,p2)428checkpoint.merge\_params\_dim\_0(self.attention.qkv\_lin.weight,'attention.query\_key\_value.weight',p1,p2)

#

අවධානයටපෙර ස්ථර සම්මතය

431checkpoint.merge\_params\_duplicate(self.pre\_ln\_attn.bias,'input\_layernorm.bias',p1,p2)432checkpoint.merge\_params\_duplicate(self.pre\_ln\_attn.weight,'input\_layernorm.weight',p1,p2)

#

FFNදෙවන පරිණාමනය

435checkpoint.merge\_params\_dim\_0(self.ffn.dense\_h\_h4.bias,'mlp.dense\_h\_to\_4h.bias',p1,p2)436checkpoint.merge\_params\_dim\_0(self.ffn.dense\_h\_h4.weight,'mlp.dense\_h\_to\_4h.weight',p1,p2)

#

FFNපළමු පරිවර්තනය

439checkpoint.merge\_params\_sum(self.ffn.dense\_h4\_h.bias,'mlp.dense\_4h\_to\_h.bias',p1,p2)440checkpoint.merge\_params\_dim\_1(self.ffn.dense\_h4\_h.weight,'mlp.dense\_4h\_to\_h.weight',p1,p2)

#

FFNට පෙර ස්ථර සම්මතය

443checkpoint.merge\_params\_duplicate(self.pre\_ln\_ffn.bias,'post\_attention\_layernorm.bias',p1,p2)444checkpoint.merge\_params\_duplicate(self.pre\_ln\_ffn.weight,'post\_attention\_layernorm.weight',p1,p2)

#

අවසානසාමාන්යකරණ ස්තරය

447classFinalNorm(NeoXModule):

#

  • n_hidden කාවැද්දීම ප්රමාණය වේ
452def\_\_init\_\_(self,n\_hidden:int=6\_144):

#

456super().\_\_init\_\_()457458self.ln=nn.LayerNorm(n\_hidden)

#

  • x හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
460defforward(self,x:torch.Tensor):

#

464returnself.ln(x)

#

මුරපොලපූරණය කිරීමට කේතය

466defload\_state(self,p1:Dict[str,torch.Tensor],p2:Dict[str,torch.Tensor]):

#

470withmonit.section('Load final normalization layer'):471checkpoint.merge\_params\_duplicate(self.ln.bias,'norm.bias',p1,p2)472checkpoint.merge\_params\_duplicate(self.ln.weight,'norm.weight',p1,p2)

#

කියවීමේස්ථරය

475classReadoutLayer(NeoXModule):

#

  • n_hidden කාවැද්දීම ප්රමාණය වේ
  • n_vocab වචන මාලාවේ ප්රමාණය වේ
480def\_\_init\_\_(self,n\_hidden:int=6\_144,n\_vocab:int=50\_432):

#

485super().\_\_init\_\_()486487self.linear=nn.Linear(n\_hidden,n\_vocab,bias=False)

#

  • x හැඩයේ කාවැද්දීම් වේ [batch_size, seq_len, n_hidden]
489defforward(self,x:torch.Tensor):

#

493returnself.linear(x)

#

මුරපොලපූරණය කිරීමට කේතය

495defload\_state(self,p1:Dict[str,torch.Tensor],p2:Dict[str,torch.Tensor]):

#

499withmonit.section('Load final linear layer'):500checkpoint.merge\_params\_dim\_0(self.linear.weight,'final\_linear.weight',p1,p2)

#

503classLayerGenerator:504pre\_created\_layers:Dict[Any,Optional[NeoXModule]]

#

ස්ථර නිර්මාණය කිරීමට උත්පාදක යන්ත්රය

ස්ථර ජනනය කරනු ලබන්නේ මුරපොලවල් මෙන් එකම අනුපිළිවෙලකි.

ස්ථරයක් නොමැතිNone විට එය ලබා දෙයි; අපි ස්ථර දර්ශක නියෝක්ස් ලෙස භාවිතා කරන අතර අපගේ ක්රියාත්මක කිරීමේදී අපට අවශ්ය නොවන පරිවර්තන ස්ථර දෙකක් තිබේ.

  • n_vocab යනු වචන මාලාවේ ටෝකන ගණන
  • n_hidden මෙම කාවැද්දීම් දී විශේෂාංග සංඛ්යාව
  • n_layers ට්රාන්ස්ෆෝමර් ස්ථර ගණන වේ
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • filter_layers භාවිතා කළ යුතු ස්ථර සමූහයයි. කිසිවක් නොමැති නම් සියලුම ස්ථර භාවිතා කරනු ඇත. අඩු ස්ථර සහිත ආකෘතියේ කුඩා අනුවාදයන් පරීක්ෂා කිරීමට මෙය භාවිතා කරයි- is_clone_layers ට්රාන්ස්ෆෝමර් ස්ථර ක්ලෝන කළ යුතුද යන්න නියම කරයි (ටිකක් වේගවත්)
  • dtype ආකෘතියේ දත්ත වර්ගයයි
  • device ආකෘතියේ උපාංගය වේ
  • is_llm_int8 INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයි
  • llm_int8_threshold පිටත විශේෂාංග වෙන් කිරීමα සඳහා භාවිතා කරන එළිපත්ත වේ
  • is_flash_attentionFlashAttention භාවිතා කළ යුතුද යන්න නියම කරයි
506def\_\_init\_\_(self,\*,n\_vocab:int=50\_432,n\_hidden:int=6\_144,507n\_layers:int=44,n\_heads:int=64,508filter\_layers:Optional[Set]=None,509is\_clone\_layers:bool=True,510dtype:torch.dtype=torch.float,511device:torch.device=torch.device('cpu'),512is\_llm\_int8:bool=False,513llm\_int8\_threshold:float=6.0,514is\_flash\_attention:bool=False515):

#

538iffilter\_layersisNone:539filter\_layers=set(range(n\_layers+3))540541self.n\_vocab=n\_vocab542self.n\_hidden=n\_hidden543self.n\_layers=n\_layers544self.n\_heads=n\_heads545self.filter\_layers=filter\_layers546self.is\_clone\_layers=is\_clone\_layers547self.dtype=dtype548self.device=device549self.is\_llm\_int8=is\_llm\_int8550self.llm\_int8\_threshold=llm\_int8\_threshold551self.is\_flash\_attention=is\_flash\_attention552553self.pre\_created\_layers=dict(554transformer\_layer=None,555)

#

භාවිතයසඳහා ස්තරය සකස් කරයි

අපිස්තරය උපාංගය වෙත ගෙන ගොස් නිවැරදි දත්ත වර්ගයට පරිවර්තනය කරමු

  • layer සකස් කළ යුතු ස්ථරයයි

සකස්කළ ස්තරය_ආපසු ලබා දෙයි_

557def\_prepare\_layer(self,layer:NeoXModule):

#

566returnlayer.to(self.device,self.dtype)

#

පිරික්සුම්ස්ථානය පැටවීමෙන් පසු ස්ථර පරිවර්තනයන්

මෙමශ්රිතය පිරික්සුම් ස්ථානය පැටවීමෙන් පසු ස්ථර පරිවර්තනයන් ක්රියාත්මක කරයි.

දැනටඑය අදාළ වන්නේ INT8 ප්රමාණකරණය පමණි.

  • layer සකස් කළ යුතු ස්ථරයයි
  • is_llm_int8 INT8 ප්රමාණකරණය භාවිතා කළ යුතුද යන්න නියම කරයි
  • device ආකෘතියේ උපාංගය වේ
  • llm_int8_threshold යනු පිටත විශේෂාංග වෙන් කිරීම α සඳහා භාවිතා කරන එළිපත්ත

සකස්කළ ස්තරය_ආපසු ලබා දෙයි_

[email protected]\_grad()569defpost\_load\_prepare(self,layer:NeoXModule,\*,570is\_llm\_int8:bool=None,571device:torch.device=None,572llm\_int8\_threshold:float=None,573):

#

නියමකර නොමැති නම් පෙරනිමි අගයන් ලබා ගන්න

591ifis\_llm\_int8isNone:592is\_llm\_int8=self.is\_llm\_int8593ifdeviceisNone:594device=self.device595ifllm\_int8\_thresholdisNone:596llm\_int8\_threshold=self.llm\_int8\_threshold

#

INT8ප්රමාණකරණය භාවිතා නොකරන්නේ නම් මඟ හරින්න

599ifnotis\_llm\_int8:600returnlayer

#

ට්රාන්ස්ෆෝමර්ස්ථර වල රේඛීය ස්ථර පමණක් පරිවර්තනය කරන්න

603ifnotisinstance(layer,TransformerLayer):604returnlayer

#

උපයෝගිතාවල make_llm_int8_linear අර්ථ දක්වා ඇති භාවිතා කරන්න.

607fromlabml\_nn.neox.utils.llm\_int8importmake\_llm\_int8\_linear

#

රේඛීයස්ථර පරිවර්තනය කරන්න

610withmonit.section('Convert to int8'):611layer.attention.output=make\_llm\_int8\_linear(layer.attention.output,612device=device,613threshold=llm\_int8\_threshold)614layer.attention.qkv\_lin=make\_llm\_int8\_linear(layer.attention.qkv\_lin,615device=device,616threshold=llm\_int8\_threshold)617layer.ffn.dense\_h\_h4=make\_llm\_int8\_linear(layer.ffn.dense\_h\_h4,618device=device,619threshold=llm\_int8\_threshold)620layer.ffn.dense\_h4\_h=make\_llm\_int8\_linear(layer.ffn.dense\_h4\_h,621device=device,622threshold=llm\_int8\_threshold)

#

624returnlayer

#

ස්තරයක්නිර්මාණය කර හැච් කරයි

පරාමිතීන්ආරම්භ කිරීමට කාලය ගතවන නිසා හැඹිලි ස්ථර පිටපත් කිරීම නව ස්ථර ආරම්භ කිරීමට වඩා වේගවත් වේ.

  • name යනු ස්තරයේ නමයි
  • creator ස්තරය නිර්මාණය කිරීමේ කාර්යයයි

සාදනලද ස්තරය හෝ කැච් ස්ථරයේ පිටපතක්_ආපසු ලබා දෙයි_

626def\_create\_and\_cache\_layer(self,name:str,creator:Callable[[],NeoXModule]):

#

638ifnotself.is\_clone\_layers:639returnself.\_prepare\_layer(creator())640641ifself.pre\_created\_layers[name]isNone:642self.pre\_created\_layers[name]=self.\_prepare\_layer(creator())643644layer=copy.deepcopy(self.pre\_created\_layers[name])645returnlayer

#

647def\_create\_transformer\_layer(self):648returnself.\_create\_and\_cache\_layer(649'transformer\_layer',650lambda:TransformerLayer(self.n\_hidden,self.n\_heads,is\_flash\_attention=self.is\_flash\_attention)651)

#

653def\_create\_embedding\_layer(self):654returnEmbedding(self.n\_vocab,self.n\_hidden)

#

656def\_create\_final\_norm\_layer(self):657returnFinalNorm(self.n\_hidden)

#

659def\_create\_readout\_layer(self):660returnReadoutLayer(self.n\_hidden,self.n\_vocab)

#

ස්ථරලබා ගැනීම සඳහා උත්පාදක යන්ත්රය

[email protected]\_grad()663defget\_layers(self)-\>Generator[Tuple[NeoXModule,Tuple[str,str]],None,None]:

#

කාවැද්දීමස්ථරය

668if0inself.filter\_layers:669withmonit.section('Embedding layer'):670layer=self.\_prepare\_layer(self.\_create\_embedding\_layer())671yieldlayer,('layer\_00-model\_00-model\_states.pt','layer\_00-model\_01-model\_states.pt')

#

ට්රාන්ස්ෆෝමර්ස්ථර

674foriinrange(self.n\_layers):

#

ට්රාන්ස්ෆෝමර්ස්ථරය

676ifi+1inself.filter\_layers:677withmonit.section(f'Transformer Layer {i}'):678yieldself.\_create\_transformer\_layer(),\679(f'layer\_{i+2:02d}-model\_00-model\_states.pt',680f'layer\_{i+2:02d}-model\_01-model\_states.pt')

#

අවසානසාමාන්යකරණ ස්තරය

683ifself.n\_layers+1inself.filter\_layers:684withmonit.section('Final norm layer'):685layer=self.\_prepare\_layer(self.\_create\_final\_norm\_layer())686yieldlayer,('layer\_47-model\_00-model\_states.pt','layer\_47-model\_01-model\_states.pt')

#

කියවීමේස්ථරය

689ifself.n\_layers+2inself.filter\_layers:690withmonit.section('Readout layer'):691layer=self.\_prepare\_layer(self.\_create\_readout\_layer())692yieldlayer,('layer\_48-model\_00-model\_states.pt','layer\_48-model\_01-model\_states.pt')693694forkinself.pre\_created\_layers.keys():695self.pre\_created\_layers[k]=None

#

මුළුස්ථර ගණන නැවත ලබා දෙයි

697@property698deftotal\_layers(self):

#

702returnself.n\_layers+3

#

ස්ථරපැටවීමට උත්පාදක යන්ත්රය

[email protected]\_grad()705defload(self)-\>Generator[NeoXModule,None,None]:

#

709withmonit.section("Layers"):710fori,(layer,files)inenumerate(self.get\_layers()):711iffilesisnotNone:712layer.load\_state(\*checkpoint.load\_checkpoint\_files(files))713714layer=self.post\_load\_prepare(layer)715716monit.progress(min(0.99,(i+1)/self.total\_layers))717yieldlayer

Trending Research Paperslabml.ai