diff --git a/trainers/promptsrc.py b/trainers/promptsrc.py index 71891d5..38076f7 100644 --- a/trainers/promptsrc.py +++ b/trainers/promptsrc.py @@ -220,39 +220,6 @@ class CustomCLIP(nn.Module): self.dtype = clip_model.dtype self.total_epochs = cfg.OPTIM.MAX_EPOCH self.n_cls = len(classnames) - - confidence_type = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_TYPE', 'entropy') - self.confidence_type = confidence_type - self.temperature = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_TEMPERATURE', 1.0) - self.momentum = getattr(cfg.TRAINER.PROMPTSRC, 'CONFIDENCE_MOMENTUM', 0.9) - self.register_buffer('running_confidence', torch.tensor(0.0)) - self.steps = 0 - - def compute_confidence(self, logits): - if self.confidence_type == 'entropy': - probs = F.softmax(logits / self.temperature, dim=1) - entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1) - normalized_entropy = entropy / torch.log(torch.tensor(self.n_cls, dtype=logits.dtype, device=logits.device)) - confidence = 1 - normalized_entropy - elif self.confidence_type == 'max_prob': - probs = F.softmax(logits / self.temperature, dim=1) - confidence = probs.max(dim=1).values - elif self.confidence_type == 'margin': - probs = F.softmax(logits / self.temperature, dim=1) - top2_probs, _ = probs.topk(2, dim=1) - confidence = top2_probs[:, 0] - top2_probs[:, 1] - elif self.confidence_type == 'max_margin': - max_prob = F.softmax(logits / self.temperature, dim=1).max(dim=1).values - entropy = -(F.softmax(logits / self.temperature, dim=1) * - torch.log(F.softmax(logits / self.temperature, dim=1) + 1e-10)).sum(dim=1) - normalized_entropy = entropy / torch.log(torch.tensor(self.n_cls, dtype=logits.dtype, device=logits.device)) - confidence = max_prob * (1 - normalized_entropy) - else: - probs = F.softmax(logits / self.temperature, dim=1) - confidence = probs.max(dim=1).values - - confidence = confidence.clamp(0, 1) - return confidence def forward(self, image, label=None): tokenized_prompts = self.tokenized_prompts @@ -281,17 +248,7 @@ class CustomCLIP(nn.Module): logits_strong = logit_scale * image_features @ text_features_strong.t() logits_weak = logit_scale * image_features @ text_features_weak.t() - confidence = self.compute_confidence(zero_shot_logits) - - if self.training: - confidence_batch = confidence.mean().detach() - if self.steps == 0: - self.running_confidence = confidence_batch - else: - self.running_confidence = self.momentum * self.running_confidence + (1 - self.momentum) * confidence_batch - self.steps += 1 - - alpha = confidence.unsqueeze(1) + alpha = 0.5 logits_final = alpha * logits_strong + (1 - alpha) * logits_weak @@ -498,4 +455,4 @@ class PromptSRC(TrainerX): print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) # set strict=False - self._models[name].load_state_dict(state_dict, strict=False) \ No newline at end of file + self._models[name].load_state_dict(state_dict, strict=False)