Compare commits

...

2 Commits

Author SHA1 Message Date
984ce9f4bb rename distill variable 2026-02-25 21:02:56 +08:00
f26f793937 rename ewa 2026-02-25 17:36:27 +08:00
7 changed files with 50 additions and 48 deletions

View File

@@ -37,8 +37,8 @@ TRAINER:
PREC: "fp16"
PROMPT_DEPTH_VISION: 9
PROMPT_DEPTH_TEXT: 9
IMAGE_LOSS_WEIGHT: 10
TEXT_LOSS_WEIGHT_STRONG: 10
TEXT_LOSS_WEIGHT_WEAK: 25
GPA_MEAN: 15
GPA_STD: 1
IMAGE_LOSS_WEIGHT: 8
TEXT_LOSS_WEIGHT_STRONG: 8
TEXT_LOSS_WEIGHT_WEAK: 24
EWA_MEAN: 15
EWA_STD: 1

View File

@@ -40,5 +40,5 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 6
GPA_STD: 10
EWA_MEAN: 6
EWA_STD: 10

View File

@@ -39,5 +39,5 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 6
GPA_STD: 10
EWA_MEAN: 6
EWA_STD: 10

View File

@@ -11,7 +11,7 @@ Training PromptSRC on ImageNet for 20 epochs takes around 6 hours for a single s
## PromptSRC
#### (1) Base-to-Novel class generalization setting
The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as GPA STD, GPA Mean, SCL loss weight coefficients, prompt length and prompt depth etc., can be modified using this config file.
The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as EWA STD, EWA Mean, SCL loss weight coefficients, prompt length and prompt depth etc., can be modified using this config file.
Run the commands below to train PromptSRC on ImageNet.

View File

@@ -9,7 +9,7 @@ datasets=(
"caltech101"
"fgvc_aircraft"
"stanford_cars"
# "sun397"
"sun397"
# "imagenet"
)

View File

@@ -121,8 +121,8 @@ def extend_cfg(cfg):
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK = 10 # lambda3: weak text constraint weight
cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT = 10
cfg.TRAINER.DZGCOOP.GPA_MEAN = 15
cfg.TRAINER.DZGCOOP.GPA_STD = 1
cfg.TRAINER.DZGCOOP.EWA_MEAN = 15
cfg.TRAINER.DZGCOOP.EWA_STD = 1
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
# Config for independent Vision Language prompting (independent-vlp)

View File

