Add dual prompt branches and softmax-confidence fusion

2026-02-05 18:46:37 +08:00
parent ea5e9f17ba
commit 91e873c365
2 changed files with 79 additions and 48 deletions


@@ -119,6 +119,8 @@ def extend_cfg(cfg):
     cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION = 9  # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
     cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT = 9  # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
     cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT = 25
+    cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG = 25  # lambda2: strong text constraint weight
+    cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK = 2.5  # lambda3: weak text constraint weight
     cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10
     cfg.TRAINER.PROMPTSRC.GPA_MEAN = 15
     cfg.TRAINER.PROMPTSRC.GPA_STD = 1
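The two new weights give the strong and weak text-distillation terms separate strengths; with the defaults above, the weak branch is constrained an order of magnitude less. A minimal sketch of how these knobs behave, assuming the yacs `CfgNode` that Dassl-style trainers use (the override value is illustrative):

```python
from yacs.config import CfgNode as CN

cfg = CN()
cfg.TRAINER = CN()
cfg.TRAINER.PROMPTSRC = CN()
# Defaults mirror the diff above: the weak constraint is 10x lighter.
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG = 25   # lambda2
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK = 2.5    # lambda3

# Dassl-style command-line override via an opts list; yacs coerces the
# string "5.0" to a float when merging.
cfg.merge_from_list(["TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK", "5.0"])
print(cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK)  # 5.0
```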


@@ -107,28 +107,32 @@ class VLPromptLearner(nn.Module):
         assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

         if ctx_init and n_ctx <= 4:
             # use given words to initialize context vectors
             ctx_init = ctx_init.replace("_", " ")
             n_ctx = n_ctx
             prompt = clip.tokenize(ctx_init)
             with torch.no_grad():
                 embedding = clip_model.token_embedding(prompt).type(dtype)
-            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
-            prompt_prefix = ctx_init
+            ctx_vectors_strong = embedding[0, 1: 1 + n_ctx, :]
+            prompt_prefix_strong = ctx_init
         else:
             # random initialization
-            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
-            nn.init.normal_(ctx_vectors, std=0.02)
-            prompt_prefix = " ".join(["X"] * n_ctx)
-        print(f"Independent V-L design")
-        print(f'Initial text context: "{prompt_prefix}"')
+            ctx_vectors_strong = torch.empty(n_ctx, ctx_dim, dtype=dtype)
+            nn.init.normal_(ctx_vectors_strong, std=0.02)
+            prompt_prefix_strong = " ".join(["X"] * n_ctx)
+        # The weak branch is always randomly initialized; it sits outside the
+        # if/else so it is defined on both initialization paths
+        ctx_vectors_weak = torch.empty(n_ctx, ctx_dim, dtype=dtype)
+        nn.init.normal_(ctx_vectors_weak, std=0.02)
+        prompt_prefix_weak = " ".join(["X"] * n_ctx)
+        print(f"Independent V-L design with Dual Prompt Branches")
+        print(f'Strong branch initial text context: "{prompt_prefix_strong}"')
+        print(f'Weak branch initial text context: "{prompt_prefix_weak}"')
         print(f"Number of context words (tokens) for Language prompting: {n_ctx}")
         print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.PROMPTSRC.N_CTX_VISION}")
-        self.ctx = nn.Parameter(ctx_vectors)
+        self.ctx_strong = nn.Parameter(ctx_vectors_strong)
+        self.ctx_weak = nn.Parameter(ctx_vectors_weak)

         classnames = [name.replace("_", " ") for name in classnames]
         name_lens = [len(_tokenizer.encode(name)) for name in classnames]
-        prompts = [prompt_prefix + " " + name + "." for name in classnames]
+        prompts = [prompt_prefix_strong + " " + name + "." for name in classnames]

         tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])  # (n_cls, n_tkn)

         # Also create frozen CLIP
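In isolation, the dual parameterization above amounts to two independently optimized context tensors of identical shape. A minimal sketch, with illustrative `n_ctx` and `ctx_dim` (the real values come from the config and CLIP's token embedding width):

```python
import torch
import torch.nn as nn

n_ctx, ctx_dim = 4, 512  # illustrative, not read from any config
ctx_vectors_strong = torch.empty(n_ctx, ctx_dim)
nn.init.normal_(ctx_vectors_strong, std=0.02)
ctx_vectors_weak = torch.empty(n_ctx, ctx_dim)
nn.init.normal_(ctx_vectors_weak, std=0.02)

ctx_strong = nn.Parameter(ctx_vectors_strong)  # updated independently...
ctx_weak = nn.Parameter(ctx_vectors_weak)      # ...from the strong branch
```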
@@ -188,15 +192,19 @@ class VLPromptLearner(nn.Module):
         return prompts

     def forward(self):
-        ctx = self.ctx
-        if ctx.dim() == 2:
-            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
+        ctx_strong = self.ctx_strong
+        ctx_weak = self.ctx_weak
+        if ctx_strong.dim() == 2:
+            ctx_strong = ctx_strong.unsqueeze(0).expand(self.n_cls, -1, -1)
+            ctx_weak = ctx_weak.unsqueeze(0).expand(self.n_cls, -1, -1)

         prefix = self.token_prefix
         suffix = self.token_suffix
-        prompts = self.construct_prompts(ctx, prefix, suffix)
+        prompts_strong = self.construct_prompts(ctx_strong, prefix, suffix)
+        prompts_weak = self.construct_prompts(ctx_weak, prefix, suffix)

-        return prompts
+        return prompts_strong, prompts_weak


 class CustomCLIP(nn.Module):
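Shape-wise, the branch expansion above does the following; a toy sketch with made-up sizes (the real `n_cls` comes from the dataset, `ctx_dim` from CLIP):

```python
import torch

n_cls, n_ctx, ctx_dim = 10, 4, 512  # illustrative
ctx_strong = torch.randn(n_ctx, ctx_dim)
ctx_weak = torch.randn(n_ctx, ctx_dim)
# (n_ctx, ctx_dim) -> (n_cls, n_ctx, ctx_dim): the learned context is shared
# across classes; only the class-name suffix embeddings differ per class.
ctx_strong = ctx_strong.unsqueeze(0).expand(n_cls, -1, -1)
ctx_weak = ctx_weak.unsqueeze(0).expand(n_cls, -1, -1)
print(ctx_strong.shape)  # torch.Size([10, 4, 512])
```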
@@ -215,29 +223,41 @@ class CustomCLIP(nn.Module):
         tokenized_prompts = self.tokenized_prompts
         logit_scale = self.logit_scale.exp()

-        prompts = self.prompt_learner()
-        # Compute the prompted image and text features
-        text_features = self.text_encoder(prompts, tokenized_prompts)
+        prompts_strong, prompts_weak = self.prompt_learner()
+        # Frozen zero-shot visual features: now needed at train AND eval time,
+        # since the fusion weight depends on them
+        with torch.no_grad():
+            zero_shot_features = self.prompt_learner.ZS_image_encoder(image.type(self.dtype))
+            zero_shot_features = zero_shot_features / zero_shot_features.norm(dim=-1, keepdim=True)
         image_features = self.image_encoder(image.type(self.dtype))
         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
-        # Compute the prompted logits
-        logits = logit_scale * image_features @ text_features.t()
-        if self.prompt_learner.training:
-            # Now calculate the frozen pre-trained features
-            fixed_embeddings = self.prompt_learner.fixed_embeddings  # precomputed pre-trained frozen textual features
-            fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
-            with torch.no_grad():
-                zero_shot_features = self.prompt_learner.ZS_image_encoder(image.type(self.dtype))
-                zero_shot_features = zero_shot_features / zero_shot_features.norm(dim=-1, keepdim=True)
-                # Compute pre-trained frozen visual features
-                zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
-            return F.cross_entropy(logits,
-                                   label), text_features, fixed_embeddings, zero_shot_features, \
-                image_features, zero_shot_logits, logits
+        # Prompted text features for both branches
+        text_features_strong = self.text_encoder(prompts_strong, tokenized_prompts)
+        text_features_strong = text_features_strong / text_features_strong.norm(dim=-1, keepdim=True)
+        text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
+        text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
+        # Frozen pre-trained textual features and the zero-shot logits
+        fixed_embeddings = self.prompt_learner.fixed_embeddings
+        fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
+        zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
+        logits_strong = logit_scale * image_features @ text_features_strong.t()
+        logits_weak = logit_scale * image_features @ text_features_weak.t()
+        # Fuse the two branches with the zero-shot max-softmax confidence
+        zs_probs = F.softmax(zero_shot_logits, dim=1)
+        confidence = zs_probs.max(dim=1).values
+        alpha = confidence.unsqueeze(1)
+        logits_final = alpha * logits_strong + (1 - alpha) * logits_weak
+        if self.prompt_learner.training:
+            loss_ce = F.cross_entropy(logits_final, label)
+            return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, \
+                image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
         else:
-            return logits
+            return logits_final
@TRAINER_REGISTRY.register()
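The heart of the commit is this confidence-gated fusion: the frozen CLIP branch's maximum softmax probability decides, per sample, how much to trust the strongly constrained branch. A self-contained sketch with random logits; batch and class counts are illustrative:

```python
import torch
import torch.nn.functional as F

batch, n_cls = 8, 100  # illustrative
zero_shot_logits = torch.randn(batch, n_cls)
logits_strong = torch.randn(batch, n_cls)
logits_weak = torch.randn(batch, n_cls)

# Per-sample confidence = max class probability under the frozen model.
alpha = F.softmax(zero_shot_logits, dim=1).max(dim=1, keepdim=True).values
# Confident zero-shot prediction -> lean on the strongly constrained branch;
# uncertain prediction -> fall back to the weakly constrained one.
logits_final = alpha * logits_strong + (1 - alpha) * logits_weak
assert logits_final.shape == (batch, n_cls)
```

Since the max softmax probability is bounded below by 1/n_cls, alpha lies in [1/n_cls, 1], so the strong branch always keeps at least a sliver of weight.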
@@ -323,22 +343,25 @@ class PromptSRC(TrainerX):
             scaler.step(optim)
             scaler.update()
         else:
-            loss_ce, normalized_text_features, zs_clip_text_embeddings, zs_image_embedd, image_ft, \
-                zero_shot_logits, logits = model(image, label)
-            # Calculate the L_SCL_text loss
-            loss_scl_text = F.l1_loss(normalized_text_features, zs_clip_text_embeddings.cuda(),
-                                      reduction='mean') * self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT
-            # Calculate the L_SCL_image loss
-            loss_scl_image = F.l1_loss(image_ft, zs_image_embedd.cuda(),
-                                       reduction='mean') * self.cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT
-            # Now calculate L_SCL_logits
+            loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
+                zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
+            lambda1 = self.cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT
+            lambda2 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG
+            lambda3 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK
+            # L_SCL_image plus the strong and weak L_SCL_text constraints
+            loss_scl_image = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
+            loss_scl_text_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
+            loss_scl_text_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
+            # L_SCL_logits: distill the frozen zero-shot logits into the fused logits
             L_SCL_logits = F.kl_div(
-                F.log_softmax(logits / 1, dim=1),
+                F.log_softmax(logits_final / 1, dim=1),
                 F.log_softmax(zero_shot_logits / 1, dim=1),
                 reduction='sum',
                 log_target=True
-            ) * (1 * 1) / logits.numel()
-            L_SCL = (L_SCL_logits + loss_scl_text + loss_scl_image)
+            ) * (1 * 1) / logits_final.numel()
+            L_SCL = (L_SCL_logits + loss_scl_text_strong + loss_scl_text_weak + loss_scl_image)
             loss = (loss_ce + L_SCL)
             optim.zero_grad()
             loss.backward()
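The KL term keeps its PromptSRC form, only retargeted at the fused logits. A dummy-tensor sketch with the temperature written out explicitly (the diff hard-codes T = 1 as literal 1s):

```python
import torch
import torch.nn.functional as F

T = 1.0  # temperature, written as literal 1s in the diff
logits_final = torch.randn(8, 100)      # illustrative shapes
zero_shot_logits = torch.randn(8, 100)

# PyTorch's kl_div treats the FIRST argument as log-predictions, so this
# distills the frozen zero-shot distribution into the fused prompted one.
L_SCL_logits = F.kl_div(
    F.log_softmax(logits_final / T, dim=1),
    F.log_softmax(zero_shot_logits / T, dim=1),
    reduction='sum',
    log_target=True,
) * (T * T) / logits_final.numel()
```

The (T * T) factor and the division by numel() rather than batch size are inherited unchanged from the original PromptSRC implementation.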
@@ -425,6 +448,12 @@ class PromptSRC(TrainerX):
if "prompt_learner.token_suffix" in state_dict:
del state_dict["prompt_learner.token_suffix"]
# Handle backward compatibility: if old checkpoint has ctx, initialize both ctx_strong and ctx_weak
if "prompt_learner.ctx" in state_dict:
ctx = state_dict.pop("prompt_learner.ctx")
state_dict["prompt_learner.ctx_strong"] = ctx.clone()
state_dict["prompt_learner.ctx_weak"] = ctx.clone()
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
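Because the remap runs before loading, both new branches start from the old single-branch weights, and `strict=False` absorbs any keys the checkpoint does not cover. A toy illustration of that loading behavior with a hypothetical layer (not part of this codebase):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
ckpt = {"weight": torch.zeros(2, 4)}  # "bias" key deliberately missing
result = model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)  # ['bias'] -- tolerated instead of raising
```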