100 lines
3.8 KiB
Python
100 lines
3.8 KiB
Python
from .earlytrain import EarlyTrain
|
|
import torch, time
|
|
from torch import nn
|
|
import numpy as np
|
|
from datasets.data_manager import select_dm_loader
|
|
|
|
# Acknowledgement to
|
|
# https://github.com/mtoneva/example_forgetting
|
|
|
|
class Forgetting(EarlyTrain):
|
|
def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None, balance=True, #default True
|
|
dst_test=None, **kwargs):
|
|
super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model=specific_model,
|
|
dst_test=dst_test,**kwargs)
|
|
|
|
self.balance = balance
|
|
|
|
def get_hms(self, seconds):
|
|
# Format time for printing purposes
|
|
|
|
m, s = divmod(seconds, 60)
|
|
h, m = divmod(m, 60)
|
|
|
|
return h, m, s
|
|
|
|
def before_train(self):
|
|
self.train_loss = 0.
|
|
self.correct = 0.
|
|
self.total = 0.
|
|
|
|
def after_loss(self, outputs, loss, targets, batch_inds, epoch):
|
|
with torch.no_grad():
|
|
_, predicted = torch.max(outputs.data, 1)
|
|
|
|
cur_acc = (predicted == targets).clone().detach().requires_grad_(False).type(torch.float32)
|
|
self.forgetting_events[batch_inds.clone().detach()[(self.last_acc[batch_inds]-cur_acc)>0.01]]+=1.
|
|
self.last_acc[batch_inds] = cur_acc
|
|
|
|
def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
|
|
pass
|
|
# self.train_loss += loss.item()
|
|
# self.total += targets.size(0)
|
|
# _, predicted = torch.max(outputs.data, 1)
|
|
# self.correct += predicted.eq(targets.data).cpu().sum()
|
|
#
|
|
# if batch_idx % self.args.print_freq == 0:
|
|
# print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%' % (
|
|
# epoch, self.epochs, batch_idx + 1, (self.n_train // batch_size) + 1, loss.item(),
|
|
# 100. * self.correct.item() / self.total))
|
|
|
|
|
|
|
|
def after_epoch(self):
|
|
pass
|
|
# epoch_time = time.time() - self.start_time
|
|
# self.elapsed_time += epoch_time
|
|
# print('| Elapsed time : %d:%02d:%02d' % (self.get_hms(self.elapsed_time)))
|
|
|
|
def before_run(self):
|
|
self.elapsed_time = 0
|
|
self.forgetting_events = torch.zeros(self.n_train, requires_grad=False).cuda()
|
|
self.test_initial_acc()
|
|
# self.last_acc = torch.zeros(self.n_train, requires_grad=False).cuda()
|
|
|
|
def test_initial_acc(self):
|
|
self.model.no_grad = True
|
|
self.model.eval()
|
|
self.last_acc = torch.zeros(self.n_train, requires_grad=False).cuda()
|
|
|
|
print('\n=> Testing Initial acc for Forgetting')
|
|
train_loader = select_dm_loader(self.args, self.dst_train)
|
|
for batch_idx, batch in enumerate(train_loader):
|
|
image, target,batch_inds = batch['img'].cuda(), batch['label'].cuda(), batch['index'].cuda()
|
|
output = self.model(image, target)
|
|
predicted = torch.max(output.data, 1).indices
|
|
|
|
cur_acc = (predicted == target).clone().detach().requires_grad_(False).type(torch.float32)
|
|
self.last_acc[batch_inds] = cur_acc
|
|
|
|
|
|
self.model.no_grad = False
|
|
|
|
def finish_run(self):
|
|
pass
|
|
|
|
def select(self, **kwargs):
|
|
self.run()
|
|
|
|
if not self.balance:
|
|
top_examples = self.train_indx[np.argsort(self.forgetting_events.cpu().numpy())][::-1][:self.coreset_size]
|
|
else:
|
|
top_examples = np.array([], dtype=np.int64)
|
|
for c in range(self.num_classes):
|
|
c_indx = self.train_indx[self.dst_train_label == c]
|
|
budget = round(self.fraction * len(c_indx))
|
|
top_examples = np.append(top_examples,
|
|
c_indx[np.argsort(self.forgetting_events[c_indx].cpu().numpy())[::-1][:budget]])
|
|
|
|
return {"indices": top_examples, "scores": self.forgetting_events}
|