Back to Annotated Deep Learning Paper Implementations

කටුසටහනක් RNN

docs/si/sketch_rnn/index.html

latest34.0 KB
Original Source

homesketch_rnn

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/sketch_rnn/ init.py)

#

කටුසටහනක් RNN

මෙයකඩදාසි පිළිබඳ විවරණ කරන ලද PyTorch ක්රියාත්මක කිරීම ස්කෙච් ඇඳීම්වල ස්නායුක නිරූපණයකි .

ස්කෙච්ආර්එන්එන් යනු අනුක්රම-සිට-අනුක්රමික විචල්යතා ස්වයංක්රීය ආකේතයකි. එන්කෝඩරය සහ විකේතකය යන දෙකම පුනරාවර්තන ස්නායුක ජාල ආකෘති වේ. ආ roke ාත මාලාවක් පුරෝකථනය කිරීමෙන් ආ roke ාතය පදනම් කරගත් සරල ඇඳීම් ප්රතිනිර්මාණය කිරීමට එය ඉගෙන ගනී. එක් එක් ආ roke ාතය ගවුසියානු මිශ්රණයක් ලෙස විකේතකය පුරෝකථනය කරයි.

දත්තලබා ගැනීම

ඉක්මන් වෙතින් දත්ත බාගන්න, අඳින්න! දත්ත කට්ටලය. Sketch-RNNහි npz ගොනු බාගත කිරීම සඳහා සබැඳියක් ඇත QuickDraw දත්ත සමුදාය කියවීමේ කොටස. බාගත කළ npz ගොනුව (ය) data/sketch ෆෝල්ඩරයේ තබන්න. මෙම කේතය bicycle දත්ත කට්ටලය භාවිතා කිරීමට වින්යාස කර ඇත. ඔබට මෙය වින්යාසයන් තුළ වෙනස් කළ හැකිය.

පිළිගැනීම්

ඇලෙක්සිස් ඩේවිඩ් ජැක් විසින් පයිටෝර්ච් ස්කෙච් ආර්එන්එන් ව්යාපෘතියෙන් උදව් ලබා ගත්තේය

32importmath33fromtypingimportOptional,Tuple,Any3435importnumpyasnp36importtorch37importtorch.nnasnn38frommatplotlibimportpyplotasplt39fromtorchimportoptim40fromtorch.utils.dataimportDataset,DataLoader4142importeinops43fromlabmlimportlab,experiment,tracker,monit44fromlabml\_helpers.deviceimportDeviceConfigs45fromlabml\_helpers.moduleimportModule46fromlabml\_helpers.optimizerimportOptimizerConfigs47fromlabml\_helpers.train\_validimportTrainValidConfigs,hook\_model\_outputs,BatchIndex

#

දත්තකට්ටලය

මෙමපන්තිය දත්ත පැටවීම සහ පූර්ව සැකසීම.

50classStrokesDataset(Dataset):

#

datasetSeq_len හැඩයේ හිරිවැටීම් අරා ලැයිස්තුවකි, 3. එය ආ ro ාත අනුපිළිවෙලක් වන අතර සෑම පහරක්ම නිඛිල 3 කින් නිරූපණය කෙරේ. පළමු දෙක x සහ y දිගේ විස්ථාපන වේ (Δx, Δy) සහ අවසාන නිඛිලය පෑනෙහි තත්වය නිරූපණය කරයි, එය කඩදාසි ස්පර්ශ කරන්නේ 1 නම් සහ 0 එසේ නොමැතිනම්.

57def\_\_init\_\_(self,dataset:np.array,max\_seq\_length:int,scale:Optional[float]=None):

#

67data=[]

#

අපිඑක් එක් අනුපිළිවෙල සහ පෙරීම හරහා නැවත නැවතත්

69forseqindataset:

#

ආro ාත අනුක්රමයේ දිග අපගේ පරාසය තුළ තිබේ නම් පෙරහන් කරන්න

71if10\<len(seq)\<=max\_seq\_length:

#

කලම්ප Δx, Δy කිරීමට [−1000,1000]

73seq=np.minimum(seq,1000)74seq=np.maximum(seq,-1000)

#

පාවෙනලක්ෂ්ය අරාවකට පරිවර්තනය කර එකතු කරන්න data

76seq=np.array(seq,dtype=np.float32)77data.append(seq)

