From fa3afbcae169399ea91d97b4fbb0dd85469bb3ea Mon Sep 17 00:00:00 2001
From: rain-bus
Date: Wed, 25 Feb 2026 21:02:56 +0800
Subject: [PATCH] rename distill variable

---
 trainers/dzgcoop.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/trainers/dzgcoop.py b/trainers/dzgcoop.py
index 9d6aa47..d3784c8 100644
--- a/trainers/dzgcoop.py
+++ b/trainers/dzgcoop.py
@@ -142,22 +142,22 @@ class VLPromptLearner(nn.Module):
         embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
         self.ZS_image_encoder = clip_model_temp_image.visual
         # Now pre-compute the frozen VL embeddings from LLM descriptions
-        all_teacher_features = []
+        semantic_guidance_features = []
         desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
         with open(desc_file, "r") as f:
             all_desc = json.load(f)
 
         template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
         for cls in classnames:
-            cls_descs = [template.format(cls)[:-1] + f", features with {desc}" for desc in all_desc[cls]]
+            cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]]
             cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
             with torch.no_grad():
                 cls_feature = clip_model_temp.encode_text(cls_token)
                 cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
                 cls_feature = torch.mean(cls_feature, dim=0)
-            all_teacher_features.append(cls_feature)
+            semantic_guidance_features.append(cls_feature)
 
-        self.fixed_embeddings = torch.stack(all_teacher_features)
+        self.semantic_embeddings = torch.stack(semantic_guidance_features)
         print(f"Using LLM descriptions from: {desc_file}")
         # These token vectors will be saved when in save_model(),
         # but they should be ignored in load_model() as we want to use
@@ -238,10 +238,10 @@ class CustomCLIP(nn.Module):
         text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
         text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
 
-        fixed_embeddings = self.prompt_learner.fixed_embeddings
-        fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
+        semantic_embeddings = self.prompt_learner.semantic_embeddings
+        semantic_embeddings = semantic_embeddings / semantic_embeddings.norm(dim=-1, keepdim=True)
 
-        zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
+        zero_shot_logits = logit_scale * zero_shot_features.cuda() @ semantic_embeddings.half().cuda().t()
 
         logits_strong = logit_scale * image_features @ text_features_strong.t()
         logits_weak = logit_scale * image_features @ text_features_weak.t()
@@ -255,7 +255,7 @@ class CustomCLIP(nn.Module):
 
         if self.prompt_learner.training:
             loss_ce = F.cross_entropy(logits_final, label)
-            return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
+            return loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
         else:
             return logits_final
 
@@ -343,7 +343,7 @@ class DZGCoOp(TrainerX):
             scaler.step(optim)
             scaler.update()
         else:
-            loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
+            loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zs_image_embedd, image_ft, \
                 zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
 
             lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT
@@ -351,17 +351,17 @@
             lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK
 
             L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
-            L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
-            L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
+            L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
+            L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
 
-            L_zlg = F.kl_div(
+            L_zpg = F.kl_div(
                 F.log_softmax(logits_final / 1, dim=1),
                 F.log_softmax(zero_shot_logits / 1, dim=1),
                 reduction='sum',
                 log_target=True
             ) * (1 * 1) / logits_final.numel()
 
-            L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
+            L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
             loss = (loss_ce + L_zg)
             optim.zero_grad()
             loss.backward()