
Label Smoothing Loss


This module implements label smoothing: instead of training against one-hot targets, the correct class receives probability 1 - smoothing (the confidence) and the remaining probability mass is spread uniformly over the other classes, excluding the padding token. The loss is the KL divergence between the model's predicted log-probabilities and this smoothed target distribution.

import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import nn

class LabelSmoothingLoss(nn.Module):
    def __init__(self, size: int, padding_idx: int, smoothing: float = 0.0):
        super().__init__()
        # KL divergence between the predicted log-probabilities and the
        # smoothed target distribution, summed over the batch
        self.loss = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        # Probability assigned to the correct class
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        # Number of classes (vocabulary size)
        self.size = size
        # Keep the last target distribution around for inspection
        self.true_dist = None

    def forward(self, x: torch.Tensor, target: torch.Tensor):
        # `x` holds log-probabilities of shape `[batch, size]`
        assert x.shape[1] == self.size
        true_dist = x.clone()
        # Spread the smoothing mass uniformly over the `size - 2` classes
        # that are neither the correct class nor the padding token
        true_dist.fill_(self.smoothing / (self.size - 2))
        # The correct class gets the confidence `1 - smoothing`
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        # Never predict the padding token
        true_dist[:, self.padding_idx] = 0
        # Zero out the rows whose target is the padding token
        mask = torch.nonzero(target == self.padding_idx, as_tuple=False)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        # Detach the target distribution so no gradient flows through it
        return self.loss(x, true_dist.detach())
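As a minimal sketch of what forward produces (not part of the original file): with size = 5, padding_idx = 0 and smoothing = 0.4, the correct class gets 1 - 0.4 = 0.6 and each of the remaining size - 2 = 3 classes gets 0.4 / 3 ≈ 0.1333, so every non-padding row still sums to 1.

loss_fn = LabelSmoothingLoss(size=5, padding_idx=0, smoothing=0.4)
log_probs = torch.log_softmax(torch.randn(1, 5), dim=1)
loss_fn(log_probs, torch.tensor([2]))
# Padding column is zero; class 2 gets the confidence 0.6:
# tensor([[0.0000, 0.1333, 0.6000, 0.1333, 0.1333]])
print(loss_fn.true_dist)
assert torch.allclose(loss_fn.true_dist.sum(dim=1), torch.tensor(1.0))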

def _test_label_smoothing():
    smooth_loss = LabelSmoothingLoss(5, 0, 0.4)
    predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0]], dtype=torch.float)
    _ = smooth_loss(predict.log(),
                    torch.tensor([2, 1, 0], dtype=torch.long))

Show the target distributions expected by the system. With smoothing 0.4 the correct class gets probability 0.6 and the padding column is zeroed; the third row is all zeros because its target is the padding index.

    plt.imshow(smooth_loss.true_dist)
    plt.show()

    smooth_loss = LabelSmoothingLoss(5, 0, 0.1)

    def loss_sample(x):
        d = x + 3
        # Put probability x / (x + 3) on the correct class and spread the
        # rest uniformly over the three remaining non-padding classes
        predict2 = torch.tensor([[0, x / d, 1 / d, 1 / d, 1 / d]],
                                dtype=torch.float)
        return smooth_loss(predict2.log(),
                           torch.tensor([1], dtype=torch.long)).item()

    # Loss as a function of confidence in the correct class: with label
    # smoothing the minimum is at a finite confidence, and the loss rises
    # again as the prediction saturates toward 1
    plt.plot(np.arange(1, 100), [loss_sample(x) for x in range(1, 100)])
    plt.show()


if __name__ == '__main__':
    _test_label_smoothing()
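As a usage sketch (an assumption, not part of the original file): since nn.KLDivLoss expects log-probabilities, a model's logits should go through log_softmax before being passed to this loss. The toy linear model, feature size, and targets below are hypothetical.

vocab_size, padding_idx = 10, 0
model = nn.Linear(8, vocab_size)  # hypothetical toy model
criterion = LabelSmoothingLoss(vocab_size, padding_idx, smoothing=0.1)

features = torch.randn(4, 8)           # four token positions
targets = torch.tensor([3, 7, 1, 0])   # the last target is padding
# Convert raw logits to log-probabilities before computing the loss
log_probs = torch.log_softmax(model(features), dim=-1)
loss = criterion(log_probs, targets)
loss.backward()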
