Files
DAPT/deepcore/methods/herding.py
2025-10-07 22:42:55 +08:00

110 lines
4.4 KiB
Python

from .earlytrain import EarlyTrain
import torch
import numpy as np
from .methods_utils import euclidean_dist
from ..nets.nets_utils import MyDataParallel
class Herding(EarlyTrain):
    """Herding coreset selection over mixed image/text model embeddings.

    Greedily picks the sample that moves the running sum of selected
    features closest to the (budget-scaled) mean feature, one sample per
    step, until ``budget`` samples are chosen.  Features come from
    ``self.model`` (image-encoder output mixed with text features via
    :meth:`mixing_feature`).
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 specific_model="ResNet18", balance: bool = False, metric="euclidean", **kwargs):
        """
        Args:
            dst_train: training dataset to select from.
            args: global config; this class reads ``args.workers``,
                ``args.device``, ``args.DATASET.SELECTION_BATCH_SIZE`` and
                ``args.TRAIN.PRINT_FREQ``.
            fraction: fraction of the dataset to keep (forwarded to EarlyTrain).
            random_seed: seed forwarded to EarlyTrain.
            epochs: training epochs forwarded to EarlyTrain.
            specific_model: backbone name forwarded to EarlyTrain.
            balance: if True, run herding independently inside each class;
                otherwise select globally over the whole dataset.
            metric: ``"euclidean"`` or any callable ``(a, b) -> pairwise
                distance matrix``; any other value silently falls back to
                euclidean distance.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs=epochs,
                         specific_model=specific_model, **kwargs)
        # "euclidean" and every other non-callable value both resolve to
        # euclidean_dist, so a single conditional expression suffices.
        self.metric = metric if callable(metric) else euclidean_dist
        # NOTE(review): this unconditionally replaces EarlyTrain's run loop,
        # skipping training and jumping straight to selection — presumably
        # the model arrives pretrained; confirm this is intended.
        self.run = lambda: self.finish_run()
        self.balance = balance
        self.select_bs = self.args.DATASET.SELECTION_BATCH_SIZE

    def _construct_raw_matrix(self, index=None):
        """Legacy feature matrix: flattened raw input pixels.

        Bug fix: this used to be assigned to ``self.construct_matrix`` inside
        ``__init__``, and the instance attribute shadowed the class-level
        :meth:`construct_matrix` below — making the embedding-based path
        (and ``before_run``'s ``emb_dim``) dead code.  It is kept under a
        private name as a debugging fallback.

        Args:
            index: optional sequence of sample indices; None means the whole
                training set.

        Returns:
            A detached ``[N, prod(input_shape)]`` tensor on ``args.device``.
        """
        data_loader = torch.utils.data.DataLoader(
            self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
            batch_size=self.n_train if index is None else len(index),
            num_workers=self.args.workers)
        # Single batch covering all requested samples.
        inputs, _ = next(iter(data_loader))
        return inputs.flatten(1).requires_grad_(False).to(self.args.device)

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        # No per-batch bookkeeping is needed during (skipped) training.
        pass

    def mixing_feature(self, img_fea, text_fea, lam=0.5):
        """Convex combination of image and text features.

        Initial heuristic — may not be optimal (per the original author's
        note); ``lam=0.5`` weighs both modalities equally.
        """
        return lam * img_fea + (1 - lam) * text_fea

    def construct_matrix(self, index=None):
        """Build the ``[N, emb_dim]`` feature matrix from model embeddings.

        Runs the model in eval mode under ``no_grad`` and fills the matrix
        batch-by-batch with the mixed image/text features.

        Args:
            index: optional sequence of sample indices; None means the whole
                training set.

        Returns:
            A ``[sample_num, emb_dim]`` CUDA tensor of detached embeddings.
        """
        self.model.eval()
        self.model.no_grad = True
        with torch.no_grad():
            sample_num = self.n_train if index is None else len(index)
            # NOTE(review): hard-coded .cuda() here, while the raw-pixel
            # fallback uses args.device — consider unifying.
            matrix = torch.zeros([sample_num, self.emb_dim], requires_grad=False).cuda()
            data_loader = self.select_dm(self.dst_train, index, is_train=False)
            for i, batch in enumerate(data_loader):
                image, label = batch['img'].cuda(), batch['label'].cuda()
                img_f, text_f, _ = self.model(image, label, record=True)
                # Use the mixed image/text feature as the final embedding.
                final_embed = self.mixing_feature(img_f, text_f)
                # Assumes the loader yields batches of size select_bs (last
                # batch possibly smaller) — TODO confirm select_dm's batching.
                matrix[i * self.select_bs:min((i + 1) * self.select_bs, sample_num)] = final_embed
        self.model.no_grad = False
        self.model.train()
        return matrix

    def before_run(self):
        # Embedding width, used to pre-allocate the feature matrix.
        self.emb_dim = self.model.image_encoder.output_dim

    def herding(self, matrix, budget: int, index=None):
        """Greedy herding selection on a precomputed feature matrix.

        At step i, picks the unselected row maximizing the distance metric
        between ``(i+1)*mean - sum(selected)`` and the candidate rows, i.e.
        the row that best closes the gap to the scaled mean.

        Args:
            matrix: ``[N, D]`` feature matrix.
            budget: number of samples to select; negative raises, values
                above N are clamped to N.
            index: optional original-dataset indices aligned with ``matrix``
                rows; if None, positions 0..N-1 are returned.

        Returns:
            Array of selected indices (into ``index`` if given).

        Raises:
            ValueError: if ``budget`` is negative.
        """
        sample_num = matrix.shape[0]
        if budget < 0:
            raise ValueError("Illegal budget size.")
        elif budget > sample_num:
            budget = sample_num
        indices = np.arange(sample_num)
        with torch.no_grad():
            mu = torch.mean(matrix, dim=0)
            select_result = np.zeros(sample_num, dtype=bool)
            for i in range(budget):
                if i % self.args.TRAIN.PRINT_FREQ == 0:
                    print("| Selecting [%3d/%3d]" % (i + 1, budget))
                # Residual target is (i+1)*mu minus the selected sum; pick
                # the candidate farthest from it under self.metric.
                dist = self.metric(((i + 1) * mu - torch.sum(matrix[select_result], dim=0)).view(1, -1),
                                   matrix[~select_result])
                p = torch.argmax(dist).item()
                # Map position among unselected rows back to an absolute row.
                p = indices[~select_result][p]
                select_result[p] = True
        if index is None:
            index = indices
        return index[select_result]

    def finish_run(self):
        """Run selection and return the chosen indices.

        Returns:
            dict with key ``"indices"``: the selected sample indices, either
            per-class (balanced) or global.
        """
        # Unwrap DataParallel so attribute access (image_encoder, no_grad)
        # hits the real module.
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module
        if self.balance:
            selection_result = np.array([], dtype=np.int32)
            for c in range(self.num_classes):
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                selection_result = np.append(
                    selection_result,
                    self.herding(self.construct_matrix(class_index),
                                 budget=round(self.fraction * len(class_index)),
                                 index=class_index))
        else:
            selection_result = self.herding(self.construct_matrix(), budget=self.coreset_size)
        return {"indices": selection_result}

    def select(self, **kwargs):
        """Public entry point: delegates to self.run (selection only)."""
        selection_result = self.run()
        return selection_result