Files
DAPT/deepcore/methods/uniform.py
2025-10-07 22:42:55 +08:00

35 lines
1.3 KiB
Python

import numpy as np
from .coresetmethod import CoresetMethod
class Uniform(CoresetMethod):
    """Select a coreset by uniform random sampling of training indices.

    Two modes are supported:

    * balanced (default): the same ``fraction`` is drawn from each class
      separately, using ``self.dst_train_label`` to find class members;
    * unbalanced: ``fraction`` of the whole training set is drawn at once.

    Sampling uses a private ``np.random.RandomState`` seeded with
    ``random_seed``. It does not touch NumPy's global random state.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, balance=True, replace=False, **kwargs):
        """Configure the sampler.

        Args:
            dst_train: Training dataset, forwarded to ``CoresetMethod``. Its
                length is cached as ``self.n_train``.
            args: Experiment arguments, forwarded to ``CoresetMethod``.
            fraction: Proportion of samples to keep. It applies per class
                when ``balance`` is True and overall otherwise.
            random_seed: Seed for the sampler. ``None`` draws fresh entropy,
                so results are not reproducible.
            balance: If True, sample each class separately.
            replace: If True, sample with replacement.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        super().__init__(dst_train, args, fraction, random_seed)
        self.balance = balance
        self.replace = replace
        self.n_train = len(self.dst_train)

    def select_balance(self):
        """The same sampling proportions were used in each class separately.

        For each class ``c`` in ``range(self.num_classes)``, this draws
        ``round(self.fraction * n_c)`` indices from the samples whose label
        in ``self.dst_train_label`` equals ``c``. Because rounding happens
        per class, the total may differ slightly from
        ``round(self.n_train * self.fraction)``.

        Returns:
            1-D int64 ``np.ndarray`` of dataset indices, grouped by class in
            ascending class order. The array is also stored in ``self.index``.

        Raises:
            ValueError: Raised by NumPy when ``replace`` is False and a class
                has fewer samples than requested.
        """
        # Private RNG: the same seed yields the same indices as the former
        # np.random.seed + np.random.choice (both use legacy MT19937
        # seeding), but NumPy's process-wide random state is left untouched.
        rng = np.random.RandomState(self.random_seed)
        all_index = np.arange(self.n_train)
        # Start from an empty int64 array. This keeps the result int64, as
        # before, and is well-defined when there are no classes.
        per_class = [np.array([], dtype=np.int64)]
        for c in range(self.num_classes):
            c_index = (self.dst_train_label == c)
            per_class.append(rng.choice(all_index[c_index], round(self.fraction * c_index.sum().item()),
                                        replace=self.replace))
        # One concatenate instead of np.append in the loop, which would copy
        # the whole accumulated array on every iteration.
        self.index = np.concatenate(per_class)
        return self.index

    def select_no_balance(self):
        """Draw ``round(self.n_train * self.fraction)`` indices from the whole training set.

        Returns:
            1-D ``np.ndarray`` of dataset indices, also stored in ``self.index``.

        Raises:
            ValueError: Raised by NumPy when ``replace`` is False and more
                samples are requested than ``self.n_train``.
        """
        # Private RNG for the same reason as in select_balance.
        rng = np.random.RandomState(self.random_seed)
        self.index = rng.choice(np.arange(self.n_train), round(self.n_train * self.fraction),
                                replace=self.replace)
        return self.index

    def select(self, **kwargs):
        """Run the sampler chosen by ``self.balance``.

        Args:
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            dict with the key ``"indices"``, holding the selected dataset indices.
        """
        return {"indices": self.select_balance() if self.balance else self.select_no_balance()}