rename distill variable
This commit is contained in:
@@ -142,22 +142,22 @@ class VLPromptLearner(nn.Module):
|
|||||||
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
|
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
|
||||||
self.ZS_image_encoder = clip_model_temp_image.visual
|
self.ZS_image_encoder = clip_model_temp_image.visual
|
||||||
# Now pre-compute the frozen VL embeddings from LLM descriptions
|
# Now pre-compute the frozen VL embeddings from LLM descriptions
|
||||||
all_teacher_features = []
|
semantic_guidance_features = []
|
||||||
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
|
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
|
||||||
with open(desc_file, "r") as f:
|
with open(desc_file, "r") as f:
|
||||||
all_desc = json.load(f)
|
all_desc = json.load(f)
|
||||||
|
|
||||||
template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
|
template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
|
||||||
for cls in classnames:
|
for cls in classnames:
|
||||||
cls_descs = [template.format(cls)[:-1] + f", features with {desc}" for desc in all_desc[cls]]
|
cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]]
|
||||||
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
|
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
cls_feature = clip_model_temp.encode_text(cls_token)
|
cls_feature = clip_model_temp.encode_text(cls_token)
|
||||||
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
|
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
|
||||||
cls_feature = torch.mean(cls_feature, dim=0)
|
cls_feature = torch.mean(cls_feature, dim=0)
|
||||||
all_teacher_features.append(cls_feature)
|
semantic_guidance_features.append(cls_feature)
|
||||||
|
|
||||||
self.fixed_embeddings = torch.stack(all_teacher_features)
|
self.semantic_embeddings = torch.stack(semantic_guidance_features)
|
||||||
print(f"Using LLM descriptions from: {desc_file}")
|
print(f"Using LLM descriptions from: {desc_file}")
|
||||||
# These token vectors will be saved when in save_model(),
|
# These token vectors will be saved when in save_model(),
|
||||||
# but they should be ignored in load_model() as we want to use
|
# but they should be ignored in load_model() as we want to use
|
||||||
@@ -238,10 +238,10 @@ class CustomCLIP(nn.Module):
|
|||||||
text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
|
text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
|
||||||
text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
|
text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
fixed_embeddings = self.prompt_learner.fixed_embeddings
|
semantic_embeddings = self.prompt_learner.semantic_embeddings
|
||||||
fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
|
semantic_embeddings = semantic_embeddings / semantic_embeddings.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
|
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ semantic_embeddings.half().cuda().t()
|
||||||
|
|
||||||
logits_strong = logit_scale * image_features @ text_features_strong.t()
|
logits_strong = logit_scale * image_features @ text_features_strong.t()
|
||||||
logits_weak = logit_scale * image_features @ text_features_weak.t()
|
logits_weak = logit_scale * image_features @ text_features_weak.t()
|
||||||
@@ -255,7 +255,7 @@ class CustomCLIP(nn.Module):
|
|||||||
|
|
||||||
if self.prompt_learner.training:
|
if self.prompt_learner.training:
|
||||||
loss_ce = F.cross_entropy(logits_final, label)
|
loss_ce = F.cross_entropy(logits_final, label)
|
||||||
return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
|
return loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
|
||||||
else:
|
else:
|
||||||
return logits_final
|
return logits_final
|
||||||
|
|
||||||
@@ -343,7 +343,7 @@ class DZGCoOp(TrainerX):
|
|||||||
scaler.step(optim)
|
scaler.step(optim)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
else:
|
else:
|
||||||
loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
|
loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zs_image_embedd, image_ft, \
|
||||||
zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
|
zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
|
||||||
|
|
||||||
lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT
|
lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT
|
||||||
@@ -351,17 +351,17 @@ class DZGCoOp(TrainerX):
|
|||||||
lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK
|
lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK
|
||||||
|
|
||||||
L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
|
L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
|
||||||
L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
|
L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
|
||||||
L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
|
L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
|
||||||
|
|
||||||
L_zlg = F.kl_div(
|
L_zpg = F.kl_div(
|
||||||
F.log_softmax(logits_final / 1, dim=1),
|
F.log_softmax(logits_final / 1, dim=1),
|
||||||
F.log_softmax(zero_shot_logits / 1, dim=1),
|
F.log_softmax(zero_shot_logits / 1, dim=1),
|
||||||
reduction='sum',
|
reduction='sum',
|
||||||
log_target=True
|
log_target=True
|
||||||
) * (1 * 1) / logits_final.numel()
|
) * (1 * 1) / logits_final.numel()
|
||||||
|
|
||||||
L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
|
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
|
||||||
loss = (loss_ce + L_zg)
|
loss = (loss_ce + L_zg)
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|||||||
Reference in New Issue
Block a user