temp
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -19,7 +19,6 @@ from dassl.metrics import compute_accuracy
|
||||
from dassl.utils import load_pretrained_weights, load_checkpoint, mkdir_if_missing
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
from dassl.evaluation import Classification,EvaluatorBase
|
||||
from pygrad.pcgrad import PCGrad
|
||||
from datasets.data_manager import DataManager
|
||||
from dassl.data.datasets import build_dataset
|
||||
|
||||
@@ -35,16 +34,6 @@ from .util import GradCAM,denorm
|
||||
import cv2
|
||||
_tokenizer = _Tokenizer()
|
||||
|
||||
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
|
||||
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',]
|
||||
|
||||
|
||||
#['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
|
||||
#'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',
|
||||
#]
|
||||
|
||||
BACKGROUND_CATEGORY_FOOD = ['table','forks','tablecloth','hands','spoon','glasses','dishes']
|
||||
|
||||
def load_clip_to_cpu(cfg):
|
||||
backbone_name = cfg.MODEL.BACKBONE.NAME
|
||||
url = clip._MODELS[backbone_name]
|
||||
@@ -159,28 +148,18 @@ class MultiModalPromptLearner(nn.Module):
|
||||
prompts = [prompt_prefix + " " + name + "." for name in classnames]
|
||||
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
|
||||
|
||||
|
||||
###Introduce Background
|
||||
bg_template = 'a clean origami {}.'
|
||||
|
||||
bg_classesnames = [bg_template.format(name) for name in BACKGROUND_CATEGORY +BACKGROUND_CATEGORY_FOOD ]
|
||||
tokenized_bg_prompts = torch.cat([clip.tokenize(bg) for bg in bg_classesnames])
|
||||
bg_num = len(BACKGROUND_CATEGORY) + len(BACKGROUND_CATEGORY_FOOD)
|
||||
tokenized_prompts = torch.cat((tokenized_prompts,tokenized_bg_prompts),dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
|
||||
self.bg_embeding = embedding[-bg_num:]
|
||||
|
||||
# These token vectors will be saved by save_model(),
|
||||
# but they should be ignored in load_model() as we want to use
|
||||
# those computed using the current class names
|
||||
self.register_buffer("token_prefix", embedding[:-bg_num, :1, :]) # SOS
|
||||
self.register_buffer("token_suffix", embedding[:-bg_num, 1 + n_ctx:, :]) # CLS, EOS
|
||||
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
|
||||
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
|
||||
|
||||
self.n_cls = n_cls
|
||||
self.n_ctx = n_ctx
|
||||
self.tokenized_prompts = tokenized_prompts # torch.Tensor [class_num 77] [:-bg_num]
|
||||
self.tokenized_prompts = tokenized_prompts # torch.Tensor [class_num, 77]
|
||||
self.name_lens = name_lens
|
||||
|
||||
def construct_prompts(self, ctx, prefix, suffix, label=None):
|
||||
@@ -204,8 +183,7 @@ class MultiModalPromptLearner(nn.Module):
|
||||
dim=1,
|
||||
)
|
||||
|
||||
final_prompts = torch.cat((prompts,self.bg_embeding.cuda()),dim=0)
|
||||
return final_prompts
|
||||
return prompts
|
||||
|
||||
def forward(self):
|
||||
ctx = self.ctx
|
||||
@@ -264,17 +242,44 @@ class CustomCLIP(nn.Module):
|
||||
def cos_sim(self, a, b):
    """Return the cosine similarity between tensors ``a`` and ``b``.

    Thin wrapper around ``F.cosine_similarity`` called with the library's
    default arguments.
    """
    similarity = F.cosine_similarity(a, b)
    return similarity
|
||||
|
||||
def contrastive_loss(self, anchor, positive, negative, temperature=0.07):
    """
    InfoNCE contrastive loss for foreground-background discrimination.

    Every anchor is compared with exactly one positive and one negative, so
    the per-sample loss is
    ``-log(exp(s_pos) / (exp(s_pos) + exp(s_neg)))`` where ``s_pos`` and
    ``s_neg`` are the temperature-scaled cosine similarities.

    Args:
        anchor: Complete image features [B, D]
        positive: Foreground features [B, D]
        negative: Background features [B, D]
        temperature: Temperature parameter for softmax; must be positive.

    Returns:
        loss: Contrastive learning loss value (``F.cross_entropy`` over the
            two-way logits, using its default reduction).

    Raises:
        ValueError: If ``temperature`` is a plain number that is not > 0.
    """
    # A zero temperature divides by zero (inf/NaN loss) and a negative one
    # silently flips the objective, so reject both up front. Tensor-valued
    # temperatures are passed through untouched.
    if isinstance(temperature, (int, float)) and temperature <= 0:
        raise ValueError(
            f"temperature must be positive, got {temperature!r}"
        )

    # Temperature-scaled cosine similarity of the anchor to each candidate.
    sim_pos = F.cosine_similarity(anchor, positive, dim=-1) / temperature  # [B]
    sim_neg = F.cosine_similarity(anchor, negative, dim=-1) / temperature  # [B]

    # Two-way classification where column 0 (the positive) is always the
    # correct class; cross-entropy over these logits is the InfoNCE loss.
    logits = torch.stack([sim_pos, sim_neg], dim=1)  # [B, 2]
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

    return F.cross_entropy(logits, labels)
