# DAPT/trainers/util.py
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import os
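
# Generic background words whose text prompts are appended after the foreground
# class names; the logit tensor is therefore laid out as
# [foreground classes..., background words...], which is why the methods below
# slice foreground scores with `logit[:, :-len(BACKGROUND_CATEGORY)]`.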
BACKGROUND_CATEGORY = ['ground', 'land', 'grass', 'tree', 'building', 'wall', 'sky',
                       'lake', 'water', 'river', 'sea', 'railway', 'railroad', 'keyboard',
                       'helmet', 'cloud', 'house', 'mountain', 'ocean', 'road', 'rock',
                       'street', 'valley', 'bridge', 'sign']


class GradCAM(object):
    def __init__(self, model_dict):
        layer_name = model_dict['layer_name']
        self.model_arch = model_dict['arch']
        self.gradient = dict()
        self.activation = dict()
        self.gradient_t = dict()
        self.activation_t = dict()

        def backward_hook(module, grad_input, grad_output):
            self.gradient['value'] = grad_output[0]
            return None

        def forward_hook(module, input, output):
            self.activation['value'] = output
            return None

        def backward_hook_t(module, grad_input, grad_output):
            self.gradient_t['value'] = grad_output[0]
            return None

        def forward_hook_t(module, input, output):
            self.activation_t['value'] = output
            return None

        # Hook the LayerNorm in front of the last residual attention block of
        # the image encoder; its activations and gradients are cached by the
        # hooks above.
        target_layer = self.model_arch.image_encoder.transformer.resblocks[-1].ln_1
        # target_layer_t = self.model_arch.image_encoder.transformer.resblocks[-2].mlp.c_proj
        target_layer.register_forward_hook(forward_hook)
        # Note: register_backward_hook is deprecated in recent PyTorch releases;
        # register_full_backward_hook is the modern replacement.
        target_layer.register_backward_hook(backward_hook)
        # target_layer_t.register_forward_hook(forward_hook_t)
        # target_layer_t.register_backward_hook(backward_hook_t)
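
    # Note on tensor layout: CLIP's transformer runs sequence-first, so the
    # cached activation/gradient tensors are (seq_len, batch, dim) with the
    # sequence ordered as [CLS token, patch tokens..., learned context tokens].
    # The slicing below ([:1], [1:-n_ctx], [-n_ctx:]) assumes this ordering.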

    def forward(self, input, labels, cfg=None, retain_graph=False, split=None, attn_mask=False):
        b, c, h, w = input.shape
        patch_num, ori_size = self.model_arch.image_encoder.patch_num, self.model_arch.image_encoder.input_resolution
        if attn_mask:
            logit, mask = self.model_arch.forward_test(input, labels, cfg=cfg, attn_mask=attn_mask)
            # CLS-to-patch attention, reshaped onto the patch grid.
            cls_mask = mask[:, 1:-self.model_arch.prompt_learner.n_ctx, :1].reshape(b, -1, patch_num, patch_num)
            # Patch-to-patch affinity, normalized per column.
            aff = mask[:, 1:-self.model_arch.prompt_learner.n_ctx, 1:-self.model_arch.prompt_learner.n_ctx]
            aff = aff / (aff.sum(dim=1, keepdim=True) + 1e-6)
            # Refinement variants kept for reference:
            # aff = (aff + aff.permute(0, 2, 1)) / 2
            # aff = torch.bmm(aff, aff)
            # aff = F.softmax(aff, dim=1)
            # cls_mask = torch.bmm(cls_mask, aff).reshape(b, -1, patch_num, patch_num)
            # cls_mask = mask[:, 1:-n_ctx, :1].permute(0, 2, 1).reshape(b, -1, patch_num, patch_num)
            # final_cls_mask = F.interpolate(cls_mask, size=(ori_size, ori_size), mode='bilinear', align_corners=True)
            # final_cls_mask = final_cls_mask / (final_cls_mask.max() + 1e-6)
        else:
            logit = self.model_arch.forward_test(input, labels, cfg=cfg)
        # Prediction over foreground classes only (background logits sit at the end).
        pred_label = torch.argmax(logit[:, :-len(BACKGROUND_CATEGORY)])
        sign = pred_label == labels
        # if (split == 'true' and sign == False) or (split == 'wrong' and sign == True):
        #     print(f'Ignore the not {split} sample')
        #     return None
        pred = logit[:, :-len(BACKGROUND_CATEGORY)].argmax(dim=-1)
        background_logit = logit[:, -len(BACKGROUND_CATEGORY):]
        one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1] - len(BACKGROUND_CATEGORY)).to(torch.float16)
        # Back-propagate the softmax score of the ground-truth foreground class.
        loss = (F.softmax(logit[:, :-len(BACKGROUND_CATEGORY)], dim=-1) * one_hot_labels).mean()
        # Alternatives: + background_logit.mean(); (logit[:, :-len(BACKGROUND_CATEGORY)] * one_hot_labels).mean(); F.cross_entropy(logit, labels)
        self.model_arch.zero_grad()
        loss.backward(retain_graph=retain_graph)
        gradients = self.gradient['value']      # (seq_len, batch, dim)
        activations = self.activation['value']  # (seq_len, batch, dim)
        # Patch-token activations, excluding the CLS token and learned context tokens.
        visual_feature = activations[1:-self.model_arch.prompt_learner.n_ctx]
        cls_token_gradient, prompt_gradient = gradients[:1, :, :], gradients[-self.model_arch.prompt_learner.n_ctx:, :, :].mean(keepdim=True, dim=0)
        # Gradient averaged over all patch tokens.
        visual_gradient = torch.mean(gradients[1:-self.model_arch.prompt_learner.n_ctx], keepdim=True, dim=0)
        lam = 0.5
        # Pooling variants kept for reference (CLS-token or prompt-token gradients):
        # token_gradient = cls_token_gradient
        # token_gradient = cls_token_gradient * prompt_gradient.mean(dim=0, keepdim=True)
        token_gradient = visual_gradient
        # Weight each patch activation by the pooled gradient: (B, P, D) x (B, D, 1).
        final_visual_feature = torch.bmm(visual_feature.permute(1, 0, 2), token_gradient.permute(1, 2, 0))
        final_visual_feature = F.relu(final_visual_feature).permute(0, 2, 1)
        # if attn_mask:
        #     final_visual_feature = torch.bmm(final_visual_feature, aff)
        final_visual_feature = final_visual_feature.reshape(final_visual_feature.shape[0], 1, patch_num, patch_num)
        final_visual_feature = F.interpolate(final_visual_feature, size=(ori_size, ori_size), mode='bilinear', align_corners=True)
        final_visual_feature_min, final_visual_feature_max = final_visual_feature.min(), final_visual_feature.max()
        # Max-normalize (min-max normalization kept as an alternative):
        # saliency_map = (final_visual_feature - final_visual_feature_min) / (final_visual_feature_max - final_visual_feature_min + 1e-6)
        saliency_map = final_visual_feature / (final_visual_feature_max + 1e-6)
        threshold = 0.5
        # saliency_map[saliency_map >= threshold] = 1
        saliency_map[saliency_map < threshold] = 0
        return saliency_map
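
    # In effect, `forward` computes a Grad-CAM-style map per patch p,
    #     CAM_p = ReLU( sum_d a_{p,d} * g_d ),
    # where a is the patch activation at ln_1 and g is the gradient averaged
    # over patch tokens, then upsamples it to the input resolution,
    # max-normalizes it, and zeroes values below 0.5. `forward_train` below
    # repeats this with raw logits over all categories (background included)
    # and a hard 0/1 threshold instead.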

    def forward_train(self, input, labels, cfg=None, retain_graph=False, split=None, attn_mask=False):
        b, c, h, w = input.shape
        patch_num, ori_size = self.model_arch.image_encoder.patch_num, self.model_arch.image_encoder.input_resolution
        if attn_mask:
            logit, mask = self.model_arch.forward_test(input, labels, cfg=cfg, attn_mask=attn_mask)
            cls_mask = mask[:, 1:-self.model_arch.prompt_learner.n_ctx, :1].reshape(b, -1, patch_num, patch_num)
            aff = mask[:, 1:-self.model_arch.prompt_learner.n_ctx, 1:-self.model_arch.prompt_learner.n_ctx]
            aff = aff / (aff.sum(dim=1, keepdim=True) + 1e-6)
            # See `forward` above for the commented-out refinement variants.
        else:
            logit = self.model_arch.forward_test(input, labels, cfg=cfg)
        # Unlike `forward`, the prediction here is taken over all categories.
        pred_label = torch.argmax(logit)
        sign = pred_label == labels
        # if (split == 'true' and sign == False) or (split == 'wrong' and sign == True):
        #     print(f'Ignore the not {split} sample')
        #     return None
        # pred = logit[:, -len(BACKGROUND_CATEGORY):].argmax(dim=-1)
        # background_logit = logit[:, -len(BACKGROUND_CATEGORY):]
        one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
        # Back-propagate the raw logit of the ground-truth class.
        loss = (logit * one_hot_labels).mean()
        self.model_arch.zero_grad()
        loss.backward(retain_graph=retain_graph)
        gradients = self.gradient['value']      # (seq_len, batch, dim)
        activations = self.activation['value']  # (seq_len, batch, dim)
        visual_feature = activations[1:-self.model_arch.prompt_learner.n_ctx]
        cls_token_gradient, prompt_gradient = gradients[:1, :, :], gradients[-self.model_arch.prompt_learner.n_ctx:, :, :].mean(keepdim=True, dim=0)
        visual_gradient = torch.mean(gradients[1:-self.model_arch.prompt_learner.n_ctx], keepdim=True, dim=0)
        lam = 0.5
        token_gradient = visual_gradient
        final_visual_feature = torch.bmm(visual_feature.permute(1, 0, 2), token_gradient.permute(1, 2, 0))
        final_visual_feature = F.relu(final_visual_feature).permute(0, 2, 1)
        # if attn_mask:
        #     final_visual_feature = torch.bmm(final_visual_feature, aff)
        final_visual_feature = final_visual_feature.reshape(final_visual_feature.shape[0], 1, patch_num, patch_num)
        final_visual_feature = F.interpolate(final_visual_feature, size=(ori_size, ori_size), mode='bilinear', align_corners=True)
        final_visual_feature_min, final_visual_feature_max = final_visual_feature.min(), final_visual_feature.max()
        saliency_map = final_visual_feature / (final_visual_feature_max + 1e-6)
        # Hard-binarize the map for training supervision.
        threshold = 0.5
        saliency_map[saliency_map >= threshold] = 1
        saliency_map[saliency_map < threshold] = 0
        return saliency_map

    def show_cam(self, img, mask, save_path=None):
        # Overlay a saliency mask on an RGB image tensor and save the result.
        # `img`: (3, H, W) tensor in [0, 1]; `mask`: saliency map in [0, 1].
        if torch.is_tensor(mask):
            mask = mask.detach().cpu()
        heat_map = cv2.applyColorMap(np.uint8(255 * np.asarray(mask).squeeze()), cv2.COLORMAP_JET)
        heatmap = torch.from_numpy(heat_map).permute(2, 0, 1).float().div(255)
        b, g, r = heatmap.split(1)
        heatmap = torch.cat([r, g, b])  # OpenCV returns BGR; reorder to RGB
        rate = 0.5
        res = rate * heatmap + (1 - rate) * img
        res = res.div(res.max()).squeeze()
        res = np.transpose(np.uint8(255 * res), (1, 2, 0))
        pil_image = Image.fromarray(res)
        if save_path is not None:
            pil_image.save(save_path)
        return pil_image


def denorm(img, mean, std):
    # Undo channel-wise normalization: `img` is (3, H, W) in normalized space.
    mean, std = np.array(mean), np.array(std)
    img = img * std[:, None, None] + mean[:, None, None]
    # img = np.clip(img * 255, 0, 255) / 255  # optional clamp to [0, 1]
    return img
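

# --- Usage sketch (hypothetical wiring, not part of the original file). ---
# `model` is assumed to be the DAPT CLIP-style model exposing `image_encoder`,
# `prompt_learner`, and `forward_test` as used above; the mean/std values are
# the standard CLIP normalization statistics:
#
#     cam = GradCAM({'arch': model, 'layer_name': 'ln_1'})
#     saliency = cam.forward(image, label, cfg=cfg)   # (1, 1, H, W) in [0, 1]
#     rgb = denorm(image.squeeze(0).cpu().numpy(),
#                  mean=(0.48145466, 0.4578275, 0.40821073),
#                  std=(0.26862954, 0.26130258, 0.27577711))
#     cam.show_cam(torch.from_numpy(rgb).float(), saliency, save_path='cam.jpg')
#
# A self-contained round-trip check of `denorm` on synthetic data (the CLIP
# statistics are an assumption; they are not defined in this file):
if __name__ == '__main__':
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    rgb = np.random.rand(3, 224, 224).astype(np.float32)  # image in [0, 1]
    normed = (rgb - np.array(clip_mean)[:, None, None]) / np.array(clip_std)[:, None, None]
    recovered = denorm(normed, clip_mean, clip_std)
    assert np.allclose(recovered, rgb, atol=1e-5)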