#

එවිටඅපි (Δx, Δy) ඒකාබද්ධ සම්මත අපගමනය වන පරිමාණ සාධකය ගණනය. මධ්යන්යය කෙසේ වෙතත් සමීප බැවින් සරල බව සඳහා මධ්යන්යය සකස් කර නොමැති බව කඩදාසි සටහන් 0කරයි.

83ifscaleisNone:84scale=np.std(np.concatenate([np.ravel(s[:,0:2])forsindata]))85self.scale=scale

#

සියලුමඅනුපිළිවෙලවල් අතර දිගම අනුක්රමික දිග ලබා ගන්න

88longest\_seq\_len=max([len(seq)forseqindata])

#

ආරම්භකඅනුක්රමය (sos) සහ අවසාන අනුක්රමය (eos) සඳහා අමතර පියවර දෙකක් සමඟ අපි PyTorch දත්ත අරාව ආරම්භ කරමු. සෑම පියවරක්ම දෛශිකයකි (Δx,Δy,p1​,p2​,p3​). එකක් පමණක් p1​,p2​,p3​ වන 1 අතර අනෙක් ඒවා වේ 0. ඒවා නියෝජනය කරන්නේ පෑන පහළට, _පෑන ඉහළට_සහ එම _අනුපිළිවෙලට අනුක්රමය අවසන්_කිරීමයි. p1​ ඊළඟ පියවරේදී පෑන කඩදාසි ස්පර්ශ 1 කරන්නේ නම්. p2​ ඊළඟ පියවරේදී පෑන කඩදාසි ස්පර්ශ නොකරන්නේ 1 නම්. p3​ එය චිත්රයේ අවසානය 1 නම් වේ.

98self.data=torch.zeros(len(data),longest\_seq\_len+2,5,dtype=torch.float)

#

වෙස්අරා අවශ්ය වන්නේ එක් අමතර පියවරක් පමණි, මන්ද එය විකේතනයේ ප්රතිදානයන් සඳහා වන data[:-1] අතර එය ඊළඟ පියවර ගනී.

101self.mask=torch.zeros(len(data),longest\_seq\_len+1)102103fori,seqinenumerate(data):104seq=torch.from\_numpy(seq)105len\_seq=len(seq)

#

පරිමාණයසහ කට්ටලය Δx,Δy

107self.data[i,1:len\_seq+1,:2]=seq[:,:2]/scale

#

p1​

109self.data[i,1:len\_seq+1,2]=1-seq[:,2]

#

p2​

111self.data[i,1:len\_seq+1,3]=seq[:,2]

#

p3​

113self.data[i,len\_seq+1:,4]=1

#

අනුක්රමයඅවසන් වන තෙක් මාස්ක් ක්රියාත්මක වේ

115self.mask[i,:len\_seq+1]=1

#

අනුපිළිවෙලආරම්භ කිරීම (0,0,1,0,0)

118self.data[:,0,2]=1

#

දත්තසමුදාය ප්රමාණය

120def\_\_len\_\_(self):

#

122returnlen(self.data)

#

නියැදියක්ලබා ගන්න

124def\_\_getitem\_\_(self,idx:int):

#

126returnself.data[idx],self.mask[idx]

#

ද්වි-විචල්යගවුසියානු මිශ්රණය

මිශ්රණයනියෝජනය වන්නේ Π සහ N(μx​,μy​,σx​,σy​,ρxy​). මෙම පන්තිය උෂ්ණත්වය සකස් කරන අතර පරාමිතීන්ගෙන් වර්ගීකරණ හා ගුසියානු බෙදාහැරීම් නිර්මාණය කරයි.

129classBivariateGaussianMixture:

#

139def\_\_init\_\_(self,pi\_logits:torch.Tensor,mu\_x:torch.Tensor,mu\_y:torch.Tensor,140sigma\_x:torch.Tensor,sigma\_y:torch.Tensor,rho\_xy:torch.Tensor):141self.pi\_logits=pi\_logits142self.mu\_x=mu\_x143self.mu\_y=mu\_y144self.sigma\_x=sigma\_x145self.sigma\_y=sigma\_y146self.rho\_xy=rho\_xy

#

මිශ්රණයේබෙදාහැරීම් ගණන, M

148@property149defn\_distributions(self):

