Back to Annotated Deep Learning Paper Implementations

metrics.py

docs/helpers/metrics.html

latest · 2.1 KB
Original Source

home / helpers

View code on Github

#

import dataclasses
from abc import ABC
from typing import Any

import torch
from labml import tracker

#

class StateModule:
    """Base interface for modules whose state object is created and swapped in.

    A subclass builds a state object in ``create_state`` and receives it via
    ``set_state``; ``on_epoch_start`` / ``on_epoch_end`` are epoch hooks.
    Every method other than ``__init__`` raises ``NotImplementedError`` here
    and is expected to be overridden.
    """

    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        # Accept arbitrary arguments so that calling an un-overridden module the
        # way subclasses are called (e.g. ``m(output, target)``) reports
        # NotImplementedError instead of a confusing TypeError.
        raise NotImplementedError

    def create_state(self) -> Any:
        """Return a new state object for this module."""
        raise NotImplementedError

    def set_state(self, data: Any):
        """Install ``data`` as the module's current state.

        :param data: presumably an object previously returned by
            ``create_state`` — confirm against the caller.
        """
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook invoked at the start of an epoch."""
        raise NotImplementedError

    def on_epoch_end(self):
        """Hook invoked at the end of an epoch."""
        raise NotImplementedError

#

class Metric(StateModule, ABC):
    """A ``StateModule`` that can report its accumulated value.

    Subclasses (for example ``Accuracy``) override ``track`` to publish
    whatever they have accumulated; the default does nothing.
    """

    def track(self):
        """Publish the current metric value (no-op in the base class)."""
        return None

#

@dataclasses.dataclass
class AccuracyState:
    """Running counters shared by the accuracy metrics."""

    # Number of positions counted so far (positions with an ignored target excluded).
    samples: int = 0
    # Number of counted positions whose prediction equalled the target.
    correct: int = 0

    def reset(self):
        """Zero both counters in place."""
        self.samples = 0
        self.correct = 0


class Accuracy(Metric):
    """Classification accuracy computed from per-class scores.

    The predicted class is the ``argmax`` over the last dimension of
    ``output``. Positions whose target equals ``ignore_index`` are excluded
    from both the numerator and the denominator.
    """

    data: AccuracyState

    def __init__(self, ignore_index: int = -1):
        """
        :param ignore_index: target value marking positions to skip.
        """
        super().__init__()
        self.ignore_index = ignore_index

    def _accumulate(self, pred: torch.Tensor, target: torch.Tensor):
        """Add counts for flat ``pred``/``target``, skipping ignored targets."""
        keep = target != self.ignore_index
        self.data.correct += (pred.eq(target) & keep).sum().item()
        self.data.samples += keep.sum().item()

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Accumulate counts for one batch.

        :param output: scores with classes in the last dimension; all leading
            dimensions are flattened.
        :param target: class indices with as many elements as ``output`` has
            leading positions.
        """
        # ``reshape`` rather than ``view`` so non-contiguous tensors (e.g. a
        # transposed batch) are accepted too.
        output = output.reshape(-1, output.shape[-1])
        target = target.reshape(-1)
        self._accumulate(output.argmax(dim=-1), target)

    def create_state(self):
        """Return fresh, zeroed counters."""
        return AccuracyState()

    def set_state(self, data: Any):
        """Use ``data`` (an ``AccuracyState``) as the counters to update."""
        self.data = data

    def on_epoch_start(self):
        """Clear the counters."""
        self.data.reset()

    def on_epoch_end(self):
        """Report the accumulated accuracy."""
        self.track()

    def track(self):
        """Send ``correct / samples`` to the labml tracker when any samples were seen."""
        if self.data.samples == 0:
            return
        tracker.add("accuracy.", self.data.correct / self.data.samples)


class AccuracyDirect(Accuracy):
    """Accuracy when ``output`` already holds predicted class indices.

    No ``argmax`` is taken: ``output`` and ``target`` are flattened and
    compared element-wise. Positions whose target equals ``ignore_index``
    are skipped, exactly as in ``Accuracy``.
    """

    data: AccuracyState

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Accumulate counts for one batch of direct predictions.

        :param output: predicted class indices, any shape.
        :param target: target class indices with the same number of elements.
        """
        self._accumulate(output.reshape(-1), target.reshape(-1))

labml.ai