rename ewa

This commit is contained in:
2026-02-25 17:36:27 +08:00
parent 61864e192a
commit f26f793937
7 changed files with 41 additions and 39 deletions

View File

@@ -37,8 +37,8 @@ TRAINER:
PREC: "fp16" PREC: "fp16"
PROMPT_DEPTH_VISION: 9 PROMPT_DEPTH_VISION: 9
PROMPT_DEPTH_TEXT: 9 PROMPT_DEPTH_TEXT: 9
IMAGE_LOSS_WEIGHT: 10 IMAGE_LOSS_WEIGHT: 8
TEXT_LOSS_WEIGHT_STRONG: 10 TEXT_LOSS_WEIGHT_STRONG: 8
TEXT_LOSS_WEIGHT_WEAK: 25 TEXT_LOSS_WEIGHT_WEAK: 24
GPA_MEAN: 15 EWA_MEAN: 15
GPA_STD: 1 EWA_STD: 1

View File

@@ -40,5 +40,5 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3 PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25 TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10 IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 6 EWA_MEAN: 6
GPA_STD: 10 EWA_STD: 10

View File

@@ -39,5 +39,5 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3 PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25 TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10 IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 6 EWA_MEAN: 6
GPA_STD: 10 EWA_STD: 10

View File

@@ -11,7 +11,7 @@ Training PromptSRC on ImageNet for 20 epochs takes around 6 hours for a single s
## PromptSRC ## PromptSRC
#### (1) Base-to-Novel class generalization setting #### (1) Base-to-Novel class generalization setting
The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as GPA STD, GPA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file. The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as EWA STD, EWA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file.
Run the commands below to train PromptSRC on ImageNet. Run the commands below to train PromptSRC on ImageNet.

View File

@@ -9,7 +9,7 @@ datasets=(
"caltech101" "caltech101"
"fgvc_aircraft" "fgvc_aircraft"
"stanford_cars" "stanford_cars"
# "sun397" "sun397"
# "imagenet" # "imagenet"
) )

View File

@@ -121,8 +121,8 @@ def extend_cfg(cfg):
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK = 10 # lambda3: weak text constraint weight cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK = 10 # lambda3: weak text constraint weight
cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT = 10 cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT = 10
cfg.TRAINER.DZGCOOP.GPA_MEAN = 15 cfg.TRAINER.DZGCOOP.EWA_MEAN = 15
cfg.TRAINER.DZGCOOP.GPA_STD = 1 cfg.TRAINER.DZGCOOP.EWA_STD = 1
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
# Config for independent Vision Language prompting (independent-vlp) # Config for independent Vision Language prompting (independent-vlp)

View File

