init
This commit is contained in:
174
utils/loss_utils.py
Normal file
174
utils/loss_utils.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from utils.utils import process_zero_values
|
||||
import ipdb
|
||||
|
||||
|
||||
def _assert_no_grad(variable):
|
||||
assert not variable.requires_grad, \
|
||||
"nn criterions don't compute the gradient w.r.t. targets - please " \
|
||||
"mark these variables as volatile or not requiring gradients"
|
||||
|
||||
|
||||
class _Loss(nn.Module):
|
||||
def __init__(self, size_average=True):
|
||||
super(_Loss, self).__init__()
|
||||
self.size_average = size_average
|
||||
|
||||
|
||||
class _WeightedLoss(_Loss):
    """Loss base class that additionally keeps a ``weight`` buffer."""

    def __init__(self, weight=None, size_average=True):
        """Forward ``size_average`` to ``_Loss`` and register ``weight``.

        ``weight`` is registered as a module buffer (it may be ``None``).
        """
        super().__init__(size_average)
        self.register_buffer('weight', weight)
|
||||
|
||||
|
||||
class CrossEntropyClassWeighted(_Loss):
    """Cross-entropy loss whose class weights are supplied per call."""

    def __init__(self, size_average=True, ignore_index=-100, reduce=None, reduction='elementwise_mean'):
        """Store the options later handed to ``F.cross_entropy``.

        ``reduce`` is accepted but not used by this class.
        """
        super().__init__(size_average)
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input, target, weight=None):
        """Return ``F.cross_entropy`` of ``input`` against ``target``.

        Args:
            input: Logits passed straight to ``F.cross_entropy``.
            target: Targets passed straight to ``F.cross_entropy``.
            weight: Optional per-class weights for this call.
        """
        options = {
            'ignore_index': self.ignore_index,
            'reduction': self.reduction,
        }
        return F.cross_entropy(input, target, weight, **options)
|
||||
|
||||
|
||||
### Cloned from the unofficial implementation at https://github.com/krumo/swd_pytorch/blob/master/swd_pytorch.py
|
||||
def discrepancy_slice_wasserstein(p1, p2):
    """Sliced Wasserstein discrepancy between two 2-D tensors.

    When ``p1`` has more than one column, both inputs are multiplied by a
    random matrix with 128 unit-norm columns. Each column of the (projected)
    inputs is then sorted along dim 0 via ``topk`` and the mean squared
    difference of the sorted values is returned.

    The projection matrix is placed on ``p1``'s device and dtype instead of
    being forced onto CUDA, so the function also runs on CPU-only hosts. It
    is still sampled with the CPU generator, as before.

    Args:
        p1: 2-D tensor; its shape drives the projection and the sort length.
        p2: Tensor handled the same way as ``p1``
            (presumably the same shape -- TODO confirm against callers).

    Returns:
        0-d tensor holding the mean squared difference.
    """
    num_rows, num_cols = p1.shape[0], p1.shape[1]
    if num_cols > 1:
        # Sample on CPU (keeps the original RNG stream), then follow p1.
        proj = torch.randn(num_cols, 128).to(device=p1.device, dtype=p1.dtype)
        # Scale every column to unit L2 norm.
        proj *= torch.rsqrt(torch.sum(torch.mul(proj, proj), 0, keepdim=True))
        p1 = torch.matmul(p1, proj)
        p2 = torch.matmul(p2, proj)
    # topk with k == number of rows is a full descending sort along dim 0.
    p1 = torch.topk(p1, num_rows, dim=0)[0]
    p2 = torch.topk(p2, num_rows, dim=0)[0]
    dist = p1 - p2
    return torch.mean(torch.mul(dist, dist))
