import os

import numpy as np
import torch


class CoresetMethod(object):
    """Base class for coreset selection methods."""

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, **kwargs):
        # The coreset fraction must lie in (0, 1].
        if fraction <= 0.0 or fraction > 1.0:
            raise ValueError("Illegal Coreset Size.")

        self.dm = dst_train
        self.dst_train = dst_train.dataset.train_x
        self.num_classes = dst_train.dataset.num_classes
        self.fraction = fraction
        self.random_seed = random_seed
        self.index = []
        self.args = args
        self.dst_train_label = self.get_train_label(self.dst_train)
        self.n_train = len(self.dst_train)
        # Number of training samples to keep in the coreset.
        self.coreset_size = round(self.n_train * fraction)
        self.max_epoch = self.args.OPTIM_SELECTION.MAX_EPOCH

    def select(self, **kwargs):
        # Subclasses override this to perform the actual coreset selection
        # and return the selected indices.
        return

    def get_train_label(self, dst_train):
        # Collect the label of every training sample into a single array.
        labels = []
        for item in dst_train:
            labels.append(item.label)
        return np.asarray(labels)

    def pre_run(self):
        self.train_indx = np.arange(self.n_train)
        print(f'Start pre-finetuning CLIP on the full training set for {self.max_epoch} epochs')
        file_save_name = self.args.DATASET.NAME + '_' + str(self.args.SEED) + '.pth'
        output_checkpoint_dir = os.path.join('checkpoints', file_save_name)
        if self.max_epoch > 0:
            if os.path.exists(output_checkpoint_dir):
                # Reuse the pre-finetuned weights if a checkpoint already exists.
                print('The checkpoint already exists; loading it.')
                ckpt = torch.load(output_checkpoint_dir)
                self.model.load_state_dict(ckpt)
            else:
                # NOTE: self.epoch, self.model, before_epoch, train, test and
                # after_epoch are assumed to be provided elsewhere (e.g. by a
                # trainer subclass).
                for epoch in range(self.epoch, self.max_epoch):
                    # list_of_train_idx = np.random.choice(
                    #     np.arange(self.n_pretrain if self.if_dst_pretrain else self.n_train),
                    #     self.n_pretrain_size, replace=False)
                    self.before_epoch()  # PASS
                    self.train(epoch)
                    self.test(epoch)
                    self.after_epoch()
                torch.save(self.model.state_dict(), output_checkpoint_dir)
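

# A minimal sketch of how a concrete method could subclass CoresetMethod,
# assuming only the attributes set in __init__ above. The name
# "RandomSelection" and this usage are illustrative assumptions, not taken
# from the repository.
class RandomSelection(CoresetMethod):
    def select(self, **kwargs):
        # Draw `coreset_size` indices uniformly at random, reproducibly seeded.
        rng = np.random.default_rng(self.random_seed)
        self.index = rng.choice(self.n_train, size=self.coreset_size, replace=False)
        return self.index

# Hypothetical usage, assuming a data manager `dm` and config `cfg` shaped like
# the `dst_train` and `args` arguments expected by the constructor:
#   selector = RandomSelection(dm, cfg, fraction=0.3, random_seed=0)
#   coreset_indices = selector.select()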