rename ewa

This commit is contained in:
2026-02-25 17:36:27 +08:00
parent 61864e192a
commit f26f793937
7 changed files with 41 additions and 39 deletions

View File

@@ -37,8 +37,8 @@ TRAINER:
PREC: "fp16"
PROMPT_DEPTH_VISION: 9
PROMPT_DEPTH_TEXT: 9
IMAGE_LOSS_WEIGHT: 10
TEXT_LOSS_WEIGHT_STRONG: 10
TEXT_LOSS_WEIGHT_WEAK: 25
GPA_MEAN: 15
GPA_STD: 1
IMAGE_LOSS_WEIGHT: 8
TEXT_LOSS_WEIGHT_STRONG: 8
TEXT_LOSS_WEIGHT_WEAK: 24
EWA_MEAN: 15
EWA_STD: 1

View File

@@ -40,5 +40,5 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 6
GPA_STD: 10
EWA_MEAN: 6
EWA_STD: 10

View File

@@ -39,5 +39,5 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 6
GPA_STD: 10
EWA_MEAN: 6
EWA_STD: 10

View File

@@ -11,7 +11,7 @@ Training PromptSRC on ImageNet for 20 epochs takes around 6 hours for a single s
## PromptSRC
#### (1) Base-to-Novel class generalization setting
The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as GPA STD, GPA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file.
The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as EWA STD, EWA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file.
Run the commands below to train PromptSRC on ImageNet.

View File

@@ -9,7 +9,7 @@ datasets=(
"caltech101"
"fgvc_aircraft"
"stanford_cars"
# "sun397"
"sun397"
# "imagenet"
)

View File

@@ -121,8 +121,8 @@ def extend_cfg(cfg):
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK = 10 # lambda3: weak text constraint weight
cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT = 10
cfg.TRAINER.DZGCOOP.GPA_MEAN = 15
cfg.TRAINER.DZGCOOP.GPA_STD = 1
cfg.TRAINER.DZGCOOP.EWA_MEAN = 15
cfg.TRAINER.DZGCOOP.EWA_STD = 1
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
# Config for independent Vision Language prompting (independent-vlp)

View File

@@ -149,7 +149,7 @@ class VLPromptLearner(nn.Module):
template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
for cls in classnames:
cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]]
cls_descs = [template.format(cls)[:-1] + f", features with {desc}" for desc in all_desc[cls]]
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
with torch.no_grad():
cls_feature = clip_model_temp.encode_text(cls_token)
@@ -312,11 +312,11 @@ class DZGCoOp(TrainerX):
self.total_epochs = cfg.OPTIM.MAX_EPOCH
self.step_counter = 1
N = cfg.OPTIM.MAX_EPOCH
mean = cfg.TRAINER.DZGCOOP.GPA_MEAN
stdev = cfg.TRAINER.DZGCOOP.GPA_STD
gauss = self.get_gauss(mean, stdev)
self.gauss = np.array([gauss(a) for a in range(1, N + 1)])
self.gauss = self.gauss / sum(self.gauss)
mean = cfg.TRAINER.DZGCOOP.EWA_MEAN
stdev = cfg.TRAINER.DZGCOOP.EWA_STD
normal = self.get_normal(mean, stdev)
self.normal_weights = np.array([normal(a) for a in range(1, N + 1)])
self.normal_weights = self.normal_weights / sum(self.normal_weights)
self.scaler = GradScaler() if cfg.TRAINER.DZGCOOP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
@@ -324,8 +324,8 @@ class DZGCoOp(TrainerX):
if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model)
# Keep model with GPA
self.previous_model_gpa = None
# Keep model with EWA
self.previous_model_ewa = None
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
@@ -354,14 +354,14 @@ class DZGCoOp(TrainerX):
L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
L_zpg = F.kl_div(
L_zlg = F.kl_div(
F.log_softmax(logits_final / 1, dim=1),
F.log_softmax(zero_shot_logits / 1, dim=1),
reduction='sum',
log_target=True
) * (1 * 1) / logits_final.numel()
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
loss = (loss_ce + L_zg)
optim.zero_grad()
loss.backward()
@@ -371,20 +371,22 @@ class DZGCoOp(TrainerX):
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
# Means one epoch is completed, perform GPA
# Means one epoch is completed, perform EWA
self.step_counter = self.step_counter + 1
current_epoch_weight = self.gauss[self.step_counter - 2]
current_epoch_weight = self.normal_weights[self.step_counter - 2]
current_model_weights = copy.deepcopy(model.state_dict())
for key in current_model_weights:
current_model_weights[key] = current_model_weights[key].cpu()
weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight)
if self.previous_model_gpa is None:
self.previous_model_gpa = weighted_state_dict
if self.previous_model_ewa is None:
self.previous_model_ewa = weighted_state_dict
else:
self.previous_model_gpa = self.state_dict_add(weighted_state_dict, self.previous_model_gpa)
self.previous_model_ewa = self.state_dict_add(weighted_state_dict, self.previous_model_ewa)
if self.step_counter == self.model.total_epochs + 1:
print("Using GPA model for final inference...")
model.load_state_dict(self.previous_model_gpa)
self.model.load_state_dict(self.previous_model_gpa)
print("Using EWA model for final inference...")
model.load_state_dict(self.previous_model_ewa)
self.model.load_state_dict(self.previous_model_ewa)
return loss_summary
def state_dict_weighting(self, main_dict, weightage, prompt_only=False):
@@ -392,24 +394,24 @@ class DZGCoOp(TrainerX):
updated_dict = copy.deepcopy(main_dict)
if not prompt_only:
for key in main_dict:
updated_dict[key] = main_dict[key] * weightage
updated_dict[key] = main_dict[key].cpu() * weightage
return updated_dict
else:
return main_dict * weightage
return main_dict.cpu() * weightage
def state_dict_add(self, dict1, dict2, prompt_only=False):
# Average all parameters
if not prompt_only:
modified_dict = dict2
for key in dict1:
modified_dict[key] = (modified_dict[key] + dict1[key])
modified_dict[key] = modified_dict[key].cpu() + dict1[key].cpu()
return modified_dict
else:
return dict1 + dict2
return dict1.cpu() + dict2.cpu()
def get_gauss(self, mu, sigma):
gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return gauss
def get_normal(self, mu, sigma):
normal = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return normal
def parse_batch_train(self, batch):
input = batch["img"]
@@ -456,4 +458,4 @@ class DZGCoOp(TrainerX):
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
self._models[name].load_state_dict(state_dict, strict=False)