Files
clip-symnets/models/CB_Loss.py
2024-05-21 19:41:56 +08:00

41 lines
2.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn.functional as F
import numpy as np
class CBLoss(torch.nn.Module):
    """Loss over one-hot targets with "softmax", "sigmoid" or "focal" variants.

    Despite the name, the class-balanced (effective-number) re-weighting is
    currently disabled: the "sigmoid" and "softmax" variants weight every
    element by the constant ``no_of_classes``, and "focal" is unweighted.
    ``samples_per_cls`` and ``beta`` are stored but not used by ``forward``.

    Args:
        samples_per_cls: Number of samples in each class (stored; unused while
            re-weighting is disabled).
        no_of_classes: Number of classes ``C``.
        loss_type: One of ``"softmax"``, ``"sigmoid"`` or ``"focal"``.
        beta: The ``beta`` hyper-parameter of the reference paper, default
            0.9999 (stored; unused while re-weighting is disabled).
        gamma: Focusing parameter used by the "focal" variant, default 2.0.
    """

    _LOSS_TYPES = ("softmax", "sigmoid", "focal")

    def __init__(self, samples_per_cls, no_of_classes, loss_type, beta=0.9999, gamma=2.0):
        super().__init__()
        self.samples_per_cls = samples_per_cls
        self.no_of_classes = no_of_classes
        self.loss_type = loss_type
        self.beta = beta
        self.gamma = gamma
        # Kept for backward compatibility with code that reads this attribute;
        # forward() places its tensors on logits.device instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, logits, labels):
        """Compute the configured loss, reduced to a scalar by its mean.

        Args:
            logits: Raw scores. They must have the same shape as the one-hot
                target, i.e. (batch, no_of_classes) for 1-D ``labels``.
            labels: Integer class indices accepted by ``F.one_hot``.

        Returns:
            A 0-dim tensor holding the mean loss.

        Raises:
            ValueError: If ``loss_type`` is not a supported variant.
        """
        if self.loss_type not in self._LOSS_TYPES:
            raise ValueError(
                f"loss_type must be one of {self._LOSS_TYPES}, got {self.loss_type!r}"
            )

        # Follow the input's device rather than a device fixed at construction,
        # so CPU logits on a CUDA machine (or a non-default GPU) still work.
        device = logits.device
        labels_one_hot = F.one_hot(labels, self.no_of_classes).float().to(device)

        # Per-element weights. The class-balanced scheme (presumably Cui et al.,
        # "Class-Balanced Loss Based on Effective Number of Samples") would use
        # (1 - beta) / (1 - beta ** n_c) normalised to sum to C, gathered per
        # sample. It is disabled, which reduces to a constant weight of C.
        weights = torch.full_like(labels_one_hot, float(self.no_of_classes))

        if self.loss_type == "focal":
            ce_loss = F.binary_cross_entropy_with_logits(
                input=logits, target=labels_one_hot, reduction="none"
            )
            pt = torch.exp(-ce_loss)  # model probability of the true target
            return ((1 - pt) ** self.gamma * ce_loss).mean()

        if self.loss_type == "sigmoid":
            return F.binary_cross_entropy_with_logits(
                input=logits, target=labels_one_hot, weight=weights
            )

        # "softmax": BCE on softmax probabilities against the one-hot target.
        pred = logits.softmax(dim=1)
        return F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)