Upload to Main
This commit is contained in:
49
deepcore/methods/coresetmethod.py
Normal file
49
deepcore/methods/coresetmethod.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
class CoresetMethod(object):
|
||||
def __init__(self, dst_train, args, fraction=0.5, random_seed=None,**kwargs):
|
||||
if fraction <= 0.0 or fraction > 1.0:
|
||||
raise ValueError("Illegal Coreset Size.")
|
||||
|
||||
self.dm = dst_train
|
||||
self.dst_train = dst_train.dataset.train_x
|
||||
self.num_classes = dst_train.dataset.num_classes
|
||||
self.fraction = fraction
|
||||
self.random_seed = random_seed
|
||||
self.index = []
|
||||
self.args = args
|
||||
self.dst_train_label = self.get_train_label(self.dst_train)
|
||||
self.n_train = len(self.dst_train)
|
||||
self.coreset_size = round(self.n_train * fraction)
|
||||
self.max_epoch = self.args.OPTIM_SELECTION.MAX_EPOCH
|
||||
|
||||
def select(self, **kwargs):
|
||||
return
|
||||
|
||||
def get_train_label(self,dst_train):
|
||||
####Readable
|
||||
ind = []
|
||||
for i,item in enumerate(dst_train):
|
||||
ind.append(item.label)
|
||||
return np.asarray(ind)
|
||||
def pre_run(self):
|
||||
self.train_indx = np.arange(self.n_train)
|
||||
print(f'Start pre-funing CLIP with all datasets by {self.max_epoch} epoch')
|
||||
file_save_name = self.args.DATASET.NAME + '_' + str(self.args.SEED) + '.pth'
|
||||
output_checkpoint_dir = os.path.join('checkpoints', file_save_name)
|
||||
if self.max_epoch > 0:
|
||||
|
||||
if os.path.exists(output_checkpoint_dir):
|
||||
print(f'The checkpiont exists! Load that shit')
|
||||
ckpt = torch.load(output_checkpoint_dir)
|
||||
self.model.load_state_dict(ckpt)
|
||||
else:
|
||||
for epoch in range(self.epoch, self.max_epoch):
|
||||
# list_of_train_idx = np.random.choice(np.arange(self.n_pretrain if self.if_dst_pretrain else self.n_train),
|
||||
# self.n_pretrain_size, replace=False)
|
||||
self.before_epoch() # PASS
|
||||
self.train(epoch)
|
||||
self.test(epoch)
|
||||
self.after_epoch()
|
||||
torch.save(self.model.state_dict(), output_checkpoint_dir)
|
||||
Reference in New Issue
Block a user