This commit is contained in:
2026-02-05 12:12:11 +08:00
parent 1925ddfc86
commit 7fcf319dcf

View File

@@ -220,39 +220,6 @@ class CustomCLIP(nn.Module):
# --- Fragment of CustomCLIP.__init__ (the `def` line is outside this diff hunk) ---
self.dtype = clip_model.dtype
self.total_epochs = cfg.OPTIM.MAX_EPOCH
self.n_cls = len(classnames)
# Confidence-weighting hyperparameters, read from cfg with safe defaults.
confidence_type = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_TYPE', 'entropy')
self.confidence_type = confidence_type
self.temperature = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_TEMPERATURE', 1.0)
self.momentum = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_MOMENTUM', 0.9)
# Buffer (not a parameter) holding an EMA of batch-mean confidence; it is
# updated in forward() during training and saved with the state dict.
self.register_buffer('running_confidence', torch.tensor(0.0))
self.steps = 0
def compute_confidence(self, logits):
    """Return a per-sample confidence score in [0, 1] for a batch of logits.

    Args:
        logits: tensor of shape (batch, n_cls) — raw classification logits.

    Returns:
        1-D tensor of length ``batch`` with one confidence value per sample,
        clamped to [0, 1]. The scoring strategy is selected by
        ``self.confidence_type``:
        - 'entropy':    1 - normalized entropy of the softmax distribution.
        - 'max_prob':   maximum softmax probability.
        - 'margin':     gap between the top-2 softmax probabilities.
        - 'max_margin': max-prob multiplied by (1 - normalized entropy).
        - anything else falls back to 'max_prob'.
    """
    # Every strategy needs the temperature-scaled probabilities, so compute
    # them once here (the previous 'max_margin' branch recomputed the
    # softmax three times).
    probs = F.softmax(logits / self.temperature, dim=1)

    # log(n_cls) in the logits' dtype/device, used to normalize entropy to [0, 1].
    def _normalized_entropy():
        # 1e-10 guards against log(0) for zero-probability classes.
        entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)
        log_n = torch.log(torch.tensor(self.n_cls, dtype=logits.dtype, device=logits.device))
        return entropy / log_n

    if self.confidence_type == 'entropy':
        confidence = 1 - _normalized_entropy()
    elif self.confidence_type == 'max_prob':
        confidence = probs.max(dim=1).values
    elif self.confidence_type == 'margin':
        top2_probs, _ = probs.topk(2, dim=1)
        confidence = top2_probs[:, 0] - top2_probs[:, 1]
    elif self.confidence_type == 'max_margin':
        confidence = probs.max(dim=1).values * (1 - _normalized_entropy())
    else:
        # Unknown strategy: fall back to max-probability (original behavior).
        confidence = probs.max(dim=1).values

    # Floating-point noise can push entropy-based scores slightly outside [0, 1].
    confidence = confidence.clamp(0, 1)
    return confidence
# Forward pass over an image batch; body is truncated here by the diff hunk
# below and continues in the next hunk.
def forward(self, image, label=None):
tokenized_prompts = self.tokenized_prompts
@@ -281,17 +248,7 @@ class CustomCLIP(nn.Module):
# --- Fragment of CustomCLIP.forward (earlier lines elided by the diff) ---
# Similarity logits against the "strong" and "weak" prompt text features.
logits_strong = logit_scale * image_features @ text_features_strong.t()
logits_weak = logit_scale * image_features @ text_features_weak.t()
# Per-sample confidence derived from the frozen zero-shot logits.
confidence = self.compute_confidence(zero_shot_logits)
# During training, maintain an exponential moving average of the batch-mean
# confidence in the `running_confidence` buffer.
if self.training:
confidence_batch = confidence.mean().detach()
if self.steps == 0:
self.running_confidence = confidence_batch
else:
self.running_confidence = self.momentum * self.running_confidence + (1 - self.momentum) * confidence_batch
self.steps += 1
# NOTE(review): `alpha` is first set to the per-sample confidence but is
# immediately overwritten with the constant 0.5, so the confidence weights
# never influence the blend below — confirm which behavior is intended
# (dead assignment, likely a debugging leftover).
alpha = confidence.unsqueeze(1)
alpha = 0.5
logits_final = alpha * logits_strong + (1 - alpha) * logits_weak
@@ -498,4 +455,4 @@ class PromptSRC(TrainerX):
# --- Fragment of PromptSRC.load_model (enclosing def outside this hunk) ---
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
# NOTE(review): load_state_dict is invoked twice with identical arguments —
# the second call is redundant (possibly a diff-scrape artifact); confirm and
# remove the duplicate.
self._models[name].load_state_dict(state_dict, strict=False)