@@ -142,7 +142,7 @@ class VLPromptLearner(nn.Module):
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
self.ZS_image_encoder = clip_model_temp_image.visual
# Now pre-compute the frozen VL embeddings from LLM descriptions
all_teacher_features = []
semantic_guidance_features = []
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
with open(desc_file, "r") as f:
all_desc = json.load(f)
@@ -155,9 +155,9 @@ class VLPromptLearner(nn.Module):
cls_feature = clip_model_temp.encode_text(cls_token)
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
cls_feature = torch.mean(cls_feature, dim=0)
all_teacher_features.append(cls_feature)
semantic_guidance_features.append(cls_feature)
self.fixed_embeddings = torch.stack(all_teacher_features)
self.semantic_embeddings = torch.stack(semantic_guidance_features)
print(f"Using LLM descriptions from: {desc_file}")
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
@@ -238,10 +238,10 @@ class CustomCLIP(nn.Module):
text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
fixed_embeddings = self.prompt_learner.fixed_embeddings
fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
semantic_embeddings = self.prompt_learner.semantic_embeddings
semantic_embeddings = semantic_embeddings / semantic_embeddings.norm(dim=-1, keepdim=True)
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ semantic_embeddings.half().cuda().t()
logits_strong = logit_scale * image_features @ text_features_strong.t()
logits_weak = logit_scale * image_features @ text_features_weak.t()
@@ -255,7 +255,7 @@ class CustomCLIP(nn.Module):
if self.prompt_learner.training:
loss_ce = F.cross_entropy(logits_final, label)
return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
return loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
else:
return logits_final
@@ -312,11 +312,11 @@ class DZGCoOp(TrainerX):
self.total_epochs = cfg.OPTIM.MAX_EPOCH
self.step_counter = 1
N = cfg.OPTIM.MAX_EPOCH
mean = cfg.TRAINER.DZGCOOP.GPA_MEAN
stdev = cfg.TRAINER.DZGCOOP.GPA_STD
gauss = self.get_gauss(mean, stdev)
self.gauss = np.array([gauss(a) for a in range(1, N + 1)])
self.gauss = self.gauss / sum(self.gauss)
mean = cfg.TRAINER.DZGCOOP.EWA_MEAN
stdev = cfg.TRAINER.DZGCOOP.EWA_STD
normal = self.get_normal(mean, stdev)
self.normal_weights = np.array([normal(a) for a in range(1, N + 1)])
self.normal_weights = self.normal_weights / sum(self.normal_weights)
self.scaler = GradScaler() if cfg.TRAINER.DZGCOOP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
@@ -324,8 +324,8 @@ class DZGCoOp(TrainerX):
if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model)
# Keep model with GPA
self.previous_model_gpa = None
# Keep model with EWA
self.previous_model_ewa = None
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
@@ -343,7 +343,7 @@ class DZGCoOp(TrainerX):
scaler.step(optim)
scaler.update()
else:
loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zs_image_embedd, image_ft, \
zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT
@@ -351,17 +351,17 @@ class DZGCoOp(TrainerX):
lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK
L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
L_zpg = F.kl_div(
L_zlg = F.kl_div(
F.log_softmax(logits_final / 1, dim=1),
F.log_softmax(zero_shot_logits / 1, dim=1),
reduction='sum',
log_target=True
) * (1 * 1) / logits_final.numel()
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
loss = (loss_ce + L_zg)
optim.zero_grad()
loss.backward()
@@ -371,20 +371,22 @@ class DZGCoOp(TrainerX):
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
# Means one epoch is completed, perform GPA
# Means one epoch is completed, perform EWA
self.step_counter = self.step_counter + 1
current_epoch_weight = self.gauss[self.step_counter - 2]
current_epoch_weight = self.normal_weights[self.step_counter - 2]
current_model_weights = copy.deepcopy(model.state_dict())
for key in current_model_weights:
current_model_weights[key] = current_model_weights[key].cpu()
weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight)
if self.previous_model_gpa is None:
self.previous_model_gpa = weighted_state_dict
if self.previous_model_ewa is None:
self.previous_model_ewa = weighted_state_dict
else:
self.previous_model_gpa = self.state_dict_add(weighted_state_dict, self.previous_model_gpa)
self.previous_model_ewa = self.state_dict_add(weighted_state_dict, self.previous_model_ewa)
if self.step_counter == self.model.total_epochs + 1:
print("Using GPA model for final inference...")
model.load_state_dict(self.previous_model_gpa)
self.model.load_state_dict(self.previous_model_gpa)
print("Using EWA model for final inference...")
model.load_state_dict(self.previous_model_ewa)
self.model.load_state_dict(self.previous_model_ewa)
return loss_summary
def state_dict_weighting(self, main_dict, weightage, prompt_only=False):
@@ -392,24 +394,24 @@ class DZGCoOp(TrainerX):
updated_dict = copy.deepcopy(main_dict)
if not prompt_only:
for key in main_dict:
updated_dict[key] = main_dict[key] * weightage
updated_dict[key] = main_dict[key].cpu() * weightage
return updated_dict
else:
return main_dict * weightage
return main_dict.cpu() * weightage
def state_dict_add(self, dict1, dict2, prompt_only=False):
# Average all parameters
if not prompt_only:
modified_dict = dict2
for key in dict1:
modified_dict[key] = (modified_dict[key] + dict1[key])
modified_dict[key] = modified_dict[key].cpu() + dict1[key].cpu()
return modified_dict
else:
return dict1 + dict2
return dict1.cpu() + dict2.cpu()
def get_gauss(self, mu, sigma):
gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return gauss
def get_normal(self, mu, sigma):
normal = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return normal
def parse_batch_train(self, batch):
input = batch["img"]
@@ -456,4 +458,4 @@ class DZGCoOp(TrainerX):
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
self._models[name].load_state_dict(state_dict, strict=False)