Upload to Main
This commit is contained in:
120
deepcore/methods/uncertainty.py
Normal file
120
deepcore/methods/uncertainty.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch
|
||||
import numpy as np
|
||||
from datasets.data_manager import select_dm_loader
|
||||
import time
|
||||
|
||||
class Uncertainty(EarlyTrain):
    """Coreset selection that ranks training samples by model uncertainty.

    Every sample receives one score computed from the model outputs. All
    three supported scores grow with the model's confidence, so a low score
    marks an uncertain sample:

    * ``"LeastConfidence"`` -- the largest raw model output (no softmax).
    * ``"Entropy"`` -- ``sum(p * log(p + 1e-6))``, i.e. the negated entropy.
    * ``"Margin"`` -- top-1 minus top-2 softmax probability.

    The attributes ``model``, ``args``, ``dst_train``, ``n_train``,
    ``num_classes``, ``dst_train_label``, ``coreset_size`` and ``fraction``
    are read here but never assigned in this class; they are presumably set
    up by ``EarlyTrain`` -- confirm against the base class.
    """

    _SELECTION_CHOICES = ("LeastConfidence", "Entropy", "Margin")

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 selection_method="Margin", specific_model=None, balance=False, **kwargs):
        """Validate the scoring method and store the selection settings.

        Args:
            dst_train: training dataset, forwarded to ``EarlyTrain``.
            args: run configuration, forwarded to ``EarlyTrain``.
            fraction: forwarded to ``EarlyTrain``; also used by
                ``finish_run`` as the share of each class to keep.
            random_seed: forwarded to ``EarlyTrain``.
            epochs: forwarded to ``EarlyTrain`` and stored on the instance.
            selection_method: one of ``"LeastConfidence"``, ``"Entropy"``
                or ``"Margin"``.
            specific_model: forwarded to ``EarlyTrain``.
            balance: when true, ``finish_run`` selects per class.
            **kwargs: forwarded to ``EarlyTrain``.

        Raises:
            NotImplementedError: if ``selection_method`` is not supported.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        if selection_method not in self._SELECTION_CHOICES:
            raise NotImplementedError("Selection algorithm unavailable.")
        self.selection_method = selection_method

        self.epochs = epochs
        self.balance = balance

    def before_train(self):
        """Do nothing; this method needs no work at this hook."""

    def after_loss(self, outputs, loss, targets, batch_inds, epoch):
        """Do nothing; this method needs no work at this hook."""

    def after_epoch(self):
        """Do nothing; this method needs no work at this hook."""

    def before_run(self):
        """Do nothing; this method needs no work at this hook."""

    def num_classes_mismatch(self):
        """Reject a pretrain dataset whose class count differs.

        Raises:
            ValueError: always.
        """
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """Do nothing; this method needs no work at this hook."""

    def finish_run(self):
        """Score the training set and choose the coreset indices.

        Returns:
            dict: ``"indices"`` holds the selected sample indices.
            ``"scores"`` holds a list with one score array per class when
            ``self.balance`` is set, otherwise a single score array over
            the whole training set.
        """
        if self.balance:
            selection_result = np.array([], dtype=np.int64)
            scores = []
            for c in range(self.num_classes):
                print(f"Balance Processing on the train set class {c}")
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                class_scores = self.rank_uncertainty_clip(class_index)
                scores.append(class_scores)
                budget = round(len(class_index) * self.fraction)
                # NOTE(review): ascending order keeps the LOWEST scores
                # here, while the unbalanced branch below keeps the
                # HIGHEST. Behaviour is preserved as found -- confirm
                # which direction is intended.
                keep = np.argsort(class_scores)[:budget]
                selection_result = np.append(selection_result, class_index[keep])
        else:
            print("Imbalance Processing on the train set class")
            scores = self.rank_uncertainty_clip()
            selection_result = np.argsort(scores)[::-1][:self.coreset_size]
        return {"indices": selection_result, "scores": scores}

    def _batch_uncertainty_scores(self, logits):
        """Turn one batch of model outputs into per-sample scores.

        Args:
            logits: 2-D tensor of model outputs; classes are taken along
                dim 1.

        Returns:
            numpy.ndarray: one score per row of ``logits``.

        Raises:
            NotImplementedError: if ``self.selection_method`` was changed
                to an unsupported value after construction.
        """
        if self.selection_method == "LeastConfidence":
            return logits.max(dim=1).values.cpu().numpy()

        if self.selection_method == "Entropy":
            probs = torch.softmax(logits, dim=1).cpu().numpy()
            # The 1e-6 offset keeps log() finite for zero probabilities.
            return (np.log(probs + 1e-6) * probs).sum(axis=1)

        if self.selection_method == "Margin":
            probs = torch.softmax(logits, dim=1)
            rows = torch.arange(probs.shape[0], device=probs.device)
            top_idx = torch.argmax(probs, dim=1)
            top_prob = probs[rows, top_idx]  # advanced indexing copies
            # Mask out the winner so the next max finds the runner-up.
            probs[rows, top_idx] = -1.0
            runner_up = probs.max(dim=1).values
            return (top_prob - runner_up).cpu().numpy()

        raise NotImplementedError("Selection algorithm unavailable.")

    def rank_uncertainty(self, index=None):
        """Score samples using a plain ``DataLoader`` over ``dst_train``.

        Args:
            index: optional indices restricting scoring to a subset of
                ``self.dst_train``; ``None`` scores the whole dataset.

        Returns:
            numpy.ndarray: float64 scores in loader order.
        """
        # Put the model that is actually used for inference below into
        # eval mode (the previous code called eval() on specific_model,
        # which defaults to None).
        self.model.eval()
        # A float64 empty seed keeps the result dtype and the
        # empty-loader result of an np.append accumulation.
        chunks = [np.array([])]
        with torch.no_grad():
            dataset = self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index)
            train_loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=self.args.selection_batch,
                num_workers=self.args.workers)
            batch_num = len(train_loader)

            for i, (inputs, _) in enumerate(train_loader):
                if i % self.args.print_freq == 0:
                    print("| Selecting for batch [%3d/%3d]" % (i + 1, batch_num))
                logits = self.model(inputs.to(self.args.device))
                chunks.append(self._batch_uncertainty_scores(logits))
        return np.concatenate(chunks)

    def rank_uncertainty_clip(self, index=None):
        """Score samples using the loader built by ``select_dm_loader``.

        Batches are read through their ``'img'`` and ``'label'`` keys,
        moved to CUDA, and the model is called with both image and label.

        Args:
            index: forwarded to ``select_dm_loader``; ``finish_run`` passes
                the indices of one class, or nothing for the whole set.

        Returns:
            numpy.ndarray: float64 scores in loader order.
        """
        self.model.eval()
        chunks = [np.array([])]
        with torch.no_grad():
            train_loader = select_dm_loader(self.args, self.dst_train, index)
            for batch in train_loader:
                image, label = batch['img'].cuda(), batch['label'].cuda()
                logits = self.model(image, label)
                chunks.append(self._batch_uncertainty_scores(logits))
        # Switch back to training mode only after every batch is scored.
        self.model.train()
        return np.concatenate(chunks)

    def select(self, **kwargs):
        """Return the result of ``self.run()``; ``kwargs`` are ignored."""
        selection_result = self.run()
        return selection_result

    def select_without_train(self):
        """Return the result of ``finish_run`` without calling ``run``."""
        selection_result = self.finish_run()
        return selection_result
|
||||
Reference in New Issue
Block a user