
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/__init__.py)

# Evidential Deep Learning to Quantify Classification Uncertainty

This is a PyTorch implementation of the paper [Evidential Deep Learning to Quantify Classification Uncertainty](https://arxiv.org/abs/1806.01768).

Dempster-Shafer Theory of Evidence assigns belief masses to sets of classes (unlike assigning a probability to a single class). The sum of the masses over all subsets is $1$. Individual class probabilities (plausibilities) can be derived from these masses.

Assigning a mass to the set of all classes means the sample could be any one of the classes; i.e., the model is saying "I don't know".

If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and an overall uncertainty mass $u \ge 0$ to the set of all classes, such that

$$u + \sum_{k=1}^K b_k = 1$$

Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$ as $b_k = \frac{e_k}{S}$ and $u = \frac{K}{S}$, where $S = \sum_{k=1}^K (e_k + 1)$. The paper uses the term evidence as a measure of the amount of support collected from data in favor of a sample being classified into a certain class.

This corresponds to a Dirichlet distribution with parameters $\alpha_k = e_k + 1$, where $\alpha_0 = S = \sum_{k=1}^K \alpha_k$ is known as the Dirichlet strength. The Dirichlet distribution $D(\mathbf{p} \mid \boldsymbol{\alpha})$ is a distribution over categorical distributions; i.e., you can sample class probabilities from it. The expected probability of class $k$ is $\hat{p}_k = \frac{\alpha_k}{S}$.
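For concreteness, here is a small numeric sketch of the quantities above for $K = 3$ classes (the evidence values are made up):

```python
import torch

evidence = torch.tensor([4., 1., 0.])  # hypothetical evidence for K = 3 classes
k = evidence.shape[-1]

alpha = evidence + 1.          # Dirichlet parameters alpha_k = e_k + 1 -> [5, 2, 1]
strength = alpha.sum()         # Dirichlet strength S = 8
belief = evidence / strength   # b_k = e_k / S -> [0.500, 0.125, 0.000]
uncertainty = k / strength     # u = K / S -> 0.375

# The belief masses and the uncertainty mass sum to one
assert torch.isclose(belief.sum() + uncertainty, torch.tensor(1.))

expected_p = alpha / strength  # p_hat_k = alpha_k / S -> [0.625, 0.250, 0.125]
```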

We get the model to output the evidence $\mathbf{e} = \boldsymbol{\alpha} - 1 = f(x \mid \Theta)$ for a given input $x$. We use a function such as ReLU or Softplus at the final layer to ensure $f(x \mid \Theta) \ge 0$.
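For instance, a model head could look like the following; this is a minimal sketch (the layer sizes are illustrative, not the architecture used in `experiment.py`):

```python
import torch
from torch import nn

# A toy evidential classifier for MNIST-shaped inputs;
# the final Softplus guarantees non-negative evidence
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Softplus(),  # e = f(x | Theta) >= 0
)

x = torch.randn(32, 1, 28, 28)  # a fake batch
evidence = model(x)             # shape [32, 10], all entries >= 0
```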

The paper proposes a few loss functions to train the model, which we have implemented below.

Here is the training code `experiment.py` to train a model on the MNIST dataset.

```python
import torch

from labml import tracker
from torch import nn
```

## Type II Maximum Likelihood Loss

The distribution $D(\mathbf{p} \mid \boldsymbol{\alpha})$ is a prior on the likelihood $\text{Multi}(\mathbf{y} \mid \mathbf{p})$, and the negative log marginal likelihood is calculated by integrating over the class probabilities $\mathbf{p}$.

If the target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is

$$\mathcal{L}(\Theta) = -\log \left( \int \prod_{k=1}^K p_k^{y_k} \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} \, d\mathbf{p} \right) = \sum_{k=1}^K y_k \bigl( \log S - \log \alpha_k \bigr)$$

(The integral evaluates to the ratio of beta functions $B(\boldsymbol{\alpha} + \mathbf{y}) / B(\boldsymbol{\alpha})$, which reduces to the closed form on the right for one-hot $\mathbf{y}$.)

```python
class MaximumLikelihoodLoss(nn.Module):
```

* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \alpha_k$

```python
        strength = alpha.sum(dim=-1)
```

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigl( \log S - \log \alpha_k \bigr)$

```python
        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
```

Mean loss over the batch

```python
        return loss.mean()
```
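As a quick sanity check (the numbers here are illustrative, not from the experiment), the loss falls as evidence concentrates on the correct class:

```python
import torch

loss_fn = MaximumLikelihoodLoss()

target = torch.tensor([[0., 1., 0.]])   # one-hot target: class 1

weak = torch.tensor([[0., 1., 0.]])     # a little evidence for the right class
strong = torch.tensor([[0., 20., 0.]])  # a lot of evidence for it

print(loss_fn(weak, target))    # log(4) - log(2)   ~= 0.693
print(loss_fn(strong, target))  # log(23) - log(21) ~= 0.091
```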

## Bayes Risk with Cross Entropy Loss

Bayes risk is the expected cost of making incorrect estimates. It takes a cost function that gives the cost of an incorrect estimate and integrates it over all possible outcomes, weighted by the probability distribution.

Here the cost function is cross-entropy loss; for one-hot coded $\mathbf{y}$,

$$\sum_{k=1}^K -y_k \log p_k$$

We integrate this cost over all $\mathbf{p}$:

$$\mathcal{L}(\Theta) = \int \left[ \sum_{k=1}^K -y_k \log p_k \right] \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} \, d\mathbf{p} = \sum_{k=1}^K y_k \bigl( \psi(S) - \psi(\alpha_k) \bigr)$$

where $\psi(\cdot)$ is the digamma function. This follows from the Dirichlet identity $\mathbb{E}[\log p_k] = \psi(\alpha_k) - \psi(S)$.

```python
class CrossEntropyBayesRisk(nn.Module):
```

* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \alpha_k$

```python
        strength = alpha.sum(dim=-1)
```

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigl( \psi(S) - \psi(\alpha_k) \bigr)$

```python
        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
```

Mean loss over the batch

```python
        return loss.mean()
```
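The closed form above relies on the identity $\mathbb{E}[\log p_k] = \psi(\alpha_k) - \psi(S)$, which can be sanity-checked with Monte Carlo samples from the Dirichlet distribution; a sketch (the parameters and sample count are arbitrary):

```python
import torch

alpha = torch.tensor([2., 5., 1.])
dist = torch.distributions.Dirichlet(alpha)

# Monte Carlo estimate of E[log p_k]
samples = dist.sample((200_000,))
mc_estimate = samples.log().mean(dim=0)

# Closed form psi(alpha_k) - psi(S)
closed_form = torch.digamma(alpha) - torch.digamma(alpha.sum())

print(mc_estimate, closed_form)  # should agree to roughly two decimal places
```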

## Bayes Risk with Squared Error Loss

Here the cost function is squared error,

$$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$

We integrate this cost over all $\mathbf{p}$:

$$\mathcal{L}(\Theta) = \int \left[ \sum_{k=1}^K (y_k - p_k)^2 \right] \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} \, d\mathbf{p} = \sum_{k=1}^K \mathbb{E}\left[ y_k^2 - 2 y_k p_k + p_k^2 \right] = \sum_{k=1}^K \left( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \right)$$

where $\mathbb{E}[p_k] = \hat{p}_k = \frac{\alpha_k}{S}$ is the expected probability when sampled from the Dirichlet distribution, and $\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$, where

$$\text{Var}(p_k) = \frac{\alpha_k (S - \alpha_k)}{S^2 (S + 1)} = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$$

is the variance.

This gives

$$\begin{aligned} \mathcal{L}(\Theta) &= \sum_{k=1}^K \left( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \right) \\ &= \sum_{k=1}^K \left( \bigl( y_k - \mathbb{E}[p_k] \bigr)^2 + \text{Var}(p_k) \right) \\ &= \sum_{k=1}^K \left( \bigl( y_k - \hat{p}_k \bigr)^2 + \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1} \right) \end{aligned}$$

The first part of the equation, $(y_k - \mathbb{E}[p_k])^2$, is the error term, and the second part is the variance.

```python
class SquaredErrorBayesRisk(nn.Module):
```

* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \alpha_k$

```python
        strength = alpha.sum(dim=-1)
```

$\hat{p}_k = \frac{\alpha_k}{S}$

```python
        p = alpha / strength[:, None]
```

Error $(y_k - \hat{p}_k)^2$

```python
        err = (target - p) ** 2
```

Variance $\text{Var}(p_k) = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$

```python
        var = p * (1 - p) / (strength[:, None] + 1)
```

Sum of the two terms

```python
        loss = (err + var).sum(dim=-1)
```

Mean loss over the batch

```python
        return loss.mean()
```
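To see the error and variance terms at work, compare hypothetical confident and unsure predictions on the same one-hot target (the evidence values are made up); a confidently wrong prediction is penalized far more than simply admitting ignorance:

```python
import torch

loss_fn = SquaredErrorBayesRisk()

target = torch.tensor([[1., 0.]])

confident_right = torch.tensor([[18., 0.]])  # strong evidence for the correct class
no_evidence = torch.tensor([[0., 0.]])       # "I don't know"
confident_wrong = torch.tensor([[0., 18.]])  # strong evidence for the wrong class

print(loss_fn(confident_right, target))  # ~0.010
print(loss_fn(no_evidence, target))      # ~0.667
print(loss_fn(confident_wrong, target))  # ~1.810
```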

## KL Divergence Regularization Loss

This regularizer tries to shrink the total evidence to zero for samples that cannot be correctly classified.

First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$, the Dirichlet parameters after removing the evidence for the correct class.

$$\text{KL} \Big[ D(\mathbf{p} \mid \tilde{\boldsymbol{\alpha}}) \,\Big\Vert\, D(\mathbf{p} \mid \langle 1, \dots, 1 \rangle) \Big] = \log \left( \frac{\Gamma \left( \sum_{k=1}^K \tilde{\alpha}_k \right)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \right) + \sum_{k=1}^K (\tilde{\alpha}_k - 1) \left[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \right]$$

where $\Gamma(\cdot)$ is the gamma function, $\psi(\cdot)$ is the digamma function, and $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.

```python
class KLDivergenceLoss(nn.Module):
```

* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

Number of classes

```python
        n_classes = evidence.shape[-1]
```

Remove non-misleading evidence: $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$

```python
        alpha_tilde = target + (1 - target) * alpha
```

$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$

```python
        strength_tilde = alpha_tilde.sum(dim=-1)
```

The first term:

$$\log \left( \frac{\Gamma \left( \sum_{k=1}^K \tilde{\alpha}_k \right)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \right) = \log \Gamma \left( \sum_{k=1}^K \tilde{\alpha}_k \right) - \log \Gamma(K) - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)$$

```python
        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))
```

The second term:

$$\sum_{k=1}^K (\tilde{\alpha}_k - 1) \left[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \right]$$

```python
        second = ((alpha_tilde - 1) *
                  (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
                  ).sum(dim=-1)
```

Sum of the two terms

```python
        loss = first + second
```

Mean loss over the batch

```python
        return loss.mean()
```
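In training, this regularizer is added to one of the losses above with an annealing coefficient, so that it does not dominate before the model has learned anything; the paper anneals it as $\min(1, t/10)$ over the first 10 epochs. A sketch (the choice of base loss is illustrative):

```python
import torch

def total_loss(evidence: torch.Tensor, target: torch.Tensor, epoch: int):
    # Base classification loss plus the annealed KL regularizer
    base = SquaredErrorBayesRisk()(evidence, target)
    kl = KLDivergenceLoss()(evidence, target)
    annealing = min(1., epoch / 10.)
    return base + annealing * kl
```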

## Track statistics

This module computes statistics and tracks them with labml tracker.

```python
class TrackStatistics(nn.Module):
```

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

Number of classes

```python
        n_classes = evidence.shape[-1]
```

Predictions that match the target (greedily picking the class with the highest evidence, and hence the highest expected probability)

```python
        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
```

Track accuracy

```python
        tracker.add('accuracy.', match.sum() / match.shape[0])
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \alpha_k$

```python
        strength = alpha.sum(dim=-1)
```

$\hat{p}_k = \frac{\alpha_k}{S}$

```python
        expected_probability = alpha / strength[:, None]
```

Expected probability of the selected (greedily picked, highest-probability) class

```python
        expected_probability, _ = expected_probability.max(dim=-1)
```

Uncertainty mass $u = \frac{K}{S}$

```python
        uncertainty_mass = n_classes / strength
```

Track $u$ for correct predictions

```python
        tracker.add('u.succ.', uncertainty_mass.masked_select(match))
```

Track $u$ for incorrect predictions

```python
        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
```

Track $\hat{p}_k$ for correct predictions

```python
        tracker.add('prob.succ.', expected_probability.masked_select(match))
```

Track $\hat{p}_k$ for incorrect predictions

```python
        tracker.add('prob.fail.', expected_probability.masked_select(~match))
```
