
#

Evidential Deep Learning to Quantify Classification Uncertainty Experiment

This trains a model based on Evidential Deep Learning to Quantify Classification Uncertainty on the MNIST dataset.
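As a quick reminder of the idea (following the paper): for $K = 10$ classes the model produces a non-negative evidence $e_k \ge 0$ for each class, and the evidences parameterize a Dirichlet distribution over the class probabilities:

$$\alpha_k = e_k + 1, \qquad S = \sum_{k=1}^{K} \alpha_k, \qquad \hat{p}_k = \frac{\alpha_k}{S}, \qquad u = \frac{K}{S}$$

where $\hat{p}_k$ is the expected probability of class $k$ and the uncertainty mass $u$ is large when the total evidence is small.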

from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker, experiment
from labml.configs import option, calculate
from labml_nn.helpers.schedule import Schedule, RelativePiecewise
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
    CrossEntropyBayesRisk, SquaredErrorBayesRisk

#

LeNet-based model for MNIST classification

class Model(nn.Module):

#

    def __init__(self, dropout: float):
        super().__init__()

#

First 5x5 convolution layer

        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)

#

ReLU activation

        self.act1 = nn.ReLU()

#

2x2 max-pooling

        self.max_pool1 = nn.MaxPool2d(2, 2)

#

Second 5x5 convolution layer

        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)

#

ReLU activation

        self.act2 = nn.ReLU()

#

2x2 max-pooling

        self.max_pool2 = nn.MaxPool2d(2, 2)

#

First fully-connected layer that maps to 500 features

        self.fc1 = nn.Linear(50 * 4 * 4, 500)

#

ReLU activation

        self.act3 = nn.ReLU()

#

Final fully-connected layer to output evidence for the 10 classes. The ReLU or Softplus activation is applied to this output outside the model to get the non-negative evidence.

        self.fc2 = nn.Linear(500, 10)

#

Dropout for the hidden layer

        self.dropout = nn.Dropout(p=dropout)

#

  • x is the batch of MNIST images of shape [batch_size, 1, 28, 28]
    def __call__(self, x: torch.Tensor):

#

Apply first convolution and max pooling. The result has shape [batch_size, 20, 12, 12]

        x = self.max_pool1(self.act1(self.conv1(x)))

#

Apply second convolution and max pooling. The result has shape [batch_size, 50, 4, 4]

        x = self.max_pool2(self.act2(self.conv2(x)))

#

Flatten the tensor to shape [batch_size, 50 * 4 * 4]

        x = x.view(x.shape[0], -1)

#

Apply hidden layer

        x = self.act3(self.fc1(x))

#

Apply dropout

        x = self.dropout(x)

#

Apply final layer and return

        return self.fc2(x)
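To make the pieces above concrete, here is a minimal, self-contained sketch (not part of the experiment; the batch size and the Softplus choice are just for illustration) that pushes a dummy batch through the model and converts the raw outputs to evidence and uncertainty with the paper's formulas:

import torch
import torch.nn as nn

model = Model(dropout=0.5)
x = torch.zeros(32, 1, 28, 28)              # dummy batch of MNIST-sized images
outputs = model(x)                          # raw outputs, shape [32, 10]
evidence = nn.Softplus()(outputs)           # non-negative evidence e_k >= 0
alpha = evidence + 1.                       # Dirichlet parameters alpha_k = e_k + 1
strength = alpha.sum(dim=-1, keepdim=True)  # S = sum_k alpha_k, shape [32, 1]
prob = alpha / strength                     # expected class probabilities
uncertainty = 10. / strength                # uncertainty mass u = K / S for K = 10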

#

Configurations

We use MNISTConfigs configurations.

class Configs(MNISTConfigs):

#

KL Divergence regularization

    kl_div_loss = KLDivergenceLoss()

#

KL Divergence regularization coefficient schedule

    kl_div_coef: Schedule

#

KL Divergence regularization coefficient schedule points, given as (fraction of training, coefficient) pairs

    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]

#

Stats module for tracking

    stats = TrackStatistics()

#

Dropout

    dropout: float = 0.5

#

Module to convert the model outputs to non-negative evidence

    outputs_to_evidence: nn.Module

#

Initialization

    def init(self):

#

Set tracker configurations

        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)

#

        self.state_modules = []

#

Training or validation step

    def step(self, batch: Any, batch_idx: BatchIndex):

#

Training/Evaluation mode

        self.model.train(self.mode.is_train)

#

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

#

One-hot encoded targets; e.g. class index 3 maps to the 10-dimensional vector with a 1 in position 3

        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]

#

Update global step (number of samples processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

#

Get model outputs

        outputs = self.model(data)

#

Get evidence $e_k \ge 0$

        evidence = self.outputs_to_evidence(outputs)

#

Calculate loss

        loss = self.loss_func(evidence, target)

#

Calculate KL Divergence regularization loss

        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)

#

KL Divergence loss coefficient $\lambda_t$

        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)

#

Total loss

        loss = loss + annealing_coef * kl_div_loss
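This follows the paper's annealing scheme: the total loss is $\mathcal{L} = \mathcal{L}_{\text{data}} + \lambda_t \, \mathcal{L}_{\text{KL}}$, where $\lambda_t$ grows from $0$ towards $1$ over training so that the KL regularizer does not dominate early on, while the evidence is still small.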

#

Track statistics

        self.stats(evidence, target)

#

Train the model

        if self.mode.is_train:

#

Calculate gradients

            loss.backward()

#

Take optimizer step

            self.optimizer.step()

#

Clear the gradients

            self.optimizer.zero_grad()

#

Save the tracked metrics

        tracker.save()

#

Create model

@option(Configs.model)
def mnist_model(c: Configs):

#

    return Model(c.dropout).to(c.device)

#

KL Divergence Loss Coefficient Schedule

@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):

#

Create a relative piecewise schedule

    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
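RelativePiecewise is a labml helper; assuming it linearly interpolates between the given (fraction-of-training, value) points (which matches the annealing coefficient tracked above), a hypothetical plain-Python stand-in for the schedule would look like:

def piecewise_coef(step: int, total_steps: int,
                   points=((0., 0.), (0.2, 0.01), (1., 1.))) -> float:
    # Hypothetical sketch of the schedule semantics, not the labml implementation:
    # linearly interpolate between (fraction_of_training, value) points.
    frac = min(step / total_steps, 1.)
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if frac <= x1:
            return y0 + (y1 - y0) * (frac - x0) / (x1 - x0)
    return points[-1][1]

For example, half-way through training this gives $0.01 + 0.99 \cdot (0.5 - 0.2) / 0.8 \approx 0.38$.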

#

Maximum Likelihood Loss

calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())

#

Cross Entropy Bayes Risk

calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())

#

Squared Error Bayes Risk

calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())

#

ReLU to calculate evidence

calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())

#

Softplus to calculate evidence

calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
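The choice matters in practice: ReLU clips negative outputs to exactly zero evidence (with zero gradient through those units), while Softplus is a smooth approximation that keeps the evidence strictly positive and the gradients non-zero. The configuration in main below defaults to Softplus.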

#

def main():

#

Create experiment

    experiment.create(name='evidence_mnist')

#

Create configurations

    conf = Configs()

#

Load configurations

    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

#

        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',

        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    })

#

Start the experiment and run the training loop

    with experiment.start():
        conf.run()

#

if __name__ == '__main__':
    main()
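Assuming the labml-nn package is installed and the module layout matches the imports above, the experiment should be runnable directly, e.g. python -m labml_nn.uncertainty.evidence.experiment.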
