# DeepFool-based coreset selection method.
from .earlytrain import EarlyTrain
|
|
import torch
|
|
import numpy as np
|
|
|
|
|
|
class DeepFool(EarlyTrain):
    """Coreset selection scored by the DeepFool adversarial perturbation.

    After the EarlyTrain pre-training phase, every training sample is scored
    with the squared L2 norm of the (approximate) minimal perturbation that
    flips the model's prediction for it.  Samples with the SMALLEST scores
    sit closest to the decision boundary and are the ones kept in the
    coreset (see ``finish_run``).

    Args:
        dst_train: training dataset to select from.
        args: experiment namespace; fields read here include ``device``,
            ``num_classes``, ``print_freq``, ``selection_batch`` and
            ``workers``.
        fraction: fraction of the dataset to keep (default 0.5).
        random_seed: seed forwarded to EarlyTrain.
        epochs: number of pre-training epochs before scoring.
        specific_model: optional model name forwarded to EarlyTrain.
        balance: if True, select ``fraction`` of each class independently;
            otherwise select globally by score.
        max_iter: maximum number of DeepFool update steps per batch.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 specific_model=None, balance: bool = False, max_iter: int = 50, **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        self.balance = balance    # per-class balanced selection if True
        self.max_iter = max_iter  # cap on DeepFool iterations per batch

    def num_classes_mismatch(self):
        """EarlyTrain callback: fail hard if pretrain/train class counts differ."""
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """EarlyTrain callback: print training progress every ``print_freq`` batches."""
        if batch_idx % self.args.print_freq == 0:
            print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % (
                epoch, self.epochs, batch_idx + 1, (self.n_pretrain_size // batch_size) + 1, loss.item()))

    def finish_run(self):
        """Score all training samples with DeepFool and pick the coreset.

        Returns:
            dict with:
              - "indices": selected sample indices (smallest scores win);
              - "scores": per-sample squared perturbation norms (``r``).
        """
        # NOTE(review): ``no_grad`` looks like a custom flag on the project's
        # model wrapper (presumably suppresses graph building in forward) —
        # confirm against the nets package.
        self.model.no_grad = False

        # Create a data loader for self.dst_train with batch size self.args.selection_batch.
        # No shuffling, so batch i covers dataset indices
        # [i * selection_batch, i * selection_batch + batch_len) and scores
        # can be written back positionally below.
        batch_loader = torch.utils.data.DataLoader(self.dst_train, batch_size=self.args.selection_batch
                                                   , num_workers=self.args.workers)

        # Per-sample scores: squared L2 norm of the DeepFool perturbation.
        r = np.zeros(self.n_train, dtype=np.float32)
        batch_num = len(batch_loader)
        for i, (inputs, targets) in enumerate(batch_loader):
            if i % self.args.print_freq == 0:
                print('| Selecting Batch [%3d/%3d]' % (i + 1, batch_num))
            r[(i * self.args.selection_batch):(i * self.args.selection_batch + targets.shape[0])] = self.deep_fool(
                inputs)

        if self.balance:
            # Per-class selection: within each class keep the ``fraction`` of
            # samples with the smallest scores (closest to the boundary).
            selection_result = np.array([], dtype=np.int64)
            for c in range(self.args.num_classes):
                class_index = np.arange(self.n_train)[self.dst_train.targets == c]
                selection_result = np.append(selection_result, class_index[
                    r[class_index].argsort()[:round(len(class_index) * self.fraction)]])
        else:
            # Global selection: keep the coreset_size smallest scores overall.
            selection_result = r.argsort()[:self.coreset_size]
        return {"indices": selection_result, "scores": r}

    def deep_fool(self, inputs):
        """Run the DeepFool algorithm on one batch.

        Args:
            inputs: batch tensor indexed as (N, C, H, W) below — the
                perturbation buffer is sized C*H*W per sample.

        Returns:
            np.ndarray of shape (N,): squared L2 norm of the accumulated
            perturbation for each sample (smaller = closer to the boundary).
        """
        # Here, start running DeepFool algorithm.
        self.model.eval()

        # Initialize a boolean mask indicating if selection has been stopped at corresponding positions.
        # True = sample still active (prediction not yet flipped).
        sample_size = inputs.shape[0]
        boolean_mask = np.ones(sample_size, dtype=bool)
        all_idx = np.arange(sample_size)

        # A matrix to store total pertubations (one flattened C*H*W row per sample).
        r_tot = np.zeros([sample_size, inputs.shape[1] * inputs.shape[2] * inputs.shape[3]])

        # Set requires_grad for inputs.
        # NOTE(review): requires_grad_ mutates the caller's tensor in place;
        # on a CPU device ``.to`` returns the same tensor object.
        cur_inputs = inputs.requires_grad_(True).to(self.args.device)

        original_shape = inputs.shape[1:]

        # set requires_grad for all parametres in network as False to accelerate autograd
        for p in self.model.parameters():
            p.requires_grad_(False)

        # Record the initial predictions: a sample is "finished" once its
        # prediction moves away from this label.
        self.model.no_grad = True
        first_preds = self.model(cur_inputs).argmax(dim=1)
        self.model.no_grad = False

        for i in range(self.max_iter):
            f_all = self.model(cur_inputs)

            # Per-class gradients of the summed logits w.r.t. the active
            # inputs, flattened to (num_classes, n_active, C*H*W).  The graph
            # is retained until the last class, then freed.
            w_k = []
            for c in range(self.args.num_classes):
                w_k.append(torch.autograd.grad(f_all[:, c].sum(), cur_inputs,
                                               retain_graph=False if c + 1 == self.args.num_classes else True)[
                               0].flatten(1))
            w_k = torch.stack(w_k, dim=0)
            # Gradient differences relative to the originally-predicted class.
            # Note ``boolean_mask[boolean_mask]`` is an all-True vector of
            # length n_active, i.e. it selects every currently-active row.
            w_k = w_k - w_k[first_preds, boolean_mask[boolean_mask]].unsqueeze(0)
            w_k_norm = w_k.norm(dim=2)

            w_k_norm[first_preds, boolean_mask[
                boolean_mask]] = 1.  # Set w_k_norm for preds positions to 1. to avoid division by zero.

            # Linearized distance to each class boundary: |f_k - f_pred| / ||w_k||.
            l_all = (f_all - f_all[boolean_mask[boolean_mask], first_preds].unsqueeze(1)).detach().abs() / w_k_norm.T
            l_all[boolean_mask[
                boolean_mask], first_preds] = np.inf  # Set l_k for preds positions to inf, as the argmin for each
            # row will be calculated soon.

            # Closest boundary per sample, and the minimal step r_i toward it.
            l_hat = l_all.argmin(dim=1)
            r_i = l_all[boolean_mask[boolean_mask], l_hat].unsqueeze(1) / w_k_norm[
                l_hat, boolean_mask[boolean_mask]].T.unsqueeze(1) * w_k[l_hat, boolean_mask[boolean_mask]]

            # Update r_tot values (only the rows of still-active samples).
            r_tot[boolean_mask] += r_i.cpu().numpy()

            cur_inputs += r_i.reshape([r_i.shape[0]] + list(original_shape))

            # Re-input the updated sample into the network and get new predictions.
            self.model.no_grad = True
            preds = self.model(cur_inputs).argmax(dim=1)
            self.model.no_grad = False

            # In DeepFool algorithm, the iteration stops when the updated sample produces a different prediction
            # in the model.
            index_unfinished = (preds == first_preds)
            if torch.all(~index_unfinished):
                break

            # Shrink the working batch to unfinished samples and clear their
            # flipped peers in the full-batch mask.
            cur_inputs = cur_inputs[index_unfinished]
            first_preds = first_preds[index_unfinished]
            boolean_mask[all_idx[boolean_mask][~index_unfinished.cpu().numpy()]] = False

        # Squared L2 norm of each sample's total perturbation.
        return (r_tot * r_tot).sum(axis=1)

    def select(self, **kwargs):
        """Run pre-training plus DeepFool scoring; return the selection dict from finish_run."""
        selection_result = self.run()
        return selection_result
|