rename distill variable

This commit is contained in:
2026-02-25 21:02:56 +08:00
parent f26f793937
commit fa3afbcae1

View File

@@ -142,22 +142,22 @@ class VLPromptLearner(nn.Module):
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
self.ZS_image_encoder = clip_model_temp_image.visual
# Now pre-compute the frozen VL embeddings from LLM descriptions
-all_teacher_features = []
+semantic_guidance_features = []
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
with open(desc_file, "r") as f:
all_desc = json.load(f)
template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
for cls in classnames:
-cls_descs = [template.format(cls)[:-1] + f", features with {desc}" for desc in all_desc[cls]]
+cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]]
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
with torch.no_grad():
cls_feature = clip_model_temp.encode_text(cls_token)
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
cls_feature = torch.mean(cls_feature, dim=0)
-all_teacher_features.append(cls_feature)
+semantic_guidance_features.append(cls_feature)
-self.fixed_embeddings = torch.stack(all_teacher_features)
+self.semantic_embeddings = torch.stack(semantic_guidance_features)
print(f"Using LLM descriptions from: {desc_file}")
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
@@ -238,10 +238,10 @@ class CustomCLIP(nn.Module):
text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
fixed_embeddings = self.prompt_learner.fixed_embeddings
fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
semantic_embeddings = self.prompt_learner.semantic_embeddings
semantic_embeddings = semantic_embeddings / semantic_embeddings.norm(dim=-1, keepdim=True)
-zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
+zero_shot_logits = logit_scale * zero_shot_features.cuda() @ semantic_embeddings.half().cuda().t()
logits_strong = logit_scale * image_features @ text_features_strong.t()
logits_weak = logit_scale * image_features @ text_features_weak.t()
@@ -255,7 +255,7 @@ class CustomCLIP(nn.Module):
if self.prompt_learner.training:
loss_ce = F.cross_entropy(logits_final, label)
-return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
+return loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
else:
return logits_final
@@ -343,7 +343,7 @@ class DZGCoOp(TrainerX):
scaler.step(optim)
scaler.update()
else:
-loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
+loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zs_image_embedd, image_ft, \
zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT
@@ -351,17 +351,17 @@ class DZGCoOp(TrainerX):
lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK
L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
-L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
-L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
+L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
+L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
-L_zlg = F.kl_div(
+L_zpg = F.kl_div(
F.log_softmax(logits_final / 1, dim=1),
F.log_softmax(zero_shot_logits / 1, dim=1),
reduction='sum',
log_target=True
) * (1 * 1) / logits_final.numel()
-L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
+L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
loss = (loss_ce + L_zg)
optim.zero_grad()
loss.backward()