import torch
import numpy as np
from scipy.linalg import lstsq
from scipy.optimize import nnls

from .earlytrain import EarlyTrain
from ..nets.nets_utils import MyDataParallel


# https://github.com/krishnatejakk/GradMatch
class GradMatch(EarlyTrain):
    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None,
                 balance=True, dst_val=None, lam: float = 1., **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)
        self.balance = balance
        self.dst_val = dst_val

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        if batch_idx % self.args.print_freq == 0:
            print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % (
                epoch, self.epochs, batch_idx + 1, (self.n_pretrain_size // batch_size) + 1, loss.item()))

    def orthogonal_matching_pursuit(self, A, b, budget: int, lam: float = 1.):
        '''Approximately solves min_x |x|_0 s.t. Ax = b using Orthogonal Matching Pursuit.

        Acknowledgement to:
        https://github.com/krishnatejakk/GradMatch/blob/main/GradMatch/selectionstrategies/helpers/omp_solvers.py

        Args:
            A: design matrix of size (d, n)
            b: measurement vector of length d
            budget: selection budget
            lam: regularization coefficient for the final output vector
        Returns:
            vector of length n
        '''
        with torch.no_grad():
            d, n = A.shape
            if budget <= 0:
                budget = 0
            elif budget > n:
                budget = n

            device = A.device
            x = np.zeros(n, dtype=np.float32)
            resid = b.clone()
            indices = []
            boolean_mask = torch.ones(n, dtype=bool, device=device)
            all_idx = torch.arange(n, device=device)

            for i in range(budget):
                if i % self.args.print_freq == 0:
                    print("| Selecting [%3d/%3d]" % (i + 1, budget))
                # Greedily pick the not-yet-selected column most correlated with the residual.
                projections = torch.matmul(A.T, resid)
                index = torch.argmax(projections[boolean_mask])
                index = all_idx[boolean_mask][index]

                indices.append(index.item())
                boolean_mask[index] = False

                if len(indices) == 1:
                    A_i = A[:, index]
                    x_i = projections[index] / torch.dot(A_i, A_i).view(-1)
                    A_i = A[:, index].view(1, -1)
                else:
                    A_i = torch.cat((A_i, A[:, index].view(1, -1)), dim=0)
                    # Solve the regularized normal equations (A_i A_i^T + lam * I) x = A_i b.
                    temp = torch.matmul(A_i, torch.transpose(A_i, 0, 1)) + lam * torch.eye(A_i.shape[0], device=device)
                    x_i = torch.linalg.lstsq(temp, torch.matmul(A_i, b).view(-1, 1)).solution
                # Refit the residual on the current support.
                resid = b - torch.matmul(torch.transpose(A_i, 0, 1), x_i).view(-1)
            if budget > 1:
                # Final non-negative least-squares refit on the selected support.
                x_i = nnls(temp.cpu().numpy(), torch.matmul(A_i, b).view(-1).cpu().numpy())[0]
                x[indices] = x_i
            elif budget == 1:
                x[indices[0]] = 1.
        return x
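
    # Example (a minimal sketch with hypothetical values, not part of the selection
    # pipeline): given per-sample gradients as the columns of A and a target gradient
    # b, OMP returns a sparse non-negative weight vector whose support is the selected
    # subset. `selector` is assumed to be a constructed GradMatch instance (only
    # `self.args.print_freq` is used above).
    #
    #   A = torch.randn(32, 200, device=selector.args.device)   # (d, n) design matrix
    #   b = A[:, :10].mean(dim=1)                                # measurement vector
    #   w = selector.orthogonal_matching_pursuit(A, b, budget=10)
    #   support = np.nonzero(w)[0]                               # chosen column indices
    #   # The weighted sum of the selected columns approximates b in least squares.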

    def orthogonal_matching_pursuit_np(self, A, b, budget: int, lam: float = 1.):
        '''Approximately solves min_x |x|_0 s.t. Ax = b using Orthogonal Matching Pursuit.

        Acknowledgement to:
        https://github.com/krishnatejakk/GradMatch/blob/main/GradMatch/selectionstrategies/helpers/omp_solvers.py

        Args:
            A: design matrix of size (d, n)
            b: measurement vector of length d
            budget: selection budget
            lam: regularization coefficient for the final output vector
        Returns:
            vector of length n
        '''
        d, n = A.shape
        if budget <= 0:
            budget = 0
        elif budget > n:
            budget = n

        x = np.zeros(n, dtype=np.float32)
        resid = np.copy(b)
        indices = []
        boolean_mask = np.ones(n, dtype=bool)
        all_idx = np.arange(n)

        for i in range(budget):
            if i % self.args.print_freq == 0:
                print("| Selecting [%3d/%3d]" % (i + 1, budget))
            # Greedily pick the not-yet-selected column most correlated with the residual.
            projections = A.T.dot(resid)
            index = np.argmax(projections[boolean_mask])
            index = all_idx[boolean_mask][index]

            indices.append(index.item())
            boolean_mask[index] = False

            if len(indices) == 1:
                A_i = A[:, index]
                x_i = projections[index] / A_i.T.dot(A_i)
            else:
                A_i = np.vstack([A_i, A[:, index]])
                # Solve the regularized normal equations (A_i A_i^T + lam * I) x = A_i b.
                x_i = lstsq(A_i.dot(A_i.T) + lam * np.identity(A_i.shape[0]), A_i.dot(b))[0]
            # Refit the residual on the current support.
            resid = b - A_i.T.dot(x_i)
        if budget > 1:
            # Final non-negative least-squares refit on the selected support.
            x_i = nnls(A_i.dot(A_i.T) + lam * np.identity(A_i.shape[0]), A_i.dot(b))[0]
            x[indices] = x_i
        elif budget == 1:
            x[indices[0]] = 1.
        return x

    def calc_gradient(self, index=None, val=False):
        '''Computes per-sample gradients of the loss w.r.t. the last-layer parameters.'''
        self.model.eval()
        if val:
            batch_loader = torch.utils.data.DataLoader(
                self.dst_val if index is None else torch.utils.data.Subset(self.dst_val, index),
                batch_size=self.args.selection_batch, num_workers=self.args.workers)
            sample_num = len(self.dst_val.targets) if index is None else len(index)
        else:
            batch_loader = torch.utils.data.DataLoader(
                self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
                batch_size=self.args.selection_batch, num_workers=self.args.workers)
            sample_num = self.n_train if index is None else len(index)

        self.embedding_dim = self.model.get_last_layer().in_features
        # Each row stores [bias gradient (num_classes) | flattened weight gradient
        # (num_classes * embedding_dim)].
        gradients = torch.zeros([sample_num, self.args.num_classes * (self.embedding_dim + 1)],
                                requires_grad=False, device=self.args.device)

        for i, (input, targets) in enumerate(batch_loader):
            self.model_optimizer.zero_grad()
            outputs = self.model(input.to(self.args.device)).requires_grad_(True)
            loss = self.criterion(outputs, targets.to(self.args.device)).sum()
            batch_num = targets.shape[0]
            with torch.no_grad():
                # The gradient w.r.t. the last-layer bias equals the gradient w.r.t. the logits.
                bias_parameters_grads = torch.autograd.grad(loss, outputs, retain_graph=True)[0].cpu()
                # The gradient w.r.t. the last-layer weights is the outer product of the
                # penultimate-layer embedding and the bias gradient.
                weight_parameters_grads = self.model.embedding_recorder.embedding.cpu().view(
                    batch_num, 1, self.embedding_dim).repeat(1, self.args.num_classes, 1) * \
                    bias_parameters_grads.view(batch_num, self.args.num_classes, 1).repeat(
                    1, 1, self.embedding_dim)
                gradients[i * self.args.selection_batch:min((i + 1) * self.args.selection_batch, sample_num)] = \
                    torch.cat([bias_parameters_grads, weight_parameters_grads.flatten(1)], dim=1)

        return gradients
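
    # Per-sample gradient layout (a sketch, assuming a cross-entropy criterion, with
    # C = num_classes and E = embedding_dim): for softmax probabilities p, a one-hot
    # label y, and penultimate embedding emb of shape (E,),
    #
    #   bias_grad   = p - y                       # dL/d(logits), shape (C,)
    #   weight_grad = bias_grad[:, None] * emb    # outer product, shape (C, E)
    #   row         = torch.cat([bias_grad, weight_grad.flatten()])  # length C * (E + 1)
    #
    # which matches the rows assembled in calc_gradient above.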

    def finish_run(self):
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module
        self.model.no_grad = True
        with self.model.embedding_recorder:
            if self.dst_val is not None:
                val_num = len(self.dst_val.targets)

            if self.balance:
                # Class-balanced selection: run OMP within each class separately.
                selection_result = np.array([], dtype=np.int64)
                weights = np.array([], dtype=np.float32)
                for c in range(self.args.num_classes):
                    class_index = np.arange(self.n_train)[self.dst_train.targets == c]
                    cur_gradients = self.calc_gradient(class_index)
                    if self.dst_val is not None:
                        # Also calculate gradients of the validation set.
                        val_class_index = np.arange(val_num)[self.dst_val.targets == c]
                        cur_val_gradients = torch.mean(self.calc_gradient(val_class_index, val=True), dim=0)
                    else:
                        cur_val_gradients = torch.mean(cur_gradients, dim=0)
                    if self.args.device == "cpu":
                        # Compute OMP on numpy.
                        cur_weights = self.orthogonal_matching_pursuit_np(
                            cur_gradients.numpy().T, cur_val_gradients.numpy(),
                            budget=round(len(class_index) * self.fraction))
                    else:
                        cur_weights = self.orthogonal_matching_pursuit(
                            cur_gradients.to(self.args.device).T, cur_val_gradients.to(self.args.device),
                            budget=round(len(class_index) * self.fraction))
                    selection_result = np.append(selection_result, class_index[np.nonzero(cur_weights)[0]])
                    weights = np.append(weights, cur_weights[np.nonzero(cur_weights)[0]])
            else:
                cur_gradients = self.calc_gradient()
                if self.dst_val is not None:
                    # Also calculate gradients of the validation set.
                    cur_val_gradients = torch.mean(self.calc_gradient(val=True), dim=0)
                else:
                    cur_val_gradients = torch.mean(cur_gradients, dim=0)
                if self.args.device == "cpu":
                    # Compute OMP on numpy.
                    cur_weights = self.orthogonal_matching_pursuit_np(
                        cur_gradients.numpy().T, cur_val_gradients.numpy(), budget=self.coreset_size)
                else:
                    cur_weights = self.orthogonal_matching_pursuit(
                        cur_gradients.T, cur_val_gradients, budget=self.coreset_size)
                selection_result = np.nonzero(cur_weights)[0]
                weights = cur_weights[selection_result]
        self.model.no_grad = False
        return {"indices": selection_result, "weights": weights}

    def select(self, **kwargs):
        selection_result = self.run()
        return selection_result
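
# Usage sketch (hypothetical driver; names outside this module are assumptions).
# `args` must carry the fields referenced above: device, num_classes,
# selection_batch, workers, and print_freq.
#
#   method = GradMatch(dst_train, args, fraction=0.1, random_seed=0, epochs=10,
#                      balance=True, dst_val=None)
#   result = method.select()
#   coreset = torch.utils.data.Subset(dst_train, result["indices"])
#   # result["weights"] holds the OMP weights for weighted training on the coreset.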