|
||||
|
||||
|
||||
class McDalNetLoss(_WeightedLoss):
    """Discrepancy between the outputs of two classifiers.

    ``forward`` takes two logit tensors and measures how far apart their
    softmax distributions (over dim 1) are, using the metric selected by
    ``dis_type``. The ``weight`` / ``size_average`` constructor arguments are
    passed to the base class and are not read by ``forward``.
    """

    def __init__(self, weight=None, size_average=True):
        super(McDalNetLoss, self).__init__(weight, size_average)

    def forward(self, input1, input2, dis_type='L1'):
        """Compute the selected discrepancy.

        Args:
            input1: Logits of the first classifier; softmax is taken on dim 1.
            input2: Logits of the second classifier; softmax is taken on dim 1.
            dis_type: One of ``'L1'``, ``'CE'``, ``'KL'``, ``'L2'``, ``'Wasse'``.

        Returns:
            0-d tensor with the discrepancy.

        Raises:
            ValueError: If ``dis_type`` is not a supported name. Previously
                this surfaced as an ``UnboundLocalError`` on ``return``.
        """
        if dis_type == 'L1':
            prob_s = F.softmax(input1, dim=1)
            prob_t = F.softmax(input2, dim=1)
            loss = torch.mean(torch.abs(prob_s - prob_t))  # element-wise mean
        elif dis_type == 'CE':
            # Symmetric cross entropy, averaged over all elements.
            loss = - ((F.log_softmax(input2, dim=1)).mul(F.softmax(input1, dim=1))).mean() - (
                (F.log_softmax(input1, dim=1)).mul(F.softmax(input2, dim=1))).mean()
            loss = loss * 0.5
        elif dis_type == 'KL':
            # Averaged over elements, so not the textbook KL divergence (which
            # sums within an instance and averages over instances).
            # dim=1 is now explicit, matching the other branches; the implicit
            # dim already resolved to 1 for 2-D and 4-D inputs.
            loss = (F.kl_div(F.log_softmax(input1, dim=1), F.softmax(input2, dim=1))) + (
                F.kl_div(F.log_softmax(input2, dim=1), F.softmax(input1, dim=1)))
            loss = loss * 0.5
        # The following two distances are marked in the original source as not
        # evaluated in the authors' paper and needing further investigation.
        elif dis_type == 'L2':
            nClass = input1.size()[1]
            prob_s = F.softmax(input1, dim=1)
            prob_t = F.softmax(input2, dim=1)
            loss = torch.norm(prob_s - prob_t, p=2, dim=1).mean() / nClass
        elif dis_type == 'Wasse':
            # Sliced Wasserstein discrepancy on the two probability tensors.
            prob_s = F.softmax(input1, dim=1)
            prob_t = F.softmax(input2, dim=1)
            loss = discrepancy_slice_wasserstein(prob_s, prob_t)
        else:
            raise ValueError(
                "Unsupported dis_type: %r (expected 'L1', 'CE', 'KL', 'L2' or 'Wasse')" % (dis_type,))

        return loss
|
||||
|
||||
|
||||
class TargetDiscrimLoss(_WeightedLoss):
    """Negative log of the softmax mass in columns ``num_classes`` onward."""

    def __init__(self, weight=None, size_average=True, num_classes=31):
        super(TargetDiscrimLoss, self).__init__(weight, size_average)
        self.num_classes = num_classes

    def forward(self, input):
        """Return ``-mean(log(sum_k softmax(input)[:, num_classes + k]))``.

        Rows whose summed probability is exactly zero get ``1e-6`` added
        before the log to avoid ``log(0)``. The guard is built from the data
        itself, so it follows the input's device and dtype instead of
        requiring CUDA.

        Args:
            input: Logits; softmax is taken over dim 1.

        Returns:
            0-d loss tensor.
        """
        prob = F.softmax(input, dim=1)
        mass = prob[:, self.num_classes:].sum(1)
        # 1e-6 exactly where the mass is zero, 0 elsewhere; carries no gradient.
        guard = (mass.detach() == 0).to(mass.dtype) * 1e-6
        return -((mass + guard).log().mean())
|
||||
|
||||
class SourceDiscrimLoss(_WeightedLoss):
    """Negative log of the softmax mass in the first ``num_classes`` columns."""

    def __init__(self, weight=None, size_average=True, num_classes=31):
        super(SourceDiscrimLoss, self).__init__(weight, size_average)
        self.num_classes = num_classes

    def forward(self, input):
        """Return ``-mean(log(sum_k softmax(input)[:, k]))`` for ``k < num_classes``.

        Rows whose summed probability is exactly zero get ``1e-6`` added
        before the log to avoid ``log(0)``. The guard is built from the data
        itself, so it follows the input's device and dtype instead of
        requiring CUDA.

        Args:
            input: Logits; softmax is taken over dim 1.

        Returns:
            0-d loss tensor.
        """
        prob = F.softmax(input, dim=1)
        mass = prob[:, :self.num_classes].sum(1)
        # 1e-6 exactly where the mass is zero, 0 elsewhere; carries no gradient.
        guard = (mass.detach() == 0).to(mass.dtype) * 1e-6
        return -((mass + guard).log().mean())
|
||||
|
||||
|
||||
class ConcatenatedCELoss(_WeightedLoss):
    """Symmetric cross entropy between the two halves of one softmax output."""

    def __init__(self, weight=None, size_average=True, num_classes=31):
        super().__init__(weight, size_average)
        self.num_classes = num_classes

    def forward(self, input):
        """Compute the loss for ``input``.

        The softmax over dim 1 is split at column ``num_classes``. Each part
        is passed through ``process_zero_values`` before the log, and the two
        cross terms are summed over dim 1, averaged over dim 0, and halved.

        Args:
            input: Logits; softmax is taken over dim 1.

        Returns:
            0-d loss tensor.
        """
        prob = F.softmax(input, dim=1)
        first_part = process_zero_values(prob[:, :self.num_classes])
        second_part = process_zero_values(prob[:, self.num_classes:])

        term_a = (first_part.log().mul(second_part)).sum(1).mean()
        term_b = (second_part.log().mul(first_part)).sum(1).mean()
        loss = - term_a - term_b
        return loss * 0.5
