import torch
import torch.nn as nn
import torch.nn.functional as F


class SmoothCrossEntropy(nn.Module):
    """Label-smoothed classification loss, returned without reduction over dim 0.

    The target distribution puts ``1 - epsilon`` on the labelled class and
    spreads ``epsilon`` evenly over the remaining ``C - 1`` classes, where
    ``C = logits.shape[1]``. The returned value is the KL divergence between
    that target and ``softmax(logits)`` taken along dim 1.
    """

    def __init__(self, epsilon: float = 0.):
        """Store the smoothing factor ``epsilon`` as a Python float."""
        super().__init__()
        self.epsilon = float(epsilon)

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """Compute the smoothed loss for each sample.

        Args:
            logits: Unnormalised scores; dim 1 is treated as the class axis.
            labels: Integer class indices used to index dim 1 of ``logits``
                (presumably shaped like ``logits`` without dim 1 -- confirm
                against callers).

        Returns:
            The element-wise KL terms summed over dim 1.

        Raises:
            ZeroDivisionError: If ``logits.shape[1] == 1``, because the
                off-target mass is divided by ``C - 1``.
        """
        num_other_classes = logits.shape[1] - 1
        off_value = self.epsilon / num_other_classes
        on_value = 1 - self.epsilon

        # Start from the uniform "off" mass everywhere, then overwrite the
        # entry of the labelled class with the "on" mass.
        class_index = torch.unsqueeze(labels, 1)
        smoothed = torch.full_like(logits, off_value).scatter_(1, class_index, on_value)

        # kl_div expects log-probabilities as input and probabilities as target.
        log_probs = F.log_softmax(logits, dim=1)
        pointwise = F.kl_div(log_probs, smoothed, reduction='none')
        return pointwise.sum(dim=1)