Back to Annotated Deep Learning Paper Implementations

වර්ගීකරණ අවිනිශ්චිතතා අත්හදා බැලීම ගණනය කිරීම සඳහා ගැඹුරු ඉගෙනීම

docs/si/uncertainty/evidence/experiment.html

latest10.8 KB
Original Source

homeuncertaintyevidence

View code on Github

#

වර්ගීකරණ අවිනිශ්චිතතා අත්හදා බැලීම ගණනය කිරීම සඳහා ගැඹුරු ඉගෙනීම

MNISTදත්ත කට්ටලයේ වර්ගීකරණ අවිනිශ්චිතතාව ගණනය කිරීම සඳහා සාක්ෂි ගැඹුරු ඉගෙනීම මත පදනම් වූ ආකෘතියක් මෙය පුහුණු කරයි.

14fromtypingimportAny1516importtorch.nnasnn17importtorch.utils.data1819fromlabmlimporttracker,experiment20fromlabml.configsimportoption,calculate21fromlabml\_helpers.moduleimportModule22fromlabml\_helpers.scheduleimportSchedule,RelativePiecewise23fromlabml\_helpers.train\_validimportBatchIndex24fromlabml\_nn.experiments.mnistimportMNISTConfigs25fromlabml\_nn.uncertainty.evidenceimportKLDivergenceLoss,TrackStatistics,MaximumLikelihoodLoss,\26CrossEntropyBayesRisk,SquaredErrorBayesRisk

#

Lenetපදනම් කරගත් ආකෘතිය සිට MNIST වර්ගීකරණය

29classModel(Module):

#

34def\_\_init\_\_(self,dropout:float):35super().\_\_init\_\_()

#

පළමු 5x5 කැටි ගැසුණු ස්ථරය

37self.conv1=nn.Conv2d(1,20,kernel\_size=5)

#

Reluසක්රිය

39self.act1=nn.ReLU()

#

2x2 උපරිම තටාක

41self.max\_pool1=nn.MaxPool2d(2,2)

#

දෙවන 5x5 කැටි ගැසුණු ස්ථරය

43self.conv2=nn.Conv2d(20,50,kernel\_size=5)

#

Reluසක්රිය

45self.act2=nn.ReLU()

#

2x2 උපරිම තටාක

47self.max\_pool2=nn.MaxPool2d(2,2)

#

500 විශේෂාංග සිතියම් ගත කරන පළමු පූර්ණ සම්බන්ධිත ස්ථරය

49self.fc1=nn.Linear(50\*4\*4,500)

#

Reluසක්රිය

51self.act3=nn.ReLU()

#

10 පන්ති සඳහා නිමැවුම් සාක්ෂි සඳහා අවසාන පූර්ණ සම්බන්ධිත ස්ථරය. Negative ණාත්මක නොවන සාක්ෂි ලබා ගැනීම සඳහා ආකෘතියෙන් පිටත RelU හෝ Softplus සක්රිය කිරීම මේ සඳහා යොදනු ලැබේ

55self.fc2=nn.Linear(500,10)

#

සැඟවුණුස්තරය සඳහා අතහැර දැමීම

57self.dropout=nn.Dropout(p=dropout)

#

  • x හැඩයේ MNIST රූප කාණ්ඩයයි [batch_size, 1, 28, 28]
59def\_\_call\_\_(self,x:torch.Tensor):

#

පළමුකැටි ගැසීම සහ උපරිම තටාක යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, 20, 12, 12]

65x=self.max\_pool1(self.act1(self.conv1(x)))

#

දෙවනකැටි ගැසීම සහ උපරිම තටාක යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, 50, 4, 4]

68x=self.max\_pool2(self.act2(self.conv2(x)))

#

ටෙන්සරයහැඩයට සමතලා කරන්න [batch_size, 50 * 4 * 4]

70x=x.view(x.shape[0],-1)

#

සැඟවුණුස්ථරය යොදන්න

72x=self.act3(self.fc1(x))

#

අතහැරදැමීම යොදන්න

74x=self.dropout(x)

#

අවසානස්ථරය යොදන්න සහ ආපසු යන්න

76returnself.fc2(x)

#

වින්යාසකිරීම්

අපි MNISTConfigs වින්යාසයන් භාවිතා කරමු.

79classConfigs(MNISTConfigs):

#

KL අපසරනය විධිමත්

87kl\_div\_loss=KLDivergenceLoss()

#

KLඅපසරනය විධිමත් කිරීමේ සංගුණකය කාලසටහන

89kl\_div\_coef:Schedule

#

KLඅපසරනය විධිමත් කිරීමේ සංගුණකය කාලසටහන

91kl\_div\_coef\_schedule=[(0,0.),(0.2,0.01),(1,1.)]

#

ලුහුබැඳීමසඳහාසංඛ්යාන මොඩියුලය

93stats=TrackStatistics()

#

හැලීම

95dropout:float=0.5

#

ආදර්ශප්රතිදානය ශුන්ය නොවන සාක්ෂි බවට පරිවර්තනය කිරීමේ මොඩියුලය

