init
This commit is contained in:
12
models/SmoothCrossEntropy.py
Normal file
12
models/SmoothCrossEntropy.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
class SmoothCrossEntropy(nn.Module):
|
||||
def __init__(self, epsilon: float = 0.):
|
||||
super(SmoothCrossEntropy, self).__init__()
|
||||
self.epsilon = float(epsilon)
|
||||
|
||||
def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
|
||||
target_probs = torch.full_like(logits, self.epsilon / (logits.shape[1] - 1))
|
||||
target_probs.scatter_(1, labels.unsqueeze(1), 1 - self.epsilon)
|
||||
return F.kl_div(torch.log_softmax(logits, 1), target_probs, reduction='none').sum(1)
|
||||
Reference in New Issue
Block a user