DAPT/deepcore/methods/coresetmethod.py

import os

import numpy as np
import torch


class CoresetMethod(object):
    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, **kwargs):
        if fraction <= 0.0 or fraction > 1.0:
            raise ValueError("Illegal coreset fraction: must be in (0, 1].")
        self.dm = dst_train
        self.dst_train = dst_train.dataset.train_x
        self.num_classes = dst_train.dataset.num_classes
        self.fraction = fraction
        self.random_seed = random_seed
        self.index = []
        self.args = args
        self.dst_train_label = self.get_train_label(self.dst_train)
        self.n_train = len(self.dst_train)
        self.coreset_size = round(self.n_train * fraction)
        self.max_epoch = self.args.OPTIM_SELECTION.MAX_EPOCH
        self.epoch = 0  # starting epoch for pre_run's training loop

    def select(self, **kwargs):
        # No-op in the base class; concrete selection methods override this.
        return

    def get_train_label(self, dst_train):
        # Collect every training item's label into a single numpy array.
        return np.asarray([item.label for item in dst_train])

    def pre_run(self):
        self.train_indx = np.arange(self.n_train)
        print(f'Start pre-finetuning CLIP on the full training set for {self.max_epoch} epochs')
        file_save_name = self.args.DATASET.NAME + '_' + str(self.args.SEED) + '.pth'
        output_checkpoint_path = os.path.join('checkpoints', file_save_name)
        if self.max_epoch > 0:
            if os.path.exists(output_checkpoint_path):
                print('The checkpoint exists; loading it.')
                ckpt = torch.load(output_checkpoint_path)
                self.model.load_state_dict(ckpt)
            else:
                for epoch in range(self.epoch, self.max_epoch):
                    # list_of_train_idx = np.random.choice(np.arange(self.n_pretrain if self.if_dst_pretrain else self.n_train),
                    #                                      self.n_pretrain_size, replace=False)
                    self.before_epoch()  # PASS (no-op hook)
                    self.train(epoch)
                    self.test(epoch)
                    self.after_epoch()
                os.makedirs('checkpoints', exist_ok=True)  # ensure the target directory exists
                torch.save(self.model.state_dict(), output_checkpoint_path)
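
For illustration, here is a minimal sketch of how a concrete method might subclass `CoresetMethod` by overriding `select`. The `UniformRandom` class below is hypothetical (it is not part of this file or repository) and assumes plain uniform sampling over the training indices, returning the selection as a dict of indices.

# Hypothetical example, not part of the original file: a minimal
# subclass showing how `select` is expected to be overridden.
class UniformRandom(CoresetMethod):
    def select(self, **kwargs):
        # Draw `coreset_size` indices uniformly at random, seeded for
        # reproducibility when `random_seed` is provided.
        rng = np.random.default_rng(self.random_seed)
        self.index = rng.choice(self.n_train, self.coreset_size, replace=False)
        return {"indices": self.index}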