import datetime
import os
import time
from copy import deepcopy
from typing import Optional

import numpy as np
import torch
from dassl.utils import MetricMeter, AverageMeter
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torchvision import transforms
from tqdm import tqdm

from datasets.data_manager import select_dm_loader
from .coresetmethod import CoresetMethod
from .. import nets


class EarlyTrain(CoresetMethod):
    """Core training code for coreset selection methods that require pre-training.

    The class drives a tune/evaluate loop over ``self.model`` and exposes empty
    hook methods (``before_run``, ``after_loss``, ``finish_run``, ...) for
    subclasses to override in order to collect whatever statistics their
    selection strategy needs.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None,
                 torchvision_pretrain: bool = False, dst_pretrain_dict: Optional[dict] = None,
                 fraction_pretrain=1., dst_test=None, **kwargs):
        """Set up the training state.

        Args:
            dst_train: training dataset; forwarded to ``CoresetMethod.__init__``.
            args: configuration object; forwarded to ``CoresetMethod.__init__``.
            fraction: fraction of the training set to keep as coreset.
            random_seed: forwarded to ``CoresetMethod.__init__``.
            epochs: stored as ``self.epochs``.
            specific_model: model to tune; stored as ``self.model``.
            torchvision_pretrain: when true and ``args.im_size`` is not 224x224,
                a ``Resize(224)`` transform is appended to (copies of) the datasets.
            dst_pretrain_dict: optional description of a separate pre-training
                dataset. When non-empty it must contain the keys ``im_size``,
                ``channel``, ``dst_train`` and ``num_classes``.
            fraction_pretrain: fraction of the pre-training data, in ``(0, 1]``.
            dst_test: test dataset; stored as ``self.dst_test``.
            **kwargs: when non-empty, must provide ``optim``, ``schedule`` and
                ``scar`` (optimizer, LR scheduler and the scaler used in the
                AMP branch of ``train``).

        Raises:
            ValueError: if ``fraction_pretrain`` is outside ``(0, 1]`` or the
                pre-training dataset's image size / channel count differs from
                ``args``.
            AttributeError: if ``dst_pretrain_dict`` lacks a required key.
        """
        super().__init__(dst_train, args, fraction, random_seed)
        # Work on a private shallow copy: avoids the shared mutable default and
        # keeps the caller's dict untouched when 'dst_train' is replaced below.
        dst_pretrain_dict = dict(dst_pretrain_dict) if dst_pretrain_dict else {}

        self.epochs = epochs
        self.n_train = len(self.dst_train)
        self.coreset_size = round(self.n_train * fraction)
        self.model = specific_model
        # NOTE(review): ``self.dm`` is not assigned in this class; presumably the
        # base class provides the data manager -- confirm.
        self.train_loader = self.dm.train_loader_x
        self.test_loader = self.dm.test_loader
        if kwargs:
            self.optim = kwargs['optim']
            self.sche = kwargs['schedule']
            self.scar = kwargs['scar']
        self.start_epoch = self.epoch = 0
        self.max_epoch = self.args.OPTIM_SELECTION.MAX_EPOCH

        if fraction_pretrain <= 0. or fraction_pretrain > 1.:
            raise ValueError("Illegal pretrain fraction value.")
        self.fraction_pretrain = fraction_pretrain

        if dst_pretrain_dict:
            required_keys = ('im_size', 'channel', 'dst_train', 'num_classes')
            if any(key not in dst_pretrain_dict for key in required_keys):
                raise AttributeError(
                    'Argument dst_pretrain_dict must contain im_size, channel, dst_train and num_classes.')
            pretrain_im_size = dst_pretrain_dict['im_size']
            # Compare both spatial dimensions (index 0 and index 1).
            if pretrain_im_size[0] != args.im_size[0] or pretrain_im_size[1] != args.im_size[1]:
                raise ValueError("im_size of pretrain dataset does not match that of the training dataset.")
            if dst_pretrain_dict['channel'] != args.channel:
                raise ValueError("channel of pretrain dataset does not match that of the training dataset.")
            if dst_pretrain_dict['num_classes'] != args.num_classes:
                self.num_classes_mismatch()

        self.dst_pretrain_dict = dst_pretrain_dict
        self.torchvision_pretrain = torchvision_pretrain
        self.if_dst_pretrain = (len(self.dst_pretrain_dict) != 0)

        if torchvision_pretrain:
            # Pretrained models in torchvision only accept 224*224 inputs, therefore we resize current
            # datasets to 224*224.
            # NOTE(review): Resize is given a single int here; confirm that this yields the intended
            # 224*224 output for non-square inputs.
            if args.im_size[0] != 224 or args.im_size[1] != 224:
                # Deep copies keep the caller's dataset objects unmodified.
                self.dst_train = deepcopy(dst_train)
                self.dst_train.transform = transforms.Compose(
                    [self.dst_train.transform, transforms.Resize(224)])
                if self.if_dst_pretrain:
                    self.dst_pretrain_dict['dst_train'] = deepcopy(dst_pretrain_dict['dst_train'])
                    self.dst_pretrain_dict['dst_train'].transform = transforms.Compose(
                        [self.dst_pretrain_dict['dst_train'].transform, transforms.Resize(224)])

        if self.if_dst_pretrain:
            self.n_pretrain = len(self.dst_pretrain_dict['dst_train'])
        self.n_pretrain_size = round(
            self.fraction_pretrain * (self.n_pretrain if self.if_dst_pretrain else self.n_train))
        self.dst_test = dst_test

    def train(self, epoch, list_of_train_idx=None, **kwargs):
        """Train the model for one epoch.

        Args:
            epoch: zero-based index of the epoch being run; used for logging
                and forwarded to the ``after_loss`` / ``while_update`` hooks.
            list_of_train_idx: unused; kept for interface compatibility.
            **kwargs: unused; kept for interface compatibility.

        Returns:
            Whatever ``self.finish_train()`` returns.
        """
        self.before_train()
        self.model.train()
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()

        print('\n=> Training Pre-tuning Epoch #%d' % epoch)
        train_loader = select_dm_loader(self.args, self.dst_train, is_train=True)
        self.num_batches = len(train_loader)

        # Loop-invariant lookups.
        model = self.model
        optim = self.optim
        scaler = self.scar
        prec = self.args.TRAINER.MAPLE.PREC
        batch_size = self.args.DATALOADER.TRAIN_X.BATCH_SIZE
        print_freq = self.args.TRAIN.PRINT_FREQ
        only_few_batches = self.num_batches < print_freq

        for i, batch in enumerate(train_loader):
            data_time.update(time.time() - end)
            image, label, real_ind = batch['img'].cuda(), batch['label'].cuda(), batch['index'].cuda()

            if prec == "amp":
                with autocast():
                    loss, outputs = model(image, label)
                optim.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
            else:
                loss, outputs = model(image, label)
                optim.zero_grad()
                loss.backward()
                optim.step()

            self.after_loss(outputs, loss, label, real_ind, epoch)
            self.while_update(outputs, loss, label, epoch, i, batch_size)

            loss_summary = {"loss": loss.item()}

            # The scheduler advances once per epoch, on the last batch.
            if (i + 1) == self.num_batches:
                self.sche.step()

            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            meet_freq = (i + 1) % print_freq == 0
            if meet_freq or only_few_batches:
                # Use the ``epoch`` argument: ``self.epoch`` is not advanced by
                # ``run`` and would make every epoch report as the first one.
                nb_remain = self.num_batches - i - 1
                nb_remain += (self.max_epoch - epoch - 1) * self.num_batches
                eta_seconds = batch_time.avg * nb_remain
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))

                info = [
                    f"epoch [{epoch + 1}/{self.max_epoch}]",
                    f"batch [{i + 1}/{self.num_batches}]",
                    f"time {batch_time.val:.3f} ({batch_time.avg:.3f})",
                    f"data {data_time.val:.3f} ({data_time.avg:.3f})",
                    f"{losses}",
                    f"lr {optim.param_groups[0]['lr']:.4e}",
                    f"eta {eta}",
                ]
                print(" ".join(info))

            end = time.time()

        return self.finish_train()

    def run(self):
        """Tune the model (or load a cached checkpoint) and return ``finish_run()``.

        The checkpoint path is ``checkpoints/<DATASET.NAME>_<SEED>.pth``. When
        ``self.max_epoch > 0`` and that file exists, its state dict is loaded
        instead of training; otherwise the model is trained and tested for
        each epoch and the resulting state dict is saved to that path.
        """
        self.train_indx = np.arange(self.n_train)
        self.before_run()
        print(f'Start pre-funing CLIP with all datasets by {self.max_epoch} epoch')
        file_save_name = self.args.DATASET.NAME + '_' + str(self.args.SEED) + '.pth'
        output_checkpoint_dir = os.path.join('checkpoints', file_save_name)

        if self.max_epoch > 0:
            if os.path.exists(output_checkpoint_dir):
                print(f'The checkpiont exists! Load that shit')
                ckpt = torch.load(output_checkpoint_dir)
                self.model.load_state_dict(ckpt)
            else:
                # Create the target directory up front so that a missing
                # 'checkpoints' folder cannot fail the save after training.
                os.makedirs(os.path.dirname(output_checkpoint_dir), exist_ok=True)
                for epoch in range(self.epoch, self.max_epoch):
                    self.before_epoch()
                    self.train(epoch)
                    self.test(epoch)
                    self.after_epoch()
                torch.save(self.model.state_dict(), output_checkpoint_dir)

        return self.finish_run()

    def test(self, epoch):
        """Evaluate the model on ``self.test_loader`` and print top-1 accuracy.

        Args:
            epoch: epoch index, used only in the printed messages.
        """
        self.model.no_grad = True
        self.model.eval()

        correct = 0.
        total = 0.

        print('\n=> Testing Tuning Epoch #%d' % epoch)

        # Setting ``model.no_grad`` above only sets an attribute; disable
        # autograd explicitly so evaluation does not build graphs.
        with torch.no_grad():
            for batch in self.test_loader:
                image, target = batch['img'].cuda(), batch['label']
                output = self.model(image, target.cuda())

                predicted = torch.max(output.data, 1).indices.cpu()
                correct += predicted.eq(target).sum().item()
                total += target.size(0)

        # Guard against an empty test loader.
        accuracy = 100. * correct / total if total else float('nan')
        print(f'| Test Epoch {epoch} Test Acc: {accuracy:.3f}%')

        self.model.no_grad = False

    # ------------------------------------------------------------------
    # Hooks: no-ops here, intended to be overridden by subclasses.
    # ------------------------------------------------------------------

    def num_classes_mismatch(self):
        """Hook called from ``__init__`` when the pre-training class count differs."""
        pass

    def before_train(self):
        """Hook called at the start of ``train``."""
        pass

    def after_loss(self, outputs, loss, targets, batch_inds, epoch):
        """Hook called in ``train`` after each optimizer step."""
        pass

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """Hook called in ``train`` after ``after_loss`` for each batch."""
        pass

    def finish_train(self):
        """Hook called at the end of ``train``; its result is returned by ``train``."""
        pass

    def before_epoch(self):
        """Hook called in ``run`` before each epoch."""
        pass

    def after_epoch(self):
        """Hook called in ``run`` after each epoch's train and test pass."""
        pass

    def before_run(self):
        """Hook called at the start of ``run``."""
        pass

    def finish_run(self):
        """Hook called at the end of ``run``; its result is the selection result."""
        pass

    def select(self, **kwargs):
        """Run the tuning loop and return the result of ``run()``."""
        selection_result = self.run()
        return selection_result

    def select_without_train(self, **kwargs):
        """Return ``finish_run()`` directly, skipping the tuning loop."""
        return self.finish_run()

    @torch.no_grad()
    def calcluate_clip_probability(self, batch):
        """Return scaled similarity logits between image and text features.

        NOTE(review): ``self.specific_model`` and ``self.text_feature`` are not
        assigned anywhere in this class; they must be set elsewhere before this
        method is called -- confirm.
        """
        images = batch["img"].cuda()
        self.specific_model = self.specific_model.cuda()
        image_features = self.specific_model.encode_image(images)
        # L2-normalise along the last dimension.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        logit_scale = self.specific_model.logit_scale.exp()
        return logit_scale * image_features @ self.text_feature.t()

    def select_dm(self, data, ind=None, is_train=None):
        """Build a data loader for ``data`` via ``select_dm_loader``."""
        return select_dm_loader(self.args, data, ind, is_train)

    def parse_batch_test(self, batch):
        """Return ``(img, label)`` from ``batch``, both moved to CUDA."""
        images = batch["img"].cuda()
        labels = batch["label"].cuda()
        return images, labels

    def parse_batch_train(self, batch):
        """Return ``(img, label, index)`` from ``batch``, all moved to CUDA."""
        images = batch["img"].cuda()
        labels = batch["label"].cuda()
        indices = batch["index"].cuda()
        return images, labels, indices

    def calc_gradient(self, index=None):
        """Calculate the gradient matrix on the current network for the training dataset.

        For every sample, the gradient of the loss w.r.t. the logits is
        concatenated with its product with the visual embedding (flattened).

        Args:
            index: forwarded to ``select_dm`` as the sample selection.

        Returns:
            A ``float32`` NumPy array with one row per sample.

        Note:
            The model is switched back to train mode before returning.
            ``self.num_classes`` is not assigned in this class and is expected
            to be provided elsewhere -- confirm.
        """
        self.model.eval()
        data_loader = self.select_dm(self.dst_train, index, is_train=False)

        # Per-batch gradient blocks, kept on CPU as NumPy arrays.
        gradients = []
        for batch in tqdm(data_loader):
            self.optim.zero_grad()
            image, label = batch['img'].cuda(), batch['label'].cuda()
            bs_size = image.shape[0]
            loss, visual_embedding, logit = self.model(image, label, cal_gradient=True)
            with torch.no_grad():
                bias_parameters_grads = torch.autograd.grad(loss, logit)[0]
                # Broadcasting (bs, 1, D) * (bs, C, 1) -> (bs, C, D) gives the
                # same values as repeating both operands, without the copies.
                weight_parameters_grads = visual_embedding.view(bs_size, 1, -1) * \
                    bias_parameters_grads.view(bs_size, self.num_classes, 1)
                gradients.append(torch.cat([bias_parameters_grads, weight_parameters_grads.flatten(1)],
                                           dim=1).cpu().numpy())

        gradients = np.concatenate(gradients, axis=0, dtype=np.float32)
        print('Finish Gradient Calculation')
        self.model.train()
        return gradients