|
||||
|
||||
|
||||
|
||||
class ConcatenatedEMLoss(_WeightedLoss):
    """Entropy of the element-wise sum of the two halves of a softmax output."""

    def __init__(self, weight=None, size_average=True, num_classes=31):
        super().__init__(weight, size_average)
        self.num_classes = num_classes

    def forward(self, input):
        """Compute the loss for ``input``.

        The softmax over dim 1 is split at column ``num_classes`` and the two
        parts are added element-wise. The sum goes through
        ``process_zero_values`` and the result is ``-sum(p * log p)`` over
        dim 1, averaged over dim 0.

        Args:
            input: Logits; softmax is taken over dim 1.

        Returns:
            0-d loss tensor.
        """
        prob = F.softmax(input, dim=1)
        merged = prob[:, :self.num_classes] + prob[:, self.num_classes:]
        merged = process_zero_values(merged)
        per_row = merged.log().mul(merged).sum(1)
        return - per_row.mean()
|
||||
|
||||
class MinEntropyConsensusLoss(nn.Module):
    """Minimum, over classes, of the summed negative log-probabilities of two inputs."""

    def __init__(self, num_classes):
        super(MinEntropyConsensusLoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, x, y):
        """Compute the consensus loss.

        Both inputs go through ``log_softmax`` over dim 1. An identity mask
        of size ``num_classes`` picks out ``-log p`` per class, the two
        inputs' values are added, the per-row minimum over classes is taken,
        and half of the mean of those minima is returned.

        The identity mask is created on ``x``'s device rather than forced
        onto CUDA, so the loss also works on CPU-only hosts.

        Args:
            x: Logits of shape ``(batch, num_classes)``.
            y: Logits of the same shape as ``x``.

        Returns:
            0-d loss tensor.
        """
        i = torch.eye(self.num_classes, device=x.device).unsqueeze(0)
        x = F.log_softmax(x, dim=1)
        y = F.log_softmax(y, dim=1)
        x = x.unsqueeze(-1)
        y = y.unsqueeze(-1)

        # (1, C, C) * (B, C, 1) summed over dim 1 leaves -log p per class.
        ce_x = (- 1.0 * i * x).sum(1)
        ce_y = (- 1.0 * i * y).sum(1)

        ce = 0.5 * (ce_x + ce_y).min(1)[0].mean()

        return ce
|
||||
432
utils/utils.py
Normal file
432
utils/utils.py
Normal file
@@ -0,0 +1,432 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from datasets import build_dataset
|
||||
from datasets.utils import build_data_loader
|
||||
from engine import partial_model
|
||||
from clip.model import ModifiedResNet, VisionTransformer
|
||||
import torchvision.transforms as transforms
|
||||
import os
|
||||
import random
|
||||
import logging
|
||||
from clip import clip
|
||||
def to_cuda(x):
    """Return ``x.cuda()`` when CUDA is available, otherwise ``x`` unchanged."""
    return x.cuda() if torch.cuda.is_available() else x
|
||||
|
||||
def to_cpu(x):
    """Return the result of ``x.cpu()``."""
    on_host = x.cpu()
    return on_host
|
||||
|
||||
def to_numpy(x):
    """Convert a tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    unconditionally. ``.cpu()`` is a no-op for CPU tensors, so the call no
    longer depends on ``torch.cuda.is_available()``.

    Args:
        x: A ``torch.Tensor`` on any device.

    Returns:
        ``numpy.ndarray`` with the tensor's values.
    """
    return x.detach().cpu().numpy()
|
||||
|
||||
def to_onehot(label, num_classes):
    """One-hot encode ``label`` by selecting rows of an identity matrix.

    Args:
        label: 1-D index tensor (as required by ``torch.index_select``).
        num_classes: Size of the identity matrix / width of the output.

    Returns:
        Tensor on ``label``'s device with one row per entry of ``label``.
    """
    eye = torch.eye(num_classes).to(label.device)
    return eye.index_select(0, label)
|
||||
|
||||
def accuracy(output, target):
    """Top-1 accuracy in percent.

    NOTE: a second ``accuracy`` with a ``topk`` parameter is defined later in
    this module and rebinds the name at import time.

    Args:
        output: Score tensor; the arg-max is taken over dim 1.
        target: Ground-truth class indices, one per row of ``output``.

    Returns:
        Tensor of shape ``(1,)`` holding ``100 * correct / batch_size``.
    """
    n_samples = target.size(0)
    top1 = output.topk(1, 1, True, True)[1].t()
    hits = top1.eq(target.view(1, -1).expand_as(top1))
    n_hits = hits[:1].view(-1).float().sum(0, keepdim=True)
    return n_hits.mul_(100.0 / n_samples)
|
||||
|
||||
|
||||
def accuracy_for_each_class(output, target, total_vector, correct_vector):
    """Accumulate per-class sample and hit counts, keyed by the true class.

    For every sample, ``total_vector[true_class]`` is incremented by one and
    ``correct_vector[true_class]`` by one when the top-1 prediction matches.
    Both vectors are updated in place and also returned.

    The hit mask is flattened with ``reshape(-1)`` instead of ``squeeze()``,
    which turned a batch of one into a 0-d tensor and made ``correct[i]``
    raise.

    Args:
        output: Score tensor; the arg-max is taken over dim 1.
        target: Ground-truth class indices, one per row of ``output``.
        total_vector: Per-class sample counter, indexable by class id.
        correct_vector: Per-class hit counter, indexable by class id.

    Returns:
        Tuple ``(total_vector, correct_vector)``.
    """
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.view(1, -1)).float().cpu().reshape(-1).tolist()
    labels = target.reshape(-1).tolist()
    for label, hit in zip(labels, hits):
        total_vector[label] += 1
        correct_vector[label] += hit

    return total_vector, correct_vector
