from .earlytrain import EarlyTrain
from .methods_utils.euclidean import euclidean_dist_pair_np
from .methods_utils.cossim import cossim_pair_np
import numpy as np
import torch
from tqdm import tqdm
from .. import nets
from copy import deepcopy
from torchvision import transforms


class Cal(EarlyTrain):
    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None,
                 balance=False, metric="euclidean", neighbors: int = 10, pretrain_model: str = "ResNet18", **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        self.balance = balance

        assert 0 < neighbors < 100
        self.neighbors = neighbors

        # Pairwise distance used for the k-NN search. The metric is applied to a
        # single embedding matrix; cosine similarity is negated so that, as with
        # the euclidean metric, smaller values mean "closer".
        if metric == "euclidean":
            self.metric = euclidean_dist_pair_np
        elif metric == "cossim":
            self.metric = lambda a: -1. * cossim_pair_np(a)
        elif callable(metric):
            self.metric = metric
        else:
            self.metric = euclidean_dist_pair_np

        self.pretrain_model = pretrain_model

    def num_classes_mismatch(self):
        raise ValueError("num_classes of the pretraining dataset does not match that of the training dataset.")

    # Initial implementation; the mixing weight lam=0.5 may not be optimal.
    def mixing_feature(self, img_fea, text_fea, lam=0.5):
        # Convex combination of the image and text embeddings.
        # return img_fea
        return lam * img_fea + (1 - lam) * text_fea
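
    # A minimal sketch of the mixing step (shapes are hypothetical; lam=0.5 is
    # the default above, not a tuned value):
    #
    #     img_f  = torch.randn(32, 512)          # batch of image embeddings
    #     text_f = torch.randn(32, 512)          # matching text embeddings
    #     fused  = 0.5 * img_f + 0.5 * text_f    # == mixing_feature(img_f, text_f)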

    def find_knn(self):
        """
        Find the k nearest neighbors of every data point with the pretrained embedding model.

        :return: knn matrix of shape [n_samples, self.neighbors]
        """
        # Initialize pretrained model
        # model = nets.__dict__[self.pretrain_model](channel=self.args.channel, num_classes=self.args.num_classes,
        #                                            im_size=(224, 224), record_embedding=True, no_grad=True,
        #                                            pretrained=True).to(self.args.device)
        self.model.eval()
        batch_size = self.args.DATASET.SELECTION_BATCH_SIZE

        # # Resize dst_train to 224*224
        # if self.args.im_size[0] != 224 or self.args.im_size[1] != 224:
        #     dst_train = deepcopy(self.dst_train)
        #     dst_train.transform = transforms.Compose([dst_train.transform, transforms.Resize(224)])
        # else:
        #     dst_train = self.dst_train

        # Calculate the distance matrix and return the knn results.
        if self.balance:
            knn = []
            # Softmax probabilities for every training sample, filled class by
            # class at the samples' original dataset positions.
            probs = np.zeros([self.n_train, self.num_classes])
            for c in tqdm(range(self.num_classes)):
                print(f'Start processing class {c}/{self.num_classes}')
                class_index = np.arange(self.n_train)[self.dst_train_label == c]

                # Record embedding vectors for this class.
                # batch_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(dst_train, class_index),
                #                                            batch_size=self.args.selection_batch,
                #                                            num_workers=self.args.workers)
                embeddings = []
                c_probs = np.zeros([len(class_index), self.num_classes])
                data_loader = self.select_dm(self.dst_train, class_index, is_train=False)
                for i, batch in enumerate(data_loader):
                    image, label = batch['img'].cuda(), batch['label'].cuda()
                    # Image features, text features and classification logits
                    # recorded from the multimodal surrogate model.
                    img_f, text_f, logit = self.model(image, label, record=True)
                    final_feature = self.mixing_feature(img_f, text_f)
                    embeddings.append(final_feature.cpu().numpy())
                    c_probs[i * batch_size:(i + 1) * batch_size] = \
                        torch.softmax(logit, dim=1).detach().cpu().numpy()

                embeddings = np.concatenate(embeddings, axis=0)
                probs[class_index] = c_probs
                # Column 0 of the argsort is each sample itself (distance 0),
                # so keep columns 1..neighbors.
                knn.append(np.argsort(self.metric(embeddings), axis=1)[:, 1:(self.neighbors + 1)])
            self.probs = probs
            return knn
        else:
            # Record embedding vectors for the whole training set.
            embeddings = []
            probs = []
            batch_loader = self.select_dm(self.dst_train, None, is_train=False)
            print('Start processing all classes')
            for batch in tqdm(batch_loader):
                image, label = batch['img'].cuda(), batch['label'].cuda()
                img_f, text_f, logit = self.model(image, label, record=True)
                final_feature = self.mixing_feature(img_f, text_f)
                embeddings.append(final_feature.cpu().numpy())
                probs.append(torch.softmax(logit, dim=1).detach().cpu().numpy())
            embeddings = np.concatenate(embeddings, axis=0)
            self.probs = np.concatenate(probs, axis=0)
            return np.argsort(self.metric(embeddings), axis=1)[:, 1:(self.neighbors + 1)]
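
    # Note on the knn result: self.metric(embeddings) is an [n, n] pairwise
    # distance matrix, so each row of the returned matrix holds the indices of
    # the `neighbors` nearest other samples (class-local indices in the
    # balanced case, loader-order indices otherwise).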

    def calc_kl(self, knn, index=None):
        self.model.eval()
        self.model.no_grad = True
        sample_num = self.n_train if index is None else len(index)
        batch_size = self.args.DATASET.SELECTION_BATCH_SIZE
        # probs = np.zeros([sample_num, self.num_classes])
        #
        # batch_loader = torch.utils.data.DataLoader(
        #     self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
        #     batch_size=self.args.selection_batch, num_workers=self.args.workers)
        # batch_num = len(batch_loader)
        #
        # for i, (inputs, _) in enumerate(batch_loader):
        #     probs[i * self.args.selection_batch:(i + 1) * self.args.selection_batch] = torch.nn.functional.softmax(
        #         self.model(inputs.to(self.args.device)), dim=1).detach().cpu()
        probs = self.probs if index is None else self.probs[index]
        s = np.zeros(sample_num)
        for i in range(0, sample_num, batch_size):
            print("| Calculating KL-divergence for batch [%3d/%3d] with batch size [%3d]" % (i, sample_num, batch_size))
            # aa: each sample's own distribution, repeated once per neighbor -> [b, k, C]
            aa = np.expand_dims(probs[i:(i + batch_size)], 1).repeat(self.neighbors, 1)
            # bb: the distributions of its k nearest neighbors -> [b, k, C]
            bb = probs[knn[i:(i + batch_size)], :]
            # Mean symmetric KL divergence between a sample and its neighbors.
            s[i:(i + batch_size)] = np.mean(
                np.sum(0.5 * aa * np.log(aa / bb) + 0.5 * bb * np.log(bb / aa), axis=2), axis=1)
        self.model.no_grad = False
        return s
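
    # The score above is the mean symmetric KL divergence to the k neighbors,
    #     s_i = (1/k) * sum_j [ 0.5 * KL(p_i || p_j) + 0.5 * KL(p_j || p_i) ].
    # A minimal numeric sketch (hypothetical toy distributions):
    #
    #     p_i = np.array([0.9, 0.1]); p_j = np.array([0.6, 0.4])
    #     kl  = 0.5 * np.sum(p_i * np.log(p_i / p_j)) \
    #         + 0.5 * np.sum(p_j * np.log(p_j / p_i))   # ~0.27: disagreement score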

    def finish_run(self):
        scores = []
        if self.balance:
            selection_result = np.array([], dtype=np.int32)
            for c, knn in zip(range(self.num_classes), self.knn):
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                scores.append(self.calc_kl(knn, class_index))
                # Keep the fraction of samples with the smallest divergence scores.
                selection_result = np.append(
                    selection_result,
                    class_index[np.argsort(scores[-1])[:round(self.fraction * len(class_index))]])
        else:
            scores.append(self.calc_kl(self.knn))
            selection_result = np.argsort(scores[-1])[:self.coreset_size]
        return {"indices": selection_result, "scores": scores}

    def select(self, **kwargs):
        self.knn = self.find_knn()
        selection_result = self.run()
        return selection_result
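
# A minimal usage sketch (hypothetical: the `args` object, dataset wiring and
# the surrogate model are supplied by the surrounding selection framework):
#
#     method = Cal(dst_train, args, fraction=0.1, epochs=10,
#                  balance=True, metric="euclidean", neighbors=10)
#     result = method.select()
#     coreset = torch.utils.data.Subset(dst_train, result["indices"])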