# clip-symnets/zero_test_imagenet.py
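# Zero-shot ImageNet evaluation with CLIP: the script builds two text classifiers
# (one from hand-written prompt templates, one from GPT-3 generated prompts) and
# scores ImageNet test images against both in validate().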
import os
import time
from clip import clip
import torch.nn as nn
import numpy as np
import torch.optim
from opts import opts # The options for the project
# from trainer import validate # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
    set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights, gpt_clip_classifier, \
    calculate_zeroshot_weights_GPT, calculate_zero, all_classifier_GPT, all_classifier_GPTWithCLIP
from Adapter import Weight_Adapter
import logging
import torch.nn.functional as F
import yaml
import json
import torch
import glob
from datasets.imagenet import ImageNet
class CustomCrossAttention(nn.Module):
    def __init__(self, feature_dim):
        super(CustomCrossAttention, self).__init__()
        self.query_projection = nn.Linear(feature_dim, feature_dim)
        self.key_projection = nn.Linear(feature_dim, feature_dim)
        self.value_projection = nn.Linear(feature_dim, feature_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, text_features, image_features):
        # Assumes the text batch size is no larger than the image batch size.
        text_batch_size = text_features.size(0)
        image_batch_size = image_features.size(0)
        # Repeat text_features to match the batch size of image_features.
        if text_batch_size < image_batch_size:
            repeat_times = image_batch_size // text_batch_size
            text_features = text_features.repeat(repeat_times, 1)
        query = self.query_projection(text_features)
        key = self.key_projection(image_features)
        value = self.value_projection(image_features)
        # Compute (unscaled) attention scores between text queries and image keys.
        attention_scores = torch.matmul(query, key.transpose(-2, -1))
        attention_scores = self.softmax(attention_scores)
        # Apply the attention weights to the image values.
        attended_features = torch.matmul(attention_scores, value)
        return attended_features
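# Minimal usage sketch for CustomCrossAttention (the class is defined here but not
# called elsewhere in this file); the shapes below are illustrative assumptions:
#   attn = CustomCrossAttention(feature_dim=512).cuda()
#   text_feat = torch.randn(16, 512).cuda()    # e.g. one feature per class prompt
#   image_feat = torch.randn(64, 512).cuda()   # batch of image features
#   out = attn(text_feat, image_feat)          # -> [64, 512] attended image values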
def coral_loss(source_features, target_features):
    """
    Compute the Deep CORAL loss (covariance-only version).
    Note: this definition is shadowed by the redefinition of coral_loss below,
    so only the second version is actually used at runtime.
    :param source_features: source-domain features, shape [batch_size, feature_dim]
    :param target_features: target-domain features, shape [batch_size, feature_dim]
    :return: CORAL loss
    """
    d = source_features.data.shape[1]  # feature dimension
    source_mean = torch.mean(source_features, dim=0)
    target_mean = torch.mean(target_features, dim=0)
    source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
    target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)
    coral_loss = torch.sum(torch.pow(source_cov - target_cov, 2))  # / (4 * d * d)
    return coral_loss
def coral_loss(source_features, target_features):
    """
    Compute the Deep CORAL loss (mean difference plus covariance difference).
    :param source_features: source-domain features, shape [batch_size, feature_dim]
    :param target_features: target-domain features, shape [batch_size, feature_dim]
    :return: CORAL loss
    """
    # Feature dimension (kept for reference; not used in the final loss).
    d = source_features.data.shape[1]
    # Means of the two domains.
    source_mean = torch.mean(source_features, dim=0)
    target_mean = torch.mean(target_features, dim=0)
    # Mean-difference term.
    mean_diff = torch.pow(source_mean - target_mean, 2).mean()
    # Covariance matrices.
    source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
    target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)
    # Mean squared difference of the covariance matrices.
    cov_diff = torch.pow(source_cov - target_cov, 2).mean()
    # Total CORAL loss: mean difference plus covariance difference.
    total_coral_loss = mean_diff + cov_diff
    return total_coral_loss
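# Quick sanity check (illustrative values only): identical batches give zero CORAL loss.
#   x = torch.randn(32, 512)
#   assert torch.isclose(coral_loss(x, x), torch.tensor(0.0), atol=1e-5)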
def shuffle_data(weights, labels):
    # Generate a random permutation of indices.
    indices = torch.randperm(len(weights))
    # Shuffle the data and the labels with the same permutation.
    shuffled_weights = weights[indices]
    shuffled_labels = labels[indices]
    return shuffled_weights, shuffled_labels
def compute_kernel(x, y):
    """
    Compute a Gaussian kernel matrix between two batches of features.
    """
    x_size = x.size(0)
    y_size = y.size(0)
    dim = x.size(1)
    tiled_x = x.view(x_size, 1, dim).repeat(1, y_size, 1)
    tiled_y = y.view(1, y_size, dim).repeat(x_size, 1, 1)
    kernel_matrix = torch.exp(-torch.mean((tiled_x - tiled_y) ** 2, dim=2) / float(dim))
    return kernel_matrix
def mmd_loss(source_features, target_features):
    """
    Compute the Maximum Mean Discrepancy (MMD) loss between source-domain and
    target-domain features.
    """
    source_kernel = compute_kernel(source_features, source_features)
    target_kernel = compute_kernel(target_features, target_features)
    cross_kernel = compute_kernel(source_features, target_features)
    mmd = source_kernel.mean() + target_kernel.mean() - 2 * cross_kernel.mean()
    return mmd
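# Illustrative usage (random tensors only): the MMD of a batch with itself is 0,
# and it grows as the two feature distributions drift apart.
#   src = torch.randn(64, 512)
#   tgt = torch.randn(64, 512) + 1.0
#   print(mmd_loss(src, src).item(), mmd_loss(src, tgt).item())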
def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
          adapter, optimizer,
          epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, gpt_weight, gpt_label,
          gpt3_prompt):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.train()
    Image_Encoder.train()
    model.eval()
    logit_scale = model.logit_scale.exp()
    adapter.train()
    new_epoch_flag = False
    end = time.time()
    concatenatedCELoss = ConcatenatedCELoss(num_classes=len(classnames)).cuda()
    try:
        (image, label, _) = source_train_loader_batch.__next__()[1]
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = source_train_loader_batch.__next__()[1]
    target_target = label.cuda()
    # Labels for the self-supervised contrastive objective.
    label_self_supervised = label.cuda()
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()
    # target_source = label.cuda()
    input_target = image.cuda()
    # input_source = calculate_zeroshot_weights(classnames, target_source, templates, CLIP_Text, Text_Encoder)
    input_source = calculate_zeroshot_weights_GPT(classnames, target_source, templates, CLIP_Text, Text_Encoder,
                                                  gpt3_prompt)
    gpt_weight, gpt_label = shuffle_data(gpt_weight, gpt_label)
    # input_source = torch.cat((
    #     input_source, gpt_weight
    # ), dim=0)
    # target_source = torch.cat((
    #     target_source, gpt_label
    # ), dim=0)
    data_time.update(time.time() - end)
    target_target_temp = target_target + len(classnames)
    target_source_temp = target_source + len(classnames)
    target_target_temp = target_target_temp.cuda()
    # Encode the target images with the frozen CLIP image encoder.
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    input_target_add = Image_Encoder(input_target_temp)
    # Text features go straight through the fully connected adapter.
    output_source = adapter(input_source) * logit_scale
    # Image features go straight through the fully connected adapter as well.
    output_target = adapter(input_target_add) * logit_scale
    # self_input_source = calculate_zeroshot_weights(classnames, label_self_supervised, templates, CLIP_Text,
    #                                                Text_Encoder)
    self_input_source = calculate_zeroshot_weights_GPT(classnames, label_self_supervised, templates, CLIP_Text,
                                                       Text_Encoder, gpt3_prompt)
    # input_source = calculate_zeroshot_weights_GPT(classnames, target_source, templates, CLIP_Text, Text_Encoder, gpt3_prompt)
    # MMD loss between the two domains (disabled).
    # mmd_loss_val = mmd_loss(self_input_source, input_target_add)
    # lambda_mmd = 1000
    # mmd_loss_val = lambda_mmd * mmd_loss_val
    # CORAL loss between text (source) and image (target) features.
    coral_loss_value = coral_loss(self_input_source, input_target_add)
    lambda_coral = 50
    loss_1 = lambda_coral * coral_loss_value
    # Self-supervised text features (variant 1: raw text features).
    # self_output_source = adapter(self_input_source)
    # self_output_source = F.normalize(self_output_source[:, :len(classnames)])
    self_output_source = F.normalize(self_input_source)
    # Self-supervised image features (variant 1: raw image features).
    # self_output_target = output_target / logit_scale
    # self_output_target = F.normalize(self_output_target[:, len(classnames):])
    self_output_target = F.normalize(input_target_add)
    # Construct self-supervised labels 0..batch_size-1 (0-255 for a batch of 256).
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_1 = (
        F.cross_entropy(logits_per_image, self_supervised_labels) +
        F.cross_entropy(logits_per_text, self_supervised_labels)
    ) / 2
    # Self-supervised text features (variant 2: adapter outputs, source half).
    self_output_source = adapter(self_input_source)
    self_output_source = F.normalize(self_output_source[:, :len(classnames)])
    # self_output_source = F.normalize(self_input_source)
    # Self-supervised image features (variant 2: adapter outputs, target half).
    self_output_target = output_target / logit_scale
    self_output_target = F.normalize(self_output_target[:, len(classnames):])
    # self_output_target = F.normalize(input_target_add)
    # Construct self-supervised labels 0..batch_size-1 again.
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_2 = (
        F.cross_entropy(logits_per_image, self_supervised_labels) +
        F.cross_entropy(logits_per_text, self_supervised_labels)
    ) / 2
    loss_self_supervised = loss_self_supervised_2  # loss_self_supervised_2 + loss_self_supervised_1
    # Supervised cross-entropy losses on the two classifier halves.
    loss_task_s_Cs = criterion(output_source[:, :len(classnames)], target_source)
    loss_task_s_Ct = criterion(output_target[:, len(classnames):], target_target)
    # For source data, push probability mass toward the upper half of the classifier;
    # for target data, push it toward the lower half.
    loss_domain_st_Cst_part1 = criterion(output_source, target_source)
    loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)
    # Category-level confusion.
    loss_category_st_G = 0.5 * criterion(output_target, target_target) + 0.5 * criterion(output_source,
                                                                                          target_source_temp)
    # Domain-level confusion (disabled).
    # loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) + 0.5 * criterion_classifier_source(
    #     output_target)
    lam = 1  # 2 / (1 + math.exp(-1 * 10 * epoch / args.epochs)) - 1
    # if epoch < 30:
    #     self_lam = 5
    # else:
    self_lam = 0
    loss_confusion_target = concatenatedCELoss(output_target)
    loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
    loss_G = loss_category_st_G + lam * loss_confusion_target
    loss_T = loss_G + loss_classifier + self_lam * loss_self_supervised + lambda_coral * coral_loss_value  # + mmd_loss_val
    prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
    prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
    losses_classifier.update(loss_classifier.item(), input_source.size(0))
    losses_G.update(loss_G.item(), input_source.size(0))
    losses_T.update(loss_T.item(), input_source.size(0))
    top1_source.update(prec1_source[0], input_source.size(0))
    top1_target.update(prec1_target[0], input_source.size(0))
    optimizer.zero_grad()
    loss_T.backward()
    optimizer.step()
    scheduler.step()
    batch_time.update(time.time() - end)
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time,
                  data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T, top1S=top1_source,
                  top1T=top1_target))
    return source_train_loader_batch, epoch, new_epoch_flag
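# Note: train() above is carried over from the adaptation training script; main() in
# this zero-shot test file only calls validate() and never invokes train().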
def cls_acc(output, target, topk=1):
    pred = output.topk(topk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
    acc = 100 * acc / target.shape[0]
    return acc
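# Illustrative check: with logits that rank the true class first, cls_acc returns 100.0.
#   logits = torch.tensor([[0.9, 0.1], [0.8, 0.2]])
#   labels = torch.tensor([0, 0])
#   print(cls_acc(logits, labels))  # 100.0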
def validate(best_epoch, classnames, templates, val_loader, model, epoch, args, criterion, best_prec,
             zero_weights, gpt_weight):
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()
    model.eval()
    end = time.time()
    logit_scale = model.logit_scale.exp()
    for i, (image, label) in enumerate(val_loader):
        image = image.cuda()
        label = label.cuda()
        input_target = image.cuda()
        target_target = label.cuda()
        target_source = label.cuda()
        # Encode the images with the frozen CLIP image encoder.
        with torch.no_grad():
            input_target_add = model.encode_image(input_target)
            input_target_add /= input_target_add.norm(dim=-1, keepdim=True)
        # Zero-shot logits against the template-ensemble weights and the GPT-prompt weights.
        logit_1 = 100. * input_target_add @ zero_weights
        logit_2 = 100. * input_target_add @ gpt_weight
        # measure accuracy and record loss
        prec1_source = cls_acc(logit_1, target_target)
        prec1_target = cls_acc(logit_2, target_target)
        top1_source.update(prec1_source, image.size(0))
        top1_target.update(prec1_target, image.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                  'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                  'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                  'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                      epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source, lossT=losses_target,
                      top1S=top1_source, top1T=top1_target))
    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))
    # prec = max(top1_target.avg, top1_source.avg).item()
    # if prec > best_prec:
    #     best_prec = max(top1_target.avg, top1_source.avg).item()
    #     best_epoch = epoch
    #     print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)
    return 0, 0  # prec, best_epoch
def clip_classifier(classnames, template, clip_model):
    with torch.no_grad():
        clip_weights = []
        for classname in classnames:
            # Tokenize the prompts
            classname = classname.replace('_', ' ')
            texts = [t.format(classname) for t in template]
            texts = clip.tokenize(texts).cuda()
            # prompt ensemble for ImageNet
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            clip_weights.append(class_embedding)
        clip_weights = torch.stack(clip_weights, dim=1).cuda()
    return clip_weights
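# Sketch of how the resulting [feature_dim, num_classes] weight matrix is used for
# zero-shot prediction (mirrors validate() above; the image batch is illustrative):
#   weights = clip_classifier(classnames, templates, model)      # [D, C]
#   feats = model.encode_image(images.cuda())
#   feats = feats / feats.norm(dim=-1, keepdim=True)             # [B, D]
#   logits = 100. * feats @ weights                              # [B, C]
#   pred = logits.argmax(dim=-1)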
def all_classifier(classnames, templates, model):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).cuda()  # clip.tokenize converts the prompts into token id tensors
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights
def main():
    args = opts()
    set_seed(2023)
    model, preprocess = clip.load(args.name)
    model = model.cuda()
    model.eval()
    # model.float()
    # classnames, templates, loader, train_loader, val_loader = get_dataset_loader(args, preprocess)
    imagenet = ImageNet("/root/autodl-tmp/", 16, preprocess)
    test_loader = torch.utils.data.DataLoader(imagenet.test, batch_size=64, num_workers=8, shuffle=False)
    classnames = imagenet.classnames
    templates = imagenet.template
    train_loader_cache = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=False)
    train_loader_F = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=True)
    loader = test_loader
    # cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
    # Load the GPT-3 prompt file(s) from the 'gpt_file' folder.
    json_files = glob.glob('gpt_file/imagenet_prompt.json')
    for file_path in json_files:
        # Open and read each prompt JSON file.
        with open(file_path, 'r') as f:
            gpt3_prompt = json.load(f)
    # gpt_weight = all_classifier_GPTWithCLIP(classnames, gpt3_prompt, model, templates)
    gpt_weight = all_classifier_GPT(classnames, gpt3_prompt, model)
    gpt_label = torch.arange(len(classnames), device="cuda:0", dtype=torch.long)
    # Classification layer
    # Loss function
    criterion = nn.CrossEntropyLoss().cuda()
    zero_weights = clip_classifier(classnames, templates, model)
    current_epoch = 0
    best_prec = 0
    best_epoch = 0
    while current_epoch < 1:
        prec, best_epoch = validate(best_epoch, classnames, templates, loader, model,
                                    current_epoch, args, criterion,
                                    best_prec,
                                    zero_weights, gpt_weight)
        current_epoch += 1


if __name__ == '__main__':
    main()