from .earlytrain import EarlyTrain
import torch
import numpy as np
from .methods_utils import euclidean_dist
from ..nets.nets_utils import MyDataParallel


def k_center_greedy(matrix, budget: int, metric, device, random_seed=None, index=None, already_selected=None,
                    print_freq: int = 20):
    """Greedy k-center (farthest-first) selection.

    Starting from ``already_selected`` (or, when there is none, from one point chosen uniformly at random),
    repeatedly add the point whose distance to its nearest selected point is largest.

    Args:
        matrix: 2-D ``torch.Tensor`` or ``np.ndarray``; row ``j`` holds the features of sample ``j``.
            A NumPy array is converted to a tensor on ``device``.
        budget: number of new points to select on top of ``already_selected``. When nothing is pre-selected,
            the random initial point counts toward the budget. Values larger than the number of still
            unselected points are clamped.
        metric: callable ``metric(x, y)`` expected to return a ``(len(x), len(y))`` matrix of non-negative
            distances.
        device: device on which the running minimum distances are kept.
        random_seed: seed for drawing the initial point when nothing is pre-selected.
        index: optional ids, one per row of ``matrix``; defaults to ``np.arange(len(matrix))``.
        already_selected: ids (in ``index`` space) that are already selected; ``None`` means none.
        print_freq: print progress every ``print_freq`` iterations.

    Returns:
        ``np.ndarray`` with the ids (taken from ``index``, in ``index`` order, not selection order) of every
        selected point, including ``already_selected``.

    Raises:
        ValueError: if ``budget`` is negative.
    """
    if isinstance(matrix, torch.Tensor):
        assert matrix.dim() == 2
    elif isinstance(matrix, np.ndarray):
        assert matrix.ndim == 2
        matrix = torch.from_numpy(matrix).requires_grad_(False).to(device)

    sample_num = matrix.shape[0]
    assert sample_num >= 1

    if budget < 0:
        raise ValueError("Illegal budget size.")

    if index is None:
        index = np.arange(sample_num)
    else:
        assert matrix.shape[0] == len(index)
        index = np.asarray(index)

    assert callable(metric)

    # ``None`` (the default) means "nothing pre-selected"; np.array(None) would be an unsized 0-d array.
    already_selected = np.array([] if already_selected is None else already_selected)

    with torch.no_grad():
        # All False when already_selected is empty or shares no ids with ``index``.
        select_result = np.isin(index, already_selected)

        if not select_result.any():
            if budget == 0:
                return index[select_result]
            # Seed the search with one uniformly random point; it counts toward the budget. A local RandomState
            # draws the same point as ``np.random.seed(random_seed); np.random.randint(...)`` would, without
            # reseeding NumPy's global RNG.
            select_result[np.random.RandomState(random_seed).randint(0, sample_num)] = True
            budget -= 1

        # Never ask for more points than remain unselected.
        budget = min(budget, int(np.count_nonzero(~select_result)))
        if budget == 0:
            return index[select_result]

        # mins[j] is the distance from point j to its nearest selected point. Selected points hold -1 so argmax
        # never picks them again (this relies on metric returning non-negative distances). Only this vector is
        # kept rather than a full (selected x samples) distance matrix.
        mins = torch.full((sample_num,), -1.0, device=device)
        unselected = ~select_result
        first_dist = metric(matrix[select_result], matrix[unselected])
        mins[unselected] = torch.min(first_dist, dim=0).values.to(mins.dtype)

        for i in range(budget):
            if i % print_freq == 0:
                print("| Selecting [%3d/%3d]" % (i + 1, budget))
            p = torch.argmax(mins).item()
            select_result[p] = True

            if i == budget - 1:
                break
            mins[p] = -1
            unselected = ~select_result
            # Distances from the new center to every still-unselected point, as a flat vector.
            new_dist = metric(matrix[[p]], matrix[unselected]).reshape(-1).to(mins.dtype)
            mins[unselected] = torch.min(mins[unselected], new_dist)

    return index[select_result]