|
||||
|
||||
def recall_for_each_class(output, target, total_vector, correct_vector):
    """Accumulate per-class prediction and hit counts, keyed by the predicted class.

    For every sample, ``total_vector[predicted_class]`` is incremented by one
    and ``correct_vector[predicted_class]`` by one when that prediction
    matches the target. Both vectors are updated in place and also returned.

    The hit mask is flattened with ``reshape(-1)`` instead of ``squeeze()``,
    which turned a batch of one into a 0-d tensor and made ``correct[i]``
    raise.

    Args:
        output: Score tensor; the arg-max is taken over dim 1.
        target: Ground-truth class indices, one per row of ``output``.
        total_vector: Per-class prediction counter, indexable by class id.
        correct_vector: Per-class hit counter, indexable by class id.

    Returns:
        Tuple ``(total_vector, correct_vector)``.
    """
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.view(1, -1)).float().cpu().reshape(-1).tolist()
    predicted = pred[0].tolist()
    for cls, hit in zip(predicted, hits):
        total_vector[cls] += 1
        correct_vector[cls] += hit

    return total_vector, correct_vector
|
||||
|
||||
def process_one_values(tensor):
    """Nudge entries that are exactly 1 down by ``1e-6``.

    If no entry equals 1 the input object itself is returned. The offset is
    derived from the input, so it shares its device and dtype; the previous
    version built a ``FloatTensor`` and forced it onto CUDA.

    Args:
        tensor: Input tensor.

    Returns:
        ``tensor`` itself, or a new tensor with the exact ones reduced.
    """
    ones_mask = tensor.detach() == 1
    if not ones_mask.any():
        return tensor
    # The offset carries no gradient; gradients still flow through `tensor`.
    return tensor - ones_mask.to(tensor.dtype) * 1e-6
|
||||
|
||||
def process_zero_values(tensor):
    """Nudge entries that are exactly 0 up by ``1e-6`` (e.g. before a log).

    If no entry equals 0 the input object itself is returned. The offset is
    derived from the input, so it shares its device and dtype; the previous
    version built a ``FloatTensor`` and forced it onto CUDA.

    Args:
        tensor: Input tensor.

    Returns:
        ``tensor`` itself, or a new tensor with the exact zeros raised.
    """
    zeros_mask = tensor.detach() == 0
    if not zeros_mask.any():
        return tensor
    # The offset carries no gradient; gradients still flow through `tensor`.
    return tensor + zeros_mask.to(tensor.dtype) * 1e-6
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, weighted sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
||||
|
||||
def prepare_directories(args, CLIP_Text, CLIP_Image):
    """Check for and create the required directories and files.

    Creates ``args.filename_dir`` and the checkpoint directory
    ``<savedir>/<dataset_name>_epx/<shot>shot``, makes sure the log file
    ``<filename_dir>/<dataset_name>.txt`` exists, saves the two CLIP modules
    into the checkpoint directory, and points ``logging`` at the log file.

    The modules are saved into the same ``os.path.join``-built directory that
    was just created. The previous string concatenation only matched it when
    ``args.savedir`` ended with a path separator.

    Args:
        args: Namespace with ``filename_dir``, ``savedir``, ``dataset_name``
            and ``shot``.
        CLIP_Text: Object written to ``CLIP_Text.pth`` via ``torch.save``.
        CLIP_Image: Object written to ``CLIP_Image.pth`` via ``torch.save``.
    """
    checkpoint_dir = os.path.join(args.savedir, f"{args.dataset_name}_epx", f"{args.shot}shot")

    # Create the directories if they do not exist.
    for dir_path in [args.filename_dir, checkpoint_dir]:
        os.makedirs(dir_path, exist_ok=True)
        print(f"{dir_path} directory is ready.")

    # Create the log file if it does not exist.
    filename = os.path.join(args.filename_dir, f"{args.dataset_name}.txt")
    if not os.path.exists(filename):
        with open(filename, 'a'):  # 'a' mode creates the file when missing
            pass
        print(f"Created {filename}")
    else:
        print(f"{filename} already exists.")

    # Save the CLIP parts that are not unfrozen (per the original comment).
    torch.save(CLIP_Text, os.path.join(checkpoint_dir, 'CLIP_Text.pth'))
    torch.save(CLIP_Image, os.path.join(checkpoint_dir, 'CLIP_Image.pth'))
    logging.basicConfig(filename=filename, level=logging.INFO)
|
||||
def set_seed(seed_value):
    """Set all random seeds to ensure reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU). When CUDA is
    available it also seeds the CUDA generators and switches cuDNN to
    deterministic, non-benchmark mode. Finally exports ``PYTHONHASHSEED``.
    """
    # Python random module, NumPy module, PyTorch CPU operations.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if torch.cuda.is_available():
        # Current GPU first, then all GPUs.
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

        # These settings help reproducibility but may cost performance.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    # PYTHONHASHSEED influences Python's hash-based operations.
    # NOTE(review): presumably only affects interpreters started afterwards -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed_value)
