This commit is contained in:
2026-02-03 10:21:07 +08:00
parent e556f17ebc
commit 0c2ae25cf8
81 changed files with 572 additions and 76 deletions

View File

@@ -19,7 +19,6 @@ from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint, mkdir_if_missing
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.evaluation import Classification,EvaluatorBase
from pygrad.pcgrad import PCGrad
from datasets.data_manager import DataManager
from dassl.data.datasets import build_dataset
@@ -35,16 +34,6 @@ from .util import GradCAM,denorm
import cv2
_tokenizer = _Tokenizer()
# Generic background class names; these are formatted into extra "background"
# text prompts and concatenated after the foreground class prompts
# (see MultiModalPromptLearner, which tokenizes them with the bg template).
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',]
#['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
#'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',
#]
# Additional background names — presumably for food-image datasets; the two
# lists are merged when building the background prompts.
BACKGROUND_CATEGORY_FOOD = ['table','forks','tablecloth','hands','spoon','glasses','dishes']
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
@@ -159,28 +148,18 @@ class MultiModalPromptLearner(nn.Module):
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
###Introduce Background
bg_template = 'a clean origami {}.'
bg_classesnames = [bg_template.format(name) for name in BACKGROUND_CATEGORY +BACKGROUND_CATEGORY_FOOD ]
tokenized_bg_prompts = torch.cat([clip.tokenize(bg) for bg in bg_classesnames])
bg_num = len(BACKGROUND_CATEGORY) + len(BACKGROUND_CATEGORY_FOOD)
tokenized_prompts = torch.cat((tokenized_prompts,tokenized_bg_prompts),dim=0)
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
self.bg_embeding = embedding[-bg_num:]
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names
self.register_buffer("token_prefix", embedding[:-bg_num, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:-bg_num, 1 + n_ctx:, :]) # CLS, EOS
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor [class_num 77] [:-bg_num]
self.tokenized_prompts = tokenized_prompts # torch.Tensor [class_num, 77]
self.name_lens = name_lens
def construct_prompts(self, ctx, prefix, suffix, label=None):
@@ -204,8 +183,7 @@ class MultiModalPromptLearner(nn.Module):
dim=1,
)
final_prompts = torch.cat((prompts,self.bg_embeding.cuda()),dim=0)
return final_prompts
return prompts
def forward(self):
ctx = self.ctx
@@ -264,17 +242,44 @@ class CustomCLIP(nn.Module):
def cos_sim(self, a, b):
    """Return the cosine similarity of ``a`` and ``b``.

    Thin wrapper around ``F.cosine_similarity`` using its default
    reduction dimension (dim=1).
    """
    similarity = F.cosine_similarity(a, b)
    return similarity
def contrastive_loss(self, anchor, positive, negative, temperature=0.07):
    """InfoNCE contrastive loss for foreground-background discrimination.

    Args:
        anchor: Complete image features [B, D].
        positive: Foreground features [B, D].
        negative: Background features [B, D].
        temperature: Temperature applied to both similarities before softmax.

    Returns:
        Scalar loss -log(exp(s_pos/T) / (exp(s_pos/T) + exp(s_neg/T)))
        averaged over the batch.
    """
    # Temperature-scaled cosine similarities, each of shape [B].
    scaled_sims = [
        F.cosine_similarity(anchor, feats, dim=-1) / temperature
        for feats in (positive, negative)
    ]
    # Treat the pair as a 2-way classification problem where index 0
    # (the positive similarity) is always the correct class.
    pair_logits = torch.stack(scaled_sims, dim=1)  # [B, 2]
    targets = pair_logits.new_zeros(pair_logits.shape[0], dtype=torch.long)
    return F.cross_entropy(pair_logits, targets)