#

151returnself.pi\_logits.shape[-1]

#

උෂ්ණත්වයඅනුව සකසන්න τ

153defset\_temperature(self,temperature:float):

#

Πk​^​←τΠk​^​​

158self.pi\_logits/=temperature

#

σx2​←σx2​τ

160self.sigma\_x\*=math.sqrt(temperature)

#

σy2​←σy2​τ

162self.sigma\_y\*=math.sqrt(temperature)

#

164defget\_distribution(self):

#

කලම්ප σx​, σy​ සහ NaN s ලබා වළක්වා ρxy​ ගැනීමට

166sigma\_x=torch.clamp\_min(self.sigma\_x,1e-5)167sigma\_y=torch.clamp\_min(self.sigma\_y,1e-5)168rho\_xy=torch.clamp(self.rho\_xy,-1+1e-5,1-1e-5)

#

මාධ්යයන්ලබා ගන්න

171mean=torch.stack([self.mu\_x,self.mu\_y],-1)

#

කෝවිචියන්ස්අනුකෘතිය ලබා ගන්න

173cov=torch.stack([174sigma\_x\*sigma\_x,rho\_xy\*sigma\_x\*sigma\_y,175rho\_xy\*sigma\_x\*sigma\_y,sigma\_y\*sigma\_y176],-1)177cov=cov.view(\*sigma\_y.shape,2,2)

#

ද්වි-විචල්යසාමාන්ය ව්යාප්තියක් සාදන්න.

📝එය [[a, 0], [b, c]] කොහෙද ලෙස scale_tril අනුකෘතිය කාර්යක්ෂම වනු a=σx​,b=ρxy​σy​,c=σy​1−ρxy2​​ඇත. නමුත් සරල බව සඳහා අපි සම-විචල්යතා අනුකෘතිය භාවිතා කරමු. ද්වි-විචල්ය බෙදාහැරීම්, ඒවායේ සම-විචල්යතා අනුකෘතිය සහසම්භාවිතා dens නත්ව ක්රියාකාරිත්වය ගැන වැඩිදුර කියවීමට ඔබට අවශ්ය නම් මෙය හොඳ සම්පතකි .

188multi\_dist=torch.distributions.MultivariateNormal(mean,covariance\_matrix=cov)

#

පිවිසුම් Π වලින් වර්ගීකරණ බෙදාහැරීමක් සාදන්න

191cat\_dist=torch.distributions.Categorical(logits=self.pi\_logits)

#

194returncat\_dist,multi\_dist

#

එන්කෝඩර්මොඩියුලය

මෙයද්විපාර්ශ්වික LSTM කින් සමන්විත වේ

197classEncoderRNN(Module):

#

204def\_\_init\_\_(self,d\_z:int,enc\_hidden\_size:int):205super().\_\_init\_\_()

#

ආදානය (Δx,Δy,p1​,p2​,p3​) ලෙස අනුක්රමයක් ගනිමින් ද්විපාර්ශ්වික LSTM සාදන්න.

208self.lstm=nn.LSTM(5,enc\_hidden\_size,bidirectional=True)

#

ලබාගැනීමට ප්රධානියා μ

210self.mu\_head=nn.Linear(2\*enc\_hidden\_size,d\_z)

#

ලබාගැනීමට ප්රධානියා σ^

212self.sigma\_head=nn.Linear(2\*enc\_hidden\_size,d\_z)

#

214defforward(self,inputs:torch.Tensor,state=None):

#

ද්විපාර්ශ්විකඑල්එස්ටීඑම් හි සැඟවුණු තත්වය යනු අවසාන ටෝකනයේ ප්රතිදානය ඉදිරි දිශාවට සහ ප්රතිලෝම දිශාවට පළමු ටෝකනය සංයුක්ත කිරීමයි, එය අපට අවශ්ය දෙයයි. h→​=encode→​(S),h←​=encode←←​(Sreverse​),h=[h→​;h←​]

221\_,(hidden,cell)=self.lstm(inputs.float(),state)

#

රාජ්යයටහැඩය ඇත [2, batch_size, hidden_size] , එහිදී පළමු මානය දිශාව වේ. ලබා ගැනීම සඳහා අපි එය නැවත සකස්

කරමු h=[h→​;h←​]

225hidden=einops.rearrange(hidden,'fb b h -\> b (fb h)')

