import os
import time
from clip import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim
from opts import opts  # The options for the project
# from trainer import validate  # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
    set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights, gpt_clip_classifier, \
    calculate_zeroshot_weights_GPT, calculate_zero, all_classifier_GPT, all_classifier_GPTWithCLIP
from Adapter import Weight_Adapter
import logging
import yaml
import json
import glob


class CustomCrossAttention(nn.Module):
    def __init__(self, feature_dim):
        super(CustomCrossAttention, self).__init__()
        self.query_projection = nn.Linear(feature_dim, feature_dim)
        self.key_projection = nn.Linear(feature_dim, feature_dim)
        self.value_projection = nn.Linear(feature_dim, feature_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, text_features, image_features):
        # Assumes the text_features batch size is smaller than the image_features batch size
        text_batch_size = text_features.size(0)
        image_batch_size = image_features.size(0)
        # Repeat text_features to match the image_features batch size
        if text_batch_size < image_batch_size:
            repeat_times = image_batch_size // text_batch_size
            text_features = text_features.repeat(repeat_times, 1)

        query = self.query_projection(text_features)
        key = self.key_projection(image_features)
        value = self.value_projection(image_features)

        # Compute attention scores
        attention_scores = torch.matmul(query, key.transpose(-2, -1))
        attention_scores = self.softmax(attention_scores)

        # Apply the attention weights to the values
        attended_features = torch.matmul(attention_scores, value)
        return attended_features


def coral_loss(source_features, target_features):
    """
    Compute the Deep CORAL loss (covariance alignment only).
    Note: this variant is shadowed by the redefinition below, which is the one used at runtime.
    :param source_features: source-domain features, shape [batch_size, feature_dim]
    :param target_features: target-domain features, shape [batch_size, feature_dim]
    :return: CORAL loss
    """
    d = source_features.data.shape[1]  # feature dimension
    source_mean = torch.mean(source_features, dim=0)
    target_mean = torch.mean(target_features, dim=0)
    source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
    target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)
    coral_loss = torch.sum(torch.pow(source_cov - target_cov, 2))  # / (4 * d * d)
    return coral_loss


def coral_loss(source_features, target_features):
    """
    Compute the Deep CORAL loss (mean difference + covariance difference).
    :param source_features: source-domain features, shape [batch_size, feature_dim]
    :param target_features: target-domain features, shape [batch_size, feature_dim]
    :return: CORAL loss
    """
    # Feature dimension
    d = source_features.data.shape[1]
    # Means of the two feature batches
    source_mean = torch.mean(source_features, dim=0)
    target_mean = torch.mean(target_features, dim=0)
    # Mean difference
    mean_diff = torch.pow(source_mean - target_mean, 2).mean()
    # Covariance matrices
    source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
    target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)
    # Mean of the squared covariance difference
    cov_diff = torch.pow(source_cov - target_cov, 2).mean()
    # Return the sum of the mean and covariance differences
    total_coral_loss = mean_diff + cov_diff
    return total_coral_loss
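

# Illustrative sketch (not called anywhere in this script): a minimal sanity check of
# coral_loss on random features. The helper name and the 512-d feature size are assumptions
# for illustration only; the loss is exactly zero for identical batches and grows with the
# gap in first- and second-order statistics between the two feature distributions.
def _coral_loss_sanity_check(batch_size=32, feature_dim=512):
    source = torch.randn(batch_size, feature_dim)
    target = 2.0 * torch.randn(batch_size, feature_dim) + 1.0  # shifted and rescaled distribution
    print('coral(source, source):', coral_loss(source, source).item())  # 0.0
    print('coral(source, target):', coral_loss(source, target).item())  # > 0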


def shuffle_data(weights, labels):
    # Generate a random permutation of indices
    indices = torch.randperm(len(weights))
    # Shuffle the data and the labels with the same indices
    shuffled_weights = weights[indices]
    shuffled_labels = labels[indices]
    return shuffled_weights, shuffled_labels


def compute_kernel(x, y):
    """Compute the Gaussian kernel matrix."""
    x_size = x.size(0)
    y_size = y.size(0)
    dim = x.size(1)
    tiled_x = x.view(x_size, 1, dim).repeat(1, y_size, 1)
    tiled_y = y.view(1, y_size, dim).repeat(x_size, 1, 1)
    kernel_matrix = torch.exp(-torch.mean((tiled_x - tiled_y) ** 2, dim=2) / float(dim))
    return kernel_matrix


def mmd_loss(source_features, target_features):
    """Compute the maximum mean discrepancy (MMD) loss between source- and target-domain features."""
    source_kernel = compute_kernel(source_features, source_features)
    target_kernel = compute_kernel(target_features, target_features)
    cross_kernel = compute_kernel(source_features, target_features)
    mmd = source_kernel.mean() + target_kernel.mean() - 2 * cross_kernel.mean()
    return mmd


def train(classnames, templates, source_train_loader, source_train_loader_batch, model, adapter, optimizer, epoch,
          args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, gpt_weight, gpt_label,
          gpt3_prompt):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.train()
    Image_Encoder.train()
    model.eval()
    logit_scale = model.logit_scale.exp()
    adapter.train()
    new_epoch_flag = False
    end = time.time()
    concatenatedCELoss = ConcatenatedCELoss(num_classes=len(classnames)).cuda()

    try:
        (image, label, _) = source_train_loader_batch.__next__()[1]
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = source_train_loader_batch.__next__()[1]

    target_target = label.cuda()
    # Labels for the self-supervised branch
    label_self_supervised = label.cuda()
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()
    # target_source = label.cuda()
    input_target = image.cuda()
    # input_source = calculate_zeroshot_weights(classnames, target_source, templates, CLIP_Text, Text_Encoder)
    input_source = calculate_zeroshot_weights_GPT(classnames, target_source, templates, CLIP_Text, Text_Encoder,
                                                  gpt3_prompt)
    gpt_weight, gpt_label = shuffle_data(gpt_weight, gpt_label)
    # input_source = torch.cat((input_source, gpt_weight), dim=0)
    # target_source = torch.cat((target_source, gpt_label), dim=0)
    data_time.update(time.time() - end)

    target_target_temp = target_target + len(classnames)
    target_source_temp = target_source + len(classnames)
    target_target_temp = target_target_temp.cuda()

    # CLIP image encoder (frozen); the trainable Image_Encoder adapts its output
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    input_target_add = Image_Encoder(input_target_temp)

    # Text features fed directly into the adapter (fully connected layer)
    output_source = adapter(input_source) * logit_scale
    # Image features fed directly into the adapter
    output_target = adapter(input_target_add) * logit_scale

    # self_input_source = calculate_zeroshot_weights(classnames, label_self_supervised, templates, CLIP_Text,
    #                                                Text_Encoder)
    self_input_source = calculate_zeroshot_weights_GPT(classnames, label_self_supervised, templates, CLIP_Text,
                                                       Text_Encoder, gpt3_prompt)
    # input_source = calculate_zeroshot_weights_GPT(classnames, target_source, templates, CLIP_Text, Text_Encoder, gpt3_prompt)

    # MMD loss (disabled)
    # mmd_loss_val = mmd_loss(self_input_source, input_target_add)
    # lambda_mmd = 1000
    # mmd_loss_val = lambda_mmd * mmd_loss_val

    # CORAL loss between the text and image features
    coral_loss_value = coral_loss(self_input_source, input_target_add)
    lambda_coral = 50
    loss_1 = lambda_coral * coral_loss_value
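
    # The two blocks below form a CLIP-style symmetric contrastive (InfoNCE) objective:
    # text and image features are paired row by row, so the i-th text feature should score
    # highest against the i-th image feature (and vice versa), and the targets are simply
    # the indices 0..batch_size-1. Block 1 contrasts the raw encoder features, block 2 the
    # adapter outputs (source half vs. target half); loss_self_supervised keeps block 2 and
    # enters the total loss weighted by self_lam.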
    # Self-supervised block 1: raw (un-adapted) text features through the adapter was tried and disabled
    # self_output_source = adapter(self_input_source)
    # self_output_source = F.normalize(self_output_source[:, :len(classnames)])
    self_output_source = F.normalize(self_input_source)
    # Self-supervised image features
    # self_output_target = output_target / logit_scale
    # self_output_target = F.normalize(self_output_target[:, len(classnames):])
    self_output_target = F.normalize(input_target_add)
    # Construct self-supervised labels 0..batch_size-1
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_1 = (F.cross_entropy(logits_per_image, self_supervised_labels)
                              + F.cross_entropy(logits_per_text, self_supervised_labels)) / 2

    # Self-supervised block 2: text features through the adapter (source half vs. target half)
    self_output_source = adapter(self_input_source)
    self_output_source = F.normalize(self_output_source[:, :len(classnames)])
    # self_output_source = F.normalize(self_input_source)
    # Self-supervised image features
    self_output_target = output_target / logit_scale
    self_output_target = F.normalize(self_output_target[:, len(classnames):])
    # self_output_target = F.normalize(input_target_add)
    # Construct self-supervised labels 0..batch_size-1
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_2 = (F.cross_entropy(logits_per_image, self_supervised_labels)
                              + F.cross_entropy(logits_per_text, self_supervised_labels)) / 2

    loss_self_supervised = loss_self_supervised_2  # loss_self_supervised_2 + loss_self_supervised_1

    # Supervised classification cross-entropy losses
    loss_task_s_Cs = criterion(output_source[:, :len(classnames)], target_source)
    loss_task_s_Ct = criterion(output_target[:, len(classnames):], target_target)
    # For source-domain inputs, push the probability mass toward the first half of the 2K-way classifier;
    # for target-domain inputs, toward the second half.
    loss_domain_st_Cst_part1 = criterion(output_source, target_source)
    loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)
    # Class-level confusion
    loss_category_st_G = 0.5 * criterion(output_target, target_target) \
                         + 0.5 * criterion(output_source, target_source_temp)
    # Domain-level confusion (disabled)
    # loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) + 0.5 * criterion_classifier_source(output_target)

    lam = 1  # 2 / (1 + math.exp(-1 * 10 * epoch / args.epochs)) - 1
    # if epoch < 30:
    #     self_lam = 5
    # else:
    self_lam = 0
    loss_confusion_target = concatenatedCELoss(output_target)

    loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
    loss_G = loss_category_st_G + lam * loss_confusion_target
    loss_T = loss_G + loss_classifier + self_lam * loss_self_supervised + lambda_coral * coral_loss_value  # + mmd_loss_val

    prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
    prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
    losses_classifier.update(loss_classifier.item(), input_source.size(0))
    losses_G.update(loss_G.item(), input_source.size(0))
    losses_T.update(loss_T.item(), input_source.size(0))
    top1_source.update(prec1_source[0], input_source.size(0))
    top1_target.update(prec1_target[0], input_source.size(0))

    optimizer.zero_grad()
    loss_T.backward()
    optimizer.step()
    scheduler.step()

    batch_time.update(time.time() - end)
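
    # Logged quantities: Loss@C is the supervised classifier loss (loss_classifier), Loss@G the
    # class-confusion loss (loss_G), and Loss@T the total objective just back-propagated (loss_T);
    # top1S / top1T are top-1 accuracies on the source-half and target-half logits.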
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time, data_time=data_time, loss_c=losses_classifier,
                  loss_g=losses_G, loss_t=losses_T, top1S=top1_source, top1T=top1_target))

    return source_train_loader_batch, epoch, new_epoch_flag


def cls_acc(output, target, topk=1):
    pred = output.topk(topk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    acc = float(correct[:topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
    acc = 100 * acc / target.shape[0]
    return acc


def validate(best_epoch, classnames, templates, val_loader, model, epoch, args, criterion, best_prec, zero_weights,
             gpt_weight):
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    model.eval()
    end = time.time()
    logit_scale = model.logit_scale.exp()

    for i, (image, label, _) in enumerate(val_loader):
        image = image.cuda()
        label = label.cuda()
        input_target = image.cuda()
        target_target = label.cuda()
        target_source = label.cuda()

        # CLIP image encoder
        with torch.no_grad():
            input_target_add = model.encode_image(input_target)
            input_target_add /= input_target_add.norm(dim=-1, keepdim=True)
            logit_1 = 100. * input_target_add @ zero_weights
            logit_2 = 100. * input_target_add @ gpt_weight

        # Measure accuracy and record loss
        prec1_source = cls_acc(logit_1, target_target)
        prec1_target = cls_acc(logit_2, target_target)
        top1_source.update(prec1_source, image.size(0))
        top1_target.update(prec1_target, image.size(0))

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                  'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                  'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                  'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                      epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source, lossT=losses_target,
                      top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'.format(top1S=top1_source, top1T=top1_target))

    # prec = max(top1_target.avg, top1_source.avg).item()
    # if prec > best_prec:
    #     best_prec = max(top1_target.avg, top1_source.avg).item()
    #     best_epoch = epoch
    #     print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)
    return 0, 0  # prec, best_epoch


def clip_classifier(classnames, template, clip_model):
    with torch.no_grad():
        clip_weights = []
        for classname in classnames:
            # Tokenize the prompts
            classname = classname.replace('_', ' ')
            texts = [t.format(classname) for t in template]
            texts = clip.tokenize(texts).cuda()
            # Prompt ensemble for ImageNet
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            clip_weights.append(class_embedding)
        clip_weights = torch.stack(clip_weights, dim=1).cuda()
    return clip_weights
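

# Illustrative sketch (not called anywhere in this script): how the prompt-ensembled weights
# returned by clip_classifier are consumed at inference time, mirroring validate() above.
# The helper name is an assumption for illustration; clip_weights has shape [embed_dim, n_classes].
def _zero_shot_logits(image_features, clip_weights):
    # image_features: [batch, embed_dim] from model.encode_image; L2-normalize, then scale the
    # cosine similarities by 100 as in validate()
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return 100. * image_features @ clip_weights  # [batch, n_classes]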


def all_classifier(classnames, templates, model):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            texts = [template.format(classname) for template in templates]  # format the templates with the class name
            texts = clip.tokenize(texts).cuda()  # clip.tokenize turns the prompts into token ids
            class_embeddings = model.encode_text(texts)  # embed with the text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights


def main():
    args = opts()
    set_seed(2023)

    model, preprocess = clip.load(args.name)
    model = model.cuda()
    model.eval()
    # model.float()

    classnames, templates, loader, train_loader, val_loader = get_dataset_loader(args, preprocess)
    # cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    # Collect the GPT prompt file(s) under 'gpt_file'
    json_files = glob.glob('gpt_file/sun397_prompt.json')
    for file_path in json_files:
        # Open and read each prompt file
        with open(file_path, 'r') as f:
            gpt3_prompt = json.load(f)

        # gpt_weight = all_classifier_GPTWithCLIP(classnames, gpt3_prompt, model, templates)
        gpt_weight = all_classifier_GPT(classnames, gpt3_prompt, model)
        gpt_label = torch.arange(len(classnames), device="cuda:0", dtype=torch.long)

        # Loss function
        criterion = nn.CrossEntropyLoss().cuda()
        zero_weights = clip_classifier(classnames, templates, model)

        current_epoch = 0
        best_prec = 0
        best_epoch = 0
        while current_epoch < 1:
            prec, best_epoch = validate(best_epoch, classnames, templates, loader, model, current_epoch, args,
                                        criterion, best_prec, zero_weights, gpt_weight)
            current_epoch += 1


if __name__ == '__main__':
    main()