fix conf
This commit is contained in:
@@ -221,39 +221,6 @@ class CustomCLIP(nn.Module):
|
|||||||
self.total_epochs = cfg.OPTIM.MAX_EPOCH
|
self.total_epochs = cfg.OPTIM.MAX_EPOCH
|
||||||
self.n_cls = len(classnames)
|
self.n_cls = len(classnames)
|
||||||
|
|
||||||
confidence_type = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_TYPE', 'entropy')
|
|
||||||
self.confidence_type = confidence_type
|
|
||||||
self.temperature = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_TEMPERATURE', 1.0)
|
|
||||||
self.momentum = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_MOMENTUM', 0.9)
|
|
||||||
self.register_buffer('running_confidence', torch.tensor(0.0))
|
|
||||||
self.steps = 0
|
|
||||||
|
|
||||||
def compute_confidence(self, logits):
|
|
||||||
if self.confidence_type == 'entropy':
|
|
||||||
probs = F.softmax(logits / self.temperature, dim=1)
|
|
||||||
entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)
|
|
||||||
normalized_entropy = entropy / torch.log(torch.tensor(self.n_cls, dtype=logits.dtype, device=logits.device))
|
|
||||||
confidence = 1 - normalized_entropy
|
|
||||||
elif self.confidence_type == 'max_prob':
|
|
||||||
probs = F.softmax(logits / self.temperature, dim=1)
|
|
||||||
confidence = probs.max(dim=1).values
|
|
||||||
elif self.confidence_type == 'margin':
|
|
||||||
probs = F.softmax(logits / self.temperature, dim=1)
|
|
||||||
top2_probs, _ = probs.topk(2, dim=1)
|
|
||||||
confidence = top2_probs[:, 0] - top2_probs[:, 1]
|
|
||||||
elif self.confidence_type == 'max_margin':
|
|
||||||
max_prob = F.softmax(logits / self.temperature, dim=1).max(dim=1).values
|
|
||||||
entropy = -(F.softmax(logits / self.temperature, dim=1) *
|
|
||||||
torch.log(F.softmax(logits / self.temperature, dim=1) + 1e-10)).sum(dim=1)
|
|
||||||
normalized_entropy = entropy / torch.log(torch.tensor(self.n_cls, dtype=logits.dtype, device=logits.device))
|
|
||||||
confidence = max_prob * (1 - normalized_entropy)
|
|
||||||
else:
|
|
||||||
probs = F.softmax(logits / self.temperature, dim=1)
|
|
||||||
confidence = probs.max(dim=1).values
|
|
||||||
|
|
||||||
confidence = confidence.clamp(0, 1)
|
|
||||||
return confidence
|
|
||||||
|
|
||||||
def forward(self, image, label=None):
|
def forward(self, image, label=None):
|
||||||
tokenized_prompts = self.tokenized_prompts
|
tokenized_prompts = self.tokenized_prompts
|
||||||
logit_scale = self.logit_scale.exp()
|
logit_scale = self.logit_scale.exp()
|
||||||
@@ -281,17 +248,7 @@ class CustomCLIP(nn.Module):
|
|||||||
logits_strong = logit_scale * image_features @ text_features_strong.t()
|
logits_strong = logit_scale * image_features @ text_features_strong.t()
|
||||||
logits_weak = logit_scale * image_features @ text_features_weak.t()
|
logits_weak = logit_scale * image_features @ text_features_weak.t()
|
||||||
|
|
||||||
confidence = self.compute_confidence(zero_shot_logits)
|
alpha = 0.5
|
||||||
|
|
||||||
if self.training:
|
|
||||||
confidence_batch = confidence.mean().detach()
|
|
||||||
if self.steps == 0:
|
|
||||||
self.running_confidence = confidence_batch
|
|
||||||
else:
|
|
||||||
self.running_confidence = self.momentum * self.running_confidence + (1 - self.momentum) * confidence_batch
|
|
||||||
self.steps += 1
|
|
||||||
|
|
||||||
alpha = confidence.unsqueeze(1)
|
|
||||||
|
|
||||||
logits_final = alpha * logits_strong + (1 - alpha) * logits_weak
|
logits_final = alpha * logits_strong + (1 - alpha) * logits_weak
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user