#

μ

228mu=self.mu\_head(hidden)

#

σ^

230sigma\_hat=self.sigma\_head(hidden)

#

σ=exp(2σ^​)

232sigma=torch.exp(sigma\_hat/2.)

#

නියැදිය z=μ+σ⋅N(0,I)

235z=mu+sigma\*torch.normal(mu.new\_zeros(mu.shape),mu.new\_ones(mu.shape))

#

238returnz,mu,sigma\_hat

#

විකේතකමොඩියුලය

මෙයLSTM කින් සමන්විත වේ

241classDecoderRNN(Module):

#

248def\_\_init\_\_(self,d\_z:int,dec\_hidden\_size:int,n\_distributions:int):249super().\_\_init\_\_()

#

LSTMආදානය [(Δx,Δy,p1​,p2​,p3​);z] ලෙස ගනී

251self.lstm=nn.LSTM(d\_z+5,dec\_hidden\_size)

#

LSTMහි ආරම්භක තත්වය වේ [h0​;c0​]=tanh(Wz​z+bz​). init_state මේ සඳහා රේඛීය පරිවර්තනයයි

255self.init\_state=nn.Linear(d\_z,2\*dec\_hidden\_size)

#

මෙමස්තරය එක් එක් සඳහා ප්රතිදානයන් නිෂ්පාදනය කරයි n_distributions . සෑම ව්යාප්තියකටම පරාමිතීන් හයක් අවශ්ය

වේ (Πi​^​,μxi​​,μyi​​,σxi​​^​,σyi​​^​ρxyi​​^​)

260self.mixtures=nn.Linear(dec\_hidden\_size,6\*n\_distributions)

#

මෙමහිස පිවිසුම් සඳහා වේ (q1​^​,q2​^​,q3​^​)

263self.q\_head=nn.Linear(dec\_hidden\_size,3)

#

මෙය log(qk​) කොතැනද යන්න ගණනය කිරීමයි qk​=softmax(q^​)k​=∑j=13​exp(qj​^​)exp(qk​^​)​

266self.q\_log\_softmax=nn.LogSoftmax(-1)

#

මෙමපරාමිතීන් අනාගත යොමු කිරීම සඳහා ගබඩා කර ඇත

269self.n\_distributions=n\_distributions270self.dec\_hidden\_size=dec\_hidden\_size

#

272defforward(self,x:torch.Tensor,z:torch.Tensor,state:Optional[Tuple[torch.Tensor,torch.Tensor]]):

#

ආරම්භකතත්වය ගණනය කරන්න

274ifstateisNone:

#

[h0​;c0​]=tanh(Wz​z+bz​)

276h,c=torch.split(torch.tanh(self.init\_state(z)),self.dec\_hidden\_size,1)

#

h සහ c හැඩයන් [batch_size, lstm_size] ඇත. LSTM හි භාවිතා කරන හැඩය [1, batch_size, lstm_size] නිසා අපට ඒවා හැඩගස්වා ගැනීමට අවශ්යය.

279state=(h.unsqueeze(0).contiguous(),c.unsqueeze(0).contiguous())

#

LSTMධාවනය කරන්න

282outputs,state=self.lstm(x,state)

#

ලබාගන්න log(q)

285q\_logits=self.q\_log\_softmax(self.q\_head(outputs))

#

ලබාගන්න (Πi​^​,μx,i​,μy,i​,σx,i​^​,σy,i​^​ρxy,i​^​). torch.split ප්රතිදානය මානයක් self.n_distribution හරහා ප්රමාණය tensors 6 බවට splits 2 .

291pi\_logits,mu\_x,mu\_y,sigma\_x,sigma\_y,rho\_xy=\292torch.split(self.mixtures(outputs),self.n\_distributions,2)

#

ද්වි-විචල්යGaussian මිශ්රණයක් සාදන්න Π සහ N(μx​,μy​,σx​,σy​,ρxy​) කොහේද σx,i​=exp(σx,i​^​),σy,i​=exp(σy,i​^​),ρxy,i​=tanh(ρxy,i​^​) සහ Πi​=softmax(Π^)i​=∑j=13​exp(Πj​^​)exp(Πi​^​)​