|
||||
|
||||
#准备dataloader
|
||||
def get_dataset_loader(args, preprocess):
    """Build the dataset and its test, train and validation loaders.

    Args:
        args: Namespace with ``dataset_name``, ``dataset_dir`` and ``shot``.
        preprocess: Transform applied to the test and validation images.

    Returns:
        Tuple ``(classnames, templates, loader, train_loader, val_loader)``
        where ``loader`` iterates the test split.
    """
    dataset = build_dataset(args.dataset_name, args.dataset_dir, args.shot)
    classnames = dataset.classnames
    templates = dataset.template

    # Test and validation loaders share every setting except the data source.
    eval_settings = dict(batch_size=64, is_train=False, tfm=preprocess, shuffle=False)
    loader = build_data_loader(data_source=dataset.test, **eval_settings)
    val_loader = build_data_loader(data_source=dataset.val, **eval_settings)

    # Training loader (optional, if needed) with random crop/flip augmentation.
    norm_mean = (0.48145466, 0.4578275, 0.40821073)
    norm_std = (0.26862954, 0.26130258, 0.27577711)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(
            size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    train_loader = build_data_loader(
        data_source=dataset.train_x,
        batch_size=256,
        tfm=train_transform,
        is_train=True,
        shuffle=True,
    )

    return classnames, templates, loader, train_loader, val_loader
|
||||
|
||||
def configure_clip_encoders(args, model, text_layer_idx, image_layer_idx):
    """Configure and return CLIP's text and image encoders based on the model name.

    Args:
        args: Namespace whose ``name`` is one of ``"ViT-B/16"``,
            ``"ViT-B/32"``, ``"RN50"``, ``"RN101"`` or ``"RN50x16"``.
        model: CLIP model; ``model.visual`` must be exactly a
            ``VisionTransformer`` (ViT names) or ``ModifiedResNet`` (RN names).
        text_layer_idx: Passed to ``partial_model.get_text``.
        image_layer_idx: Passed to ``partial_model.get_image_vit`` /
            ``partial_model.get_image_resnet``.

    Returns:
        ``(CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)``, each after
        ``.cuda()``.

    Raises:
        ValueError: If ``args.name`` is not a supported model name.
    """
    vit_names = ("ViT-B/16", "ViT-B/32")
    resnet_names = ("RN50", "RN101", "RN50x16")

    backbone = args.name
    if backbone not in vit_names and backbone not in resnet_names:
        raise ValueError(f"Unsupported model name: {backbone}")

    # The text side is split the same way for every supported backbone.
    CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx)
    if backbone in vit_names:
        assert type(model.visual) == VisionTransformer
        CLIP_Image, Image_Encoder = partial_model.get_image_vit(model.visual, image_layer_idx)
    else:
        assert type(model.visual) == ModifiedResNet
        CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx)

    return CLIP_Text.cuda(), Text_Encoder.cuda(), CLIP_Image.cuda(), Image_Encoder.cuda()
|
||||
|
||||
def save_model(epoch, Text_Encoder, Image_Encoder, adapter, args, prec):
    """Save the model and training state.

    Writes ``Text_Encoder.pth``, ``Image_Encoder.pth`` and ``adapter.pth``
    into ``<savedir>/<dataset_name>_epx/<shot>shot/epoch_<epoch>_<prec>``.

    The path is built with ``os.path.join`` and created with
    ``os.makedirs``. This works whether or not ``args.savedir`` ends with a
    separator and when parent directories are missing, and it no longer
    shadows the builtin ``dir``.

    Args:
        epoch: Epoch number used in the directory name.
        Text_Encoder: Object saved as ``Text_Encoder.pth``.
        Image_Encoder: Object saved as ``Image_Encoder.pth``.
        adapter: Object saved as ``adapter.pth``.
        args: Namespace with ``savedir``, ``dataset_name`` and ``shot``.
        prec: Value appended to the directory name after the epoch.
    """
    save_dir = os.path.join(
        args.savedir,
        f"{args.dataset_name}_epx",
        f"{args.shot}shot",
        f"epoch_{epoch}_{prec}",
    )
    os.makedirs(save_dir, exist_ok=True)
    torch.save(Text_Encoder, os.path.join(save_dir, "Text_Encoder.pth"))
    torch.save(Image_Encoder, os.path.join(save_dir, "Image_Encoder.pth"))
    torch.save(adapter, os.path.join(save_dir, "adapter.pth"))
