import numpy as np

from .coresetmethod import CoresetMethod


class Uniform(CoresetMethod):
    """Coreset selection that draws training indices uniformly at random.

    Sampling is either done per class (``balance=True``), applying the same
    fraction to every class, or over the whole training set at once.

    NOTE(review): ``self.dst_train``, ``self.fraction``, ``self.random_seed``,
    ``self.num_classes`` and ``self.dst_train_label`` are read here but never
    assigned in this class; presumably ``CoresetMethod.__init__`` sets them --
    confirm against the base class.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None,
                 balance=True, replace=False, **kwargs):
        """Store the sampling options and the size of the training set.

        Args:
            dst_train: Training dataset; must support ``len()``.
            args: Forwarded unchanged to the base-class constructor.
            fraction: Proportion of samples to select.
            random_seed: Value handed to ``np.random.seed`` before sampling.
            balance: If true, ``select`` samples each class separately.
            replace: Passed to ``np.random.choice`` as its ``replace`` flag.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(dst_train, args, fraction, random_seed)
        self.balance = balance
        self.replace = replace
        self.n_train = len(self.dst_train)

    def select_balance(self):
        """Sample every class independently with the same fraction.

        Returns:
            1-D integer array of the selected indices, grouped by class in
            ascending class order. The same array is stored in ``self.index``.
        """
        # Seeds NumPy's global generator, so this affects later np.random use.
        np.random.seed(self.random_seed)
        self.index = np.array([], dtype=np.int64)
        every_position = np.arange(self.n_train)
        for label in range(self.num_classes):
            in_class = self.dst_train_label == label
            candidates = every_position[in_class]
            # Built-in round(): exact .5 ties go to the nearest even integer.
            n_pick = round(self.fraction * in_class.sum().item())
            picked = np.random.choice(candidates, n_pick, replace=self.replace)
            self.index = np.concatenate((self.index, picked))
        return self.index

    def select_no_balance(self):
        """Sample from the whole training set, ignoring class labels.

        Returns:
            1-D integer array of the selected indices, also stored in
            ``self.index``.
        """
        # Seeds NumPy's global generator, so this affects later np.random use.
        np.random.seed(self.random_seed)
        population = np.arange(self.n_train)
        n_pick = round(self.n_train * self.fraction)
        self.index = np.random.choice(population, n_pick, replace=self.replace)
        return self.index

    def select(self, **kwargs):
        """Run the configured sampling strategy.

        Returns:
            Dict with the single key ``"indices"`` holding the selected
            index array.
        """
        if self.balance:
            chosen = self.select_balance()
        else:
            chosen = self.select_no_balance()
        return {"indices": chosen}