Π මිශ්රණයෙන් බෙදා හැරීම තෝරා ගැනීමේ වර්ගීකරණ සම්භාවිතාවන් N(μx​,μy​,σx​,σy​,ρxy​)වේ.

305dist=BivariateGaussianMixture(pi\_logits,mu\_x,mu\_y,306torch.exp(sigma\_x),torch.exp(sigma\_y),torch.tanh(rho\_xy))

#

309returndist,q\_logits,state

#

ප්රතිසංස්කරණඅලාභය

312classReconstructionLoss(Module):

#

317defforward(self,mask:torch.Tensor,target:torch.Tensor,318dist:'BivariateGaussianMixture',q\_logits:torch.Tensor):

#

ලබා Π ගන්න N(μx​,μy​,σx​,σy​,ρxy​)

320pi,mix=dist.get\_distribution()

#

target අවසාන මානය ලක්ෂණ [seq_len, batch_size, 5] වන හැඩය ඇත (Δx,Δy,p1​,p2​,p3​). අපට අවශ්ය වන්නේ Δx,Δ y ලබා ගැනීමට සහ මිශ්රණයේ ඇති එක් එක් බෙදාහැරීම් වලින් සම්භාවිතාවන් ලබා ගැනීමයි N(μx​,μy​,σx​,σy​,ρxy​).

xy හැඩය ඇත [seq_len, batch_size, n_distributions, 2]

327xy=target[:,:,0:2].unsqueeze(-2).expand(-1,-1,dist.n\_distributions,-1)

#

සම්භාවිතාවගණනය කරන්න p(Δx,Δy)=j=1∑M​Πj​N(Δx,Δy∣μx,j​,μy,j​,σx,j​,σy,j​,ρxy,j​)

333probs=torch.sum(pi.probs\*torch.exp(mix.log\_prob(xy)),2)

#

Ls​=−Nmax​1​i=1∑Ns​​log(p(Δx,Δy)) Nmax​ (longest_seq_len ) මූලද්රව්ය probs තිබුණද, එකතුව ගනු ලබන්නේ ඉතිරිය වන Ns​ බැවිනි වෙස් වලාගෙන.

අපඑකතුව ගෙන බෙදිය යුතු යැයි හැඟෙනු ඇත Nmax​, නමුත් මෙය කෙටි අනුපිළිවෙලින් තනි අනාවැකි සඳහා වැඩි බරක් ලබා දෙනු ඇත. Ns​ අපි බෙදෙන p(Δx,Δy) විට අපි එක් එක් අනාවැකිය සමාන බර

දෙන්න Nmax​

342loss\_stroke=-torch.mean(mask\*torch.log(1e-5+probs))

#

Lp​=−Nmax​1​i=1∑Nmax​​k=1∑3​pk,i​log(qk,i​)

345loss\_pen=-torch.mean(target[:,:,2:]\*q\_logits)

#

LR​=Ls​+Lp​

348returnloss\_stroke+loss\_pen

#

එල්. එල්-අපසරනය අහිමි

මෙයලබා දී ඇති සාමාන්ය බෙදාහැරීමක් අතර KL අපසරනය ගණනය කරයි N(0,1)

351classKLDivLoss(Module):

#

358defforward(self,sigma\_hat:torch.Tensor,mu:torch.Tensor):

#

LKL​=−2Nz​1​(1+σ^−μ2−exp(σ^))

360return-0.5\*torch.mean(1+sigma\_hat-mu\*\*2-torch.exp(sigma\_hat))

#

නියැදිකරු

මෙයවිකේතකයෙන් රූප සටහනක් සාම්පල කර එය ගොඩබෑමට ලක් කරයි

363classSampler:

#

370def\_\_init\_\_(self,encoder:EncoderRNN,decoder:DecoderRNN):371self.decoder=decoder372self.encoder=encoder

#

374defsample(self,data:torch.Tensor,temperature:float):

#

Nmax​

376longest\_seq\_len=len(data)

#

z එන්කෝඩරයෙන් ලබා ගන්න

379z,\_,\_=self.encoder(data)

#

ආරම්භකඅනුක්රමය ආ roke ාතය වේ (0,0,1,0,0)

382s=data.new\_tensor([0,0,1,0,0])383seq=[s]

#

ආරම්භකවිකේතකය වේ None . විකේතකය එය ආරම්භ කරනු ඇත [h0​;c0​]=tanh(Wz​z+bz​)

