|
|
|
|
@@ -142,7 +142,7 @@ class VLPromptLearner(nn.Module):
|
|
|
|
|
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
|
|
|
|
|
self.ZS_image_encoder = clip_model_temp_image.visual
|
|
|
|
|
# Now pre-compute the frozen VL embeddings from LLM descriptions
|
|
|
|
|
all_teacher_features = []
|
|
|
|
|
semantic_guidance_features = []
|
|
|
|
|
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
|
|
|
|
|
with open(desc_file, "r") as f:
|
|
|
|
|
all_desc = json.load(f)
|
|
|
|
|
@@ -155,9 +155,9 @@ class VLPromptLearner(nn.Module):
|
|
|
|
|
cls_feature = clip_model_temp.encode_text(cls_token)
|
|
|
|
|
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
|
|
|
|
|
cls_feature = torch.mean(cls_feature, dim=0)
|
|
|
|
|
all_teacher_features.append(cls_feature)
|
|
|
|
|
semantic_guidance_features.append(cls_feature)
|
|
|
|
|
|
|
|
|
|
self.fixed_embeddings = torch.stack(all_teacher_features)
|
|
|
|
|
self.semantic_embeddings = torch.stack(semantic_guidance_features)
|
|
|
|
|
print(f"Using LLM descriptions from: {desc_file}")
|
|
|
|
|
# These token vectors will be saved when in save_model(),
|
|
|
|
|
# but they should be ignored in load_model() as we want to use
|
|
|
|
|
@@ -238,10 +238,10 @@ class CustomCLIP(nn.Module):
|
|
|
|
|
text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
|
|
|
|
|
text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
|
|
|
|
|
|
|
|
|
|
fixed_embeddings = self.prompt_learner.fixed_embeddings
|
|
|
|
|
fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
|
|
|
|
|
semantic_embeddings = self.prompt_learner.semantic_embeddings
|
|
|
|
|
semantic_embeddings = semantic_embeddings / semantic_embeddings.norm(dim=-1, keepdim=True)
|
|
|
|
|
|
|
|
|
|
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
|
|
|
|
|
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ semantic_embeddings.half().cuda().t()
|
|
|
|
|
|
|
|
|
|
logits_strong = logit_scale * image_features @ text_features_strong.t()
|
|
|
|
|
logits_weak = logit_scale * image_features @ text_features_weak.t()
|
|
|
|
|
@@ -255,7 +255,7 @@ class CustomCLIP(nn.Module):
|
|
|
|
|
|
|
|
|
|
if self.prompt_learner.training:
|
|
|
|
|
loss_ce = F.cross_entropy(logits_final, label)
|
|
|
|
|
return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
|
|
|
|
|
return loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
|
|
|
|
|
else:
|
|
|
|
|
return logits_final
|
|
|
|
|
|
|
|
|
|
@@ -312,11 +312,11 @@ class DZGCoOp(TrainerX):
|
|
|
|
|
self.total_epochs = cfg.OPTIM.MAX_EPOCH
|
|
|
|
|
self.step_counter = 1
|
|
|
|
|
N = cfg.OPTIM.MAX_EPOCH
|
|
|
|
|
mean = cfg.TRAINER.DZGCOOP.GPA_MEAN
|
|
|
|
|
stdev = cfg.TRAINER.DZGCOOP.GPA_STD
|
|
|
|
|
gauss = self.get_gauss(mean, stdev)
|
|
|
|
|
self.gauss = np.array([gauss(a) for a in range(1, N + 1)])
|
|
|
|
|
self.gauss = self.gauss / sum(self.gauss)
|
|
|
|
|
mean = cfg.TRAINER.DZGCOOP.EWA_MEAN
|
|
|
|
|
stdev = cfg.TRAINER.DZGCOOP.EWA_STD
|
|
|
|
|
normal = self.get_normal(mean, stdev)
|
|
|
|
|
self.normal_weights = np.array([normal(a) for a in range(1, N + 1)])
|
|
|
|
|
self.normal_weights = self.normal_weights / sum(self.normal_weights)
|
|
|
|
|
self.scaler = GradScaler() if cfg.TRAINER.DZGCOOP.PREC == "amp" else None
|
|
|
|
|
# Note that multi-gpu training could be slow because CLIP's size is
|
|
|
|
|
# big, which slows down the copy operation in DataParallel
|
|
|
|
|
@@ -324,8 +324,8 @@ class DZGCoOp(TrainerX):
|
|
|
|
|
if device_count > 1:
|
|
|
|
|
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
|
|
|
|
|
self.model = nn.DataParallel(self.model)
|
|
|
|
|
# Keep model with GPA
|
|
|
|
|
self.previous_model_gpa = None
|
|
|
|
|
# Keep model with EWA
|
|
|
|
|
self.previous_model_ewa = None
|
|
|
|
|
|
|
|
|
|
def forward_backward(self, batch):
|
|
|
|
|
image, label = self.parse_batch_train(batch)
|
|
|
|
|
@@ -343,7 +343,7 @@ class DZGCoOp(TrainerX):
|
|
|
|
|
scaler.step(optim)
|
|
|
|
|
scaler.update()
|
|
|
|
|
else:
|
|
|
|
|
loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
|
|
|
|
|
loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zs_image_embedd, image_ft, \
|
|
|
|
|
zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
|
|
|
|
|
|
|
|
|
|
lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT
|
|
|
|
|
@@ -351,17 +351,17 @@ class DZGCoOp(TrainerX):
|
|
|
|
|
lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK
|
|
|
|
|
|
|
|
|
|
L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
|
|
|
|
|
L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
|
|
|
|
|
L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
|
|
|
|
|
L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
|
|
|
|
|
L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
|
|
|
|
|
|
|
|
|
|
L_zpg = F.kl_div(
|
|
|
|
|
L_zlg = F.kl_div(
|
|
|
|
|
F.log_softmax(logits_final / 1, dim=1),
|
|
|
|
|
F.log_softmax(zero_shot_logits / 1, dim=1),
|
|
|
|
|
reduction='sum',
|
|
|
|
|
log_target=True
|
|
|
|
|
) * (1 * 1) / logits_final.numel()
|
|
|
|
|
|
|
|
|
|
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
|
|
|
|
|
L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
|
|
|
|
|
loss = (loss_ce + L_zg)
|
|
|
|
|
optim.zero_grad()
|
|
|
|
|
loss.backward()
|
|
|
|
|
@@ -371,20 +371,22 @@ class DZGCoOp(TrainerX):
|
|
|
|
|
|
|
|
|
|
if (self.batch_idx + 1) == self.num_batches:
|
|
|
|
|
self.update_lr()
|
|
|
|
|
# Means one epoch is completed, perform GPA
|
|
|
|
|
# Means one epoch is completed, perform EWA
|
|
|
|
|
self.step_counter = self.step_counter + 1
|
|
|
|
|
current_epoch_weight = self.gauss[self.step_counter - 2]
|
|
|
|
|
current_epoch_weight = self.normal_weights[self.step_counter - 2]
|
|
|
|
|
current_model_weights = copy.deepcopy(model.state_dict())
|
|
|
|
|
for key in current_model_weights:
|
|
|
|
|
current_model_weights[key] = current_model_weights[key].cpu()
|
|
|
|
|
weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight)
|
|
|
|
|
if self.previous_model_gpa is None:
|
|
|
|
|
self.previous_model_gpa = weighted_state_dict
|
|
|
|
|
if self.previous_model_ewa is None:
|
|
|
|
|
self.previous_model_ewa = weighted_state_dict
|
|
|
|
|
else:
|
|
|
|
|
self.previous_model_gpa = self.state_dict_add(weighted_state_dict, self.previous_model_gpa)
|
|
|
|
|
self.previous_model_ewa = self.state_dict_add(weighted_state_dict, self.previous_model_ewa)
|
|
|
|
|
|
|
|
|
|
if self.step_counter == self.model.total_epochs + 1:
|
|
|
|
|
print("Using GPA model for final inference...")
|
|
|
|
|
model.load_state_dict(self.previous_model_gpa)
|
|
|
|
|
self.model.load_state_dict(self.previous_model_gpa)
|
|
|
|
|
print("Using EWA model for final inference...")
|
|
|
|
|
model.load_state_dict(self.previous_model_ewa)
|
|
|
|
|
self.model.load_state_dict(self.previous_model_ewa)
|
|
|
|
|
return loss_summary
|
|
|
|
|
|
|
|
|
|
def state_dict_weighting(self, main_dict, weightage, prompt_only=False):
|
|
|
|
|
@@ -392,24 +394,24 @@ class DZGCoOp(TrainerX):
|
|
|
|
|
updated_dict = copy.deepcopy(main_dict)
|
|
|
|
|
if not prompt_only:
|
|
|
|
|
for key in main_dict:
|
|
|
|
|
updated_dict[key] = main_dict[key] * weightage
|
|
|
|
|
updated_dict[key] = main_dict[key].cpu() * weightage
|
|
|
|
|
return updated_dict
|
|
|
|
|
else:
|
|
|
|
|
return main_dict * weightage
|
|
|
|
|
return main_dict.cpu() * weightage
|
|
|
|
|
|
|
|
|
|
def state_dict_add(self, dict1, dict2, prompt_only=False):
|
|
|
|
|
# Average all parameters
|
|
|
|
|
if not prompt_only:
|
|
|
|
|
modified_dict = dict2
|
|
|
|
|
for key in dict1:
|
|
|
|
|
modified_dict[key] = (modified_dict[key] + dict1[key])
|
|
|
|
|
modified_dict[key] = modified_dict[key].cpu() + dict1[key].cpu()
|
|
|
|
|
return modified_dict
|
|
|
|
|
else:
|
|
|
|
|
return dict1 + dict2
|
|
|
|
|
return dict1.cpu() + dict2.cpu()
|
|
|
|
|
|
|
|
|
|
def get_gauss(self, mu, sigma):
|
|
|
|
|
gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
|
|
|
|
|
return gauss
|
|
|
|
|
def get_normal(self, mu, sigma):
|
|
|
|
|
normal = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
|
|
|
|
|
return normal
|
|
|
|
|
|
|
|
|
|
def parse_batch_train(self, batch):
|
|
|
|
|
input = batch["img"]
|
|
|
|
|
@@ -456,4 +458,4 @@ class DZGCoOp(TrainerX):
|
|
|
|
|
|
|
|
|
|
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
|
|
|
|
|
# set strict=False
|
|
|
|
|
self._models[name].load_state_dict(state_dict, strict=False)
|
|
|
|
|
self._models[name].load_state_dict(state_dict, strict=False)
|
|
|
|
|
|