Compare commits

1 Commit

Author      SHA1        Message         Date
            1d7d93ede5  Last-k Average  2026-02-07 15:58:51 +08:00
7 changed files with 37 additions and 60 deletions

View File

@@ -39,5 +39,4 @@ TRAINER:
     PROMPT_DEPTH_TEXT: 9
     TEXT_LOSS_WEIGHT: 25
     IMAGE_LOSS_WEIGHT: 10
-    GPA_MEAN: 15
-    GPA_STD: 1
+    LAST_K: 5

View File

@@ -40,5 +40,4 @@ TRAINER:
     PROMPT_DEPTH_TEXT: 3
     TEXT_LOSS_WEIGHT: 25
     IMAGE_LOSS_WEIGHT: 10
-    GPA_MEAN: 6
-    GPA_STD: 10
+    LAST_K: 5

View File

@@ -39,5 +39,4 @@ TRAINER:
     PROMPT_DEPTH_TEXT: 3
     TEXT_LOSS_WEIGHT: 25
     IMAGE_LOSS_WEIGHT: 10
-    GPA_MEAN: 6
-    GPA_STD: 10
+    LAST_K: 5
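
Taken together, the three config hunks replace the two Gaussian Prompt Aggregation (GPA) knobs with a single LAST_K. As a rough sketch of the two epoch-weighting schemes being swapped (the helper names are illustrative; the Gaussian formula is lifted from the removed get_gauss method further down, and N=20, mean=15, std=1, K=5 mirror the first config file):

```python
import numpy as np

def gpa_weights(n_epochs, mu, sigma):
    # Normalized Gaussian weight per epoch (formula from the removed get_gauss).
    gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
    w = np.array([gauss(a) for a in range(1, n_epochs + 1)])
    return w / w.sum()

def last_k_weights(n_epochs, k):
    # Uniform weight over the final k epochs, zero elsewhere.
    w = np.zeros(n_epochs)
    w[-k:] = 1.0 / k
    return w

print(gpa_weights(20, mu=15, sigma=1).round(3))  # bell curve peaking at epoch 15
print(last_k_weights(20, k=5))                   # five trailing 0.2 entries
```

With mean 15 and std 1, GPA already concentrated nearly all of its mass on a few late epochs, so a uniform average over the last five is a natural simplification.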

View File

@@ -11,7 +11,7 @@ Training PromptSRC on ImageNet for 20 epochs takes around 6 hours for a single s
 ## PromptSRC
 #### (1) Base-to-Novel class generalization setting
-The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as GPA STD, GPA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file.
+The base-to-novel PromptSRC configuration is provided in the config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters, such as LAST_K, the SCL loss weight coefficients, prompt length, and prompt depth, can be modified using this config file.
 Run the commands below to train PromptSRC on ImageNet.

View File

@@ -1,15 +1,15 @@
 seeds=(1 2 3)
 datasets=(
-    # "ucf101"
-    # "eurosat"
-    # "oxford_pets"
-    # "food101"
-    # "oxford_flowers"
-    # "dtd"
-    # "caltech101"
-    # "fgvc_aircraft"
-    # "stanford_cars"
-    # "sun397"
+    "ucf101"
+    "eurosat"
+    "oxford_pets"
+    "food101"
+    "oxford_flowers"
+    "dtd"
+    "caltech101"
+    "fgvc_aircraft"
+    "stanford_cars"
+    "sun397"
     "imagenet"
 )

View File

@@ -122,8 +122,7 @@ def extend_cfg(cfg):
     cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG = 25  # lambda2: strong text constraint weight
     cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK = 2.5  # lambda3: weak text constraint weight
     cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10
-    cfg.TRAINER.PROMPTSRC.GPA_MEAN = 15
-    cfg.TRAINER.PROMPTSRC.GPA_STD = 1
+    cfg.TRAINER.PROMPTSRC.LAST_K = 5
     cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new
     # Config for independent Vision Language prompting (independent-vlp)
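
Because extend_cfg registers LAST_K with a default, it can be overridden per experiment without code changes. A hypothetical snippet, assuming the repo keeps the CoOp-style train.py that forwards trailing command-line tokens to yacs' merge_from_list (the CfgNode scaffold below is illustrative; only the node names come from the hunk above):

```python
from yacs.config import CfgNode as CN

cfg = CN()
cfg.TRAINER = CN()
cfg.TRAINER.PROMPTSRC = CN()
cfg.TRAINER.PROMPTSRC.LAST_K = 5  # default registered by extend_cfg

# Equivalent of appending "TRAINER.PROMPTSRC.LAST_K 3" to the train command:
cfg.merge_from_list(["TRAINER.PROMPTSRC.LAST_K", "3"])
print(cfg.TRAINER.PROMPTSRC.LAST_K)  # -> 3
```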

View File

@@ -311,12 +311,8 @@ class PromptSRC(TrainerX):
         # Cosine scheduler
         self.total_epochs = cfg.OPTIM.MAX_EPOCH
         self.step_counter = 1
-        N = cfg.OPTIM.MAX_EPOCH
-        mean = cfg.TRAINER.PROMPTSRC.GPA_MEAN
-        stdev = cfg.TRAINER.PROMPTSRC.GPA_STD
-        gauss = self.get_gauss(mean, stdev)
-        self.gauss = np.array([gauss(a) for a in range(1, N + 1)])
-        self.gauss = self.gauss / sum(self.gauss)
+        self.max_k = cfg.TRAINER.PROMPTSRC.LAST_K
+        self.last_k_models = []
         self.scaler = GradScaler() if cfg.TRAINER.PROMPTSRC.PREC == "amp" else None
         # Note that multi-gpu training could be slow because CLIP's size is
         # big, which slows down the copy operation in DataParallel
@@ -324,8 +320,6 @@ class PromptSRC(TrainerX):
         if device_count > 1:
             print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
             self.model = nn.DataParallel(self.model)
-        # Keep model with GPA
-        self.previous_model_gpa = None
 
     def forward_backward(self, batch):
         image, label = self.parse_batch_train(batch)
@@ -371,45 +365,32 @@ class PromptSRC(TrainerX):
         if (self.batch_idx + 1) == self.num_batches:
             self.update_lr()
-            # Means one epoch is completed, perform GPA
             self.step_counter = self.step_counter + 1
-            current_epoch_weight = self.gauss[self.step_counter - 2]
             current_model_weights = copy.deepcopy(model.state_dict())
-            weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight)
-            if self.previous_model_gpa is None:
-                self.previous_model_gpa = weighted_state_dict
-            else:
-                self.previous_model_gpa = self.state_dict_add(weighted_state_dict, self.previous_model_gpa)
+            for key in current_model_weights:
+                current_model_weights[key] = current_model_weights[key].cpu()
+            self.last_k_models.append(current_model_weights)
+            if len(self.last_k_models) > self.max_k:
+                self.last_k_models.pop(0)
+            torch.cuda.empty_cache()
 
         if self.step_counter == self.model.total_epochs + 1:
-            print("Using GPA model for final inference...")
-            model.load_state_dict(self.previous_model_gpa)
-            self.model.load_state_dict(self.previous_model_gpa)
+            print(f"Using Last-K Averaging (K={len(self.last_k_models)}) model for final inference...")
+            averaged_state_dict = self._average_last_k_models()
+            for key in averaged_state_dict:
+                averaged_state_dict[key] = averaged_state_dict[key].cuda()
+            model.load_state_dict(averaged_state_dict)
+            self.model.load_state_dict(averaged_state_dict)
         return loss_summary
 
-    def state_dict_weighting(self, main_dict, weightage, prompt_only=False):
-        # Average all parameters
-        updated_dict = copy.deepcopy(main_dict)
-        if not prompt_only:
-            for key in main_dict:
-                updated_dict[key] = main_dict[key] * weightage
-            return updated_dict
-        else:
-            return main_dict * weightage
-
-    def state_dict_add(self, dict1, dict2, prompt_only=False):
-        # Average all parameters
-        if not prompt_only:
-            modified_dict = dict2
-            for key in dict1:
-                modified_dict[key] = (modified_dict[key] + dict1[key])
-            return modified_dict
-        else:
-            return dict1 + dict2
-
-    def get_gauss(self, mu, sigma):
-        gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
-        return gauss
+    def _average_last_k_models(self):
+        if not self.last_k_models:
+            return {}
+        averaged_dict = {}
+        for key in self.last_k_models[0]:
+            stacked = torch.stack([model_state[key] for model_state in self.last_k_models])
+            averaged_dict[key] = torch.mean(stacked, dim=0)
+        return averaged_dict
 
     def parse_batch_train(self, batch):
         input = batch["img"]