386state=None

#

අපටඅනුක්රමික අවශ්ය නොවේ

389withtorch.no\_grad():

#

නියැදි Nmax​ පහරවල්

391foriinrange(longest\_seq\_len):

#

[(Δx,Δy,p1​,p2​,p3​);z] යනු විකේතකයට ආදානය කිරීමයි

393data=torch.cat([s.view(1,1,-1),z.unsqueeze(0)],2)

#

N(μx​,μy​,σx​,σy​,ρxy​)q විකේතකයෙන් ඊළඟ තත්වය ලබා ගන්න Π

396dist,q\_logits,state=self.decoder(data,z,state)

#

ආඝාතයසාම්පලයක්

398s=self.\_sample\_step(dist,q\_logits,temperature)

#

ආro ාත අනුපිළිවෙලට නව ආ roke ාතය එක් කරන්න

400seq.append(s)

#

නියැදීමනවත්වන්න නම් p3​=1. මෙයින් ඇඟවෙන්නේ ස්කීච් කිරීම නතර වී ඇති

බවයි

402ifs[4]==1:403break

#

ආro ාත අනුපිළිවෙලෙහි පයිටෝච් ටෙන්සරයක් සාදන්න

406seq=torch.stack(seq)

#

ආro ාත අනුපිළිවෙල සැලසුම් කරන්න

409self.plot(seq)

#

411@staticmethod412def\_sample\_step(dist:'BivariateGaussianMixture',q\_logits:torch.Tensor,temperature:float):

#

නියැදීම් τ සඳහා උෂ්ණත්වය සකසන්න. මෙය පන්තියේ ක්රියාත්මක වේ BivariateGaussianMixture .

414dist.set\_temperature(temperature)

#

උෂ්ණත්වයසකස් කර Π ලබා ගන්න N(μx​,μy​,σx​,σy​,ρxy​)

416pi,mix=dist.get\_distribution()

#

මිශ්රණයෙන්භාවිතා කිරීම සඳහා බෙදා හැරීමේ දර්ශකයෙන් නියැදිය Π

418idx=pi.sample()[0,0]

#

ලොග්-සම්භාවිතාවන් q q_logits සමඟ වර්ගීකරණ බෙදාහැරීමක් සාදන්න q^​

421q=torch.distributions.Categorical(logits=q\_logits/temperature)

#

වෙතින්නියැදිය q

423q\_idx=q.sample()[0,0]

#

මිශ්රණයේසාමාන්ය බෙදාහැරීම් වලින් නියැදිය සහ සුචිගත කරන ලද එක තෝරන්න idx

426xy=mix.sample()[0,0,idx]

#

හිස්ආඝාතය සාදන්න (Δx,Δy,q1​,q2​,q3​)

429stroke=q\_logits.new\_zeros(5)

#

සකසන්න Δx,Δy

431stroke[:2]=xy

#

සකසන්න q1​,q2​,q3​

433stroke[q\_idx+2]=1

#

435returnstroke

#

437@staticmethod438defplot(seq:torch.Tensor):

#

ලබාගැනීම සඳහා සමුච්චිත (Δx,Δy) සාරාංශ ගන්න (x,y)

440seq[:,0:2]=torch.cumsum(seq[:,0:2],dim=0)

#

පෝරමයේනව අංකුර අරා සාදන්න (x,y,q2​)

442seq[:,2]=seq[:,3]443seq=seq[:,0:3].detach().cpu().numpy()

#

කොහෙද q2​ ලකුණු දී අරාව බෙදන්න 1. i.e. පෑන කඩදාසි සිට ඔසවා එහිදී ලකුණු දී පිලි අරාව බෙදී. මෙය ආ ro ාත අනුපිළිවෙල ලැයිස්තුවක් ලබා දෙයි.

448strokes=np.split(seq,np.where(seq[:,2]\>0)[0]+1)

#

පහරවල්එක් එක් අනුපිළිවෙල සැලසුම් කරන්න

450forsinstrokes:451plt.plot(s[:,0],-s[:,1])

#

අක්ෂපෙන්වන්න එපා

453plt.axis('off')

#

කුමන්ත්රණයපෙන්වන්න

455plt.show()

#

වින්යාසකිරීම්

