import torch
import torch.nn.functional as F
import numpy as np


class CBLoss(torch.nn.Module):
    """Class-balanced wrapper around sigmoid, softmax or focal losses.

    By default every sample is weighted by ``no_of_classes`` (uniform weighting).
    This preserves the historical behaviour of this module, in which the
    class-balanced weight computation was commented out.

    Pass ``class_balanced=True`` to instead weight each class by the inverse
    effective number of samples, ``(1 - beta) / (1 - beta ** n_c)``. These
    weights are normalised so that they sum to ``no_of_classes``.

    Args:
        samples_per_cls: Sequence with the number of samples of each class; its
            length must equal ``no_of_classes``. Only read when
            ``class_balanced`` is True.
        no_of_classes: Number of classes.
        loss_type: One of ``"softmax"``, ``"sigmoid"`` or ``"focal"``.
        beta: The ``beta`` hyper-parameter of the class-balanced weights
            (default 0.9999). Only read when ``class_balanced`` is True.
        gamma: Focusing parameter, used only when ``loss_type == "focal"``
            (default 2.0).
        class_balanced: If True, use the class-balanced weights described above
            instead of the uniform ones. Defaults to False for backward
            compatibility.
    """

    def __init__(self, samples_per_cls, no_of_classes, loss_type, beta=0.9999, gamma=2.0,
                 class_balanced=False):
        super().__init__()
        self.samples_per_cls = samples_per_cls
        self.no_of_classes = no_of_classes
        self.loss_type = loss_type
        self.beta = beta
        self.gamma = gamma
        self.class_balanced = class_balanced
        # Kept for callers that read this attribute. forward() places its tensors
        # on the device of ``logits``, so CPU logits also work when CUDA exists.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _class_weights(self):
        """Return a NumPy array with one weight per class (length ``no_of_classes``)."""
        if not self.class_balanced:
            # Uniform weighting: every class, hence every sample, gets no_of_classes.
            return np.full(self.no_of_classes, float(self.no_of_classes))
        counts = np.asarray(self.samples_per_cls, dtype=np.float64)
        if counts.shape != (self.no_of_classes,):
            raise ValueError(
                "samples_per_cls must contain exactly no_of_classes entries"
            )
        effective_num = 1.0 - np.power(self.beta, counts)
        weights = (1.0 - self.beta) / effective_num
        # Normalise so the class weights sum to the number of classes.
        return weights / weights.sum() * self.no_of_classes

    def forward(self, logits, labels):
        """Compute the loss selected by ``loss_type``.

        Args:
            logits: Raw scores; assumed to be shaped (batch, no_of_classes),
                as the one-hot targets built below are.
            labels: Integer class indices; assumed to be shaped (batch,).

        Returns:
            A scalar tensor. The sigmoid and softmax branches return the mean of
            the weighted element-wise binary cross-entropy. The focal branch
            returns the mean focal term; it is weighted only when
            ``class_balanced`` is True, matching the original unweighted
            default.

        Raises:
            ValueError: If ``loss_type`` is not one of the supported values.
        """
        device = logits.device
        labels = labels.to(device)
        labels_one_hot = F.one_hot(labels, self.no_of_classes).to(logits.dtype)

        class_weights = torch.as_tensor(
            self._class_weights(), dtype=logits.dtype, device=device
        )
        # Each sample gets the weight of its true class, broadcast across all
        # class columns so it matches the shape of the element-wise loss.
        weights = class_weights[labels].unsqueeze(1).expand(-1, self.no_of_classes)

        if self.loss_type == "focal":
            bce = F.binary_cross_entropy_with_logits(
                input=logits, target=labels_one_hot, reduction="none"
            )
            pt = torch.exp(-bce)  # model probability of the target value
            focal = (1 - pt) ** self.gamma * bce
            if self.class_balanced:
                focal = focal * weights
            return focal.mean()
        if self.loss_type == "sigmoid":
            return F.binary_cross_entropy_with_logits(
                input=logits, target=labels_one_hot, weight=weights
            )
        if self.loss_type == "softmax":
            # Softmax probabilities scored with a per-element binary
            # cross-entropy (not the categorical cross-entropy).
            pred = logits.softmax(dim=1)
            return F.binary_cross_entropy(
                input=pred, target=labels_one_hot, weight=weights
            )
        raise ValueError(
            f"unsupported loss_type {self.loss_type!r}; "
            "expected 'softmax', 'sigmoid' or 'focal'"
        )