import logging
import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms

from clip import clip
from clip.model import ModifiedResNet, VisionTransformer
from datasets import build_dataset
from datasets.utils import build_data_loader
from engine import partial_model

def to_cuda(x):
    """Move a tensor to the GPU when one is available."""
    if torch.cuda.is_available():
        x = x.cuda()
    return x


def to_cpu(x):
    """Move a tensor to the CPU."""
    return x.cpu()


def to_numpy(x):
    """Convert a tensor to a NumPy array, detaching it from the autograd graph."""
    return x.detach().cpu().numpy()


def to_onehot(label, num_classes):
    """Convert integer class labels to one-hot vectors."""
    identity = torch.eye(num_classes).to(label.device)
    onehot = torch.index_select(identity, 0, label)
    return onehot

def accuracy(output, target):
    """Computes the top-1 precision (%) over a batch."""
    batch_size = target.size(0)
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    correct = correct[:1].reshape(-1).float().sum(0, keepdim=True)
    res = correct.mul_(100.0 / batch_size)
    return res

def accuracy_for_each_class(output, target, total_vector, correct_vector):
    """Per-class correct/total counts, accumulated over the ground-truth labels."""
    batch_size = target.size(0)
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    # squeeze(0) keeps the result 1-D even when batch_size == 1
    correct = pred.eq(target.view(1, -1)).float().cpu().squeeze(0)
    for i in range(batch_size):
        t = target[i].item()
        total_vector[t] += 1
        correct_vector[t] += correct[i]

    return total_vector, correct_vector

def recall_for_each_class(output, target, total_vector, correct_vector):
    """Per-class correct/total counts, accumulated over the predicted labels."""
    batch_size = target.size(0)
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    # squeeze(0) keeps the result 1-D even when batch_size == 1
    correct = pred.eq(target.view(1, -1)).float().cpu().squeeze(0)
    for i in range(batch_size):
        p = pred[0][i].item()
        total_vector[p] += 1
        correct_vector[p] += correct[i]

    return total_vector, correct_vector

def process_one_values(tensor):
    """Subtract a small epsilon from entries equal to 1 (e.g. to keep log(1 - x) finite)."""
    if (tensor == 1).sum() != 0:
        eps = torch.zeros_like(tensor)  # same device and dtype as the input
        eps[tensor == 1] = 1e-6
        tensor = tensor - eps
    return tensor


def process_zero_values(tensor):
    """Add a small epsilon to entries equal to 0 (e.g. to keep log(x) finite)."""
    if (tensor == 0).sum() != 0:
        eps = torch.zeros_like(tensor)
        eps[tensor == 0] = 1e-6
        tensor = tensor + eps
    return tensor

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
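
# Minimal usage sketch for AverageMeter (the numbers are illustrative):
#
#   loss_meter = AverageMeter()
#   loss_meter.update(0.52, n=64)  # batch mean 0.52 over 64 samples
#   loss_meter.update(0.48, n=64)
#   print(loss_meter.avg)          # running average, weighted by n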

def prepare_directories(args, CLIP_Text, CLIP_Image):
    """Check for and create the necessary directories and files."""
    # Create the directories (if they do not exist)
    save_dir = os.path.join(args.savedir, f"{args.dataset_name}_epx", f"{args.shot}shot")
    for dir_path in [args.filename_dir, save_dir]:
        os.makedirs(dir_path, exist_ok=True)
        print(f"{dir_path} directory is ready.")

    # Create the log file (if it does not exist)
    filename = os.path.join(args.filename_dir, f"{args.dataset_name}.txt")
    if not os.path.exists(filename):
        open(filename, 'a').close()  # 'a' mode creates the file if it does not exist
        print(f"Created {filename}")
    else:
        print(f"{filename} already exists.")

    # Save the still-frozen parts of the CLIP model
    torch.save(CLIP_Text, os.path.join(save_dir, 'CLIP_Text.pth'))
    torch.save(CLIP_Image, os.path.join(save_dir, 'CLIP_Image.pth'))
    logging.basicConfig(filename=filename, level=logging.INFO)

def set_seed(seed_value):
    """Set all random seeds to ensure reproducibility."""
    random.seed(seed_value)        # Python random module.
    np.random.seed(seed_value)     # NumPy.
    torch.manual_seed(seed_value)  # PyTorch, CPU operations.

    # When using CUDA, seed the GPU generators as well
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)      # PyTorch, current GPU.
        torch.cuda.manual_seed_all(seed_value)  # PyTorch, all GPUs.

    # These settings help reproducibility but may hurt performance
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # PYTHONHASHSEED affects Python's hash-based operations
    os.environ['PYTHONHASHSEED'] = str(seed_value)
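
# Typical call, made once at startup before any datasets, dataloaders, or
# models are built (the seed value is illustrative):
#
#   set_seed(42)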

# Prepare the dataloaders
def get_dataset_loader(args, preprocess):
    dataset = build_dataset(args.dataset_name, args.dataset_dir, args.shot)
    classnames = dataset.classnames
    templates = dataset.template

    # Load the test and validation datasets (no augmentation, no shuffling)
    loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess,
                               shuffle=False)
    val_loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess,
                                   shuffle=False)

    # Load the training dataset, with CLIP-style augmentation and normalization
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
    ])
    train_loader = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform,
                                     is_train=True,
                                     shuffle=True)

    return classnames, templates, loader, train_loader, val_loader
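
# Minimal usage sketch; `args` comes from the launcher script and `preprocess`
# from clip.load (both names are assumptions):
#
#   model, preprocess = clip.load(args.name)
#   classnames, templates, test_loader, train_loader, val_loader = \
#       get_dataset_loader(args, preprocess)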

def configure_clip_encoders(args, model, text_layer_idx, image_layer_idx):
    """Configure and return CLIP's text and image encoders according to the model name."""
    if args.name in ("ViT-B/16", "ViT-B/32"):
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx)
        assert type(model.visual) == VisionTransformer
        CLIP_Image, Image_Encoder = partial_model.get_image_vit(model.visual, image_layer_idx)
    elif args.name in ("RN50", "RN101", "RN50x16"):
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx)
        assert type(model.visual) == ModifiedResNet
        CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx)
    else:
        raise ValueError(f"Unsupported model name: {args.name}")

    return CLIP_Text.cuda(), Text_Encoder.cuda(), CLIP_Image.cuda(), Image_Encoder.cuda()
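
# Minimal usage sketch (the layer indices are illustrative; they presumably
# control where the backbone is split between the frozen CLIP_* part and the
# trainable *_Encoder part):
#
#   model, preprocess = clip.load(args.name)
#   CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = \
#       configure_clip_encoders(args, model, text_layer_idx=0, image_layer_idx=0)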

def save_model(epoch, Text_Encoder, Image_Encoder, adapter, args, prec):
    """Save the model and training state for the current epoch."""
    base_dir = os.path.join(args.savedir, f"{args.dataset_name}_epx", f"{args.shot}shot")
    save_dir = os.path.join(base_dir, f"epoch_{epoch}_{prec}")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(Text_Encoder, os.path.join(save_dir, "Text_Encoder.pth"))
    torch.save(Image_Encoder, os.path.join(save_dir, "Image_Encoder.pth"))
    torch.save(adapter, os.path.join(save_dir, "adapter.pth"))

def set_adapter_weights(model, classnames, templates):
    """Build zero-shot text weights and duplicate them along the class dimension."""
    zeroshot_weights = []
    for classname in classnames:
        classname = classname.replace('_', ' ')
        texts = [template.format(classname) for template in templates]
        texts = clip.tokenize(texts).cuda()
        class_embeddings = model.encode_text(texts)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)  # average over templates
        class_embedding /= class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    # Two copies of the per-class embeddings, stacked as rows: shape (2C, D)
    init_weight = torch.cat([zeroshot_weights, zeroshot_weights], dim=1).T

    return init_weight
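
# Minimal usage sketch: initialize a linear adapter from the text embeddings
# (the Linear layer here is an assumption; its shape must match init_weight):
#
#   init_weight = set_adapter_weights(model, classnames, templates)  # (2C, D)
#   adapter = torch.nn.Linear(init_weight.shape[1], init_weight.shape[0], bias=False)
#   adapter.weight.data = init_weight.float()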

def set_adapter_weights_single(model, classnames, templates):
    """Same as set_adapter_weights, but returns a single copy of the per-class embeddings."""
    zeroshot_weights = []
    for classname in classnames:
        classname = classname.replace('_', ' ')
        texts = [template.format(classname) for template in templates]
        texts = clip.tokenize(texts).cuda()
        class_embeddings = model.encode_text(texts)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding /= class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    init_weight = zeroshot_weights.T  # shape (C, D)

    return init_weight

# Text features for the prompt templates of one class
def get_text_feature(classname, templates, CLIP_Text):
    with torch.no_grad():
        classname = classname.replace('_', ' ')
        str_prompts = [template.format(classname) for template in templates]
        prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
        features, eot_indices = CLIP_Text(prompts)
    return features, eot_indices

def get_text_feature_GPT(classname, templates, CLIP_Text, gpt3_prompt):
    """Like get_text_feature, but uses the GPT-3 prompts for the class.

    `templates` is kept for signature compatibility but is not used."""
    with torch.no_grad():
        classname = classname.replace('_', ' ')
        str_prompts = list(gpt3_prompt[classname])
        prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
        features, eot_indices = CLIP_Text(prompts)
    return features, eot_indices

def text_feature(str_prompts, CLIP_Text):
    with torch.no_grad():
        prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
        features, eot_indices = CLIP_Text(prompts)
    return features, eot_indices

def calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder):
    zeroshot_weights = []
    for i in range(len(label)):
        features, eot_indices = get_text_feature(classnames[label[i]], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights.T

def calculate_zero(classnames, label, templates, model):
    zeroshot_weights = []
    for i in range(len(label)):
        with torch.no_grad():
            classname = classnames[label[i]].replace('_', ' ')
            str_prompts = [template.format(classname) for template in templates]
            prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
            class_embeddings = model(prompts)

        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights.T

def calculate_zeroshot_weights_GPT(classnames, label, templates, CLIP_Text, Text_Encoder, gpt3_prompt):
    zeroshot_weights = []
    for i in range(len(label)):
        features, eot_indices = get_text_feature_GPT(classnames[label[i]], templates, CLIP_Text, gpt3_prompt)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights.T

def gpt_clip_classifier(classnames, gpt3_prompt, CLIP_Text, Text_Encoder):
    zeroshot_weights = []
    with torch.no_grad():
        for classname in classnames:
            # Tokenize the GPT-3 prompts for this class
            classname = classname.replace('_', ' ')
            texts = list(gpt3_prompt[classname])
            texts = clip.tokenize(texts).cuda()
            features, eot_indices = CLIP_Text(texts)
            class_embeddings = Text_Encoder(features, eot_indices)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding = class_embedding / class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        # prompt ensemble for ImageNet
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()

    return zeroshot_weights.T

def all_classifier(classnames, templates, model):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).cuda()  # clip.tokenize turns the prompts into token ids
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)

        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights
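
# Minimal usage sketch of the resulting zero-shot classifier; `image_features`
# stands in for L2-normalized CLIP image embeddings:
#
#   zeroshot_weights = all_classifier(classnames, templates, model)  # (D, C)
#   logits = 100. * image_features @ zeroshot_weights                # (N, C)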

def all_classifier_GPT(classnames, gpt3_prompt, model):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            str_prompts = list(gpt3_prompt[classname])
            prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
            class_embeddings = model.encode_text(prompts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)

        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights

def all_classifier_GPTWithCLIP(classnames, gpt3_prompt, model, templates):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            texts1 = [template.format(classname) for template in templates]
            texts2 = list(gpt3_prompt[classname])
            # Combine the hand-written templates with the GPT-3 prompts
            str_prompts = texts1 + texts2
            prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
            class_embeddings = model.encode_text(prompts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)

        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights

# def calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder):
#     zeroshot_weights = None
#     labels = []  # start with an empty list; converted to a tensor below
#     for i in range(len(label)):
#         features, eot_indices = get_text_feature(classnames[label[i]], templates, CLIP_Text)
#         class_embeddings = Text_Encoder(features, eot_indices)
#
#         # On the first iteration assign directly; afterwards concatenate
#         if zeroshot_weights is None:
#             zeroshot_weights = class_embeddings
#         else:
#             zeroshot_weights = torch.cat((zeroshot_weights, class_embeddings), dim=0)
#
#         # Repeat each label len(templates) times and append to labels
#         labels.extend([label[i].item()] * len(templates))
#
#     # Convert the labels list to a tensor
#     labels = torch.tensor(labels, dtype=torch.long).cuda()
#
#     return zeroshot_weights, labels

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Note: this redefines the single-value accuracy() above; it returns a list
    with one entry per requested k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
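
# Minimal usage sketch; `model` and `loader` are placeholder names, not defined
# in this module:
#
#   top1 = AverageMeter()
#   for images, target in loader:
#       output = model(to_cuda(images))
#       prec1, prec5 = accuracy(output, to_cuda(target), topk=(1, 5))
#       top1.update(prec1.item(), images.size(0))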