[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/__init__.py)
This is a PyTorch implementation of the paper *Evidential Deep Learning to Quantify Classification Uncertainty*.
The Dempster-Shafer Theory of Evidence assigns belief masses to sets of classes (rather than assigning a probability to a single class). The sum of the masses over all subsets is $1$. Individual class probabilities (plausibilities) can be derived from these masses.

Assigning a mass to the set of all classes means the sample could be any one of the classes; i.e. saying "I don't know".
If there are $K$ classes, we assign a belief mass $b_k \ge 0$ to each class and an overall uncertainty mass $u \ge 0$ to the set of all classes.
$$u + \sum_{k=1}^K b_k = 1$$
The belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$ as $b_k = \frac{e_k}{S}$ and $u = \frac{K}{S}$, where $S = \sum_{k=1}^K (e_k + 1)$. The paper uses the term *evidence* as a measure of the amount of support collected from the data in favor of a sample belonging to a certain class.
This corresponds to a Dirichlet distribution with parameters $\alpha_k = e_k + 1$; $\alpha_0 = S = \sum_{k=1}^K \alpha_k$ is known as the *Dirichlet strength*. The Dirichlet distribution $D(\mathbf{p} \mid \boldsymbol{\alpha})$ is a distribution over categorical distributions; i.e. you can sample class probabilities from it. The expected probability for class $k$ is $\hat{p}_k = \frac{\alpha_k}{S}$.
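As a toy numeric illustration (the evidence values here are made up, not from the paper), the quantities above can be computed directly:

```python
import torch

# Hypothetical evidence for K = 3 classes (for illustration only)
evidence = torch.tensor([4., 1., 0.])
K = evidence.shape[0]

# Dirichlet parameters and strength: alpha_k = e_k + 1, S = sum(alpha)
alpha = evidence + 1.
strength = alpha.sum()

# Belief masses and uncertainty mass: b_k = e_k / S, u = K / S
belief = evidence / strength
uncertainty = K / strength

# Expected class probabilities: p_hat_k = alpha_k / S
expected_prob = alpha / strength

# The masses satisfy u + sum(b_k) = 1
print(belief, uncertainty.item(), (belief.sum() + uncertainty).item())
```

With $S = 8$, this gives $b = (0.5, 0.125, 0)$ and $u = 3/8$, which sum to one as required.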
We get the model to output evidence $\mathbf{e} = \boldsymbol{\alpha} - 1 = f(\mathbf{x} \mid \Theta)$ for a given input $\mathbf{x}$. We use a function such as ReLU or Softplus at the final layer to ensure $f(\mathbf{x} \mid \Theta) \ge 0$.
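For example, an evidence-producing head might be sketched as follows (a hypothetical architecture, not the model used in the experiment; the point is only the non-negative final activation):

```python
import torch
from torch import nn

# A hypothetical model: any network whose final activation is non-negative
# (ReLU or Softplus) can serve as f(x | Theta) producing evidence.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.Softplus(),  # ensures evidence e >= 0
)

x = torch.randn(2, 1, 28, 28)
evidence = model(x)
assert (evidence >= 0).all()  # Softplus guarantees non-negative outputs
```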
The paper proposes a few loss functions to train the model, which we have implemented below.
Here is the training code `experiment.py` to train a model on the MNIST dataset.
```python
import torch

from labml import tracker
from torch import nn
```
The distribution $D(\mathbf{p} \mid \boldsymbol{\alpha})$ is a prior on the likelihood $\text{Multi}(\mathbf{y} \mid \mathbf{p})$, and the negative log marginal likelihood is calculated by integrating over the class probabilities $\mathbf{p}$.

If the target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is,

$$\mathcal{L}(\Theta) = -\log \Bigg( \int \prod_{k=1}^K p_k^{y_k} \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} \Bigg) = \sum_{k=1}^K y_k \big( \log S - \log \alpha_k \big)$$
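As a sanity check (not part of the paper's code), the closed form on the right can be compared against a Monte Carlo estimate of the integral, using a hypothetical $\boldsymbol{\alpha}$ and a one-hot target:

```python
import torch

torch.manual_seed(0)

# Hypothetical Dirichlet parameters for K = 3, one-hot target on class 0
alpha = torch.tensor([5., 2., 1.])
strength = alpha.sum()
target = torch.tensor([1., 0., 0.])

# Closed form: sum_k y_k (log S - log alpha_k)
closed_form = (target * (strength.log() - alpha.log())).sum()

# Monte Carlo estimate of -log E_{p ~ D(p|alpha)}[ prod_k p_k^{y_k} ]
p = torch.distributions.Dirichlet(alpha).sample((100_000,))
mc = -(p ** target).prod(dim=-1).mean().log()

print(closed_form.item(), mc.item())  # the two should be close
```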
```python
class MaximumLikelihoodLoss(nn.Module):
```

* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \alpha_k$

```python
        strength = alpha.sum(dim=-1)
```

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \big( \log S - \log \alpha_k \big)$

```python
        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
```

Mean loss over the batch

```python
        return loss.mean()
```
Bayes risk is the expected cost of making incorrect estimates. It takes a cost function that gives the cost of an incorrect estimate and integrates it over all possible outcomes, weighted by the probability distribution.
Here the cost function is the cross-entropy loss; for one-hot coded $\mathbf{y}$,

$$\sum_{k=1}^K -y_k \log p_k$$

We integrate this cost over all $\mathbf{p}$,

$$\mathcal{L}(\Theta) = \int \Bigg[ \sum_{k=1}^K -y_k \log p_k \Bigg] \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} = \sum_{k=1}^K y_k \big( \psi(S) - \psi(\alpha_k) \big)$$

where $\psi(\cdot)$ is the digamma function.
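The digamma identity can be checked numerically (a sketch with made-up Dirichlet parameters, not part of the implementation):

```python
import torch

torch.manual_seed(0)

# Hypothetical Dirichlet parameters, one-hot target on class 0
alpha = torch.tensor([5., 2., 1.])
strength = alpha.sum()
target = torch.tensor([1., 0., 0.])

# Closed form: sum_k y_k (psi(S) - psi(alpha_k))
closed_form = (target * (torch.digamma(strength) - torch.digamma(alpha))).sum()

# Monte Carlo estimate of E_{p ~ D(p|alpha)}[ sum_k -y_k log p_k ]
p = torch.distributions.Dirichlet(alpha).sample((100_000,))
mc = -(target * p.log()).sum(dim=-1).mean()

print(closed_form.item(), mc.item())  # the two should be close
```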
```python
class CrossEntropyBayesRisk(nn.Module):
```

* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \alpha_k$

```python
        strength = alpha.sum(dim=-1)
```

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \big( \psi(S) - \psi(\alpha_k) \big)$

```python
        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
```

Mean loss over the batch

```python
        return loss.mean()
```
Here the cost function is the squared error,

$$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$

We integrate this cost over all $\mathbf{p}$,

$$\mathcal{L}(\Theta) = \int \Bigg[ \sum_{k=1}^K (y_k - p_k)^2 \Bigg] \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} = \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 - 2 y_k p_k + p_k^2 \Big] = \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)$$

where $\mathbb{E}[p_k] = \hat{p}_k = \frac{\alpha_k}{S}$ is the expected probability when sampled from the Dirichlet distribution, and $\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$, where

$$\text{Var}(p_k) = \frac{\alpha_k (S - \alpha_k)}{S^2 (S + 1)} = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$$

is the variance. This gives,

$$\mathcal{L}(\Theta) = \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) = \sum_{k=1}^K \Big( \big( y_k - \mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) = \sum_{k=1}^K \Bigg( \big( y_k - \hat{p}_k \big)^2 + \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1} \Bigg)$$

The first part of the equation, $\big( y_k - \mathbb{E}[p_k] \big)^2$, is the squared error term, and the second part is the variance.
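The error-plus-variance decomposition can be verified numerically (a sketch with made-up Dirichlet parameters):

```python
import torch

torch.manual_seed(0)

alpha = torch.tensor([5., 2., 1.])   # hypothetical Dirichlet parameters
strength = alpha.sum()
target = torch.tensor([1., 0., 0.])  # one-hot target

p_hat = alpha / strength

# Closed form: sum_k ((y_k - p_hat_k)^2 + p_hat_k (1 - p_hat_k) / (S + 1))
closed_form = ((target - p_hat) ** 2 + p_hat * (1 - p_hat) / (strength + 1)).sum()

# Monte Carlo estimate of E_{p ~ D(p|alpha)}[ ||y - p||^2 ]
p = torch.distributions.Dirichlet(alpha).sample((100_000,))
mc = ((target - p) ** 2).sum(dim=-1).mean()

print(closed_form.item(), mc.item())  # the two should be close
```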
```python
class SquaredErrorBayesRisk(nn.Module):
```

* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \alpha_k$

```python
        strength = alpha.sum(dim=-1)
```

$\hat{p}_k = \frac{\alpha_k}{S}$

```python
        p = alpha / strength[:, None]
```

Error $(y_k - \hat{p}_k)^2$

```python
        err = (target - p) ** 2
```

Variance $\text{Var}(p_k) = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$

```python
        var = p * (1 - p) / (strength[:, None] + 1)
```

Sum of them

```python
        loss = (err + var).sum(dim=-1)
```

Mean loss over the batch

```python
        return loss.mean()
```
This loss term tries to shrink the total evidence to zero if the sample cannot be correctly classified.

First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$, the Dirichlet parameters after removing the correct evidence.

$$\text{KL} \Big[ D(\mathbf{p} \mid \tilde{\boldsymbol{\alpha}}) \, \Big\Vert \, D(\mathbf{p} \mid \langle 1, \dots, 1 \rangle) \Big] = \log \Bigg( \frac{\Gamma \big( \sum_{k=1}^K \tilde{\alpha}_k \big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) + \sum_{k=1}^K (\tilde{\alpha}_k - 1) \big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \big]$$

where $\Gamma(\cdot)$ is the gamma function, $\psi(\cdot)$ is the digamma function, and $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.
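This expression can be checked against PyTorch's built-in Dirichlet KL divergence (a verification sketch with hypothetical $\tilde{\alpha}$ values):

```python
import torch
from torch.distributions import Dirichlet, kl_divergence

alpha_tilde = torch.tensor([1., 2., 3.])  # hypothetical alpha-tilde values
strength_tilde = alpha_tilde.sum()
n_classes = alpha_tilde.shape[0]

# First term: log Gamma(S~) - log Gamma(K) - sum_k log Gamma(alpha~_k)
first = (torch.lgamma(strength_tilde)
         - torch.lgamma(torch.tensor(float(n_classes)))
         - torch.lgamma(alpha_tilde).sum())

# Second term: sum_k (alpha~_k - 1) [psi(alpha~_k) - psi(S~)]
second = ((alpha_tilde - 1)
          * (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde))).sum()

manual_kl = first + second

# Reference: KL between Dirichlet(alpha~) and the uniform Dirichlet(<1, ..., 1>)
reference = kl_divergence(Dirichlet(alpha_tilde), Dirichlet(torch.ones(n_classes)))
print(manual_kl.item(), reference.item())  # the two should match
```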
```python
class KLDivergenceLoss(nn.Module):
```

* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

Number of classes

```python
        n_classes = evidence.shape[-1]
```

Remove non-misleading evidence $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$

```python
        alpha_tilde = target + (1 - target) * alpha
```

$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$

```python
        strength_tilde = alpha_tilde.sum(dim=-1)
```

The first term,

$$\log \Bigg( \frac{\Gamma \big( \sum_{k=1}^K \tilde{\alpha}_k \big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) = \log \Gamma \Bigg( \sum_{k=1}^K \tilde{\alpha}_k \Bigg) - \log \Gamma(K) - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)$$

```python
        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))
```

The second term,

$$\sum_{k=1}^K (\tilde{\alpha}_k - 1) \big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \big]$$

```python
        second = ((alpha_tilde - 1) *
                  (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
                  ).sum(dim=-1)
```

Sum of the terms

```python
        loss = first + second
```

Mean loss over the batch

```python
        return loss.mean()
```
This module computes statistics and tracks them with labml tracker.
```python
class TrackStatistics(nn.Module):
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

Number of classes

```python
        n_classes = evidence.shape[-1]
```

Predictions that correctly match the target (greedy prediction of the highest-probability class)

```python
        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
```

Track accuracy

```python
        tracker.add('accuracy.', match.sum() / match.shape[0])
```

$\alpha_k = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \alpha_k$

```python
        strength = alpha.sum(dim=-1)
```

$\hat{p}_k = \frac{\alpha_k}{S}$

```python
        expected_probability = alpha / strength[:, None]
```

Expected probability of the (greedily) selected highest-probability class

```python
        expected_probability, _ = expected_probability.max(dim=-1)
```

Uncertainty mass $u = \frac{K}{S}$

```python
        uncertainty_mass = n_classes / strength
```

Track $u$ for correct predictions

```python
        tracker.add('u.succ.', uncertainty_mass.masked_select(match))
```

Track $u$ for incorrect predictions

```python
        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
```

Track $\hat{p}_k$ for correct predictions

```python
        tracker.add('prob.succ.', expected_probability.masked_select(match))
```

Track $\hat{p}_k$ for incorrect predictions

```python
        tracker.add('prob.fail.', expected_probability.masked_select(~match))
```
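To see why $u = \frac{K}{S}$ behaves as an uncertainty signal, compare a confident and an ambiguous sample (the evidence vectors below are hypothetical, for illustration only):

```python
import torch

n_classes = 10

# Hypothetical evidence vectors for two samples
confident = torch.tensor([50.] + [0.] * 9)  # strong evidence for class 0
ambiguous = torch.tensor([0.1] * 10)        # barely any evidence for anything

# u = K / S with S = sum_k (e_k + 1)
u_confident = n_classes / (confident + 1.).sum()
u_ambiguous = n_classes / (ambiguous + 1.).sum()

print(u_confident.item(), u_ambiguous.item())  # ~0.17 vs ~0.91
```

Low total evidence drives $S$ toward $K$ and hence $u$ toward $1$, which is the "I don't know" mass discussed above.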