# DAPT/deepcore/methods/deepfool.py
from .earlytrain import EarlyTrain
import torch
import numpy as np
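

# DeepFool-based selection: each training sample is scored by the magnitude of the
# (approximately) minimal adversarial perturbation computed with the DeepFool attack
# (Moosavi-Dezfooli et al., CVPR 2016). Samples with the smallest scores lie closest
# to the decision boundary and are the ones kept in the coreset (see finish_run below).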
class DeepFool(EarlyTrain):
    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 specific_model=None, balance: bool = False, max_iter: int = 50, **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)
        self.balance = balance
        self.max_iter = max_iter

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        if batch_idx % self.args.print_freq == 0:
            print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % (
                epoch, self.epochs, batch_idx + 1, (self.n_pretrain_size // batch_size) + 1, loss.item()))
    def finish_run(self):
        self.model.no_grad = False

        # Create a data loader for self.dst_train with batch size self.args.selection_batch.
        batch_loader = torch.utils.data.DataLoader(self.dst_train, batch_size=self.args.selection_batch,
                                                   num_workers=self.args.workers)

        # Per-sample scores: squared norm of the DeepFool perturbation
        # (smaller = closer to the decision boundary).
        r = np.zeros(self.n_train, dtype=np.float32)
        batch_num = len(batch_loader)
        for i, (inputs, targets) in enumerate(batch_loader):
            if i % self.args.print_freq == 0:
                print('| Selecting Batch [%3d/%3d]' % (i + 1, batch_num))
            r[(i * self.args.selection_batch):(i * self.args.selection_batch + targets.shape[0])] = self.deep_fool(
                inputs)

        if self.balance:
            # Class-balanced selection: within each class, keep the samples with the smallest scores.
            selection_result = np.array([], dtype=np.int64)
            for c in range(self.args.num_classes):
                class_index = np.arange(self.n_train)[self.dst_train.targets == c]
                selection_result = np.append(selection_result, class_index[
                    r[class_index].argsort()[:round(len(class_index) * self.fraction)]])
        else:
            # Global selection: keep the coreset_size samples with the smallest scores.
            selection_result = r.argsort()[:self.coreset_size]
        return {"indices": selection_result, "scores": r}
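
    # The update in deep_fool follows the DeepFool linearization: for the current input x with
    # original prediction c_hat, each competing class k defines a locally linear boundary with
    # w_k = grad f_k(x) - grad f_c_hat(x) and margin |f_k(x) - f_c_hat(x)|; the step taken is
    #     r_i = |f_k* - f_c_hat| / ||w_k*||^2 * w_k*,   k* = argmin_k |f_k - f_c_hat| / ||w_k||,
    # i.e. the smallest move that reaches the nearest linearized decision boundary.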
    def deep_fool(self, inputs):
        # Run the DeepFool algorithm on a batch of inputs.
        self.model.eval()

        sample_size = inputs.shape[0]
        # Boolean mask indicating which samples are still being perturbed (not yet misclassified).
        boolean_mask = np.ones(sample_size, dtype=bool)
        all_idx = np.arange(sample_size)

        # A matrix to store the total perturbations.
        r_tot = np.zeros([sample_size, inputs.shape[1] * inputs.shape[2] * inputs.shape[3]])

        # Set requires_grad for inputs.
        cur_inputs = inputs.requires_grad_(True).to(self.args.device)
        original_shape = inputs.shape[1:]

        # Set requires_grad to False for all network parameters to accelerate autograd.
        for p in self.model.parameters():
            p.requires_grad_(False)

        self.model.no_grad = True
        first_preds = self.model(cur_inputs).argmax(dim=1)
        self.model.no_grad = False

        for i in range(self.max_iter):
            f_all = self.model(cur_inputs)

            # Gradient of every class logit w.r.t. the still-unfinished inputs.
            w_k = []
            for c in range(self.args.num_classes):
                w_k.append(torch.autograd.grad(
                    f_all[:, c].sum(), cur_inputs,
                    retain_graph=False if c + 1 == self.args.num_classes else True)[0].flatten(1))
            w_k = torch.stack(w_k, dim=0)

            # boolean_mask[boolean_mask] is an all-True mask over the remaining samples,
            # i.e. it selects every row that is still being processed.
            w_k = w_k - w_k[first_preds, boolean_mask[boolean_mask]].unsqueeze(0)
            w_k_norm = w_k.norm(dim=2)

            # Set w_k_norm at the predicted-class positions to 1. to avoid division by zero.
            w_k_norm[first_preds, boolean_mask[boolean_mask]] = 1.

            l_all = (f_all - f_all[boolean_mask[boolean_mask], first_preds].unsqueeze(1)).detach().abs() / w_k_norm.T
            # Set l_all at the predicted-class positions to inf, as the argmin of each row is taken next.
            l_all[boolean_mask[boolean_mask], first_preds] = np.inf

            l_hat = l_all.argmin(dim=1)
            r_i = l_all[boolean_mask[boolean_mask], l_hat].unsqueeze(1) / w_k_norm[
                l_hat, boolean_mask[boolean_mask]].T.unsqueeze(1) * w_k[l_hat, boolean_mask[boolean_mask]]

            # Update r_tot values.
            r_tot[boolean_mask] += r_i.cpu().numpy()

            cur_inputs += r_i.reshape([r_i.shape[0]] + list(original_shape))

            # Re-input the updated samples into the network and get new predictions.
            self.model.no_grad = True
            preds = self.model(cur_inputs).argmax(dim=1)
            self.model.no_grad = False

            # In the DeepFool algorithm, the iteration for a sample stops once the perturbed input
            # is assigned a label different from its original prediction.
            index_unfinished = (preds == first_preds)
            if torch.all(~index_unfinished):
                break

            cur_inputs = cur_inputs[index_unfinished]
            first_preds = first_preds[index_unfinished]
            boolean_mask[all_idx[boolean_mask][~index_unfinished.cpu().numpy()]] = False

        # The score is the squared L2 norm of the total perturbation accumulated for each sample.
        return (r_tot * r_tot).sum(axis=1)
    def select(self, **kwargs):
        selection_result = self.run()
        return selection_result
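

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of DeepCore): a minimal single-sample version of
# the scoring rule above, run on a hypothetical `toy_model`. It mirrors the
# per-iteration update in `deep_fool` -- linearize the classifier, step to the
# nearest decision boundary, and score the sample by the squared norm of the
# accumulated perturbation. In the class above, DeepFool.select() returns
# {"indices": ..., "scores": ...}, keeping the samples with the smallest scores.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    num_classes, dim = 3, 8
    toy_model = torch.nn.Linear(dim, num_classes)   # hypothetical stand-in for self.model
    x = torch.randn(1, dim)

    r_tot = torch.zeros_like(x)
    first_pred = toy_model(x).argmax(dim=1).item()
    cur = x.clone()
    for _ in range(50):
        cur = cur.detach().requires_grad_(True)
        logits = toy_model(cur)
        if logits.argmax(dim=1).item() != first_pred:
            break                                            # crossed a decision boundary
        # Gradient of every class logit w.r.t. the input.
        grads = torch.stack([torch.autograd.grad(logits[0, c], cur, retain_graph=True)[0]
                             for c in range(num_classes)])   # [C, 1, dim]
        w = (grads - grads[first_pred]).flatten(1)           # w_k - w_c_hat
        f = (logits[0] - logits[0, first_pred]).detach().abs()   # |f_k - f_c_hat|
        w_norm = w.norm(dim=1)
        w_norm[first_pred] = 1.                              # avoid division by zero
        l = f / w_norm
        l[first_pred] = float("inf")
        k = int(l.argmin())                                  # nearest competing class
        r_i = (l[k] / w_norm[k]) * w[k]                      # minimal linearized step
        r_tot = r_tot + r_i.reshape_as(x)
        cur = cur.detach() + r_i.reshape_as(x)

    # Smaller scores mean the sample sits closer to the decision boundary.
    print("DeepFool score:", float((r_tot ** 2).sum()))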