DAPT/deepcore/methods/forgetting.py

import time

import numpy as np
import torch
from torch import nn

from datasets.data_manager import select_dm_loader
from .earlytrain import EarlyTrain

# Acknowledgement to
# https://github.com/mtoneva/example_forgetting
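#
# Forgetting score (Toneva et al., "An Empirical Study of Example Forgetting
# during Deep Neural Network Learning", ICLR 2019): a forgetting event is
# counted whenever an example that was classified correctly on the previous
# pass is misclassified on the current one. Examples with many forgetting
# events are treated as the hardest to retain and are kept in the coreset.

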
class Forgetting(EarlyTrain):
    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None,
                 balance=True,  # class-balanced selection by default
                 dst_test=None, **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model=specific_model,
                         dst_test=dst_test, **kwargs)
self.balance = balance
    def get_hms(self, seconds):
        # Convert a duration in seconds to (hours, minutes, seconds) for
        # printing, e.g. get_hms(3725) == (1, 2, 5).
m, s = divmod(seconds, 60)
h, m = divmod(m, 60)
return h, m, s
    def before_train(self):
        # Running statistics used by the (currently disabled) logging in while_update.
        self.train_loss = 0.
        self.correct = 0.
        self.total = 0.
    def after_loss(self, outputs, loss, targets, batch_inds, epoch):
        with torch.no_grad():
            _, predicted = torch.max(outputs, 1)
            cur_acc = (predicted == targets).float()
            # Count a forgetting event for every example whose accuracy drops
            # (1.0 on the previous pass, 0.0 now).
            forgotten = (self.last_acc[batch_inds] - cur_acc) > 0.01
            self.forgetting_events[batch_inds[forgotten]] += 1.
            self.last_acc[batch_inds] = cur_acc
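    # Worked example of the update above: if last_acc[i] == 1.0 and example i
    # is now misclassified (cur_acc[i] == 0.0), the difference 1.0 exceeds the
    # 0.01 threshold, so forgetting_events[i] is incremented; the opposite
    # transition (0 -> 1, i.e. learning) gives a negative difference and is
    # ignored.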
    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        # Intentionally a no-op; the commented-out code below would log running
        # loss and accuracy during training.
        pass
# self.train_loss += loss.item()
# self.total += targets.size(0)
# _, predicted = torch.max(outputs.data, 1)
# self.correct += predicted.eq(targets.data).cpu().sum()
#
# if batch_idx % self.args.print_freq == 0:
# print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%' % (
# epoch, self.epochs, batch_idx + 1, (self.n_train // batch_size) + 1, loss.item(),
# 100. * self.correct.item() / self.total))
    def after_epoch(self):
        # Intentionally a no-op; the commented-out code below would report
        # the elapsed training time.
        pass
# epoch_time = time.time() - self.start_time
# self.elapsed_time += epoch_time
# print('| Elapsed time : %d:%02d:%02d' % (self.get_hms(self.elapsed_time)))
    def before_run(self):
        self.elapsed_time = 0
        # One forgetting-event counter per training example, kept on the GPU.
        self.forgetting_events = torch.zeros(self.n_train, requires_grad=False).cuda()
        # Record the untrained model's per-example accuracy so that forgetting
        # in the first epoch is measured against the initial state; this also
        # initialises self.last_acc.
        self.test_initial_acc()
    def test_initial_acc(self):
        self.model.no_grad = True
        self.model.eval()
        self.last_acc = torch.zeros(self.n_train, requires_grad=False).cuda()
        print('\n=> Testing initial accuracy for Forgetting')
        train_loader = select_dm_loader(self.args, self.dst_train)
        # Inference only: disable autograd so no computation graph is retained.
        with torch.no_grad():
            for batch_idx, batch in enumerate(train_loader):
                image, target, batch_inds = batch['img'].cuda(), batch['label'].cuda(), batch['index'].cuda()
                output = self.model(image, target)
                predicted = torch.max(output, 1).indices
                self.last_acc[batch_inds] = (predicted == target).float()
        self.model.no_grad = False
def finish_run(self):
pass
    def select(self, **kwargs):
        self.run()
        if not self.balance:
            # Keep the examples with the most forgetting events overall.
            top_examples = self.train_indx[np.argsort(self.forgetting_events.cpu().numpy())][::-1][:self.coreset_size]
        else:
            # Class-balanced selection: for each class, keep the top `fraction`
            # of its examples ranked by forgetting events.
            top_examples = np.array([], dtype=np.int64)
            for c in range(self.num_classes):
                c_indx = self.train_indx[self.dst_train_label == c]
                budget = round(self.fraction * len(c_indx))
                top_examples = np.append(
                    top_examples,
                    c_indx[np.argsort(self.forgetting_events[c_indx].cpu().numpy())[::-1][:budget]])
        return {"indices": top_examples, "scores": self.forgetting_events}