@@ -149,7 +149,7 @@ class VLPromptLearner(nn.Module):
template = CUSTOM_TEMPLATES[cfg.DATASET.NAME] template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
for cls in classnames: for cls in classnames:
cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]] cls_descs = [template.format(cls)[:-1] + f", features with {desc}" for desc in all_desc[cls]]
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda() cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
with torch.no_grad(): with torch.no_grad():
cls_feature = clip_model_temp.encode_text(cls_token) cls_feature = clip_model_temp.encode_text(cls_token)
@@ -312,11 +312,11 @@ class DZGCoOp(TrainerX):
self.total_epochs = cfg.OPTIM.MAX_EPOCH self.total_epochs = cfg.OPTIM.MAX_EPOCH
self.step_counter = 1 self.step_counter = 1
N = cfg.OPTIM.MAX_EPOCH N = cfg.OPTIM.MAX_EPOCH
mean = cfg.TRAINER.DZGCOOP.GPA_MEAN mean = cfg.TRAINER.DZGCOOP.EWA_MEAN
stdev = cfg.TRAINER.DZGCOOP.GPA_STD stdev = cfg.TRAINER.DZGCOOP.EWA_STD
gauss = self.get_gauss(mean, stdev) normal = self.get_normal(mean, stdev)
self.gauss = np.array([gauss(a) for a in range(1, N + 1)]) self.normal_weights = np.array([normal(a) for a in range(1, N + 1)])
self.gauss = self.gauss / sum(self.gauss) self.normal_weights = self.normal_weights / sum(self.normal_weights)
self.scaler = GradScaler() if cfg.TRAINER.DZGCOOP.PREC == "amp" else None self.scaler = GradScaler() if cfg.TRAINER.DZGCOOP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is # Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel # big, which slows down the copy operation in DataParallel
@@ -324,8 +324,8 @@ class DZGCoOp(TrainerX):
if device_count > 1: if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model) self.model = nn.DataParallel(self.model)
# Keep model with GPA # Keep model with EWA
self.previous_model_gpa = None self.previous_model_ewa = None
def forward_backward(self, batch): def forward_backward(self, batch):
image, label = self.parse_batch_train(batch) image, label = self.parse_batch_train(batch)
@@ -354,14 +354,14 @@ class DZGCoOp(TrainerX):
L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2 L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3 L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
L_zpg = F.kl_div( L_zlg = F.kl_div(
F.log_softmax(logits_final / 1, dim=1), F.log_softmax(logits_final / 1, dim=1),
F.log_softmax(zero_shot_logits / 1, dim=1), F.log_softmax(zero_shot_logits / 1, dim=1),
reduction='sum', reduction='sum',
log_target=True log_target=True
) * (1 * 1) / logits_final.numel() ) * (1 * 1) / logits_final.numel()
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg) L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
loss = (loss_ce + L_zg) loss = (loss_ce + L_zg)
optim.zero_grad() optim.zero_grad()
loss.backward() loss.backward()
@@ -371,20 +371,22 @@ class DZGCoOp(TrainerX):
if (self.batch_idx + 1) == self.num_batches: if (self.batch_idx + 1) == self.num_batches:
self.update_lr() self.update_lr()
# Means one epoch is completed, perform GPA # Means one epoch is completed, perform EWA
self.step_counter = self.step_counter + 1 self.step_counter = self.step_counter + 1
current_epoch_weight = self.gauss[self.step_counter - 2] current_epoch_weight = self.normal_weights[self.step_counter - 2]
current_model_weights = copy.deepcopy(model.state_dict()) current_model_weights = copy.deepcopy(model.state_dict())
for key in current_model_weights:
current_model_weights[key] = current_model_weights[key].cpu()
weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight) weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight)
if self.previous_model_gpa is None: if self.previous_model_ewa is None:
self.previous_model_gpa = weighted_state_dict self.previous_model_ewa = weighted_state_dict
else: else:
self.previous_model_gpa = self.state_dict_add(weighted_state_dict, self.previous_model_gpa) self.previous_model_ewa = self.state_dict_add(weighted_state_dict, self.previous_model_ewa)
if self.step_counter == self.model.total_epochs + 1: if self.step_counter == self.model.total_epochs + 1:
print("Using GPA model for final inference...") print("Using EWA model for final inference...")
model.load_state_dict(self.previous_model_gpa) model.load_state_dict(self.previous_model_ewa)
self.model.load_state_dict(self.previous_model_gpa) self.model.load_state_dict(self.previous_model_ewa)
return loss_summary return loss_summary
def state_dict_weighting(self, main_dict, weightage, prompt_only=False): def state_dict_weighting(self, main_dict, weightage, prompt_only=False):
@@ -392,24 +394,24 @@ class DZGCoOp(TrainerX):
updated_dict = copy.deepcopy(main_dict) updated_dict = copy.deepcopy(main_dict)
if not prompt_only: if not prompt_only:
for key in main_dict: for key in main_dict:
updated_dict[key] = main_dict[key] * weightage updated_dict[key] = main_dict[key].cpu() * weightage
return updated_dict return updated_dict
else: else:
return main_dict * weightage return main_dict.cpu() * weightage
def state_dict_add(self, dict1, dict2, prompt_only=False): def state_dict_add(self, dict1, dict2, prompt_only=False):
# Average all parameters # Average all parameters
if not prompt_only: if not prompt_only:
modified_dict = dict2 modified_dict = dict2
for key in dict1: for key in dict1:
modified_dict[key] = (modified_dict[key] + dict1[key]) modified_dict[key] = modified_dict[key].cpu() + dict1[key].cpu()
return modified_dict return modified_dict
else: else:
return dict1 + dict2 return dict1.cpu() + dict2.cpu()
def get_gauss(self, mu, sigma): def get_normal(self, mu, sigma):
gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2) normal = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return gauss return normal
def parse_batch_train(self, batch): def parse_batch_train(self, batch):
input = batch["img"] input = batch["img"]
@@ -456,4 +458,4 @@ class DZGCoOp(TrainerX):
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False # set strict=False
self._models[name].load_state_dict(state_dict, strict=False) self._models[name].load_state_dict(state_dict, strict=False)