12 lines
575 B
Python
12 lines
575 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
class SmoothCrossEntropy(nn.Module):
    """Per-sample loss against label-smoothed target distributions.

    The target distribution puts ``1 - epsilon`` on the true class and spreads
    ``epsilon`` uniformly over the remaining ``C - 1`` classes. The value
    returned is ``KL(target || softmax(logits))`` summed over the class
    dimension. It differs from the cross-entropy against the same target only
    by the target's entropy, which does not depend on ``logits``, so the
    gradients w.r.t. ``logits`` are the same. With ``epsilon == 0`` the target
    is one-hot and the result equals ordinary per-sample cross-entropy.

    Args:
        epsilon: Smoothing mass taken from the true class, in ``[0, 1]``.

    Raises:
        ValueError: If ``epsilon`` is outside ``[0, 1]`` (or is NaN).
    """

    def __init__(self, epsilon: float = 0.):
        super().__init__()
        epsilon = float(epsilon)
        # Written so that NaN also fails the check.
        if not 0.0 <= epsilon <= 1.0:
            raise ValueError(f"epsilon must be in [0, 1], got {epsilon}")
        self.epsilon = epsilon

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """Compute the smoothed loss for each sample without reduction.

        Args:
            logits: Unnormalized scores with the class dimension at dim 1.
            labels: Integer class indices (int64, as required by ``scatter_``).
                They have the shape of ``logits`` with dim 1 removed.

        Returns:
            The loss with dim 1 of ``logits`` summed out, i.e. one value per
            sample (and per extra spatial position, if any).
        """
        num_classes = logits.shape[1]
        if num_classes > 1:
            off_value = self.epsilon / (num_classes - 1)
            on_value = 1.0 - self.epsilon
        else:
            # A single class has no "other" classes to share epsilon with.
            # The old code divided by zero here. All mass goes to that class.
            off_value, on_value = 0.0, 1.0
        target_probs = torch.full_like(logits, off_value)
        target_probs.scatter_(1, labels.unsqueeze(1), on_value)
        # F.kl_div expects log-probabilities as input and probabilities as target.
        return F.kl_div(torch.log_softmax(logits, 1), target_probs, reduction='none').sum(1)