fix conf
This commit is contained in:
@@ -220,39 +220,6 @@ class CustomCLIP(nn.Module):
|
||||
self.dtype = clip_model.dtype
|
||||
self.total_epochs = cfg.OPTIM.MAX_EPOCH
|
||||
self.n_cls = len(classnames)
|
||||
|
||||
confidence_type = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_TYPE', 'entropy')
|
||||
self.confidence_type = confidence_type
|
||||
self.temperature = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_TEMPERATURE', 1.0)
|
||||
self.momentum = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_MOMENTUM', 0.9)
|
||||
self.register_buffer('running_confidence', torch.tensor(0.0))
|
||||
self.steps = 0
|
||||
|
||||
def compute_confidence(self, logits):
|
||||
if self.confidence_type == 'entropy':
|
||||
probs = F.softmax(logits / self.temperature, dim=1)
|
||||
entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)
|
||||
normalized_entropy = entropy / torch.log(torch.tensor(self.n_cls, dtype=logits.dtype, device=logits.device))
|
||||
confidence = 1 - normalized_entropy
|
||||
elif self.confidence_type == 'max_prob':
|
||||
probs = F.softmax(logits / self.temperature, dim=1)
|
||||
confidence = probs.max(dim=1).values
|
||||
elif self.confidence_type == 'margin':
|
||||
probs = F.softmax(logits / self.temperature, dim=1)
|
||||
top2_probs, _ = probs.topk(2, dim=1)
|
||||
confidence = top2_probs[:, 0] - top2_probs[:, 1]
|
||||
elif self.confidence_type == 'max_margin':
|
||||
max_prob = F.softmax(logits / self.temperature, dim=1).max(dim=1).values
|
||||
entropy = -(F.softmax(logits / self.temperature, dim=1) *
|
||||
torch.log(F.softmax(logits / self.temperature, dim=1) + 1e-10)).sum(dim=1)
|
||||
normalized_entropy = entropy / torch.log(torch.tensor(self.n_cls, dtype=logits.dtype, device=logits.device))
|
||||
confidence = max_prob * (1 - normalized_entropy)
|
||||
else:
|
||||
probs = F.softmax(logits / self.temperature, dim=1)
|
||||
confidence = probs.max(dim=1).values
|
||||
|
||||
confidence = confidence.clamp(0, 1)
|
||||
return confidence
|
||||
|
||||
def forward(self, image, label=None):
|
||||
tokenized_prompts = self.tokenized_prompts
|
||||
@@ -281,17 +248,7 @@ class CustomCLIP(nn.Module):
|
||||
logits_strong = logit_scale * image_features @ text_features_strong.t()
|
||||
logits_weak = logit_scale * image_features @ text_features_weak.t()
|
||||
|
||||
confidence = self.compute_confidence(zero_shot_logits)
|
||||
|
||||
if self.training:
|
||||
confidence_batch = confidence.mean().detach()
|
||||
if self.steps == 0:
|
||||
self.running_confidence = confidence_batch
|
||||
else:
|
||||
self.running_confidence = self.momentum * self.running_confidence + (1 - self.momentum) * confidence_batch
|
||||
self.steps += 1
|
||||
|
||||
alpha = confidence.unsqueeze(1)
|
||||
alpha = 0.5
|
||||
|
||||
logits_final = alpha * logits_strong + (1 - alpha) * logits_weak
|
||||
|
||||
@@ -498,4 +455,4 @@ class PromptSRC(TrainerX):
|
||||
|
||||
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
|
||||
# set strict=False
|
||||
self._models[name].load_state_dict(state_dict, strict=False)
|
||||
self._models[name].load_state_dict(state_dict, strict=False)
|
||||
|
||||
Reference in New Issue
Block a user