110 lines
4.4 KiB
Python
110 lines
4.4 KiB
Python
from .earlytrain import EarlyTrain
|
|
import torch
|
|
import numpy as np
|
|
from .methods_utils import euclidean_dist
|
|
from ..nets.nets_utils import MyDataParallel
|
|
|
|
|
|
class Herding(EarlyTrain):
    """Coreset selection by greedy herding.

    Samples are embedded (by default as a mix of the model's image feature and
    the text feature of their label, see ``construct_matrix``) and then picked
    one by one so that the running mean of the selection tracks the mean of
    all candidate embeddings, in the spirit of iCaRL's herding exemplar
    selection.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 specific_model="ResNet18", balance: bool = False, metric="euclidean", **kwargs):
        """Configure the selector.

        Args:
            dst_train: training dataset to select from.
            args: config object; this class reads ``args.DATASET.SELECTION_BATCH_SIZE``
                and ``args.TRAIN.PRINT_FREQ`` (and ``args.workers`` /
                ``args.device`` on the raw-input path).
            fraction: forwarded to ``EarlyTrain``; also used as the per-class
                keep ratio when ``balance`` is True.
            random_seed, epochs, specific_model, **kwargs: forwarded to ``EarlyTrain``.
            balance: if True, run herding separately inside each class.
            metric: ``"euclidean"`` (default) or a callable ``metric(a, b)``
                returning distances, for herding on model embeddings. Any other
                value herds on flattened raw inputs with Euclidean distance and
                skips ``EarlyTrain``'s model-based run.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs=epochs, specific_model=specific_model, **kwargs)

        if metric == "euclidean":
            self.metric = euclidean_dist
        elif callable(metric):
            self.metric = metric
        else:
            # Raw-input herding: no model is needed, so bypass EarlyTrain.run and
            # replace the model-based construct_matrix with a pixel-space one.
            # NOTE(review): these overrides belong to this branch only (as in
            # upstream DeepCore); applied unconditionally they would shadow the
            # model-based construct_matrix method for every metric.
            self.metric = euclidean_dist
            self.run = lambda: self.finish_run()

            def _construct_matrix(index=None):
                """Load the requested samples in one batch and flatten each to a row."""
                subset = self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index)
                data_loader = torch.utils.data.DataLoader(
                    subset,
                    batch_size=self.n_train if index is None else len(index),
                    num_workers=self.args.workers)
                batch = next(iter(data_loader))
                # Support both (inputs, target, ...) tuple batches and dict
                # batches keyed by 'img' (the format used by select_dm loaders).
                inputs = batch["img"] if isinstance(batch, dict) else batch[0]
                return inputs.flatten(1).requires_grad_(False).to(self.args.device)

            self.construct_matrix = _construct_matrix

        self.balance = balance
        self.select_bs = self.args.DATASET.SELECTION_BATCH_SIZE

    def num_classes_mismatch(self):
        """Raise because the pretrained model's class count differs from the dataset's."""
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """Per-batch training hook; herding records nothing during training."""
        pass

    # Initial implementation; the mixing rule may not be optimal.
    def mixing_feature(self, img_fea, text_fea, lam=0.5):
        """Return the convex mix ``lam * img_fea + (1 - lam) * text_fea``.

        ``lam=1.0`` reduces to image features only.
        """
        return lam * img_fea + (1 - lam) * text_fea

    def construct_matrix(self, index=None):
        """Embed training samples with the current model.

        Each sample's row is ``mixing_feature(img_f, text_f)``, where both
        features come from ``self.model(image, label, record=True)``.

        Args:
            index: optional array of training-set positions; ``None`` embeds
                the whole training set.

        Returns:
            A ``(num_samples, self.emb_dim)`` float tensor on the GPU. Rows are
            filled in loader order; this assumes ``select_dm(..., is_train=False)``
            yields samples unshuffled, in ``index`` order (TODO confirm), because
            ``herding`` maps rows back through ``index``.
        """
        sample_num = self.n_train if index is None else len(index)
        # NOTE(review): emb_dim is set in before_run; presumably EarlyTrain.run
        # calls it before finish_run — confirm.
        matrix = torch.zeros([sample_num, self.emb_dim], requires_grad=False).cuda()

        self.model.eval()
        self.model.no_grad = True
        try:
            with torch.no_grad():
                data_loader = self.select_dm(self.dst_train, index, is_train=False)
                offset = 0
                for batch in data_loader:
                    image, label = batch['img'].cuda(), batch['label'].cuda()
                    img_f, text_f, _ = self.model(image, label, record=True)
                    final_embed = self.mixing_feature(img_f, text_f)
                    # Advance by the real batch length instead of assuming the
                    # loader's batch size equals SELECTION_BATCH_SIZE.
                    rows = final_embed.shape[0]
                    matrix[offset:offset + rows] = final_embed
                    offset += rows
        finally:
            # Always hand the model back in training mode, even on failure.
            self.model.no_grad = False
            self.model.train()
        return matrix

    def before_run(self):
        """Record the embedding width used to size the feature matrix."""
        self.emb_dim = self.model.image_encoder.output_dim

    def herding(self, matrix, budget: int, index=None):
        """Greedily select ``budget`` rows of ``matrix`` by herding.

        With ``mu`` the mean row and ``S`` the sum of rows chosen so far, step
        ``i`` picks the unchosen row ``x`` that minimises
        ``metric((i + 1) * mu - S, x)``. That is the sample that keeps the
        running mean of the selection closest to ``mu``.

        Args:
            matrix: 2-D tensor with one feature row per candidate.
            budget: number of rows to select; clipped to ``matrix.shape[0]``.
            index: optional array mapping row positions to dataset indices.

        Returns:
            numpy array of the selected row positions (or their ``index``
            entries), in ascending row order.

        Raises:
            ValueError: if ``budget`` is negative.
        """
        sample_num = matrix.shape[0]

        if budget < 0:
            raise ValueError("Illegal budget size.")
        budget = min(budget, sample_num)

        indices = np.arange(sample_num)
        select_result = np.zeros(sample_num, dtype=bool)
        with torch.no_grad():
            mu = torch.mean(matrix, dim=0)
            # Running sum of the selected rows, updated incrementally instead
            # of re-summing the whole selection every iteration.
            selected_sum = torch.zeros_like(mu)
            for i in range(budget):
                if i % self.args.TRAIN.PRINT_FREQ == 0:
                    print("| Selecting [%3d/%3d]" % (i + 1, budget))
                remaining = indices[~select_result]
                dist = self.metric(((i + 1) * mu - selected_sum).view(1, -1),
                                   matrix[remaining])
                # metric returns distances, so the herding choice is the
                # *closest* candidate; argmax would pick the farthest outlier.
                p = remaining[torch.argmin(dist).item()]
                select_result[p] = True
                selected_sum += matrix[p]

        if index is None:
            index = indices
        return np.asarray(index)[select_result]

    def finish_run(self):
        """Run herding (per class when ``self.balance``) and return the coreset.

        Returns:
            ``{"indices": np.ndarray}`` holding the selected training-set indices.
        """
        # The raw-input path never builds a model, so look it up defensively.
        if isinstance(getattr(self, "model", None), MyDataParallel):
            self.model = self.model.module

        if self.balance:
            selection_result = np.array([], dtype=np.int32)
            for c in range(self.num_classes):
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                if len(class_index) == 0:
                    # Nothing to select; also avoids building a zero-size loader.
                    continue
                class_budget = round(self.fraction * len(class_index))
                selection_result = np.append(
                    selection_result,
                    self.herding(self.construct_matrix(class_index), budget=class_budget, index=class_index))
        else:
            selection_result = self.herding(self.construct_matrix(), budget=self.coreset_size)
        return {"indices": selection_result}

    def select(self, **kwargs):
        """Run selection and return ``{"indices": ...}``; ``kwargs`` are ignored."""
        selection_result = self.run()
        return selection_result
|
|
|
|
|