class kCenterGreedy(EarlyTrain):
    """Coreset selection by greedy k-center over per-sample feature vectors.

    Features come from ``construct_matrix``; the selection itself is delegated to ``k_center_greedy``.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=0, specific_model="ResNet18",
                 balance: bool = False, already_selected=None, metric="euclidean", torchvision_pretrain: bool = True,
                 **kwargs):
        """Configure the selector.

        Args:
            balance: if True, run k-center separately within each class.
            already_selected: ids in ``[0, n_train)`` treated as already selected; ``None`` means none.
            metric: ``"euclidean"`` or a callable ``metric(x, y)``. Any other value uses Euclidean distance on the
                raw, flattened inputs (loaded in a single batch) and replaces ``self.run`` so it only calls
                ``finish_run``.
            Remaining arguments are forwarded to ``EarlyTrain.__init__``.

        Raises:
            ValueError: if any id in ``already_selected`` is outside ``[0, n_train)``.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs=epochs, specific_model=specific_model,
                         torchvision_pretrain=torchvision_pretrain, **kwargs)

        # ``None`` instead of a shared mutable ``[]`` default; both mean "nothing pre-selected".
        if already_selected is None:
            already_selected = []
        if len(already_selected) != 0:
            if min(already_selected) < 0 or max(already_selected) >= self.n_train:
                raise ValueError("List of already selected points out of the boundary.")
        self.already_selected = np.array(already_selected)

        self.min_distances = None

        if metric == "euclidean":
            self.metric = euclidean_dist
        elif callable(metric):
            self.metric = metric
        else:
            self.metric = euclidean_dist
            self.run = lambda: self.finish_run()

            def _construct_matrix(index=None):
                # Load the whole dataset (or the given subset) as one batch and flatten each sample.
                data_loader = torch.utils.data.DataLoader(
                    self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
                    batch_size=self.n_train if index is None else len(index),
                    num_workers=self.args.workers)
                inputs, _ = next(iter(data_loader))
                return inputs.flatten(1).requires_grad_(False).to(self.args.device)

            self.construct_matrix = _construct_matrix

        self.balance = balance

    def num_classes_mismatch(self):
        """Reject a pretraining dataset whose class count differs from the training dataset's."""
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """Print training progress every ``args.print_freq`` batches."""
        if batch_idx % self.args.print_freq == 0:
            print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % (
                epoch, self.epochs, batch_idx + 1, (self.n_pretrain_size // batch_size) + 1, loss.item()))

    def old_construct_matrix(self, index=None):
        """Legacy variant of ``construct_matrix`` that fills a preallocated ``(samples, emb_dim)`` tensor.

        Relies on ``self.emb_dim`` being set by ``before_run``.
        """
        self.model.eval()
        self.model.no_grad = True
        try:
            with torch.no_grad():
                with self.model.embedding_recorder:
                    sample_num = self.n_train if index is None else len(index)
                    matrix = torch.zeros([sample_num, self.emb_dim], requires_grad=False).to(self.args.device)

                    data_loader = torch.utils.data.DataLoader(
                        self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
                        batch_size=self.args.selection_batch, num_workers=self.args.workers)

                    for i, (inputs, _) in enumerate(data_loader):
                        self.model(inputs.to(self.args.device))
                        start = i * self.args.selection_batch
                        stop = min((i + 1) * self.args.selection_batch, sample_num)
                        matrix[start:stop] = self.model.embedding_recorder.embedding
        finally:
            # Restore the flag even if a forward pass fails.
            self.model.no_grad = False
        return matrix

    def construct_matrix(self, index=None):
        """Return the embeddings recorded by ``self.model.embedding_recorder`` for the requested samples.

        Args:
            index: optional sample ids; ``None`` means all of ``self.dst_train``.

        Returns:
            The per-batch recorded embeddings concatenated along dim 0, in data-loader order.
        """
        self.model.eval()
        self.model.no_grad = True
        try:
            with torch.no_grad():
                with self.model.embedding_recorder:
                    data_loader = torch.utils.data.DataLoader(
                        self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
                        batch_size=self.args.selection_batch, num_workers=self.args.workers)
                    matrix = []
                    for inputs, _ in data_loader:
                        self.model(inputs.to(self.args.device))
                        matrix.append(self.model.embedding_recorder.embedding)
        finally:
            # Restore the flag even if a forward pass fails.
            self.model.no_grad = False
        return torch.cat(matrix, dim=0)

    def before_run(self):
        """Record the input width of the model's last layer as the embedding size."""
        self.emb_dim = self.model.get_last_layer().in_features

    def finish_run(self):
        """Unwrap a ``MyDataParallel`` model so later code sees the underlying module."""
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module

    def select(self, **kwargs):
        """Run (pre)training, then pick the coreset with greedy k-center.

        Returns:
            ``{"indices": ids}`` with the selected training-set ids, including ``already_selected``.
        """
        self.run()
        if self.balance:
            selection_result = np.array([], dtype=np.int32)
            all_index = np.arange(self.n_train)
            for c in range(self.args.num_classes):
                class_index = all_index[self.dst_train.targets == c]
                if len(class_index) == 0:
                    # Nothing to select, and building a feature matrix for an empty subset would fail.
                    continue
                class_selected = self.already_selected[np.isin(self.already_selected, class_index)]
                selection_result = np.append(selection_result, k_center_greedy(
                    self.construct_matrix(class_index),
                    budget=round(self.fraction * len(class_index)),
                    metric=self.metric,
                    device=self.args.device,
                    random_seed=self.random_seed,
                    index=class_index,
                    already_selected=class_selected,
                    print_freq=self.args.print_freq))
        else:
            matrix = self.construct_matrix()
            # Release the model before the selection loop; only the feature matrix is needed from here on.
            del self.model_optimizer
            del self.model
            selection_result = k_center_greedy(matrix, budget=self.coreset_size,
                                               metric=self.metric, device=self.args.device,
                                               random_seed=self.random_seed,
                                               already_selected=self.already_selected,
                                               print_freq=self.args.print_freq)
        return {"indices": selection_result}