This trains a model based on the paper Evidential Deep Learning to Quantify Classification Uncertainty, on the MNIST dataset.
from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker, experiment
from labml.configs import option, calculate
from labml_nn.helpers.schedule import Schedule, RelativePiecewise
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
    CrossEntropyBayesRisk, SquaredErrorBayesRisk
class Model(nn.Module):

    def __init__(self, dropout: float):
        super().__init__()
First 5x5 convolution layer
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
ReLU activation
        self.act1 = nn.ReLU()
2x2 max-pooling
        self.max_pool1 = nn.MaxPool2d(2, 2)
Second 5x5 convolution layer
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
ReLU activation
        self.act2 = nn.ReLU()
2x2 max-pooling
        self.max_pool2 = nn.MaxPool2d(2, 2)
First fully-connected layer that maps to 500 features
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
ReLU activation
        self.act3 = nn.ReLU()
Final fully connected layer to output evidence for 10 classes. The ReLU or Softplus activation is applied to this outside the model to get the non-negative evidence
        self.fc2 = nn.Linear(500, 10)
Dropout for the hidden layer
        self.dropout = nn.Dropout(p=dropout)
x is the batch of MNIST images of shape [batch_size, 1, 28, 28]

    def __call__(self, x: torch.Tensor):
Apply first convolution and max pooling. The result has shape [batch_size, 20, 12, 12]
        x = self.max_pool1(self.act1(self.conv1(x)))
Apply second convolution and max pooling. The result has shape [batch_size, 50, 4, 4]
        x = self.max_pool2(self.act2(self.conv2(x)))
Flatten the tensor to shape [batch_size, 50 * 4 * 4]
        x = x.view(x.shape[0], -1)
Apply hidden layer
        x = self.act3(self.fc1(x))
Apply dropout
        x = self.dropout(x)
Apply final layer and return
        return self.fc2(x)
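As a quick sanity check of the shapes noted in the comments above, one can push a dummy batch through the model. This is an illustrative sketch, not part of the experiment code (the batch size of 32 is arbitrary):

import torch

# Illustrative shape check for the Model defined above.
model = Model(dropout=0.5)
x = torch.randn(32, 1, 28, 28)  # dummy batch of 32 MNIST-sized images

# 28x28 --conv 5x5--> 24x24 --max-pool 2x2--> 12x12
# 12x12 --conv 5x5-->  8x8  --max-pool 2x2-->  4x4, so fc1 sees 50 * 4 * 4 inputs
out = model(x)
assert out.shape == (32, 10)  # raw outputs for 10 classes, before the evidence activation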
We use the MNISTConfigs configurations.
class Configs(MNISTConfigs):

    kl_div_loss = KLDivergenceLoss()
KL Divergence regularization coefficient schedule
    kl_div_coef: Schedule
KL Divergence regularization coefficient schedule, given as (relative training position, value) points
    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
Stats module for tracking
    stats = TrackStatistics()
Dropout
    dropout: float = 0.5
Module to convert the model output to non-negative evidence
    outputs_to_evidence: nn.Module

    def init(self):
Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)
        self.state_modules = []
    def step(self, batch: Any, batch_idx: BatchIndex):
Training/Evaluation mode
        self.model.train(self.mode.is_train)
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
One-hot encoded targets
        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]
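Indexing an identity matrix with integer class labels is a compact way to one-hot encode; a tiny standalone illustration:

import torch

eye = torch.eye(10)              # 10x10 identity matrix, one row per class
labels = torch.tensor([3, 0, 7])
one_hot = eye[labels]            # picks row 3, row 0, row 7 -> shape [3, 10]
# one_hot[0] is [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]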
Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Get model outputs
        outputs = self.model(data)
Get evidences $e_k \ge 0$
        evidence = self.outputs_to_evidence(outputs)
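In the paper, the evidence parameterizes a Dirichlet distribution with $\alpha_k = e_k + 1$, from which the expected class probabilities and the uncertainty mass $u = K / S$ (with $S = \sum_k \alpha_k$) follow. A minimal sketch of those quantities, separate from this experiment's code:

import torch

def dirichlet_from_evidence(evidence: torch.Tensor):
    """Illustrative sketch: Dirichlet parameters, expected probabilities and
    uncertainty mass from non-negative evidence (last dimension = classes)."""
    alpha = evidence + 1.                        # alpha_k = e_k + 1
    strength = alpha.sum(dim=-1, keepdim=True)   # S = sum_k alpha_k
    prob = alpha / strength                      # expected probability alpha_k / S
    uncertainty = evidence.shape[-1] / strength  # u = K / S
    return alpha, prob, uncertainty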
Calculate loss
        loss = self.loss_func(evidence, target)
Calculate KL Divergence regularization loss
        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)
KL Divergence loss coefficient $\lambda_t$
        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)
Total loss
        loss = loss + annealing_coef * kl_div_loss
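Putting the two terms together, the objective minimized at global step $t$ is

$$\mathcal{L} = \mathcal{L}_{\text{risk}} + \lambda_t \, \mathcal{L}_{\text{KL}}, \qquad \lambda_t = \min\big(1,\ \text{kl\_div\_coef}(t)\big)$$

The annealing ramps the KL term in gradually, so that early in training it doesn't push samples the model hasn't learned yet toward total uncertainty.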
Track statistics
        self.stats(evidence, target)
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Take optimizer step
            self.optimizer.step()
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
@option(Configs.model)
def mnist_model(c: Configs):
    return Model(c.dropout).to(c.device)
@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
Create a relative piecewise schedule
    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
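RelativePiecewise takes the (relative position, value) points from kl_div_coef_schedule and the total number of training samples (epochs * dataset size). Assuming it interpolates linearly between consecutive points, its behavior looks roughly like this sketch (relative_piecewise below is a hypothetical stand-in, not the library function):

def relative_piecewise(points, total_steps, step):
    """Linearly interpolate (relative_position, value) points at a given step.

    Illustrative only; the real schedule is labml_nn.helpers.schedule.RelativePiecewise.
    """
    rel = step / total_steps
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if rel <= x1:
            return y0 + (y1 - y0) * (rel - x0) / (x1 - x0)
    return points[-1][1]

# With the schedule above, the coefficient rises to 0.01 over the first 20%
# of training and then climbs to 1.0 by the end:
points = [(0, 0.), (0.2, 0.01), (1, 1.)]
assert abs(relative_piecewise(points, 100, 10) - 0.005) < 1e-9  # 10% through training
assert abs(relative_piecewise(points, 100, 60) - 0.505) < 1e-9  # 60% through training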
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())
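As a reference for what these losses compute, here is a hedged sketch of the squared error Bayes risk from the paper: the expected squared error under the Dirichlet decomposes into an error term $(y_k - \hat{p}_k)^2$ and a variance term $\hat{p}_k (1 - \hat{p}_k) / (S + 1)$. This illustrates the formula and is not the labml_nn implementation:

import torch

def squared_error_bayes_risk(evidence: torch.Tensor, target: torch.Tensor):
    """Sketch of the expected squared error under Dirichlet(alpha), alpha = evidence + 1.

    target is one-hot, shape [batch_size, n_classes]."""
    alpha = evidence + 1.
    strength = alpha.sum(dim=-1, keepdim=True)   # S = sum_k alpha_k
    p = alpha / strength                         # expected probability p_hat_k
    err = (target - p) ** 2                      # error term
    var = p * (1 - p) / (strength + 1)           # variance term
    return (err + var).sum(dim=-1).mean()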
ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
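Both activations map raw outputs to non-negative evidence, but they behave differently: ReLU zeroes out negative logits (and their gradients), while Softplus stays smooth and strictly positive. A small illustration:

import torch
import torch.nn as nn

outputs = torch.tensor([-2.0, 0.0, 3.0])
print(nn.ReLU()(outputs))      # tensor([0.0000, 0.0000, 3.0000])
print(nn.Softplus()(outputs))  # approximately tensor([0.1269, 0.6931, 3.0486])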
def main():
Create experiment
    experiment.create(name='evidence_mnist')
Create configurations
    conf = Configs()
Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

        # Pick one of the loss functions registered above
        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',
        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    })
Start the experiment and run the training loop
    with experiment.start():
        conf.run()
if __name__ == '__main__':
    main()