from .earlytrain import EarlyTrain
import torch
from .methods_utils import FacilityLocation, submodular_optimizer
import numpy as np
from .methods_utils.euclidean import euclidean_dist_pair_np
from ..nets.nets_utils import MyDataParallel
from tqdm import tqdm


class Craig(EarlyTrain):
    """CRAIG coreset selection.

    After a short "early training" phase (handled by EarlyTrain), selects a
    weighted subset of the training data by maximizing a facility-location
    submodular function over pairwise last-layer-gradient similarities with a
    greedy optimizer. Returns both the selected indices and per-sample weights.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 specific_model=None, balance=True, greedy="LazyGreedy", **kwargs):
        """
        Args:
            dst_train: training dataset to select from.
            args: run configuration; this class reads args.device,
                args.selection_batch, args.workers and args.num_classes.
            fraction: fraction of samples to keep.
            random_seed: seed forwarded to EarlyTrain.
            epochs: number of early-training epochs before selection.
            specific_model: optional model-name override (forwarded).
            balance: if True, select `fraction` of each class independently;
                otherwise run one global selection with a per-class
                block-diagonal similarity matrix.
            greedy: key into submodular_optimizer.optimizer_choices.

        Raises:
            ModuleNotFoundError: if `greedy` names no registered optimizer.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        if greedy not in submodular_optimizer.optimizer_choices:
            raise ModuleNotFoundError("Greedy optimizer not found.")
        self._greedy = greedy
        self.balance = balance

    # EarlyTrain hooks: CRAIG needs no extra work at these points.
    def before_train(self):
        pass

    def after_loss(self, outputs, loss, targets, batch_inds, epoch):
        pass

    def before_epoch(self):
        pass

    def after_epoch(self):
        pass

    def before_run(self):
        pass

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def calc_gradient(self, index=None):
        """Pairwise euclidean distances between per-sample last-layer gradients.

        The gradient of each sample is approximated by its last layer only:
        the output gradient (bias part) concatenated with its outer product
        with the penultimate-layer embedding (weight part).

        Args:
            index: optional sequence of sample indices; None means the whole
                training set.

        Returns:
            (n, n) numpy array of pairwise euclidean distances between the
            flattened last-layer gradients of the n addressed samples.
        """
        self.model.eval()

        batch_loader = torch.utils.data.DataLoader(
            self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
            batch_size=self.args.selection_batch, num_workers=self.args.workers)
        self.embedding_dim = self.model.get_last_layer().in_features

        gradients = []

        for i, (inputs, targets) in enumerate(batch_loader):
            self.model_optimizer.zero_grad()
            outputs = self.model(inputs.to(self.args.device))
            loss = self.criterion(outputs.requires_grad_(True),
                                  targets.to(self.args.device)).sum()
            batch_num = targets.shape[0]
            with torch.no_grad():
                # d(loss)/d(outputs): gradient of the last layer's bias.
                bias_parameters_grads = torch.autograd.grad(loss, outputs)[0]
                # Outer product embedding x output-gradient: gradient of the
                # last layer's weight matrix, one (C, D) slab per sample.
                weight_parameters_grads = self.model.embedding_recorder.embedding.view(
                    batch_num, 1, self.embedding_dim).repeat(1, self.args.num_classes, 1) * \
                    bias_parameters_grads.view(batch_num, self.args.num_classes, 1).repeat(
                        1, 1, self.embedding_dim)
                gradients.append(torch.cat(
                    [bias_parameters_grads, weight_parameters_grads.flatten(1)],
                    dim=1).cpu().numpy())

        gradients = np.concatenate(gradients, axis=0)

        self.model.train()
        return euclidean_dist_pair_np(gradients)

    def calc_weights(self, matrix, result):
        """Weight each selected sample by the number of samples it represents.

        Args:
            matrix: (n, n) similarity matrix (negated, shifted distances).
            result: boolean mask or integer indices of the selected rows.

        Returns:
            1-D array with one weight per selected sample: 1 plus the number
            of samples whose most-similar selected sample it is.
        """
        # For every column (sample), the selected row of maximal similarity.
        min_sample = np.argmax(matrix[result], axis=0)
        weights = np.ones(np.sum(result) if result.dtype == bool else len(result))
        for i in min_sample:
            weights[i] = weights[i] + 1
        return weights

    def finish_run(self):
        """Run greedy facility-location selection after early training.

        Returns:
            dict with "indices" (selected sample indices into dst_train) and
            "weights" (matching per-sample coreset weights).
        """
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module

        self.model.no_grad = True
        # Recording must be active while calc_gradient runs: it reads
        # self.model.embedding_recorder.embedding for the weight gradients.
        with self.model.embedding_recorder:
            if self.balance:
                # Select round(fraction * |class|) samples per class.
                selection_result = np.array([], dtype=np.int32)
                weights = np.array([])
                for c in tqdm(range(self.num_classes)):
                    class_index = np.arange(self.n_train)[self.dst_train_label == c]
                    # Negate distances so larger means more similar; shift so
                    # the smallest entry is a small positive value.
                    matrix = -1. * self.calc_gradient(class_index)
                    matrix -= np.min(matrix) - 1e-3
                    submod_function = FacilityLocation(index=class_index, similarity_matrix=matrix)
                    submod_optimizer = submodular_optimizer.__dict__[self._greedy](
                        args=self.args, index=class_index,
                        budget=round(self.fraction * len(class_index)))
                    class_result = submod_optimizer.select(
                        gain_function=submod_function.calc_gain,
                        update_state=submod_function.update_state)
                    selection_result = np.append(selection_result, class_result)
                    weights = np.append(weights, self.calc_weights(
                        matrix, np.isin(class_index, class_result)))
            else:
                # One global selection; cross-class similarities stay zero,
                # so the matrix is block-diagonal per class (sparse pattern).
                matrix = np.zeros([self.n_train, self.n_train])
                all_index = np.arange(self.n_train)
                for c in range(self.num_classes):
                    class_index = np.arange(self.n_train)[self.dst_train_label == c]
                    matrix[np.ix_(class_index, class_index)] = -1. * self.calc_gradient(class_index)
                    matrix[np.ix_(class_index, class_index)] -= np.min(
                        matrix[np.ix_(class_index, class_index)]) - 1e-3
                submod_function = FacilityLocation(index=all_index, similarity_matrix=matrix)
                submod_optimizer = submodular_optimizer.__dict__[self._greedy](
                    args=self.args, index=all_index, budget=self.coreset_size)
                selection_result = submod_optimizer.select(
                    gain_function=submod_function.calc_gain_batch,
                    update_state=submod_function.update_state,
                    batch=self.args.selection_batch)
                weights = self.calc_weights(matrix, selection_result)
        self.model.no_grad = False
        return {"indices": selection_result, "weights": weights}

    def select(self, **kwargs):
        """Run early training + selection; return {"indices", "weights"}."""
        selection_result = self.run()
        return selection_result