# Python module (original listing: 41 lines, 2.2 KiB).
import numpy as np
import torch
import torch.nn.functional as F


class CBLoss(torch.nn.Module):
    """Class-balanced loss based on the effective number of samples per class.

    Class ``c`` with ``n_c`` training samples gets the raw weight
    ``(1 - beta) / (1 - beta ** n_c)``. The weights are then rescaled to sum to
    ``no_of_classes``, and every sample's loss is multiplied by the weight of
    its ground-truth class (Cui et al., "Class-Balanced Loss Based on
    Effective Number of Samples", CVPR 2019).

    Args:
        samples_per_cls: Number of training samples of each class, indexed by
            class id. Must contain ``no_of_classes`` positive entries.
        no_of_classes: Number of classes ``C``.
        loss_type: One of ``"softmax"``, ``"sigmoid"`` or ``"focal"``.
        beta: Effective-number hyper-parameter from the paper (default 0.9999).
            ``beta == 0`` gives uniform weights; ``beta == 1`` gives
            inverse-frequency weights.
        gamma: Focusing parameter, used only when ``loss_type == "focal"``
            (default 2.0).
    """

    _LOSS_TYPES = ("softmax", "sigmoid", "focal")

    def __init__(self, samples_per_cls, no_of_classes: int, loss_type: str,
                 beta: float = 0.9999, gamma: float = 2.0):
        super().__init__()
        self.samples_per_cls = samples_per_cls
        self.no_of_classes = no_of_classes
        self.loss_type = loss_type
        self.beta = beta
        self.gamma = gamma
        # Kept for backward compatibility only. forward() places its tensors on
        # logits.device, so CPU inputs also work on a CUDA-capable machine.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _class_weights(self) -> np.ndarray:
        """Return per-class weights (float64, shape ``(C,)``) that sum to ``C``.

        Raises:
            ValueError: If ``samples_per_cls`` does not have ``C`` entries or
                contains a non-positive count. A zero count would make the
                effective number 0 and turn every weight into inf/nan.
        """
        counts = np.asarray(self.samples_per_cls, dtype=np.float64)
        if counts.shape != (self.no_of_classes,):
            raise ValueError(
                f"samples_per_cls must have {self.no_of_classes} entries, "
                f"got shape {counts.shape}"
            )
        if np.any(counts <= 0):
            raise ValueError("samples_per_cls entries must all be positive")

        if self.beta == 1.0:
            # (1 - beta) / (1 - beta**n) is 0/0 at beta == 1; its limit is 1/n.
            weights = 1.0 / counts
        else:
            effective_num = 1.0 - np.power(self.beta, counts)
            weights = (1.0 - self.beta) / effective_num
        return weights / weights.sum() * self.no_of_classes

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the class-balanced loss for one batch.

        Args:
            logits: Raw scores of shape ``(batch, no_of_classes)``.
            labels: Integer class ids of shape ``(batch,)``, as accepted by
                ``F.one_hot``.

        Returns:
            A scalar tensor: the weighted element-wise loss averaged over all
            ``batch * no_of_classes`` elements.

        Raises:
            ValueError: If ``loss_type`` is not supported or
                ``samples_per_cls`` is invalid.
        """
        if self.loss_type not in self._LOSS_TYPES:
            raise ValueError(
                f"loss_type must be one of {self._LOSS_TYPES}, got {self.loss_type!r}"
            )

        labels_one_hot = F.one_hot(labels, self.no_of_classes).to(
            dtype=logits.dtype, device=logits.device
        )
        class_weights = torch.as_tensor(
            self._class_weights(), dtype=logits.dtype, device=logits.device
        )
        # Select each sample's ground-truth class weight, shape (batch, 1),
        # then broadcast it across the class axis to match the targets.
        per_sample = (labels_one_hot * class_weights).sum(dim=1, keepdim=True)
        weights = per_sample.expand_as(labels_one_hot)

        if self.loss_type == "focal":
            bce = F.binary_cross_entropy_with_logits(
                input=logits, target=labels_one_hot, reduction="none"
            )
            # exp(-BCE) is the probability the model assigns to each target value.
            pt = torch.exp(-bce)
            return (weights * (1.0 - pt) ** self.gamma * bce).mean()

        if self.loss_type == "sigmoid":
            return F.binary_cross_entropy_with_logits(
                input=logits, target=labels_one_hot, weight=weights
            )

        # "softmax": binary cross-entropy on the softmax probabilities.
        pred = logits.softmax(dim=1)
        return F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)