clip-symnets/models/SmoothCrossEntropy.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class SmoothCrossEntropy(nn.Module):
    """Cross-entropy with label smoothing, computed as a KL divergence to a smoothed target distribution."""

    def __init__(self, epsilon: float = 0.):
        super(SmoothCrossEntropy, self).__init__()
        self.epsilon = float(epsilon)

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        # Spread epsilon uniformly over the non-target classes ...
        target_probs = torch.full_like(logits, self.epsilon / (logits.shape[1] - 1))
        # ... and assign the remaining 1 - epsilon mass to the true class of each sample.
        target_probs.scatter_(1, labels.unsqueeze(1), 1 - self.epsilon)
        # Per-sample loss: KL(target || softmax(logits)), summed over the class dimension.
        # This differs from the smoothed cross-entropy only by the constant entropy of the target.
        return F.kl_div(torch.log_softmax(logits, 1), target_probs, reduction='none').sum(1)
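
A minimal usage sketch of the module above; the batch size, class count, and epsilon value are illustrative assumptions, not taken from the repository:

# Hypothetical example: 4 samples, 10 classes, smoothing epsilon = 0.1.
criterion = SmoothCrossEntropy(epsilon=0.1)
logits = torch.randn(4, 10)                   # raw model outputs
labels = torch.randint(0, 10, (4,))           # integer class indices
per_sample_loss = criterion(logits, labels)   # shape: (4,), one loss per sample
loss = per_sample_loss.mean()                 # reduce to a scalar before backward()

Note that forward returns an unreduced per-sample loss, so the caller is expected to apply its own reduction (e.g. mean) before backpropagation.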