docs/si/capsule_networks/mnist.html
මෙයPyTorch සමඟ MNIST ඉලක්කම් වර්ගීකරණය කිරීම සඳහා විනීත පයිටෝච් කේතයකි.
මෙමලිපිය කඩදාසි විස්තර කර ඇති අත්හදා බැලීම ක්රියාත්මක කරයි ඩයිනමික් රවුටින් කැප්සියුල අතර.
14fromtypingimportAny1516importtorch.nnasnn17importtorch.nn.functionalasF18importtorch.utils.data1920fromlabmlimportexperiment,tracker21fromlabml.configsimportoption22fromlabml\_helpers.datasets.mnistimportMNISTConfigs23fromlabml\_helpers.metrics.accuracyimportAccuracyDirect24fromlabml\_helpers.moduleimportModule25fromlabml\_helpers.train\_validimportSimpleTrainValidConfigs,BatchIndex26fromlabml\_nn.capsule\_networksimportSquash,Router,MarginLoss
29classMNISTCapsuleNetworkModel(Module):
34def\_\_init\_\_(self):35super().\_\_init\_\_()
පළමුකැටි ගැසුණු ස්ථරය 256, 9×9 convolution කර්නල්
37self.conv1=nn.Conv2d(in\_channels=1,out\_channels=256,kernel\_size=9,stride=1)
දෙවනස්ථරය (ප්රාථමික කරල්) s convolutional කරල් 32 නාලිකා (කරලක් අනුව8 ලක්ෂණ) සමග convolutional 8D කරලක් ස්ථරය. එනම්, සෑම ප්රාථමික කැප්සියුලයකම 9 × 9 කර්නලයක් සහ 2 ක ඉරි සහිත සංයුක්ත ඒකක 8 ක් අඩංගු වේ. මෙය ක්රියාත්මක කිරීම සඳහා අපි 32×8 නාලිකා සහිත සංවහන තට්ටුවක් නිර්මාණය කර එක් එක් 8 විශේෂාංග කැප්සියුල ලබා ගැනීම සඳහා එහි ප්රතිදානය නැවත සකස් කර පරිපූර්ණ කරමු.
43self.conv2=nn.Conv2d(in\_channels=256,out\_channels=32\*8,kernel\_size=9,stride=2,padding=0)44self.squash=Squash()
රවුටින්ස්තරය 32×6×6 ප්රාථමික කැප්සියුල ලැබෙන අතර 10 කැප්සියුල නිෂ්පාදනය කරයි. සෑම ප්රාථමික කැප්සියුලයකම 8 විශේෂාංග ඇති අතර ප්රතිදාන කැප්සියුල (ඉලක්කම් කැප්සියුල) 16 විශේෂාංග ඇත. රවුටින් ඇල්ගොරිතම 3 වරක් පුනරාවර්තනය වේ.
50self.digit\_capsules=Router(32\*6\*6,10,8,16,3)
මෙමකඩදාසි සඳහන් විකේතකය වේ. 10 එය ඉලක්කම් කැප්සියුල වල ප්රතිදානයන් ගනී, එක් එක් රූපය ප්රතිනිෂ්පාදනය කිරීම සඳහා 16 විශේෂාංග ඇත. එය ප්රමාණවලින් රේඛීය ස්ථර හරහා 512 සහ ReLU සක්රිය කිරීම් 1024 සමඟ ගමන් කරයි.
55self.decoder=nn.Sequential(56nn.Linear(16\*10,512),57nn.ReLU(),58nn.Linear(512,1024),59nn.ReLU(),60nn.Linear(1024,784),61nn.Sigmoid()62)
data හැඩය සහිත MNIST රූප [batch_size, 1, 28, 28]
64defforward(self,data:torch.Tensor):
පළමුකැටි ගැසුණු ස්තරය හරහා ගමන් කරන්න. මෙම ස්ථරයේ ප්රතිදානය හැඩය ඇත [batch_size, 256, 20, 20]
70x=F.relu(self.conv1(data))
දෙවනකැටි ගැසුණු ස්තරය හරහා ගමන් කරන්න. මෙම ප්රතිදානය හැඩය ඇත [batch_size, 32 * 8, 6, 6] . මෙමස්ථරයට දිගු දිගක් ඇති බව සලකන්න 2.
74x=self.conv2(x)
කැප්සියුලලබා ගැනීම සඳහා ප්රමාණය වෙනස් කර permutate කරන්න
77caps=x.view(x.shape[0],8,32\*6\*6).permute(0,2,1)
කැප්සියුලස්කොෂ් කරන්න
79caps=self.squash(caps)
ඉලක්කම්කැප්සියුල ලබා ගැනීම සඳහා රවුටරය හරහා ඒවා රැගෙන යන්න. මෙය හැඩය ඇත [batch_size, 10, 16] .
82caps=self.digit\_capsules(caps)
ප්රතිනිර්මාණයසඳහා වෙස් මුහුණු ලබා ගන්න
85withtorch.no\_grad():
කැප්සියුලජාලය විසින් පුරෝකථනය කරනු ලබන්නේ දිගම දිග සහිත කැප්සියුලයයි
87pred=(caps\*\*2).sum(-1).argmax(-1)
අනෙක්සියලුම කැප්සියුල වෙස්මුහුණ දීමට වෙස්මුහුණක් සාදන්න
89mask=torch.eye(10,device=data.device)[pred]
අනාවැකියකළ කැප්සියුලය පමණක් ලබා ගැනීම සඳහා ඉලක්කම් කැප්සියුල Mask කර ප්රතිනිර්මාණය ලබා ගැනීම සඳහා විකේතකය හරහා එය රැගෙන යන්න
93reconstructions=self.decoder((caps\*mask[:,:,None]).view(x.shape[0],-1))
රූපමානයන් ගැලපෙන පරිදි ප්රතිනිර්මාණය නැවත සකස් කරන්න
95reconstructions=reconstructions.view(-1,1,28,28)9697returncaps,reconstructions,pred
MNISTදත්ත සහ දුම්රිය සහ වලංගු කිරීමේ සැකසුම සමඟ වින්යාස කිරීම්
100classConfigs(MNISTConfigs,SimpleTrainValidConfigs):
104epochs:int=10105model:nn.Module='capsule\_network\_model'106reconstruction\_loss=nn.MSELoss()107margin\_loss=MarginLoss(n\_labels=10)108accuracy=AccuracyDirect()
110definit(self):
පාඩුසහ නිරවද්යතාව තිරයට මුද්රණය කරන්න
112tracker.set\_scalar('loss.\*',True)113tracker.set\_scalar('accuracy.\*',True)
පුහුණුවසහ වලංගු කිරීම සඳහා එපෝච් සඳහා ඒවා ගණනය කිරීම සඳහා ප්රමිතික සකස් කළ යුතුය
116self.state\_modules=[self.accuracy]
මෙමක්රමය පුහුණුකරු විසින් කැඳවනු ලැබේ
118defstep(self,batch:Any,batch\_idx:BatchIndex):
ආදර්ශප්රකාරය සකසන්න
123self.model.train(self.mode.is\_train)
පින්තූරසහ ලේබල් ලබාගෙන ඒවා ආකෘතියේ උපාංගයට ගෙන යන්න
126data,target=batch[0].to(self.device),batch[1].to(self.device)
පුහුණුමාදිලියේ වර්ධක පියවර
129ifself.mode.is\_train:130tracker.add\_global\_step(len(data))
සක්රියකිරීම් ලොග් කළ යුතුද යන්න
133withself.mode.update(is\_log\_activations=batch\_idx.is\_last):
ආකෘතියධාවනය කරන්න
135caps,reconstructions,pred=self.model(data)
සම්පූර්ණඅලාභය ගණනය කරන්න
138loss=self.margin\_loss(caps,target)+0.0005\*self.reconstruction\_loss(reconstructions,data)139tracker.add("loss.",loss)
ඇමතුම්නිරවද්යතාව මෙට්රික්
142self.accuracy(pred,target)143144ifself.mode.is\_train:145loss.backward()146147self.optimizer.step()
ලොග්පරාමිතීන් සහ අනුක්රමික
149ifbatch\_idx.is\_last:150tracker.add('model',self.model)151self.optimizer.zero\_grad()152153tracker.save()
ආකෘතියසකසන්න
156@option(Configs.model)157defcapsule\_network\_model(c:Configs):
159returnMNISTCapsuleNetworkModel().to(c.device)
අත්හදාබැලීම ක්රියාත්මක කරන්න
162defmain():
166experiment.create(name='capsule\_network\_mnist')167conf=Configs()168experiment.configs(conf,{'optimizer.optimizer':'Adam',169'optimizer.learning\_rate':1e-3})170171experiment.add\_pytorch\_models({'model':conf.model})172173withexperiment.start():174conf.run()175176177if\_\_name\_\_=='\_\_main\_\_':178main()