97outputs\_to\_evidence:Module

#

ආරම්භකකරණය

99definit(self):

#

ට්රැකර්වින්යාසයන් සකසන්න

104tracker.set\_scalar("loss.\*",True)105tracker.set\_scalar("accuracy.\*",True)106tracker.set\_histogram('u.\*',True)107tracker.set\_histogram('prob.\*',False)108tracker.set\_scalar('annealing\_coef.\*',False)109tracker.set\_scalar('kl\_div\_loss.\*',False)

#

112self.state\_modules=[]

#

පුහුණුවහෝ වලංගු කිරීමේ පියවර

114defstep(self,batch:Any,batch\_idx:BatchIndex):

#

පුහුණුව/ඇගයීම්මාදිලිය

120self.model.train(self.mode.is\_train)

#

උපාංගයවෙත දත්ත ගෙනයන්න

123data,target=batch[0].to(self.device),batch[1].to(self.device)

#

එක්-උණුසුම්කේත කරන ලද ඉලක්ක

126eye=torch.eye(10).to(torch.float).to(self.device)127target=eye[target]

#

පුහුණුප්රකාරයේදී ගෝලීය පියවර (සැකසූ සාම්පල ගණන) යාවත්කාලීන කරන්න

130ifself.mode.is\_train:131tracker.add\_global\_step(len(data))

#

ආදර්ශප්රතිදානයන් ලබා ගන්න

134outputs=self.model(data)

#

සාක්ෂිලබා ගන්න ek​≥0

136evidence=self.outputs\_to\_evidence(outputs)

#

අලාභයගණනය කරන්න

139loss=self.loss\_func(evidence,target)

#

KLඅපසරනය විධිමත් කිරීමේ අලාභය ගණනය කරන්න

141kl\_div\_loss=self.kl\_div\_loss(evidence,target)142tracker.add("loss.",loss)143tracker.add("kl\_div\_loss.",kl\_div\_loss)

#

KLඅපසරනය පාඩු සංගුණකය λt​

146annealing\_coef=min(1.,self.kl\_div\_coef(tracker.get\_global\_step()))147tracker.add("annealing\_coef.",annealing\_coef)

#

මුළුඅලාභය

150loss=loss+annealing\_coef\*kl\_div\_loss

#

සංඛ්යාලේඛනනිරීක්ෂණය කරන්න

153self.stats(evidence,target)

#

ආකෘතියපුහුණු කරන්න

156ifself.mode.is\_train:

#

අනුක්රමිකගණනය කරන්න

158loss.backward()

#

ප්රශස්තිකරණපියවර ගන්න

160self.optimizer.step()

#

අනුක්රමිකඉවත්

162self.optimizer.zero\_grad()

#

ලුහුබැඳඇති ප්රමිතික සුරකින්න

165tracker.save()

#

ආකෘතියසාදන්න

168@option(Configs.model)169defmnist\_model(c:Configs):

#

173returnModel(c.dropout).to(c.device)

#

KLඅපසරනය පාඩු සංගුණක උපලේඛනය

176@option(Configs.kl\_div\_coef)177defkl\_div\_coef(c:Configs):

#

සාපේක්ෂ කෑලි කාලසටහනක් සාදන්න

183returnRelativePiecewise(c.kl\_div\_coef\_schedule,c.epochs\*len(c.train\_dataset))

#

උපරිම සම්භාවිතාව නැතිවීම

187calculate(Configs.loss\_func,'max\_likelihood\_loss',lambda:MaximumLikelihoodLoss())

#

කුරුස එන්ට්රොපි බේස් අවදානම්

189calculate(Configs.loss\_func,'cross\_entropy\_bayes\_risk',lambda:CrossEntropyBayesRisk())

#

වර්ග දෝෂ අවදානම් බේස්

191calculate(Configs.loss\_func,'squared\_error\_bayes\_risk',lambda:SquaredErrorBayesRisk())

#

සාක්ෂිගණනය කිරීමට RELU

194calculate(Configs.outputs\_to\_evidence,'relu',lambda:nn.ReLU())

#

සාක්ෂිගණනය කිරීමට සොෆ්ට්ප්ලස්

196calculate(Configs.outputs\_to\_evidence,'softplus',lambda:nn.Softplus())

#

199defmain():

#

අත්හදාබැලීම සාදන්න

201experiment.create(name='evidence\_mnist')

#

වින්යාසයන්සාදන්න

203conf=Configs()

#

වින්යාසයන්පූරණය කරන්න

205experiment.configs(conf,{206'optimizer.optimizer':'Adam',207'optimizer.learning\_rate':0.001,208'optimizer.weight\_decay':0.005,

#

'loss_func':' max_likelihood_loss ',' අහිමි_func ':' cross_entropy_bayes_risk ',

212'loss\_func':'squared\_error\_bayes\_risk',213214'outputs\_to\_evidence':'softplus',215216'dropout':0.5,217})

#

අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න

219withexperiment.start():220conf.run()

#

224if\_\_name\_\_=='\_\_main\_\_':225main()

Trending Research Paperslabml.ai