diff --git a/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml b/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml index fbe381d..063f94a 100644 --- a/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml +++ b/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml @@ -39,5 +39,4 @@ TRAINER: PROMPT_DEPTH_TEXT: 9 TEXT_LOSS_WEIGHT: 25 IMAGE_LOSS_WEIGHT: 10 - GPA_MEAN: 15 - GPA_STD: 1 + LAST_K: 5 diff --git a/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml b/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml index 895912a..d620a60 100644 --- a/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml +++ b/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml @@ -40,5 +40,4 @@ TRAINER: PROMPT_DEPTH_TEXT: 3 TEXT_LOSS_WEIGHT: 25 IMAGE_LOSS_WEIGHT: 10 - GPA_MEAN: 6 - GPA_STD: 10 + LAST_K: 5 diff --git a/configs/trainers/PromptSRC/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml b/configs/trainers/PromptSRC/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml index 2dbbacd..520f66b 100644 --- a/configs/trainers/PromptSRC/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml +++ b/configs/trainers/PromptSRC/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml @@ -39,5 +39,4 @@ TRAINER: PROMPT_DEPTH_TEXT: 3 TEXT_LOSS_WEIGHT: 25 IMAGE_LOSS_WEIGHT: 10 - GPA_MEAN: 6 - GPA_STD: 10 + LAST_K: 5 diff --git a/docs/TRAIN.md b/docs/TRAIN.md index d55fb2c..47c3e05 100644 --- a/docs/TRAIN.md +++ b/docs/TRAIN.md @@ -11,7 +11,7 @@ Training PromptSRC on ImageNet for 20 epochs takes around 6 hours for a single s ## PromptSRC #### (1) Base-to-Novel class generalization setting -The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as GPA STD, GPA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file. +The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as LAST_K, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file. Run the commands below to train PromptSRC on ImageNet. diff --git a/scripts/promptsrc/base2new_all.sh b/scripts/promptsrc/base2new_all.sh index 65fdd42..6429188 100644 --- a/scripts/promptsrc/base2new_all.sh +++ b/scripts/promptsrc/base2new_all.sh @@ -1,15 +1,15 @@ seeds=(1 2 3) datasets=( - # "ucf101" - # "eurosat" - # "oxford_pets" - # "food101" - # "oxford_flowers" - # "dtd" - # "caltech101" - # "fgvc_aircraft" - # "stanford_cars" - # "sun397" + "ucf101" + "eurosat" + "oxford_pets" + "food101" + "oxford_flowers" + "dtd" + "caltech101" + "fgvc_aircraft" + "stanford_cars" + "sun397" "imagenet" ) diff --git a/train.py b/train.py index 09fab30..bd19a94 100644 --- a/train.py +++ b/train.py @@ -122,8 +122,7 @@ def extend_cfg(cfg): cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK = 2.5 # lambda3: weak text constraint weight cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10 - cfg.TRAINER.PROMPTSRC.GPA_MEAN = 15 - cfg.TRAINER.PROMPTSRC.GPA_STD = 1 + cfg.TRAINER.PROMPTSRC.LAST_K = 5 cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new # Config for independent Vision Language prompting (independent-vlp) diff --git a/trainers/promptsrc.py b/trainers/promptsrc.py index f12e9cd..db817c0 100644 --- a/trainers/promptsrc.py +++ b/trainers/promptsrc.py @@ -311,12 +311,8 @@ class PromptSRC(TrainerX): # Cosine scheduler self.total_epochs = cfg.OPTIM.MAX_EPOCH self.step_counter = 1 - N = cfg.OPTIM.MAX_EPOCH - mean = cfg.TRAINER.PROMPTSRC.GPA_MEAN - stdev = cfg.TRAINER.PROMPTSRC.GPA_STD - gauss = self.get_gauss(mean, stdev) - self.gauss = np.array([gauss(a) for a in range(1, N + 1)]) - self.gauss = self.gauss / sum(self.gauss) + self.max_k = cfg.TRAINER.PROMPTSRC.LAST_K + self.last_k_models = [] self.scaler = GradScaler() if cfg.TRAINER.PROMPTSRC.PREC == "amp" else None # Note that multi-gpu training could be slow because CLIP's size is # big, which slows down the copy operation in DataParallel @@ -324,8 +320,6 @@ class PromptSRC(TrainerX): if device_count > 1: print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") self.model = nn.DataParallel(self.model) - # Keep model with GPA - self.previous_model_gpa = None def forward_backward(self, batch): image, label = self.parse_batch_train(batch) @@ -371,45 +365,32 @@ class PromptSRC(TrainerX): if (self.batch_idx + 1) == self.num_batches: self.update_lr() - # Means one epoch is completed, perform GPA self.step_counter = self.step_counter + 1 - current_epoch_weight = self.gauss[self.step_counter - 2] current_model_weights = copy.deepcopy(model.state_dict()) - weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight) - if self.previous_model_gpa is None: - self.previous_model_gpa = weighted_state_dict - else: - self.previous_model_gpa = self.state_dict_add(weighted_state_dict, self.previous_model_gpa) + for key in current_model_weights: + current_model_weights[key] = current_model_weights[key].cpu() + self.last_k_models.append(current_model_weights) + if len(self.last_k_models) > self.max_k: + self.last_k_models.pop(0) + torch.cuda.empty_cache() if self.step_counter == self.model.total_epochs + 1: - print("Using GPA model for final inference...") - model.load_state_dict(self.previous_model_gpa) - self.model.load_state_dict(self.previous_model_gpa) + print(f"Using Last-K Averaging (K={len(self.last_k_models)}) model for final inference...") + averaged_state_dict = self._average_last_k_models() + for key in averaged_state_dict: + averaged_state_dict[key] = averaged_state_dict[key].cuda() + model.load_state_dict(averaged_state_dict) + self.model.load_state_dict(averaged_state_dict) return loss_summary - def state_dict_weighting(self, main_dict, weightage, prompt_only=False): - # Average all parameters - updated_dict = copy.deepcopy(main_dict) - if not prompt_only: - for key in main_dict: - updated_dict[key] = main_dict[key] * weightage - return updated_dict - else: - return main_dict * weightage - - def state_dict_add(self, dict1, dict2, prompt_only=False): - # Average all parameters - if not prompt_only: - modified_dict = dict2 - for key in dict1: - modified_dict[key] = (modified_dict[key] + dict1[key]) - return modified_dict - else: - return dict1 + dict2 - - def get_gauss(self, mu, sigma): - gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2) - return gauss + def _average_last_k_models(self): + if not self.last_k_models: + return {} + averaged_dict = {} + for key in self.last_k_models[0]: + stacked = torch.stack([model_state[key] for model_state in self.last_k_models]) + averaged_dict[key] = torch.mean(stacked, dim=0) + return averaged_dict def parse_batch_train(self, batch): input = batch["img"]