def forward(self, image, label=None,record=False,cal_gradient=False,weight=None,epoch=None,index=None,cfg=None,mask=None):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
text_features_fg = text_features[:-len(BACKGROUND_CATEGORY)]
ori_image_input = image.type(self.dtype)
# text_features = text_features + self.get_learnable_noise(text_features.shape)
text_features_fg = text_features_fg / text_features_fg.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
image_features, visual_ctx, mask_similarity = self.image_encoder(ori_image_input, shared_ctx,
deep_compound_prompts_vision)
@@ -285,18 +290,15 @@ class CustomCLIP(nn.Module):
# if label is not None:
# image_features = image_features + self.get_uniform_ball_noise(image_features.shape)
logits = logit_scale * image_features @ text_features_fg.t()
logits = logit_scale * image_features @ text_features.t()
loss_re = torch.tensor(0.0, dtype=self.dtype, device=image.device)
loss_fg = torch.tensor(0.0, dtype=self.dtype, device=image.device)
loss_contrastive = torch.tensor(0.0, dtype=self.dtype, device=image.device)
if mask != None:
text_features_bg = text_features[-len(BACKGROUND_CATEGORY):]
text_features_bg = text_features_bg / text_features_bg.norm(dim=-1, keepdim=True)
image_features_fg,_,_ = self.image_encoder(ori_image_input*mask, shared_ctx, deep_compound_prompts_vision) #, shared_ctx, deep_compound_prompts_vision
@@ -305,37 +307,27 @@ class CustomCLIP(nn.Module):
image_features_bg = image_features_bg / image_features_bg.norm(dim=-1, keepdim=True)
loss_re1 = F.triplet_margin_loss(image_features,image_features_fg.detach(),image_features_bg.detach(),margin=1.5)
# image_features_fg_ori = self.image_encoder_ori(ori_image_input*mask_random)
# image_features_bg_ori = self.image_encoder_ori(ori_image_input*(1-mask_random))
# image_features_fg_ori = image_features_fg_ori / image_features_fg_ori.norm(dim=-1, keepdim=True)
# image_features_bg_ori = image_features_bg_ori / image_features_bg_ori.norm(dim=-1,keepdim=True)
# image_features_all_ori = image_features_fg_ori + image_features_bg_ori
# image_features_all_ori = image_features_all_ori / image_features_all_ori.norm(dim=-1,keepdim=True)
# loss_reo = torch.abs(image_features_all_ori.detach() - image_features).mean()
foreground_score = logit_scale*image_features_fg.detach()@text_features_fg.t()
pseudo_label = torch.argmax(image_features_bg @ text_features_bg.t(), dim=-1)
logits_bg = logit_scale*(image_features_bg) @ text_features_bg.t()
para_bg = 0.5
para_fg = 0.1
para_vd = 0.8
loss_contrastive = self.contrastive_loss(image_features, image_features_fg.detach(), image_features_bg.detach(), temperature=0.07)
loss_bg = F.cross_entropy(logits_bg,pseudo_label)
loss_fg = F.cross_entropy(foreground_score,label)
para_fg = 0.2
para_vd = 0.6
if epoch > 6: #Tunable parameters
loss_re = para_fg*loss_fg + para_bg*loss_bg
if label is not None:
loss_fg = F.cross_entropy(logit_scale*image_features_fg.detach()@text_features.t(), label)
else:
loss_re = para_vd*loss_re1 #loss_reo would be effective in base2novel setting
loss_fg = torch.tensor(0.0, dtype=self.dtype, device=image.device)
if epoch is not None and epoch > 6: #Tunable parameters
loss_re = para_fg*loss_fg
else:
loss_re = para_vd*loss_contrastive
if self.prompt_learner.training:
if weight is None:
return F.cross_entropy(logits,label)+loss_re,logits,{'loss_vd':loss_re1.item(),'loss_bg':loss_bg.item(),'loss_fg':loss_fg.item()}
return F.cross_entropy(logits,label)+loss_re,logits,{'loss_contrastive':loss_contrastive.item(),'loss_fg':loss_fg.item()}
else:
return F.cross_entropy(weight.unsqueeze(-1)*logits,label), logits
@@ -674,8 +666,8 @@ class MaPLe(TrainerX):
model_name="model-best.pth.tar"
)
# if meet_checkpoint_freq or last_epoch:
# self.save_model(self.epoch, self.output_dir)
if meet_checkpoint_freq or last_epoch:
self.save_model(self.epoch, self.output_dir)
print(f"Now generate the attentive masking in {self.cfg.TRAINER.DAPT_MODE} \n")

View File

@@ -5,10 +5,6 @@ import cv2
from PIL import Image
import os
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',
]
class GradCAM(object):
def __init__(self,model_dict):
layer_name = model_dict['layer_name']
@@ -80,7 +76,7 @@ class GradCAM(object):
else:
logit = self.model_arch.forward_test(input,labels,cfg=cfg)
pred_label = torch.argmax(logit[:,:-len(BACKGROUND_CATEGORY)])
pred_label = torch.argmax(logit)
sign = pred_label == labels
# if (split == 'true' and sign == False) or (split == 'wrong' and sign == True):
# print(f'Ignore the not {split} sample')
@@ -88,11 +84,10 @@ class GradCAM(object):
# if attn_mask:
# return final_cls_mask
pred = logit[:,:-len(BACKGROUND_CATEGORY)].argmax(dim=-1)
background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]-len(BACKGROUND_CATEGORY)).to(torch.float16)
pred = logit.argmax(dim=-1)
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
loss = (F.softmax(logit[:,:-len(BACKGROUND_CATEGORY)])*one_hot_labels).mean() #+ background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
loss = (F.softmax(logit)*one_hot_labels).mean()
# score = logit[:,labels]
self.model_arch.zero_grad()
@@ -186,10 +181,8 @@ class GradCAM(object):
# if attn_mask:
# return final_cls_mask
# pred = logit[:,-len(BACKGROUND_CATEGORY):].argmax(dim=-1)
# background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
loss = (logit*one_hot_labels).mean() #+ background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
loss = (logit*one_hot_labels).mean()
# score = logit[:,labels]
self.model_arch.zero_grad()
loss.backward(retain_graph=retain_graph)