|
||||
|
||||
|
||||
|
||||
def set_adapter_weights(model, classnames, templates):
    """Build adapter initial weights from per-class text embeddings, duplicated.

    For each class name (underscores replaced by spaces) every template is
    formatted, tokenised and encoded with ``model.encode_text``. The
    embeddings are L2-normalised, averaged over templates and re-normalised.
    The per-class vectors are stacked along dim 1, concatenated with a copy
    of themselves along dim 1, and transposed.

    Args:
        model: Object exposing ``encode_text``.
        classnames: Iterable of class-name strings.
        templates: Iterable of format strings with one ``{}`` placeholder.

    Returns:
        The transposed, duplicated weight tensor (on CUDA).
    """
    per_class = []
    for raw_name in classnames:
        readable = raw_name.replace('_', ' ')
        tokens = clip.tokenize([tpl.format(readable) for tpl in templates]).cuda()
        embeds = model.encode_text(tokens)
        embeds /= embeds.norm(dim=-1, keepdim=True)
        centroid = embeds.mean(dim=0)
        centroid /= centroid.norm()
        per_class.append(centroid)
    stacked = torch.stack(per_class, dim=1).cuda()
    return torch.cat([stacked, stacked], dim=1).T
|
||||
|
||||
def set_adapter_weights_single(model, classnames, templates):
    """Build adapter initial weights from per-class text embeddings.

    For each class name (underscores replaced by spaces) every template is
    formatted, tokenised and encoded with ``model.encode_text``. The
    embeddings are L2-normalised, averaged over templates and re-normalised.
    The per-class vectors are stacked along dim 1 and transposed.

    Args:
        model: Object exposing ``encode_text``.
        classnames: Iterable of class-name strings.
        templates: Iterable of format strings with one ``{}`` placeholder.

    Returns:
        The transposed weight tensor (on CUDA).
    """
    per_class = []
    for raw_name in classnames:
        readable = raw_name.replace('_', ' ')
        prompts = [tpl.format(readable) for tpl in templates]
        tokens = clip.tokenize(prompts).cuda()
        embeds = model.encode_text(tokens)
        embeds /= embeds.norm(dim=-1, keepdim=True)
        centroid = embeds.mean(dim=0)
        centroid /= centroid.norm()
        per_class.append(centroid)
    stacked = torch.stack(per_class, dim=1).cuda()
    return stacked.T
|
||||
|
||||
#文本模板特征
|
||||
def get_text_feature(classname, templates, CLIP_Text):
    """Run ``CLIP_Text`` on every template filled with ``classname``.

    Underscores in ``classname`` are replaced by spaces first. The call runs
    under ``torch.no_grad()`` with the tokenised prompts moved to CUDA.

    Args:
        classname: Class-name string.
        templates: Iterable of format strings with one ``{}`` placeholder.
        CLIP_Text: Callable returning ``(features, eot_indices)``.

    Returns:
        The ``(features, eot_indices)`` pair produced by ``CLIP_Text``.
    """
    with torch.no_grad():
        readable = classname.replace('_', ' ')
        token_rows = [clip.tokenize(tpl.format(readable)) for tpl in templates]
        prompts = torch.cat(token_rows).cuda()
        features, eot_indices = CLIP_Text(prompts)
    return features, eot_indices
|
||||
def get_text_feature_GPT(classname, templates, CLIP_Text, gpt3_prompt):
    """Run ``CLIP_Text`` on the stored prompts for ``classname``.

    Underscores in ``classname`` are replaced by spaces before it is used as
    the key into ``gpt3_prompt``. ``templates`` is accepted but not used.
    The call runs under ``torch.no_grad()`` with tokens moved to CUDA.

    Args:
        classname: Class-name string.
        templates: Unused.
        CLIP_Text: Callable returning ``(features, eot_indices)``.
        gpt3_prompt: Mapping from class name to an iterable of prompt strings.

    Returns:
        The ``(features, eot_indices)`` pair produced by ``CLIP_Text``.
    """
    with torch.no_grad():
        key = classname.replace('_', ' ')
        sentences = list(gpt3_prompt[key])
        prompts = torch.cat([clip.tokenize(sentence) for sentence in sentences]).cuda()
        features, eot_indices = CLIP_Text(prompts)
    return features, eot_indices
|
||||
def text_feature(str_prompts, CLIP_Text):
    """Tokenise ``str_prompts`` and run ``CLIP_Text`` on them without gradients.

    Args:
        str_prompts: Iterable of prompt strings.
        CLIP_Text: Callable returning ``(features, eot_indices)``.

    Returns:
        The ``(features, eot_indices)`` pair produced by ``CLIP_Text``.
    """
    with torch.no_grad():
        token_rows = [clip.tokenize(sentence) for sentence in str_prompts]
        features, eot_indices = CLIP_Text(torch.cat(token_rows).cuda())
    return features, eot_indices
|
||||
|
||||
def calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder):
    """Compute one normalised text embedding per entry of ``label``.

    For each label the class name is looked up in ``classnames``, its
    template features are obtained from ``get_text_feature`` and passed
    through ``Text_Encoder``. The result is L2-normalised, averaged over
    templates and re-normalised.

    Args:
        classnames: Sequence of class-name strings indexed by label.
        label: Sequence (or 1-D tensor) of class indices.
        templates: Iterable of format strings with one ``{}`` placeholder.
        CLIP_Text: Passed through to ``get_text_feature``.
        Text_Encoder: Callable taking ``(features, eot_indices)``.

    Returns:
        The per-label embeddings stacked along dim 1, moved to CUDA and
        transposed.
    """
    rows = []
    for position in range(len(label)):
        name = classnames[label[position]]
        features, eot_indices = get_text_feature(name, templates, CLIP_Text)
        embeds = Text_Encoder(features, eot_indices)
        embeds = embeds / embeds.norm(dim=-1, keepdim=True)
        centroid = embeds.mean(dim=0)
        rows.append(centroid / centroid.norm())
    stacked = torch.stack(rows, dim=1).cuda()
    return stacked.T
