This commit is contained in:
2024-05-21 19:41:56 +08:00
commit ca67205608
217 changed files with 201004 additions and 0 deletions

469
zero_test_imagenet.py Normal file
View File

@@ -0,0 +1,469 @@
import os
import time
from clip import clip
import torch.nn as nn
import numpy as np
import torch.optim
from opts import opts # The options for the project
# from trainer import validate # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights, gpt_clip_classifier, \
calculate_zeroshot_weights_GPT,calculate_zero,all_classifier_GPT,all_classifier_GPTWithCLIP
from Adapter import Weight_Adapter
import logging
import torch.nn.functional as F
import yaml
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import glob
from datasets.imagenet import ImageNet
class CustomCrossAttention(nn.Module):
def __init__(self, feature_dim):
super(CustomCrossAttention, self).__init__()
self.query_projection = nn.Linear(feature_dim, feature_dim)
self.key_projection = nn.Linear(feature_dim, feature_dim)
self.value_projection = nn.Linear(feature_dim, feature_dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, text_features, image_features):
# 假设 text_features 的 batch_size < image_features 的 batch_size
text_batch_size = text_features.size(0)
image_batch_size = image_features.size(0)
# 重复 text_features 以匹配 image_features 的 batch_size
if text_batch_size < image_batch_size:
repeat_times = image_batch_size // text_batch_size
text_features = text_features.repeat(repeat_times, 1)
query = self.query_projection(text_features)
key = self.key_projection(image_features)
value = self.value_projection(image_features)
# 计算注意力分数
attention_scores = torch.matmul(query, key.transpose(-2, -1))
attention_scores = self.softmax(attention_scores)
# 应用注意力分数到 value 上
attended_features = torch.matmul(attention_scores, value)
return attended_features
def coral_loss(source_features, target_features):
"""
计算Deep CORAL损失。
:param source_features: 源域特征,维度为[batch_size, feature_dim]
:param target_features: 目标域特征,维度为[batch_size, feature_dim]
:return: CORAL损失
"""
d = source_features.data.shape[1] # 特征维度
source_mean = torch.mean(source_features, dim=0)
target_mean = torch.mean(target_features, dim=0)
source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)
coral_loss = torch.sum(torch.pow(source_cov - target_cov, 2)) # / (4*d*d)
return coral_loss
def coral_loss(source_features, target_features):
"""
计算Deep CORAL损失。
:param source_features: 源域特征,维度为[batch_size, feature_dim]
:param target_features: 目标域特征,维度为[batch_size, feature_dim]
:return: CORAL损失
"""
# 特征维度
d = source_features.data.shape[1]
# 计算均值
source_mean = torch.mean(source_features, dim=0)
target_mean = torch.mean(target_features, dim=0)
# 计算均值差异
mean_diff = torch.pow(source_mean - target_mean, 2).mean()
# 计算协方差矩阵
source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)
# 计算协方差矩阵差异的平均值
cov_diff = torch.pow(source_cov - target_cov, 2).mean()
# 返回均值差异和协方差矩阵差异的和
total_coral_loss = mean_diff + cov_diff
return total_coral_loss
def shuffle_data(weights, labels):
# 生成索引
indices = torch.randperm(len(weights))
# 使用索引来打乱数据和标签
shuffled_weights = weights[indices]
shuffled_labels = labels[indices]
return shuffled_weights, shuffled_labels
def compute_kernel(x, y):
"""
计算高斯核矩阵
"""
x_size = x.size(0)
y_size = y.size(0)
dim = x.size(1)
tiled_x = x.view(x_size, 1, dim).repeat(1, y_size, 1)
tiled_y = y.view(1, y_size, dim).repeat(x_size, 1, 1)
kernel_matrix = torch.exp(-torch.mean((tiled_x - tiled_y) ** 2, dim=2) / float(dim))
return kernel_matrix
def mmd_loss(source_features, target_features):
"""
计算源域和目标域特征之间的最大均值差异MMD损失
"""
source_kernel = compute_kernel(source_features, source_features)
target_kernel = compute_kernel(target_features, target_features)
cross_kernel = compute_kernel(source_features, target_features)
mmd = source_kernel.mean() + target_kernel.mean() - 2 * cross_kernel.mean()
return mmd
def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
adapter, optimizer,
epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, gpt_weight, gpt_label,
gpt3_prompt):
batch_time = AverageMeter()
data_time = AverageMeter()
losses_classifier = AverageMeter()
losses_G = AverageMeter()
losses_T = AverageMeter()
top1_source = AverageMeter()
top1_target = AverageMeter()
CLIP_Text.eval()
CLIP_Image.eval()
Text_Encoder.train()
Image_Encoder.train()
model.eval()
logit_scale = model.logit_scale.exp()
adapter.train()
new_epoch_flag = False
end = time.time()
concatenatedCELoss = ConcatenatedCELoss(num_classes=len(classnames)).cuda()
try:
(image, label, _) = source_train_loader_batch.__next__()[1]
except StopIteration:
epoch = epoch + 1
new_epoch_flag = True
source_train_loader_batch = enumerate(source_train_loader)
(image, label, _) = source_train_loader_batch.__next__()[1]
target_target = label.cuda()
# 自监督标签
label_self_supervised = label.cuda()
indices = torch.randperm(len(label))
target_source = label[indices].cuda()
# target_source = label.cuda()
input_target = image.cuda()
# input_source = calculate_zeroshot_weights(classnames, target_source, templates, CLIP_Text, Text_Encoder)
input_source = calculate_zeroshot_weights_GPT(classnames, target_source, templates, CLIP_Text, Text_Encoder,
gpt3_prompt)
gpt_weight, gpt_label = shuffle_data(gpt_weight, gpt_label)
# input_source = torch.cat((
# input_source, gpt_weight
# ), dim=0)
# target_source = torch.cat((
# target_source, gpt_label
# ), dim=0)
data_time.update(time.time() - end)
target_target_temp = target_target + len(classnames)
target_source_temp = target_source + len(classnames)
target_target_temp = target_target_temp.cuda()
# clip图片编码器
with torch.no_grad():
input_target_temp = CLIP_Image(input_target)
input_target_add = Image_Encoder(input_target_temp)
# 计算CORAL损失
# 总损失
# 文本直接输入全连接层
output_source = adapter(input_source) * logit_scale
# 图片直接输入全连接层
output_target = adapter(input_target_add) * logit_scale
# self_input_source = calculate_zeroshot_weights(classnames, label_self_supervised, templates, CLIP_Text,
# Text_Encoder)
self_input_source = calculate_zeroshot_weights_GPT(classnames, label_self_supervised, templates, CLIP_Text,
Text_Encoder, gpt3_prompt)
# input_source = calculate_zeroshot_weights_GPT(classnames, target_source, templates, CLIP_Text, Text_Encoder,gpt3_prompt)
# 计算MMD损失
# mmd_loss_val = mmd_loss(self_input_source, input_target_add)
# lambda_mmd=1000
# mmd_loss_val =lambda_mmd*mmd_loss_val
# 总损失
coral_loss_value = coral_loss(self_input_source, input_target_add)
lambda_coral = 50
loss_1 = lambda_coral * coral_loss_value
# 自监督文本输入全连接层
# self_output_source = adapter(self_input_source)
# self_output_source = F.normalize(self_output_source[:,:len(classnames)])
self_output_source = F.normalize(self_input_source)
# 自监督图像特征
# self_output_target = output_target / logit_scale
# self_output_target = F.normalize(self_output_target[:,len(classnames):])
self_output_target = F.normalize(input_target_add)
# # 构造自监督标签0-255
self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
logits_per_image = logit_scale * self_output_target @ self_output_source.T
logits_per_text = logit_scale * self_output_source @ self_output_target.T
loss_self_supervised_1 = (
F.cross_entropy(logits_per_image, self_supervised_labels) +
F.cross_entropy(logits_per_text, self_supervised_labels)
) / 2
# 自监督文本输入全连接层
self_output_source = adapter(self_input_source)
self_output_source = F.normalize(self_output_source[:, :len(classnames)])
# self_output_source = F.normalize(self_input_source)
# 自监督图像特征
self_output_target = output_target / logit_scale
self_output_target = F.normalize(self_output_target[:, len(classnames):])
# self_output_target = F.normalize(input_target_add)
# # 构造自监督标签0-255
self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
logits_per_image = logit_scale * self_output_target @ self_output_source.T
logits_per_text = logit_scale * self_output_source @ self_output_target.T
loss_self_supervised_2 = (
F.cross_entropy(logits_per_image, self_supervised_labels) +
F.cross_entropy(logits_per_text, self_supervised_labels)
) / 2
loss_self_supervised = loss_self_supervised_2 # loss_self_supervised_2 + loss_self_supervised_1
# 有监督分类的交叉熵损失
loss_task_s_Cs = criterion(output_source[:, :len(classnames)], target_source)
loss_task_s_Ct = criterion(output_target[:, len(classnames):], target_target)
# 对于源域数据,它希望让分类器上半部分所占的概率尽可能大,对于目标域数据,它希望让分类器下半部分所占的概率尽可能大。
loss_domain_st_Cst_part1 = criterion(output_source, target_source)
loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)
# 类级别混淆
loss_category_st_G = 0.5 * criterion(output_target, target_target) + 0.5 * criterion(output_source,
target_source_temp)
# 域级别混淆
# loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) + 0.5 * criterion_classifier_source(
# output_target)
lam = 1 # 2 / (1 + math.exp(-1 * 10 * epoch / args.epochs)) - 1
# if(epoch<30):
# self_lam= 5
# else:
self_lam = 0
loss_confusion_target = concatenatedCELoss(output_target)
loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
loss_G = loss_category_st_G + lam * loss_confusion_target
loss_T = loss_G + loss_classifier + self_lam * loss_self_supervised + lambda_coral * coral_loss_value # + mmd_loss_val
prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
losses_classifier.update(loss_classifier.item(), input_source.size(0))
losses_G.update(loss_G.item(), input_source.size(0))
losses_T.update(loss_T.item(), input_source.size(0))
top1_source.update(prec1_source[0], input_source.size(0))
top1_target.update(prec1_target[0], input_source.size(0))
optimizer.zero_grad()
loss_T.backward()
optimizer.step()
scheduler.step()
batch_time.update(time.time() - end)
if (epoch + 1) % args.print_freq == 0 or epoch == 0:
print('Train: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
epoch, args.epochs, batch_time=batch_time,
data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T, top1S=top1_source,
top1T=top1_target))
return source_train_loader_batch, epoch, new_epoch_flag
def cls_acc(output, target, topk=1):
pred = output.topk(topk, 1, True, True)[1].t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
acc = 100 * acc / target.shape[0]
return acc
def validate(best_epoch, classnames, templates, val_loader, model, epoch, args, criterion, best_prec,
zero_weights,gpt_weight):
batch_time = AverageMeter()
losses_source = AverageMeter()
losses_target = AverageMeter()
top1_source = AverageMeter()
top1_target = AverageMeter()
model.eval()
end = time.time()
logit_scale = model.logit_scale.exp()
for i, (image, label) in enumerate(val_loader):
image = image.cuda()
label = label.cuda()
input_target = image.cuda()
target_target = label.cuda()
target_source = label.cuda()
# clip图片编码器
with torch.no_grad():
input_target_add= model.encode_image(input_target)
input_target_add /= input_target_add.norm(dim=-1, keepdim=True)
logit_1= 100. * input_target_add @ zero_weights
logit_2 = 100.* input_target_add @ gpt_weight
# measure accuracy and record loss
prec1_source = cls_acc(logit_1, target_target)
prec1_target = cls_acc(logit_2, target_target)
top1_source.update(prec1_source, image.size(0))
top1_target.update(prec1_target, image.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
print('Test: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source, lossT=losses_target,
top1S=top1_source, top1T=top1_target))
print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
.format(top1S=top1_source, top1T=top1_target))
# prec = max(top1_target.avg, top1_source.avg).item()
# if prec > best_prec:
# best_prec = max(top1_target.avg, top1_source.avg).item()
# best_epoch = epoch
# print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)
return 0,0#prec, best_epoch
def clip_classifier(classnames, template, clip_model):
with torch.no_grad():
clip_weights = []
for classname in classnames:
# Tokenize the prompts
classname = classname.replace('_', ' ')
texts = [t.format(classname) for t in template]
texts = clip.tokenize(texts).cuda()
# prompt ensemble for ImageNet
class_embeddings = clip_model.encode_text(texts)
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings.mean(dim=0)
class_embedding /= class_embedding.norm()
clip_weights.append(class_embedding)
clip_weights = torch.stack(clip_weights, dim=1).cuda()
return clip_weights
def all_classifier(classnames, templates, model):
with torch.no_grad():
zeroshot_weights = []
for classname in classnames:
classname = classname.replace('_', ' ')
texts = [template.format(classname) for template in templates] # format with class
texts = clip.tokenize(texts).cuda() # tokenizeclip.tokenize向量化文字
class_embeddings = model.encode_text(texts) # embed with text encoder
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings.mean(dim=0)
class_embedding /= class_embedding.norm()
zeroshot_weights.append(class_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
return zeroshot_weights
def main():
args = opts()
set_seed(2023)
model, preprocess = clip.load(args.name)
model = model.cuda()
model.eval()
# model.float()
# classnames, templates, loader, train_loader,val_loader = get_dataset_loader(args, preprocess)
imagenet = ImageNet("/root/autodl-tmp/",16, preprocess)
test_loader = torch.utils.data.DataLoader(imagenet.test, batch_size=64, num_workers=8, shuffle=False)
classnames=imagenet.classnames
templates=imagenet.template
train_loader_cache = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=False)
train_loader_F = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=True)
loader=test_loader
# cfg = yaml.load(open(args.conf ig, 'r'), Loader=yaml.Loader)
# 获取'gpt_file'文件夹下所有的.yaml文件
json_files = glob.glob('gpt_file/imagenet_prompt.json')
for file_path in json_files:
# 打开并读取每个YAML文件
with open(file_path, 'r') as f:
gpt3_prompt = json.load(f)
# gpt_weight = all_classifier_GPTWithCLIP(classnames, gpt3_prompt,model,templates)
gpt_weight = all_classifier_GPT(classnames, gpt3_prompt,model)
gpt_label = torch.arange(len(classnames), device="cuda:0", dtype=torch.long)
gpt_weight, gpt_label
# 分类层
# 损失函数
criterion = nn.CrossEntropyLoss().cuda()
zero_weights = clip_classifier(classnames, templates, model)
current_epoch = 0
best_prec = 0
best_epoch = 0
while (current_epoch < 1):
prec, best_epoch = validate(best_epoch, classnames, templates, loader, model,
current_epoch, args, criterion,
best_prec,
zero_weights,gpt_weight)
current_epoch+=1
if __name__ == '__main__':
main()