init
models/CB_Loss.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import torch
import torch.nn.functional as F
import numpy as np

class CBLoss(torch.nn.Module):
    """Class-balanced loss with softmax, sigmoid, or focal variants."""

    def __init__(self, samples_per_cls, no_of_classes, loss_type, beta=0.9999, gamma=2.0):
        super(CBLoss, self).__init__()
        self.samples_per_cls = samples_per_cls  # list with the number of samples in each class
        self.no_of_classes = no_of_classes      # number of classes
        self.loss_type = loss_type              # one of "softmax", "sigmoid", or "focal"
        self.beta = beta                        # beta from the reference paper, default 0.9999
        self.gamma = gamma                      # gamma for the focal variant, default 2.0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, logits, labels):
        # Effective-number reweighting from the reference paper, currently
        # disabled in favor of uniform weights:
        # effective_num = 1.0 - np.power(self.beta, self.samples_per_cls)
        # weights = (1.0 - self.beta) / effective_num
        # weights = weights / np.sum(weights) * self.no_of_classes
        weights = np.array([1] * self.no_of_classes)  # uniform weight per class
        labels_one_hot = F.one_hot(labels, self.no_of_classes).float()
        weights = torch.tensor(weights).float()

        weights = weights.unsqueeze(0).to(self.device)
        labels_one_hot = labels_one_hot.to(self.device)

        # Pick each sample's true-class weight, then broadcast it to (batch, C).
        weights = weights.repeat(labels_one_hot.shape[0], 1) * labels_one_hot
        weights = weights.sum(1).unsqueeze(1)
        weights = weights.repeat(1, self.no_of_classes)

        if self.loss_type == "focal":
            ce_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, reduction="none")
            pt = torch.exp(-ce_loss)  # probability the model assigns to the target
            cb_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
        elif self.loss_type == "sigmoid":
            cb_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights)
        elif self.loss_type == "softmax":
            # cb_loss = F.cross_entropy(input=logits, target=labels, weight=weights)
            pred = logits.softmax(dim=1)
            cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)
        else:
            raise ValueError(f"unsupported loss_type: {self.loss_type}")

        return cb_loss
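For reference, a minimal usage sketch of the new module (not part of this commit); the class count, per-class sample counts, and batch shape below are made-up examples:

import torch
from models.CB_Loss import CBLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
samples_per_cls = [500, 200, 100, 50, 10]  # hypothetical long-tailed class counts
criterion = CBLoss(samples_per_cls, no_of_classes=5, loss_type="softmax")

logits = torch.randn(8, 5, device=device, requires_grad=True)  # stand-in model outputs
labels = torch.randint(0, 5, (8,), device=device)              # integer class labels
loss = criterion(logits, labels)
loss.backward()

With the commented-out effective-number lines re-enabled, each class weight follows the paper's formula weight_c ∝ (1 − β) / (1 − β^{n_c}), so classes with many samples are down-weighted relative to rare ones.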