|
||||
|
||||
def forward(self, image, label=None,record=False,cal_gradient=False,weight=None,epoch=None,index=None,cfg=None,mask=None):
|
||||
tokenized_prompts = self.tokenized_prompts
|
||||
logit_scale = self.logit_scale.exp()
|
||||
|
||||
prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
|
||||
text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
|
||||
text_features_fg = text_features[:-len(BACKGROUND_CATEGORY)]
|
||||
ori_image_input = image.type(self.dtype)
|
||||
# text_features = text_features + self.get_learnable_noise(text_features.shape)
|
||||
|
||||
text_features_fg = text_features_fg / text_features_fg.norm(dim=-1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
image_features, visual_ctx, mask_similarity = self.image_encoder(ori_image_input, shared_ctx,
|
||||
deep_compound_prompts_vision)
|
||||
@@ -285,18 +290,15 @@ class CustomCLIP(nn.Module):
|
||||
# if label is not None:
|
||||
# image_features = image_features + self.get_uniform_ball_noise(image_features.shape)
|
||||
|
||||
logits = logit_scale * image_features @ text_features_fg.t()
|
||||
|
||||
|
||||
|
||||
logits = logit_scale * image_features @ text_features.t()
|
||||
|
||||
loss_re = torch.tensor(0.0, dtype=self.dtype, device=image.device)
|
||||
loss_fg = torch.tensor(0.0, dtype=self.dtype, device=image.device)
|
||||
loss_contrastive = torch.tensor(0.0, dtype=self.dtype, device=image.device)
|
||||
|
||||
if mask != None:
|
||||
|
||||
|
||||
|
||||
text_features_bg = text_features[-len(BACKGROUND_CATEGORY):]
|
||||
text_features_bg = text_features_bg / text_features_bg.norm(dim=-1, keepdim=True)
|
||||
image_features_fg,_,_ = self.image_encoder(ori_image_input*mask, shared_ctx, deep_compound_prompts_vision) #, shared_ctx, deep_compound_prompts_vision
|
||||
|
||||
|
||||
@@ -305,37 +307,27 @@ class CustomCLIP(nn.Module):
|
||||
image_features_bg = image_features_bg / image_features_bg.norm(dim=-1, keepdim=True)
|
||||
|
||||
|
||||
loss_re1 = F.triplet_margin_loss(image_features,image_features_fg.detach(),image_features_bg.detach(),margin=1.5)
|
||||
|
||||
# image_features_fg_ori = self.image_encoder_ori(ori_image_input*mask_random)
|
||||
# image_features_bg_ori = self.image_encoder_ori(ori_image_input*(1-mask_random))
|
||||
# image_features_fg_ori = image_features_fg_ori / image_features_fg_ori.norm(dim=-1, keepdim=True)
|
||||
# image_features_bg_ori = image_features_bg_ori / image_features_bg_ori.norm(dim=-1,keepdim=True)
|
||||
# image_features_all_ori = image_features_fg_ori + image_features_bg_ori
|
||||
# image_features_all_ori = image_features_all_ori / image_features_all_ori.norm(dim=-1,keepdim=True)
|
||||
# loss_reo = torch.abs(image_features_all_ori.detach() - image_features).mean()
|
||||
|
||||
foreground_score = logit_scale*image_features_fg.detach()@text_features_fg.t()
|
||||
pseudo_label = torch.argmax(image_features_bg @ text_features_bg.t(), dim=-1)
|
||||
logits_bg = logit_scale*(image_features_bg) @ text_features_bg.t()
|
||||
|
||||
para_bg = 0.5
|
||||
para_fg = 0.1
|
||||
para_vd = 0.8
|
||||
loss_contrastive = self.contrastive_loss(image_features, image_features_fg.detach(), image_features_bg.detach(), temperature=0.07)
|
||||
|
||||
|
||||
loss_bg = F.cross_entropy(logits_bg,pseudo_label)
|
||||
loss_fg = F.cross_entropy(foreground_score,label)
|
||||
para_fg = 0.2
|
||||
para_vd = 0.6
|
||||
|
||||
if epoch > 6: #Tunable parameters
|
||||
loss_re = para_fg*loss_fg + para_bg*loss_bg
|
||||
|
||||
if label is not None:
|
||||
loss_fg = F.cross_entropy(logit_scale*image_features_fg.detach()@text_features.t(), label)
|
||||
else:
|
||||
loss_re = para_vd*loss_re1 #loss_reo would be effective in base2novel setting
|
||||
loss_fg = torch.tensor(0.0, dtype=self.dtype, device=image.device)
|
||||
|
||||
if epoch is not None and epoch > 6: #Tunable parameters
|
||||
loss_re = para_fg*loss_fg
|
||||
else:
|
||||
loss_re = para_vd*loss_contrastive
|
||||
|
||||
|
||||
if self.prompt_learner.training:
|
||||
if weight is None:
|
||||
return F.cross_entropy(logits,label)+loss_re,logits,{'loss_vd':loss_re1.item(),'loss_bg':loss_bg.item(),'loss_fg':loss_fg.item()}
|
||||
return F.cross_entropy(logits,label)+loss_re,logits,{'loss_contrastive':loss_contrastive.item(),'loss_fg':loss_fg.item()}
|
||||
else:
|
||||
return F.cross_entropy(weight.unsqueeze(-1)*logits,label), logits
|
||||
|
||||
@@ -674,8 +666,8 @@ class MaPLe(TrainerX):
|
||||
model_name="model-best.pth.tar"
|
||||
)
|
||||
|
||||
# if meet_checkpoint_freq or last_epoch:
|
||||
# self.save_model(self.epoch, self.output_dir)
|
||||
if meet_checkpoint_freq or last_epoch:
|
||||
self.save_model(self.epoch, self.output_dir)
|
||||
|
||||
print(f"Now generate the attentive masking in {self.cfg.TRAINER.DAPT_MODE} \n")
|
||||
|
||||
|
||||
@@ -5,10 +5,6 @@ import cv2
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
|
||||
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',
|
||||
]
|
||||
|
||||
class GradCAM(object):
|
||||
def __init__(self,model_dict):
|
||||
layer_name = model_dict['layer_name']
|
||||
@@ -80,7 +76,7 @@ class GradCAM(object):
|
||||
|
||||
else:
|
||||
logit = self.model_arch.forward_test(input,labels,cfg=cfg)
|
||||
pred_label = torch.argmax(logit[:,:-len(BACKGROUND_CATEGORY)])
|
||||
pred_label = torch.argmax(logit)
|
||||
sign = pred_label == labels
|
||||
# if (split == 'true' and sign == False) or (split == 'wrong' and sign == True):
|
||||
# print(f'Ignore the not {split} sample')
|
||||
@@ -88,11 +84,10 @@ class GradCAM(object):
|
||||
|
||||
# if attn_mask:
|
||||
# return final_cls_mask
|
||||
pred = logit[:,:-len(BACKGROUND_CATEGORY)].argmax(dim=-1)
|
||||
background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
|
||||
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]-len(BACKGROUND_CATEGORY)).to(torch.float16)
|
||||
pred = logit.argmax(dim=-1)
|
||||
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
|
||||
|
||||
loss = (F.softmax(logit[:,:-len(BACKGROUND_CATEGORY)])*one_hot_labels).mean() #+ background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
|
||||
loss = (F.softmax(logit)*one_hot_labels).mean()
|
||||
|
||||
# score = logit[:,labels]
|
||||
self.model_arch.zero_grad()
|
||||
@@ -186,10 +181,8 @@ class GradCAM(object):
|
||||
|
||||
# if attn_mask:
|
||||
# return final_cls_mask
|
||||
# pred = logit[:,-len(BACKGROUND_CATEGORY):].argmax(dim=-1)
|
||||
# background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
|
||||
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
|
||||
loss = (logit*one_hot_labels).mean() #+ background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
|
||||
loss = (logit*one_hot_labels).mean()
|
||||
# score = logit[:,labels]
|
||||
self.model_arch.zero_grad()
|
||||
loss.backward(retain_graph=retain_graph)
|
||||
|
||||
Reference in New Issue
Block a user