|
||||
def calculate_zero(classnames, label, templates, model):
    """Compute one normalised text embedding per entry of ``label`` with ``model``.

    For each label the class name (underscores replaced by spaces) is put
    into every template, the prompts are tokenised, moved to CUDA and fed to
    ``model`` under ``torch.no_grad()``. The embeddings are L2-normalised,
    averaged over templates and re-normalised.

    Args:
        classnames: Sequence of class-name strings indexed by label.
        label: Sequence (or 1-D tensor) of class indices.
        templates: Iterable of format strings with one ``{}`` placeholder.
        model: Callable mapping token tensors to embeddings.

    Returns:
        The per-label embeddings stacked along dim 1, moved to CUDA and
        transposed.
    """
    rows = []
    for position in range(len(label)):
        with torch.no_grad():
            readable = classnames[label[position]].replace('_', ' ')
            token_rows = [clip.tokenize(tpl.format(readable)) for tpl in templates]
            prompts = torch.cat(token_rows).cuda()
            embeds = model(prompts)

        embeds = embeds / embeds.norm(dim=-1, keepdim=True)
        centroid = embeds.mean(dim=0)
        rows.append(centroid / centroid.norm())
    stacked = torch.stack(rows, dim=1).cuda()
    return stacked.T
|
||||
def calculate_zeroshot_weights_GPT(classnames, label, templates, CLIP_Text, Text_Encoder, gpt3_prompt):
    """Compute one normalised embedding per entry of ``label`` from stored prompts.

    Same pipeline as the template-based variant, except that the features
    come from ``get_text_feature_GPT``.

    Args:
        classnames: Sequence of class-name strings indexed by label.
        label: Sequence (or 1-D tensor) of class indices.
        templates: Forwarded to ``get_text_feature_GPT``.
        CLIP_Text: Forwarded to ``get_text_feature_GPT``.
        Text_Encoder: Callable taking ``(features, eot_indices)``.
        gpt3_prompt: Mapping from class name to prompt strings.

    Returns:
        The per-label embeddings stacked along dim 1, moved to CUDA and
        transposed.
    """
    rows = []
    for position in range(len(label)):
        name = classnames[label[position]]
        features, eot_indices = get_text_feature_GPT(name, templates, CLIP_Text, gpt3_prompt)
        embeds = Text_Encoder(features, eot_indices)
        embeds = embeds / embeds.norm(dim=-1, keepdim=True)
        centroid = embeds.mean(dim=0)
        rows.append(centroid / centroid.norm())
    stacked = torch.stack(rows, dim=1).cuda()
    return stacked.T
|
||||
|
||||
def gpt_clip_classifier(classnames, gpt3_prompt, CLIP_Text, Text_Encoder):
    """Build classifier weights for every class from its stored prompts.

    Runs under ``torch.no_grad()``. For each class name (underscores replaced
    by spaces) the prompts in ``gpt3_prompt`` are tokenised, passed through
    ``CLIP_Text`` and ``Text_Encoder``, L2-normalised, averaged over prompts
    and re-normalised.

    Args:
        classnames: Iterable of class-name strings.
        gpt3_prompt: Mapping from class name to an iterable of prompt strings.
        CLIP_Text: Callable returning ``(features, eot_indices)``.
        Text_Encoder: Callable taking ``(features, eot_indices)``.

    Returns:
        The per-class embeddings stacked along dim 1, moved to CUDA and
        transposed.
    """
    rows = []
    with torch.no_grad():
        for raw_name in classnames:
            # Tokenize the prompts stored for this class.
            key = raw_name.replace('_', ' ')
            tokens = clip.tokenize(list(gpt3_prompt[key])).cuda()
            features, eot_indices = CLIP_Text(tokens)
            embeds = Text_Encoder(features, eot_indices)
            embeds = embeds / embeds.norm(dim=-1, keepdim=True)
            centroid = embeds.mean(dim=0)
            rows.append(centroid / centroid.norm())
        stacked = torch.stack(rows, dim=1).cuda()

    return stacked.T
|
||||
|
||||
def all_classifier(classnames, templates, model):
    """Build classifier weights for every class from templates.

    Runs under ``torch.no_grad()``. Each class name (underscores replaced by
    spaces) is put into every template, tokenised and encoded with
    ``model.encode_text``. The embeddings are L2-normalised, averaged over
    templates and re-normalised.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of format strings with one ``{}`` placeholder.
        model: Object exposing ``encode_text``.

    Returns:
        The per-class embeddings stacked along dim 1 (not transposed), on CUDA.
    """
    with torch.no_grad():
        columns = []
        for raw_name in classnames:
            readable = raw_name.replace('_', ' ')
            prompts = [tpl.format(readable) for tpl in templates]  # format with class
            tokens = clip.tokenize(prompts).cuda()  # tokenize the text
            embeds = model.encode_text(tokens)  # embed with text encoder
            embeds /= embeds.norm(dim=-1, keepdim=True)
            centroid = embeds.mean(dim=0)
            centroid /= centroid.norm()
            columns.append(centroid)

        stacked = torch.stack(columns, dim=1).cuda()
    return stacked
