from .earlytrain import EarlyTrain
import torch
import numpy as np
from datasets.data_manager import select_dm_loader
import time
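# The three scoring rules rank samples by prediction confidence: for logits of
# shape [N, C] with p = softmax(logits), "LeastConfidence" keeps the maximum
# logit, "Entropy" keeps the negative entropy of p, and "Margin" keeps the gap
# between the top-1 and top-2 probabilities. In all three cases a higher score
# means a more confident prediction. A minimal margin sketch (illustrative only,
# not used by the class below; assumes `logits` is an [N, C] tensor):
#
#     p = torch.softmax(logits, dim=1)
#     top2 = torch.topk(p, k=2, dim=1).values
#     margin = top2[:, 0] - top2[:, 1]   # higher = more confident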
class Uncertainty(EarlyTrain):
    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, selection_method="Margin",
                 specific_model=None, balance=False, **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        selection_choices = ["LeastConfidence",
                             "Entropy",
                             "Margin"]
        if selection_method not in selection_choices:
            raise NotImplementedError("Selection algorithm unavailable.")
        self.selection_method = selection_method

        self.epochs = epochs
        self.balance = balance
    # Hooks expected by the EarlyTrain base class. Uncertainty-based selection
    # needs no extra bookkeeping during training, so most of them are no-ops.
    def before_train(self):
        pass

    def after_loss(self, outputs, loss, targets, batch_inds, epoch):
        pass

    def after_epoch(self):
        pass

    def before_run(self):
        pass

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        pass
    def finish_run(self):
        if self.balance:
            # Class-balanced selection: score each class separately and keep the
            # same fraction of its lowest-scoring (most uncertain) samples.
            # In this branch, `scores` is a list of per-class arrays.
            selection_result = np.array([], dtype=np.int64)
            scores = []
            for c in range(self.num_classes):
                print(f"Balanced selection: processing train-set class {c}")
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                scores.append(self.rank_uncertainty_clip(class_index))
                selection_result = np.append(selection_result, class_index[
                    np.argsort(scores[-1])[:round(len(class_index) * self.fraction)]])
        else:
            # Unbalanced selection: score the whole train set at once and keep the
            # highest-scoring samples up to the coreset budget (note the opposite
            # sort order from the balanced branch).
            print("Unbalanced selection: processing the whole train set")
            scores = self.rank_uncertainty_clip()
            selection_result = np.argsort(scores)[::-1][:self.coreset_size]
        return {"indices": selection_result, "scores": scores}
    def rank_uncertainty(self, index=None):
        # Generic scorer: ranks dst_train (or a subset of it) with a plain DataLoader.
        self.model.eval()
        with torch.no_grad():
            train_loader = torch.utils.data.DataLoader(
                self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
                batch_size=self.args.selection_batch,
                num_workers=self.args.workers)

            scores = np.array([])
            batch_num = len(train_loader)

            for i, (input, _) in enumerate(train_loader):
                if i % self.args.print_freq == 0:
                    print("| Selecting for batch [%3d/%3d]" % (i + 1, batch_num))
                if self.selection_method == "LeastConfidence":
                    # Maximum logit per sample; higher means more confident.
                    scores = np.append(scores, self.model(input.to(self.args.device)).max(axis=1).values.cpu().numpy())
                elif self.selection_method == "Entropy":
                    # Negative entropy of the softmax; higher means more confident.
                    preds = torch.nn.functional.softmax(self.model(input.to(self.args.device)), dim=1).cpu().numpy()
                    scores = np.append(scores, (np.log(preds + 1e-6) * preds).sum(axis=1))
                elif self.selection_method == 'Margin':
                    # Gap between the top-1 and top-2 softmax probabilities.
                    preds = torch.nn.functional.softmax(self.model(input.to(self.args.device)), dim=1)
                    rows = torch.arange(preds.shape[0])
                    preds_argmax = torch.argmax(preds, dim=1)
                    max_preds = preds[rows, preds_argmax].clone()
                    preds[rows, preds_argmax] = -1.0  # mask out the top-1 class
                    preds_sub_argmax = torch.argmax(preds, dim=1)
                    scores = np.append(scores, (max_preds - preds[rows, preds_sub_argmax]).cpu().numpy())
        return scores
    def rank_uncertainty_clip(self, index=None):
        # Scorer used by finish_run: iterates dict-style batches ('img'/'label')
        # from the data-manager loader and scores them with the current model.
        self.model.eval()
        with torch.no_grad():
            train_loader = select_dm_loader(self.args, self.dst_train, index)
            scores = np.array([])

            for i, batch in enumerate(train_loader):
                # if i % self.args.print_freq == 0:
                #     print("| Selecting for batch [%3d/%3d]" % (i + 1, batch_num))
                image, label = batch['img'].cuda(), batch['label'].cuda()
                logits = self.model(image, label)  # forward pass in eval mode
                if self.selection_method == "LeastConfidence":
                    # Maximum logit per sample; higher means more confident.
                    scores = np.append(scores, logits.max(axis=1).values.cpu().numpy())
                elif self.selection_method == "Entropy":
                    # Negative entropy of the softmax; higher means more confident.
                    preds = torch.softmax(logits, dim=1).cpu().numpy()
                    scores = np.append(scores, (np.log(preds + 1e-6) * preds).sum(axis=1))
                elif self.selection_method == 'Margin':
                    # Gap between the top-1 and top-2 softmax probabilities.
                    preds = torch.softmax(logits, dim=1)
                    rows = torch.arange(preds.shape[0])
                    preds_argmax = torch.argmax(preds, dim=1)
                    max_preds = preds[rows, preds_argmax].clone()
                    preds[rows, preds_argmax] = -1.0  # mask out the top-1 class
                    preds_sub_argmax = torch.argmax(preds, dim=1)
                    scores = np.append(scores, (max_preds - preds[rows, preds_sub_argmax]).cpu().numpy())
        self.model.train()
        return scores
    def select(self, **kwargs):
        # Standard entry point: defer to the inherited run().
        selection_result = self.run()
        return selection_result

    def select_without_train(self):
        # Score with the current model only, skipping run().
        selection_result = self.finish_run()
        return selection_result
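
# Usage sketch (hypothetical; the `args` namespace is assumed to provide the
# fields used above, e.g. selection_batch, workers, print_freq, device):
#
#     method = Uncertainty(dst_train, args, fraction=0.1,
#                          selection_method="Margin", balance=True)
#     result = method.select()                  # inherited run(), then selection
#     # result = method.select_without_train()  # score with the current model only
#     coreset_indices = result["indices"]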