Upload to Main
This commit is contained in:
99
deepcore/methods/forgetting.py
Normal file
99
deepcore/methods/forgetting.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch, time
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from datasets.data_manager import select_dm_loader
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/mtoneva/example_forgetting
|
||||
|
||||
class Forgetting(EarlyTrain):
|
||||
def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None, balance=True, #default True
|
||||
dst_test=None, **kwargs):
|
||||
super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model=specific_model,
|
||||
dst_test=dst_test,**kwargs)
|
||||
|
||||
self.balance = balance
|
||||
|
||||
def get_hms(self, seconds):
|
||||
# Format time for printing purposes
|
||||
|
||||
m, s = divmod(seconds, 60)
|
||||
h, m = divmod(m, 60)
|
||||
|
||||
return h, m, s
|
||||
|
||||
def before_train(self):
|
||||
self.train_loss = 0.
|
||||
self.correct = 0.
|
||||
self.total = 0.
|
||||
|
||||
def after_loss(self, outputs, loss, targets, batch_inds, epoch):
|
||||
with torch.no_grad():
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
|
||||
cur_acc = (predicted == targets).clone().detach().requires_grad_(False).type(torch.float32)
|
||||
self.forgetting_events[batch_inds.clone().detach()[(self.last_acc[batch_inds]-cur_acc)>0.01]]+=1.
|
||||
self.last_acc[batch_inds] = cur_acc
|
||||
|
||||
def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
|
||||
pass
|
||||
# self.train_loss += loss.item()
|
||||
# self.total += targets.size(0)
|
||||
# _, predicted = torch.max(outputs.data, 1)
|
||||
# self.correct += predicted.eq(targets.data).cpu().sum()
|
||||
#
|
||||
# if batch_idx % self.args.print_freq == 0:
|
||||
# print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%' % (
|
||||
# epoch, self.epochs, batch_idx + 1, (self.n_train // batch_size) + 1, loss.item(),
|
||||
# 100. * self.correct.item() / self.total))
|
||||
|
||||
|
||||
|
||||
def after_epoch(self):
|
||||
pass
|
||||
# epoch_time = time.time() - self.start_time
|
||||
# self.elapsed_time += epoch_time
|
||||
# print('| Elapsed time : %d:%02d:%02d' % (self.get_hms(self.elapsed_time)))
|
||||
|
||||
def before_run(self):
|
||||
self.elapsed_time = 0
|
||||
self.forgetting_events = torch.zeros(self.n_train, requires_grad=False).cuda()
|
||||
self.test_initial_acc()
|
||||
# self.last_acc = torch.zeros(self.n_train, requires_grad=False).cuda()
|
||||
|
||||
def test_initial_acc(self):
|
||||
self.model.no_grad = True
|
||||
self.model.eval()
|
||||
self.last_acc = torch.zeros(self.n_train, requires_grad=False).cuda()
|
||||
|
||||
print('\n=> Testing Initial acc for Forgetting')
|
||||
train_loader = select_dm_loader(self.args, self.dst_train)
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
image, target,batch_inds = batch['img'].cuda(), batch['label'].cuda(), batch['index'].cuda()
|
||||
output = self.model(image, target)
|
||||
predicted = torch.max(output.data, 1).indices
|
||||
|
||||
cur_acc = (predicted == target).clone().detach().requires_grad_(False).type(torch.float32)
|
||||
self.last_acc[batch_inds] = cur_acc
|
||||
|
||||
|
||||
self.model.no_grad = False
|
||||
|
||||
def finish_run(self):
|
||||
pass
|
||||
|
||||
def select(self, **kwargs):
|
||||
self.run()
|
||||
|
||||
if not self.balance:
|
||||
top_examples = self.train_indx[np.argsort(self.forgetting_events.cpu().numpy())][::-1][:self.coreset_size]
|
||||
else:
|
||||
top_examples = np.array([], dtype=np.int64)
|
||||
for c in range(self.num_classes):
|
||||
c_indx = self.train_indx[self.dst_train_label == c]
|
||||
budget = round(self.fraction * len(c_indx))
|
||||
top_examples = np.append(top_examples,
|
||||
c_indx[np.argsort(self.forgetting_events[c_indx].cpu().numpy())[::-1][:budget]])
|
||||
|
||||
return {"indices": top_examples, "scores": self.forgetting_events}
|
||||
Reference in New Issue
Block a user