This commit is contained in:
2024-05-21 19:41:56 +08:00
commit ca67205608
217 changed files with 201004 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class SmoothCrossEntropy(nn.Module):
def __init__(self, epsilon: float = 0.):
super(SmoothCrossEntropy, self).__init__()
self.epsilon = float(epsilon)
def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
target_probs = torch.full_like(logits, self.epsilon / (logits.shape[1] - 1))
target_probs.scatter_(1, labels.unsqueeze(1), 1 - self.epsilon)
return F.kl_div(torch.log_softmax(logits, 1), target_probs, reduction='none').sum(1)