docs/helpers/metrics.html
import dataclasses
from abc import ABC
from typing import Any

import torch
from labml import tracker
class StateModule:
    """Base interface for modules whose mutable state lives in a separate object.

    Subclasses build that object in :meth:`create_state`, receive it in
    :meth:`set_state`, and react to epoch boundaries through
    :meth:`on_epoch_start` / :meth:`on_epoch_end`. Every hook here raises
    ``NotImplementedError``, so subclasses must override the ones they use.
    """

    def __init__(self):
        pass

    def __call__(self):
        raise NotImplementedError

    def create_state(self) -> Any:
        """Return a new state object for this module."""
        raise NotImplementedError

    def set_state(self, data: Any):
        """Attach ``data`` as this module's state (presumably an object from
        :meth:`create_state` -- confirm against callers)."""
        raise NotImplementedError

    def on_epoch_start(self):
        """Epoch-start hook; must be overridden."""
        raise NotImplementedError

    def on_epoch_end(self):
        """Epoch-end hook; must be overridden."""
        raise NotImplementedError
class Metric(StateModule, ABC):
    """A :class:`StateModule` that can publish its accumulated value."""

    def track(self):
        """Publish the metric's current value.

        The base implementation does nothing; subclasses override it.
        """
        return None
@dataclasses.dataclass
class AccuracyState:
    """Running counters used by :class:`Accuracy`."""

    samples: int = 0  # targets counted so far (ignored targets excluded)
    correct: int = 0  # counted targets whose prediction matched

    def reset(self):
        """Zero both counters."""
        self.samples = 0
        self.correct = 0


class Accuracy(Metric):
    """Classification accuracy accumulated into an :class:`AccuracyState`.

    Targets equal to ``ignore_index`` are excluded from both the number of
    correct predictions and the number of samples.
    """

    data: AccuracyState

    def __init__(self, ignore_index: int = -1):
        super().__init__()
        self.ignore_index = ignore_index

    def _accumulate(self, pred: torch.Tensor, target: torch.Tensor):
        """Add counts for flat, equally sized ``pred``/``target`` to ``self.data``.

        Positions whose target equals ``ignore_index`` are skipped.
        """
        keep = target != self.ignore_index
        self.data.correct += (pred[keep] == target[keep]).sum().item()
        self.data.samples += keep.sum().item()

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Accumulate accuracy for one batch of class scores.

        :param output: scores whose last dimension indexes classes; all
            leading dimensions are flattened and the argmax over the last
            dimension is taken as the prediction.
        :param target: class indices, flattened to one value per row of
            the flattened ``output``.
        """
        # reshape rather than view: view() raises on non-contiguous tensors
        # (e.g. transposed outputs), reshape() copies only when it must.
        output = output.reshape(-1, output.shape[-1])
        target = target.reshape(-1)
        self._accumulate(output.argmax(dim=-1), target)

    def create_state(self) -> AccuracyState:
        """Return a fresh, zeroed :class:`AccuracyState`."""
        return AccuracyState()

    def set_state(self, data: AccuracyState):
        """Use ``data`` as the counters updated by :meth:`__call__`."""
        self.data = data

    def on_epoch_start(self):
        """Zero the counters."""
        self.data.reset()

    def on_epoch_end(self):
        """Publish the accumulated accuracy via :meth:`track`."""
        self.track()

    def track(self):
        """Log ``correct / samples`` under ``"accuracy."``.

        Does nothing when no samples were counted, avoiding a division by zero.
        """
        if self.data.samples == 0:
            return
        tracker.add("accuracy.", self.data.correct / self.data.samples)


class AccuracyDirect(Accuracy):
    """Accuracy for outputs that already contain predicted class indices."""

    data: AccuracyState

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Accumulate accuracy for predicted indices ``output`` vs ``target``.

        Both tensors are flattened. Targets equal to ``ignore_index`` are
        skipped, the same as in :class:`Accuracy`, so the constructor argument
        takes effect here too.
        """
        self._accumulate(output.reshape(-1), target.reshape(-1))