Upload to Main
This commit is contained in:
109
deepcore/methods/herding.py
Normal file
109
deepcore/methods/herding.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch
|
||||
import numpy as np
|
||||
from .methods_utils import euclidean_dist
|
||||
from ..nets.nets_utils import MyDataParallel
|
||||
|
||||
|
||||
class Herding(EarlyTrain):
    """Herding-based coreset selection over mixed image/text embeddings.

    Samples are embedded by the model (image and text features blended by
    ``mixing_feature``) and then picked greedily by ``herding`` so that the
    selected set tracks the mean embedding of the candidate pool.

    NOTE(review): training, ``self.model``, ``self.n_train``,
    ``self.coreset_size``, ``self.num_classes``, ``self.dst_train_label`` and
    ``self.select_dm`` are provided by ``EarlyTrain`` or its collaborators,
    which are not visible in this file -- confirm their contracts there.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 specific_model="ResNet18", balance: bool = False, metric="euclidean", **kwargs):
        """Configure the selector.

        Args:
            dst_train: Training dataset to select from.
            args: Configuration object; this class reads
                ``args.DATASET.SELECTION_BATCH_SIZE`` and
                ``args.TRAIN.PRINT_FREQ``.
            fraction: Fraction of the data to keep (used per class when
                ``balance`` is true).
            random_seed: Forwarded to ``EarlyTrain``.
            epochs: Forwarded to ``EarlyTrain``.
            specific_model: Forwarded to ``EarlyTrain``.
            balance: If true, run herding separately inside every class.
            metric: ``"euclidean"``, or a callable ``metric(a, b)`` whose
                result is passed to ``torch.argmax`` in ``herding``. Any
                other value falls back to raw-input herding (see below).
            **kwargs: Forwarded to ``EarlyTrain``.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs=epochs, specific_model=specific_model, **kwargs)

        if metric == "euclidean":
            self.metric = euclidean_dist
        elif callable(metric):
            self.metric = metric
        else:
            # Unknown metric: use Euclidean distance on flattened raw inputs
            # and go straight to selection instead of the inherited ``run``.
            self.metric = euclidean_dist
            self.run = lambda: self.finish_run()

            # NOTE(review): this fallback reads ``args.workers`` and
            # ``args.device`` and expects ``(input, target)`` tuples, whereas
            # the rest of the class uses ``args.DATASET`` / ``args.TRAIN`` and
            # dict batches -- confirm this path still works with that config.
            def _construct_matrix(index=None):
                data_loader = torch.utils.data.DataLoader(
                    self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
                    batch_size=self.n_train if index is None else len(index), num_workers=self.args.workers)
                inputs, _ = next(iter(data_loader))
                return inputs.flatten(1).requires_grad_(False).to(self.args.device)

            # Instance attribute deliberately shadows the ``construct_matrix`` method.
            self.construct_matrix = _construct_matrix

        self.balance = balance
        self.select_bs = self.args.DATASET.SELECTION_BATCH_SIZE

    def num_classes_mismatch(self):
        """Raise ``ValueError`` because the pretrain and training class counts differ."""
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """Per-batch hook; intentionally a no-op for herding."""
        pass

    def mixing_feature(self, img_fea, text_fea, lam=0.5):
        """Return the convex combination ``lam * img_fea + (1 - lam) * text_fea``.

        This is an initial implementation and may not be the optimal way to
        combine the two modalities.
        """
        return lam * img_fea + (1 - lam) * text_fea

    def construct_matrix(self, index=None):
        """Build the embedding matrix for the whole dataset or a subset.

        Args:
            index: Optional indices into ``self.dst_train``; ``None`` means
                all ``self.n_train`` samples.

        Returns:
            A CUDA tensor of shape ``(sample_num, self.emb_dim)`` with one
            mixed image/text embedding per row, in loader order.
        """
        self.model.eval()
        self.model.no_grad = True
        try:
            with torch.no_grad():
                sample_num = self.n_train if index is None else len(index)
                matrix = torch.zeros([sample_num, self.emb_dim], requires_grad=False).cuda()
                data_loader = self.select_dm(self.dst_train, index, is_train=False)

                # Track the write position explicitly so row placement does
                # not depend on the loader's batch size matching
                # ``self.select_bs``.
                offset = 0
                for batch in data_loader:
                    image, label = batch['img'].cuda(), batch['label'].cuda()
                    img_f, text_f, _ = self.model(image, label, record=True)
                    # Use the mixed image and text feature as the embedding.
                    final_embed = self.mixing_feature(img_f, text_f)
                    end = offset + final_embed.shape[0]
                    matrix[offset:end] = final_embed
                    offset = end
        finally:
            # Always restore the training state, even if embedding fails.
            self.model.no_grad = False
            self.model.train()
        return matrix

    def before_run(self):
        """Cache the embedding width from the model's image encoder."""
        self.emb_dim = self.model.image_encoder.output_dim

    def herding(self, matrix, budget: int, index=None):
        """Greedily select ``budget`` rows of ``matrix``.

        Args:
            matrix: 2-D tensor with one candidate embedding per row.
            budget: Number of rows to select; clamped to the row count.
            index: Optional array mapping row positions to dataset indices.
                When ``None``, row positions themselves are returned.

        Returns:
            Array of the selected entries of ``index`` (or row positions),
            in ascending row order.

        Raises:
            ValueError: If ``budget`` is negative.
        """
        sample_num = matrix.shape[0]

        if budget < 0:
            raise ValueError("Illegal budget size.")
        elif budget > sample_num:
            budget = sample_num

        indices = np.arange(sample_num)
        with torch.no_grad():
            mu = torch.mean(matrix, dim=0)
            select_result = np.zeros(sample_num, dtype=bool)

            for i in range(budget):
                if i % self.args.TRAIN.PRINT_FREQ == 0:
                    print("| Selecting [%3d/%3d]" % (i + 1, budget))
                # Residual between (i + 1) * mean and the sum of the rows
                # picked so far, compared against every remaining row.
                dist = self.metric(((i + 1) * mu - torch.sum(matrix[select_result], dim=0)).view(1, -1),
                                   matrix[~select_result])
                # NOTE(review): argmax picks the row with the LARGEST metric
                # value; with a distance metric that is the farthest row.
                # Confirm whether argmin (or a similarity metric) is intended.
                p = torch.argmax(dist).item()
                # Translate the position among remaining rows to a global row.
                p = indices[~select_result][p]
                select_result[p] = True
        if index is None:
            index = indices
        return index[select_result]

    def finish_run(self):
        """Run herding and return the chosen dataset indices.

        Returns:
            ``{"indices": ndarray}``. With ``balance`` set, each class
            contributes ``round(fraction * class_size)`` samples, concatenated
            in class order; otherwise ``self.coreset_size`` samples are drawn
            from the full dataset.
        """
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module

        if not self.balance:
            return {"indices": self.herding(self.construct_matrix(), budget=self.coreset_size)}

        all_indices = np.arange(self.n_train)
        # Seed with an empty int32 array so the result is a valid (empty)
        # array even when there are no classes; concatenate once at the end
        # instead of re-copying with np.append on every class.
        per_class = [np.array([], dtype=np.int32)]
        for c in range(self.num_classes):
            class_index = all_indices[self.dst_train_label == c]
            per_class.append(self.herding(self.construct_matrix(class_index),
                                          budget=round(self.fraction * len(class_index)), index=class_index))
        return {"indices": np.concatenate(per_class)}

    def select(self, **kwargs):
        """Execute ``self.run()`` and return its selection result."""
        return self.run()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user