මේවාපෙරනිමි වින්යාසයන් වන අතර ඒවා පසුව a පසුකර යාමෙන් සකස් කළ හැකිය dict .

458classConfigs(TrainValidConfigs):

#

අත්හදාබැලීම ක්රියාත්මක කිරීම සඳහා උපාංගය තෝරා ගැනීමට උපාංග වින්යාසයන්

466device:torch.device=DeviceConfigs()

#

468encoder:EncoderRNN469decoder:DecoderRNN470optimizer:optim.Adam471sampler:Sampler472473dataset\_name:str474train\_loader:DataLoader475valid\_loader:DataLoader476train\_dataset:StrokesDataset477valid\_dataset:StrokesDataset

#

එන්කෝඩර්සහ විකේතක ප්රමාණ

480enc\_hidden\_size=256481dec\_hidden\_size=512

#

කණ්ඩායම්ප්රමාණය

484batch\_size=100

#

විශේෂාංගගණන z

487d\_z=128

#

මිශ්රණයේබෙදාහැරීම් ගණන, M

489n\_distributions=20

#

KLඅපසරනය අඞු කිරීමට බර, wKL​

492kl\_div\_loss\_weight=0.5

#

ශ්රේණියේක්ලිපින්

494grad\_clip=1.

#

නියැදීම τ සඳහා උෂ්ණත්වය

496temperature=0.4

#

වඩාදිගු ආ roke ාත අනුපිළිවෙල පෙරහන් කරන්න 200

499max\_seq\_length=200500501epochs=100502503kl\_div\_loss=KLDivLoss()504reconstruction\_loss=ReconstructionLoss()

#

506definit(self):

#

එන්කෝඩරයසහ විකේතකය ආරම්භ කරන්න

508self.encoder=EncoderRNN(self.d\_z,self.enc\_hidden\_size).to(self.device)509self.decoder=DecoderRNN(self.d\_z,self.dec\_hidden\_size,self.n\_distributions).to(self.device)

#

ප්රශස්තකරණයසකසන්න. ප්රශස්තිකරණ වර්ගය සහ ඉගෙනුම් අනුපාතය වැනි දේවල් වින්යාසගත කළ හැකිය

512optimizer=OptimizerConfigs()513optimizer.parameters=list(self.encoder.parameters())+list(self.decoder.parameters())514self.optimizer=optimizer

#

නියැදියසාදන්න

517self.sampler=Sampler(self.encoder,self.decoder)

#

npz ගොනු මාර්ගය වේ data/sketch/[DATASET NAME].npz

520path=lab.get\_data\_path()/'sketch'/f'{self.dataset\_name}.npz'

#

අංකිතගොනුව පූරණය කරන්න

522dataset=np.load(str(path),encoding='latin1',allow\_pickle=True)

#

පුහුණුදත්ත සමුදාය සාදන්න

525self.train\_dataset=StrokesDataset(dataset['train'],self.max\_seq\_length)

#

වලංගුදත්ත සමුදාය සාදන්න

527self.valid\_dataset=StrokesDataset(dataset['valid'],self.max\_seq\_length,self.train\_dataset.scale)

#

පුහුණුදත්ත පැටවුම සාදන්න

530self.train\_loader=DataLoader(self.train\_dataset,self.batch\_size,shuffle=True)

#

වලංගුදත්ත පැටවුම සාදන්න

532self.valid\_loader=DataLoader(self.valid\_dataset,self.batch\_size)

#

ටෙන්සෝර්බෝඩ්හි ස්ථර ප්රතිදානයන් නිරීක්ෂණය කිරීම සඳහා කොකු එක් කරන්න

535hook\_model\_outputs(self.mode,self.encoder,'encoder')536hook\_model\_outputs(self.mode,self.decoder,'decoder')

#

සම්පූර්ණදුම්රිය/වලංගු කිරීමේ අලාභය මුද්රණය කිරීම සඳහා ට්රැකර් සකසන්න

539tracker.set\_scalar("loss.total.\*",True)540541self.state\_modules=[]

#

543defstep(self,batch:Any,batch\_idx:BatchIndex):544self.encoder.train(self.mode.is\_train)545self.decoder.train(self.mode.is\_train)

#