|
||||
def all_classifier_GPT(classnames, gpt3_prompt, model):
    """Build classifier weights for every class from its stored prompts.

    Runs under ``torch.no_grad()``. For each class name (underscores replaced
    by spaces) the prompts in ``gpt3_prompt`` are tokenised one by one,
    concatenated and encoded with ``model.encode_text``. The embeddings are
    L2-normalised, averaged over prompts and re-normalised.

    Args:
        classnames: Iterable of class-name strings.
        gpt3_prompt: Mapping from class name to an iterable of prompt strings.
        model: Object exposing ``encode_text``.

    Returns:
        The per-class embeddings stacked along dim 1 (not transposed), on CUDA.
    """
    with torch.no_grad():
        columns = []
        for raw_name in classnames:
            key = raw_name.replace('_', ' ')
            sentences = list(gpt3_prompt[key])
            tokens = torch.cat([clip.tokenize(sentence) for sentence in sentences]).cuda()
            embeds = model.encode_text(tokens)  # embed with text encoder
            embeds /= embeds.norm(dim=-1, keepdim=True)
            centroid = embeds.mean(dim=0)
            centroid /= centroid.norm()
            columns.append(centroid)

        stacked = torch.stack(columns, dim=1).cuda()
    return stacked
|
||||
def all_classifier_GPTWithCLIP(classnames, gpt3_prompt, model, templates):
    """Build classifier weights from templates plus stored prompts.

    Runs under ``torch.no_grad()``. For each class name (underscores replaced
    by spaces) the formatted templates come first, followed by the prompts in
    ``gpt3_prompt``. All are tokenised, concatenated and encoded with
    ``model.encode_text``, then L2-normalised, averaged and re-normalised.

    Args:
        classnames: Iterable of class-name strings.
        gpt3_prompt: Mapping from class name to an iterable of prompt strings.
        model: Object exposing ``encode_text``.
        templates: Iterable of format strings with one ``{}`` placeholder.

    Returns:
        The per-class embeddings stacked along dim 1 (not transposed), on CUDA.
    """
    with torch.no_grad():
        columns = []
        for raw_name in classnames:
            key = raw_name.replace('_', ' ')
            from_templates = [tpl.format(key) for tpl in templates]
            from_prompts = list(gpt3_prompt[key])
            sentences = from_templates + from_prompts
            tokens = torch.cat([clip.tokenize(sentence) for sentence in sentences]).cuda()
            embeds = model.encode_text(tokens)  # embed with text encoder
            embeds /= embeds.norm(dim=-1, keepdim=True)
            centroid = embeds.mean(dim=0)
            centroid /= centroid.norm()
            columns.append(centroid)

        stacked = torch.stack(columns, dim=1).cuda()
    return stacked
|
||||
|
||||
# def calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder):
|
||||
# zeroshot_weights = None
|
||||
# labels = [] # 初始化labels为一个空列表,稍后将其转换为Tensor
|
||||
# for i in range(len(label)):
|
||||
# features, eot_indices = get_text_feature(classnames[label[i]], templates, CLIP_Text)
|
||||
# class_embeddings = Text_Encoder(features, eot_indices)
|
||||
#
|
||||
# # 如果是第一次迭代,直接赋值;否则,进行拼接
|
||||
# if zeroshot_weights is None:
|
||||
# zeroshot_weights = class_embeddings
|
||||
# else:
|
||||
# zeroshot_weights = torch.cat((zeroshot_weights, class_embeddings), dim=0)
|
||||
#
|
||||
# # 对于每个label,重复len(templates)次,然后将结果添加到labels列表中
|
||||
# labels.extend([label[i].item()] * len(templates))
|
||||
#
|
||||
# # 将labels列表转换为Tensor
|
||||
# labels = torch.tensor(labels, dtype=torch.long).cuda()
|
||||
#
|
||||
# return zeroshot_weights, labels
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Keep the most recent value and a running weighted mean.

    NOTE: this module defines ``AverageMeter`` twice; this later definition
    is the one bound to the name after import.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset ``val``, ``avg``, ``sum`` and ``count`` to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and recompute the mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    The unused second ``topk`` call and the unused reshaped copy of
    ``target`` from the earlier version are removed; results are unchanged.

    Args:
        output: Score tensor; top-k is taken over dim 1.
        target: Ground-truth class indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        List with one tensor of shape ``(1,)`` per ``k``, each holding the
        percentage of samples whose target is among the top-``k`` scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row r holds the rank-r prediction
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # A sample counts as correct if any of its first k rows matched.
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
|
||||
Reference in New Issue
Block a user