# docs/transformers/label_smoothing_loss.html
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import nn
class LabelSmoothingLoss(nn.Module):
    """Label smoothing loss for sequence models.

    Builds a smoothed target distribution — the true class receives
    ``1 - smoothing`` probability mass and the remaining ``smoothing`` is
    spread uniformly over the other classes, excluding the padding class
    and the true class (hence the ``size - 2`` divisor) — then measures
    the KL-divergence between the model's log-probabilities and that
    distribution.
    """

    def __init__(self, size: int, padding_idx: int, smoothing: float = 0.0):
        """
        * `size` is the number of classes (vocabulary size)
        * `padding_idx` is the padding class index; it never receives
          probability mass and padded positions contribute no loss
        * `smoothing` is the total probability mass moved off the true class
        """
        super().__init__()
        # KL-divergence summed over all elements; inputs must be
        # log-probabilities, targets are plain probabilities.
        self.loss = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        # Last computed target distribution, kept for visualization.
        self.true_dist = None

    def forward(self, x: torch.Tensor, target: torch.Tensor):
        """
        * `x` is log-probabilities of shape `[batch, size]`
        * `target` is class indices of shape `[batch]`
        """
        assert x.shape[1] == self.size
        # Build the target distribution without cloning `x`'s autograd
        # graph (the distribution is a constant w.r.t. the model): spread
        # `smoothing` mass over the `size - 2` classes that are neither
        # the true class nor the padding class.
        true_dist = torch.full_like(x, self.smoothing / (self.size - 2))
        # Put `confidence` mass on the true class of each sample.
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        # The padding class never receives probability mass.
        true_dist[:, self.padding_idx] = 0
        # Zero out rows whose target is the padding index, so padded
        # positions contribute nothing to the loss.
        mask = torch.nonzero(target == self.padding_idx, as_tuple=False)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        # Saved for inspection/plotting by callers.
        self.true_dist = true_dist
        # Gradient flows only through `x`; the target distribution is fixed.
        return self.loss(x, true_dist.detach())
def _test_label_smoothing():
    """Visual sanity checks for `LabelSmoothingLoss`.

    Plots the smoothed target distributions for a small batch, then plots
    the loss as the model's confidence on the true class grows.
    """
    smooth_loss = LabelSmoothingLoss(5, 0, 0.4)
    predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0],
                            [0, 0.2, 0.7, 0.1, 0]], dtype=torch.float)
    _ = smooth_loss(predict.log(),
                    torch.tensor([2, 1, 0], dtype=torch.long))

    # Show the target distributions expected by the system.
    plt.imshow(smooth_loss.true_dist)
    plt.show()

    smooth_loss = LabelSmoothingLoss(5, 0, 0.1)

    def loss_sample(x):
        # Distribution [0, x/d, 1/d, 1/d, 1/d]: the mass on the true
        # class (index 1) grows with `x` while the rest stays uniform.
        d = x + 3  # was `x + 3 * 1`; the `* 1` factor was a no-op
        predict2 = torch.tensor([[0, x / d, 1 / d, 1 / d, 1 / d]],
                                dtype=torch.float)
        # NOTE(review): removed a stray `print(predict)` here — it
        # printed the unrelated outer-scope tensor; leftover debug.
        return smooth_loss(predict2.log(),
                           torch.tensor([1], dtype=torch.long)).item()

    # Loss as a function of the confidence assigned to the true class.
    plt.plot(np.arange(1, 100), [loss_sample(x) for x in range(1, 100)])
    plt.show()


if __name__ == '__main__':
    _test_label_smoothing()