rename distill variable
This commit is contained in:
@@ -142,22 +142,22 @@ class VLPromptLearner(nn.Module):
|
||||
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
|
||||
self.ZS_image_encoder = clip_model_temp_image.visual
|
||||
# Now pre-compute the frozen VL embeddings from LLM descriptions
|
||||
all_teacher_features = []
|
||||
semantic_guidance_features = []
|
||||
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
|
||||
with open(desc_file, "r") as f:
|
||||
all_desc = json.load(f)
|
||||
|
||||
template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
|
||||
for cls in classnames:
|
||||
cls_descs = [template.format(cls)[:-1] + f", features with {desc}" for desc in all_desc[cls]]
|
||||
cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]]
|
||||
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
|
||||
with torch.no_grad():
|
||||
cls_feature = clip_model_temp.encode_text(cls_token)
|
||||
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
|
||||
cls_feature = torch.mean(cls_feature, dim=0)
|
||||
all_teacher_features.append(cls_feature)
|
||||
semantic_guidance_features.append(cls_feature)
|
||||
|
||||
self.fixed_embeddings = torch.stack(all_teacher_features)
|
||||
self.semantic_embeddings = torch.stack(semantic_guidance_features)
|
||||
print(f"Using LLM descriptions from: {desc_file}")
|
||||
# These token vectors will be saved when in save_model(),
|
||||
# but they should be ignored in load_model() as we want to use
|
||||
@@ -238,10 +238,10 @@ class CustomCLIP(nn.Module):
|
||||
text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
|
||||
text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
|
||||
|
||||
fixed_embeddings = self.prompt_learner.fixed_embeddings
|
||||
fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
|
||||
semantic_embeddings = self.prompt_learner.semantic_embeddings
|
||||
semantic_embeddings = semantic_embeddings / semantic_embeddings.norm(dim=-1, keepdim=True)
|
||||
|
||||
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
|
||||
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ semantic_embeddings.half().cuda().t()
|
||||
|
||||
logits_strong = logit_scale * image_features @ text_features_strong.t()
|
||||
logits_weak = logit_scale * image_features @ text_features_weak.t()
|
||||
@@ -255,7 +255,7 @@ class CustomCLIP(nn.Module):
|
||||
|
||||
if self.prompt_learner.training:
|
||||
loss_ce = F.cross_entropy(logits_final, label)
|
||||
return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
|
||||
return loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
|
||||
else:
|
||||
return logits_final
|
||||
|
||||
@@ -343,7 +343,7 @@ class DZGCoOp(TrainerX):
|
||||
scaler.step(optim)
|
||||
scaler.update()
|
||||
else:
|
||||
loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
|
||||
loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zs_image_embedd, image_ft, \
|
||||
zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
|
||||
|
||||
lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT
|
||||
@@ -351,17 +351,17 @@ class DZGCoOp(TrainerX):
|
||||
lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK
|
||||
|
||||
L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
|
||||
L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
|
||||
L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
|
||||
L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
|
||||
L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
|
||||
|
||||
L_zlg = F.kl_div(
|
||||
L_zpg = F.kl_div(
|
||||
F.log_softmax(logits_final / 1, dim=1),
|
||||
F.log_softmax(zero_shot_logits / 1, dim=1),
|
||||
reduction='sum',
|
||||
log_target=True
|
||||
) * (1 * 1) / logits_final.numel()
|
||||
|
||||
L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
|
||||
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
|
||||
loss = (loss_ce + L_zg)
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
Reference in New Issue
Block a user