උපාංගය mask වෙත ගෙන data ගොස් අනුක්රමය සහ කණ්ඩායම් මානයන් මාරු කරන්න. data හැඩය ඇති [seq_len, batch_size, 5] අතර හැඩය mask ඇත [seq_len, batch_size] .

550data=batch[0].to(self.device).transpose(0,1)551mask=batch[1].to(self.device).transpose(0,1)

#

පුහුණුමාදිලියේ වර්ධක පියවර

554ifself.mode.is\_train:555tracker.add\_global\_step(len(data))

#

ආro ාත අනුපිළිවෙල කේතනය කරන්න

558withmonit.section("encoder"):

#

ලබාගන්න zμ, සහ σ^

560z,mu,sigma\_hat=self.encoder(data)

#

බෙදාහැරීම්මිශ්රණය විකේතනය කිරීම සහ q^​

563withmonit.section("decoder"):

#

කොන්කැටෙනේට් [(Δx,Δy,p1​,p2​,p3​);z]

565z\_stack=z.unsqueeze(0).expand(data.shape[0]-1,-1,-1)566inputs=torch.cat([data[:-1],z\_stack],2)

#

බෙදාහැරීම්මිශ්රණය ලබා ගන්න q^​

568dist,q\_logits,\_=self.decoder(inputs,z,None)

#

අලාභයගණනය කරන්න

571withmonit.section('loss'):

#

LKL​

573kl\_loss=self.kl\_div\_loss(sigma\_hat,mu)

#

LR​

575reconstruction\_loss=self.reconstruction\_loss(mask,data[1:],dist,q\_logits)

#

Loss=LR​+wKL​LKL​

577loss=reconstruction\_loss+self.kl\_div\_loss\_weight\*kl\_loss

#

පාඩුලුහුබඳින්න

580tracker.add("loss.kl.",kl\_loss)581tracker.add("loss.reconstruction.",reconstruction\_loss)582tracker.add("loss.total.",loss)

#

අපපුහුණු තත්වයේ සිටී නම් පමණි

585ifself.mode.is\_train:

#

ධාවනයප්රශස්තකරණය

587withmonit.section('optimize'):

#

grad ශුන්යයට සකසන්න

589self.optimizer.zero\_grad()

#

අනුක්රමිකගණනය

591loss.backward()

#

ලොග්ආදර්ශ පරාමිතීන් සහ අනුක්රමික

593ifbatch\_idx.is\_last:594tracker.add(encoder=self.encoder,decoder=self.decoder)

#

ක්ලිප්අනුක්රමික

596nn.utils.clip\_grad\_norm\_(self.encoder.parameters(),self.grad\_clip)597nn.utils.clip\_grad\_norm\_(self.decoder.parameters(),self.grad\_clip)

#

ප්රශස්තකරන්න

599self.optimizer.step()600601tracker.save()

#

603defsample(self):

#

අහඹුලෙස වලංගු දත්ත කට්ටලයේ සිට ආකේතකය දක්වා නියැදියක් තෝරන්න

605data,\*\_=self.valid\_dataset[np.random.choice(len(self.valid\_dataset))]

#

කණ්ඩායම්මානයන් එක් කර එය උපාංගයට ගෙන යන්න

607data=data.unsqueeze(1).to(self.device)

#

නියැදිය

609self.sampler.sample(data,self.temperature)

#

612defmain():613configs=Configs()614experiment.create(name="sketch\_rnn")

#

වින්යාසකිරීමේ ශබ්දකෝෂයක් සම්මත කරන්න

617experiment.configs(configs,{618'optimizer.optimizer':'Adam',

#

ප්රතිresults ල වේගයෙන් දැකිය හැකි 1e-3 නිසා අපි ඉගෙනුම් අනුපාතයක් භාවිතා කරමු. කඩදාසි යෝජනා කර 1e-4 ඇත.

621'optimizer.learning\_rate':1e-3,

#

දත්තසමුදාය නම

623'dataset\_name':'bicycle',

#

පුහුණුව, වලංගු කිරීම සහ නියැදීම අතර මාරුවීම සඳහා එපෝච් තුළ අභ්යන්තර පුනරාවර්තන ගණන.

625'inner\_iterations':10626})627628withexperiment.start():

#

අත්හදාබැලීම ක්රියාත්මක කරන්න

630configs.run()631632633if\_\_name\_\_=="\_\_main\_\_":634main()

Trending Research Paperslabml.ai