This commit is contained in:
2026-02-03 10:21:07 +08:00
parent e556f17ebc
commit 0c2ae25cf8
81 changed files with 572 additions and 76 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+52 -60
View File
@@ -19,7 +19,6 @@ from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint, mkdir_if_missing
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.evaluation import Classification,EvaluatorBase
from pygrad.pcgrad import PCGrad
from datasets.data_manager import DataManager
from dassl.data.datasets import build_dataset
@@ -35,16 +34,6 @@ from .util import GradCAM,denorm
import cv2
_tokenizer = _Tokenizer()
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',]
#['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
#'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',
#]
BACKGROUND_CATEGORY_FOOD = ['table','forks','tablecloth','hands','spoon','glasses','dishes']
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
@@ -159,28 +148,18 @@ class MultiModalPromptLearner(nn.Module):
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
###Introduce Background
bg_template = 'a clean origami {}.'
bg_classesnames = [bg_template.format(name) for name in BACKGROUND_CATEGORY +BACKGROUND_CATEGORY_FOOD ]
tokenized_bg_prompts = torch.cat([clip.tokenize(bg) for bg in bg_classesnames])
bg_num = len(BACKGROUND_CATEGORY) + len(BACKGROUND_CATEGORY_FOOD)
tokenized_prompts = torch.cat((tokenized_prompts,tokenized_bg_prompts),dim=0)
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
self.bg_embeding = embedding[-bg_num:]
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names
self.register_buffer("token_prefix", embedding[:-bg_num, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:-bg_num, 1 + n_ctx:, :]) # CLS, EOS
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor [class_num 77] [:-bg_num]
self.tokenized_prompts = tokenized_prompts # torch.Tensor [class_num, 77]
self.name_lens = name_lens
def construct_prompts(self, ctx, prefix, suffix, label=None):
@@ -204,8 +183,7 @@ class MultiModalPromptLearner(nn.Module):
dim=1,
)
final_prompts = torch.cat((prompts,self.bg_embeding.cuda()),dim=0)
return final_prompts
return prompts
def forward(self):
ctx = self.ctx
@@ -264,17 +242,44 @@ class CustomCLIP(nn.Module):
def cos_sim(self,a,b):
return F.cosine_similarity(a,b)
def contrastive_loss(self, anchor, positive, negative, temperature=0.07):
"""
InfoNCE contrastive loss for foreground-background discrimination
Args:
anchor: Complete image features [B, D]
positive: Foreground features [B, D]
negative: Background features [B, D]
temperature: Temperature parameter for softmax
Returns:
loss: Contrastive learning loss value
"""
# Calculate similarity
sim_pos = F.cosine_similarity(anchor, positive, dim=-1) # [B]
sim_neg = F.cosine_similarity(anchor, negative, dim=-1) # [B]
# Apply temperature scaling
sim_pos = sim_pos / temperature
sim_neg = sim_neg / temperature
# InfoNCE loss: -log(exp(sim_pos) / (exp(sim_pos) + exp(sim_neg)))
logits = torch.stack([sim_pos, sim_neg], dim=1) # [B, 2]
labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
loss = F.cross_entropy(logits, labels)
return loss
def forward(self, image, label=None,record=False,cal_gradient=False,weight=None,epoch=None,index=None,cfg=None,mask=None):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
text_features_fg = text_features[:-len(BACKGROUND_CATEGORY)]
ori_image_input = image.type(self.dtype)
# text_features = text_features + self.get_learnable_noise(text_features.shape)
text_features_fg = text_features_fg / text_features_fg.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
image_features, visual_ctx, mask_similarity = self.image_encoder(ori_image_input, shared_ctx,
deep_compound_prompts_vision)
@@ -285,18 +290,15 @@ class CustomCLIP(nn.Module):
# if label is not None:
# image_features = image_features + self.get_uniform_ball_noise(image_features.shape)
logits = logit_scale * image_features @ text_features_fg.t()
logits = logit_scale * image_features @ text_features.t()
loss_re = torch.tensor(0.0, dtype=self.dtype, device=image.device)
loss_fg = torch.tensor(0.0, dtype=self.dtype, device=image.device)
loss_contrastive = torch.tensor(0.0, dtype=self.dtype, device=image.device)
if mask != None:
text_features_bg = text_features[-len(BACKGROUND_CATEGORY):]
text_features_bg = text_features_bg / text_features_bg.norm(dim=-1, keepdim=True)
image_features_fg,_,_ = self.image_encoder(ori_image_input*mask, shared_ctx, deep_compound_prompts_vision) #, shared_ctx, deep_compound_prompts_vision
@@ -305,37 +307,27 @@ class CustomCLIP(nn.Module):
image_features_bg = image_features_bg / image_features_bg.norm(dim=-1, keepdim=True)
loss_re1 = F.triplet_margin_loss(image_features,image_features_fg.detach(),image_features_bg.detach(),margin=1.5)
# image_features_fg_ori = self.image_encoder_ori(ori_image_input*mask_random)
# image_features_bg_ori = self.image_encoder_ori(ori_image_input*(1-mask_random))
# image_features_fg_ori = image_features_fg_ori / image_features_fg_ori.norm(dim=-1, keepdim=True)
# image_features_bg_ori = image_features_bg_ori / image_features_bg_ori.norm(dim=-1,keepdim=True)
# image_features_all_ori = image_features_fg_ori + image_features_bg_ori
# image_features_all_ori = image_features_all_ori / image_features_all_ori.norm(dim=-1,keepdim=True)
# loss_reo = torch.abs(image_features_all_ori.detach() - image_features).mean()
foreground_score = logit_scale*image_features_fg.detach()@text_features_fg.t()
pseudo_label = torch.argmax(image_features_bg @ text_features_bg.t(), dim=-1)
logits_bg = logit_scale*(image_features_bg) @ text_features_bg.t()
para_bg = 0.5
para_fg = 0.1
para_vd = 0.8
loss_contrastive = self.contrastive_loss(image_features, image_features_fg.detach(), image_features_bg.detach(), temperature=0.07)
loss_bg = F.cross_entropy(logits_bg,pseudo_label)
loss_fg = F.cross_entropy(foreground_score,label)
para_fg = 0.2
para_vd = 0.6
if epoch > 6: #Tunable parameters
loss_re = para_fg*loss_fg + para_bg*loss_bg
if label is not None:
loss_fg = F.cross_entropy(logit_scale*image_features_fg.detach()@text_features.t(), label)
else:
loss_re = para_vd*loss_re1 #loss_reo would be effective in base2novel setting
loss_fg = torch.tensor(0.0, dtype=self.dtype, device=image.device)
if epoch is not None and epoch > 6: #Tunable parameters
loss_re = para_fg*loss_fg
else:
loss_re = para_vd*loss_contrastive
if self.prompt_learner.training:
if weight is None:
return F.cross_entropy(logits,label)+loss_re,logits,{'loss_vd':loss_re1.item(),'loss_bg':loss_bg.item(),'loss_fg':loss_fg.item()}
return F.cross_entropy(logits,label)+loss_re,logits,{'loss_contrastive':loss_contrastive.item(),'loss_fg':loss_fg.item()}
else:
return F.cross_entropy(weight.unsqueeze(-1)*logits,label), logits
@@ -674,8 +666,8 @@ class MaPLe(TrainerX):
model_name="model-best.pth.tar"
)
# if meet_checkpoint_freq or last_epoch:
# self.save_model(self.epoch, self.output_dir)
if meet_checkpoint_freq or last_epoch:
self.save_model(self.epoch, self.output_dir)
print(f"Now generate the attentive masking in {self.cfg.TRAINER.DAPT_MODE} \n")
+5 -12
View File
@@ -5,10 +5,6 @@ import cv2
from PIL import Image
import os
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',
]
class GradCAM(object):
def __init__(self,model_dict):
layer_name = model_dict['layer_name']
@@ -80,7 +76,7 @@ class GradCAM(object):
else:
logit = self.model_arch.forward_test(input,labels,cfg=cfg)
pred_label = torch.argmax(logit[:,:-len(BACKGROUND_CATEGORY)])
pred_label = torch.argmax(logit)
sign = pred_label == labels
# if (split == 'true' and sign == False) or (split == 'wrong' and sign == True):
# print(f'Ignore the not {split} sample')
@@ -88,11 +84,10 @@ class GradCAM(object):
# if attn_mask:
# return final_cls_mask
pred = logit[:,:-len(BACKGROUND_CATEGORY)].argmax(dim=-1)
background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]-len(BACKGROUND_CATEGORY)).to(torch.float16)
pred = logit.argmax(dim=-1)
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
loss = (F.softmax(logit[:,:-len(BACKGROUND_CATEGORY)])*one_hot_labels).mean() #+ background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
loss = (F.softmax(logit)*one_hot_labels).mean()
# score = logit[:,labels]
self.model_arch.zero_grad()
@@ -186,10 +181,8 @@ class GradCAM(object):
# if attn_mask:
# return final_cls_mask
# pred = logit[:,-len(BACKGROUND_CATEGORY):].argmax(dim=-1)
# background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
loss = (logit*one_hot_labels).mean() #+ background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
loss = (logit*one_hot_labels).mean()
# score = logit[:,labels]
self.model_arch.zero_grad()
loss.backward(retain_graph=retain_graph)