This commit is contained in:
2024-05-21 19:41:56 +08:00
commit ca67205608
217 changed files with 201004 additions and 0 deletions

41
models/CB_Loss.py Normal file
View File

@@ -0,0 +1,41 @@
import torch
import torch.nn.functional as F
import numpy as np
class CBLoss(torch.nn.Module):
def __init__(self, samples_per_cls, no_of_classes, loss_type, beta=0.9999, gamma=2.0):
super(CBLoss, self).__init__()
self.samples_per_cls = samples_per_cls#samples_per_cls: 一个列表,表示每个类别的样本数量
self.no_of_classes = no_of_classes #no_of_classes: 类别数量
self.loss_type = loss_type #loss_type: 表示损失函数的类型,可以是 softmax、sigmoid 或 focal
self.beta = beta #beta: 参考论文中定义的 beta 参数,默认值是 0.9999
self.gamma = gamma #gamma: 如果损失函数类型是 focal表示 gamma 参数,默认值是 2.0
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def forward(self, logits, labels):
#effective_num = 1.0 - np.power(self.beta, self.samples_per_cls)
#weights = (1.0 - self.beta) / effective_num
#weights = weights / torch.sum(weights) * self.no_of_classes
weights = np.array([1]) * self.no_of_classes
labels_one_hot = F.one_hot(labels, self.no_of_classes).float()
weights = torch.tensor(weights).float()
weights = weights.unsqueeze(0).to(self.device)
labels_one_hot = labels_one_hot.to(self.device)
weights = weights.repeat(labels_one_hot.shape[0], 1) * labels_one_hot
weights = weights.sum(1).unsqueeze(1)
weights = weights.repeat(1, self.no_of_classes)
if self.loss_type == "focal":
ce_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, reduction="none")
pt = torch.exp(-ce_loss)
focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
loss = focal_loss
elif self.loss_type == "sigmoid":
loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights)
elif self.loss_type == "softmax":
# loss = F.cross_entropy(input=logits, target=labels, weight=weights)
pred = logits.softmax(dim=1)
cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)
return cb_loss