import math
import os
import random
import time

import clip
import torch
import torch.nn.functional as F

from utils.loss_utils import ConcatenatedCELoss


def zeroshot_classifier(classname, templates, CLIP_Text):
    """Encode one class name under every prompt template with the frozen CLIP text tower."""
    with torch.no_grad():
        classname = classname.replace('_', ' ')
        str_prompts = [template.format(classname) for template in templates]
        prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
        features, eot_indices = CLIP_Text(prompts)
    return features, eot_indices


def warm_train(classnames, templates, source_train_loader, source_train_loader_batch, model, adapter,
               criterion_classifier_source, criterion_classifier_target, optimizer, epoch, args, scheduler,
               criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder):
    random.seed(1)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    # Warm-up phase: freeze both CLIP towers and both projection encoders; only the adapter trains.
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    # CLIP's temperature: exp(4.60517) = exp(ln 100) ~= 100.
    logit_scale = math.exp(4.60517)
    model.eval()
    adapter.train()

    new_epoch_flag = False
    end = time.time()
    concatenatedCELoss = ConcatenatedCELoss(num_classes=len(classnames)).cuda()

    try:
        (image, label, _) = next(source_train_loader_batch)[1]
    except StopIteration:
        # The loader is exhausted: advance the epoch counter and restart it.
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = next(source_train_loader_batch)[1]

    target_target = label.cuda()
    # Labels for the self-supervised (contrastive) branch.
    label_self_supervised = label.cuda()
    # Shuffle the labels so the text ("source") batch covers the same classes as the
    # image ("target") batch without being aligned sample-by-sample.
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()
    input_target = image.cuda()

    # Build one unit-norm, template-averaged text embedding per shuffled label.
    zeroshot_weights = []
    for i in range(len(target_source)):
        features, eot_indices = zeroshot_classifier(classnames[target_source[i]], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    input_source = zeroshot_weights.T

    data_time.update(time.time() - end)

    # Shift labels into the second half of the 2 * num_classes classifier.
    target_target_temp = (target_target + len(classnames)).cuda()
    target_source_temp = target_source + len(classnames)

    # Frozen CLIP image encoder, followed by the image projector.
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    input_target_add = Image_Encoder(input_target_temp)

    # Text features go straight into the fully connected adapter ...
    output_source = adapter(input_source) * logit_scale
    # ... and so do the projected image features.
    output_target = adapter(input_target_add) * logit_scale

    # Self-supervised branch: text embeddings for the unshuffled labels.
    self_zeroshot_weights = []
    for i in range(len(label_self_supervised)):
        features, eot_indices = zeroshot_classifier(classnames[label_self_supervised[i]], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        self_zeroshot_weights.append(class_embedding)
    self_zeroshot_weights = torch.stack(self_zeroshot_weights, dim=1).cuda()
    self_input_source = self_zeroshot_weights.T

    # Text side of the contrastive pair: adapter output, L2-normalized.
    self_output_source = F.normalize(adapter(self_input_source))
    # Image side: undo the logit scaling, then L2-normalize.
    self_output_target = F.normalize(output_target / logit_scale)

    # Diagonal targets 0..B-1: the i-th image should match the i-th text embedding.
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised = (F.cross_entropy(logits_per_image, self_supervised_labels)
                            + F.cross_entropy(logits_per_text, self_supervised_labels)) / 2

    # Supervised cross-entropy on each half of the 2 * num_classes classifier.
    loss_task_s_Cs = criterion(output_source[:, :len(classnames)], target_source)
    loss_task_s_Ct = criterion(output_target[:, len(classnames):], target_target)

    # Source samples should put their probability mass on the first half of the classifier,
    # target samples on the second half.
    loss_domain_st_Cst_part1 = criterion(output_source, target_source)
    loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)

    # Category-level confusion: push each domain toward the other domain's half.
    loss_category_st_G = 0.5 * criterion(output_target, target_target) \
                         + 0.5 * criterion(output_source, target_source_temp)
    # Domain-level confusion (disabled):
    # loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) \
    #                    + 0.5 * criterion_classifier_source(output_target)

    # DANN-style schedule: lam ramps from 0 toward 1 over training.
    lam = 2 / (1 + math.exp(-10 * epoch / args.epochs)) - 1
    # Weight the self-supervised loss heavily during early epochs, then decay it.
    self_lam = 3 if epoch < 30 else 1 / 5

    loss_confusion_target = concatenatedCELoss(output_target)

    loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
    loss_G = loss_category_st_G + lam * loss_confusion_target
    loss_T = loss_G + loss_classifier + self_lam * loss_self_supervised

    prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
    prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
    losses_classifier.update(loss_classifier.item(), input_source.size(0))
    losses_G.update(loss_G.item(), input_source.size(0))
    losses_T.update(loss_T.item(), input_source.size(0))
    top1_source.update(prec1_source[0], input_source.size(0))
    top1_target.update(prec1_target[0], input_source.size(0))

    # Single backward pass over the combined objective.
    optimizer.zero_grad()
    loss_T.backward()
    optimizer.step()
    scheduler.step()

    batch_time.update(time.time() - end)

    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time, data_time=data_time,
                  loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T,
                  top1S=top1_source, top1T=top1_target))

    if new_epoch_flag:
        log = open(os.path.join(args.log, 'log.txt'), 'a')
        log.write("\n")
        log.write("Train:epoch: %d, Loss@C: %4f, Loss@G: %4f, Top1S acc: %3f, Top1T acc: %3f" % (
            epoch, losses_classifier.avg, losses_G.avg, top1_source.avg, top1_target.avg))
        log.close()

    return source_train_loader_batch, epoch, new_epoch_flag
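
# Note: the per-sample loop above (zeroshot_classifier -> Text_Encoder -> normalize ->
# template-average -> renormalize) is repeated verbatim in warm_train, train, and validate
# below. A minimal refactoring sketch; the helper name build_text_embeddings is hypothetical
# and assumes Text_Encoder keeps the signature used above:
def build_text_embeddings(labels, classnames, templates, CLIP_Text, Text_Encoder):
    """Return a (batch, dim) matrix of unit-norm, template-averaged class embeddings."""
    embeddings = []
    for lbl in labels:
        features, eot_indices = zeroshot_classifier(classnames[lbl], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        embeddings.append(class_embedding / class_embedding.norm())
    # Stacking along dim=0 yields (batch, dim) directly, equivalent to stack(dim=1).T above.
    return torch.stack(embeddings, dim=0).cuda()
# Each loop could then collapse to, e.g.:
#   input_source = build_text_embeddings(target_source, classnames, templates, CLIP_Text, Text_Encoder)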
def train(classnames, templates, source_train_loader, source_train_loader_batch, model, adapter,
          criterion_classifier_source, criterion_classifier_target, optimizer, epoch, args, scheduler,
          criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder):
    random.seed(1)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    # Main phase: unlike warm_train, the image projector now trains together with the adapter.
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.train()
    logit_scale = math.exp(4.60517)  # exp(ln 100) ~= 100
    model.eval()
    adapter.train()

    new_epoch_flag = False
    end = time.time()
    concatenatedCELoss = ConcatenatedCELoss(num_classes=len(classnames)).cuda()

    try:
        (image, label, _) = next(source_train_loader_batch)[1]
    except StopIteration:
        # The loader is exhausted: advance the epoch counter and restart it.
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = next(source_train_loader_batch)[1]

    target_target = label.cuda()
    # Labels for the self-supervised (contrastive) branch.
    label_self_supervised = label.cuda()
    # Shuffled labels for the text batch (see warm_train).
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()
    input_target = image.cuda()

    # One unit-norm, template-averaged text embedding per shuffled label.
    zeroshot_weights = []
    for i in range(len(target_source)):
        features, eot_indices = zeroshot_classifier(classnames[target_source[i]], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    input_source = zeroshot_weights.T

    data_time.update(time.time() - end)

    # Shift labels into the second half of the 2 * num_classes classifier.
    target_target_temp = (target_target + len(classnames)).cuda()
    target_source_temp = target_source + len(classnames)

    # Frozen CLIP image encoder; the projector is trainable in this phase.
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    input_target_add = Image_Encoder(input_target_temp)

    # Text features go straight into the fully connected adapter ...
    output_source = adapter(input_source) * logit_scale
    # ... and so do the projected image features.
    output_target = adapter(input_target_add) * logit_scale

    # Self-supervised branch: text embeddings for the unshuffled labels.
    self_zeroshot_weights = []
    for i in range(len(label_self_supervised)):
        features, eot_indices = zeroshot_classifier(classnames[label_self_supervised[i]], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        self_zeroshot_weights.append(class_embedding)
    self_zeroshot_weights = torch.stack(self_zeroshot_weights, dim=1).cuda()
    self_input_source = self_zeroshot_weights.T

    # First contrastive pair: raw (pre-adapter) text vs. image features.
    self_output_source = F.normalize(self_input_source)
    self_output_target = F.normalize(input_target_add)
    # Diagonal targets 0..B-1: the i-th image should match the i-th text embedding.
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_1 = (F.cross_entropy(logits_per_image, self_supervised_labels)
                              + F.cross_entropy(logits_per_text, self_supervised_labels)) / 2

    # Second contrastive pair: adapter outputs, each restricted to its matching classifier half.
    self_output_source = F.normalize(adapter(self_input_source)[:, :len(classnames)])
    self_output_target = F.normalize((output_target / logit_scale)[:, len(classnames):])
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_2 = (F.cross_entropy(logits_per_image, self_supervised_labels)
                              + F.cross_entropy(logits_per_text, self_supervised_labels)) / 2

    loss_self_supervised = loss_self_supervised_1 + loss_self_supervised_2

    # Supervised cross-entropy on each half of the 2 * num_classes classifier.
    loss_task_s_Cs = criterion(output_source[:, :len(classnames)], target_source)
    loss_task_s_Ct = criterion(output_target[:, len(classnames):], target_target)

    # Source samples should put their probability mass on the first half of the classifier,
    # target samples on the second half.
    loss_domain_st_Cst_part1 = criterion(output_source, target_source)
    loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)

    # Category-level confusion: push each domain toward the other domain's half.
    loss_category_st_G = 0.5 * criterion(output_target, target_target) \
                         + 0.5 * criterion(output_source, target_source_temp)
    # Domain-level confusion (disabled):
    # loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) \
    #                    + 0.5 * criterion_classifier_source(output_target)

    # DANN-style schedule: lam ramps from 0 toward 1 over training.
    lam = 2 / (1 + math.exp(-10 * epoch / args.epochs)) - 1
    # Fixed self-supervised weight in the main phase (warm_train uses a decaying schedule).
    self_lam = 0.6

    loss_confusion_target = concatenatedCELoss(output_target)

    loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
    loss_G = loss_category_st_G + lam * loss_confusion_target
    loss_T = loss_G + loss_classifier + self_lam * loss_self_supervised

    prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
    prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
    losses_classifier.update(loss_classifier.item(), input_source.size(0))
    losses_G.update(loss_G.item(), input_source.size(0))
    losses_T.update(loss_T.item(), input_source.size(0))
    top1_source.update(prec1_source[0], input_source.size(0))
    top1_target.update(prec1_target[0], input_source.size(0))

    # Single backward pass over the combined objective.
    optimizer.zero_grad()
    loss_T.backward()
    optimizer.step()
    scheduler.step()

    batch_time.update(time.time() - end)

    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time, data_time=data_time,
                  loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T,
                  top1S=top1_source, top1T=top1_target))

    if new_epoch_flag:
        log = open(os.path.join(args.log, 'log.txt'), 'a')
        log.write("\n")
        log.write("Train:epoch: %d, Loss@C: %4f, Loss@G: %4f, Top1S acc: %3f, Top1T acc: %3f" % (
            epoch, losses_classifier.avg, losses_G.avg, top1_source.avg, top1_target.avg))
        log.close()

    return source_train_loader_batch, epoch, new_epoch_flag
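
# Both self-supervised terms in train() (and the one in warm_train()) instantiate the same
# symmetric, CLIP-style contrastive objective: diagonal targets over a (B, B) similarity
# matrix, with cross-entropy in both directions. A minimal sketch of that pattern as a
# standalone function; the name clip_style_contrastive_loss is hypothetical:
def clip_style_contrastive_loss(text_feats, image_feats, logit_scale):
    """Symmetric cross-entropy over scaled cosine similarities with diagonal targets."""
    text_feats = F.normalize(text_feats)
    image_feats = F.normalize(image_feats)
    # The i-th image should match the i-th text embedding.
    labels = torch.arange(text_feats.shape[0], device=text_feats.device, dtype=torch.long)
    logits_per_image = logit_scale * image_feats @ text_feats.T
    logits_per_text = logits_per_image.T
    return (F.cross_entropy(logits_per_image, labels)
            + F.cross_entropy(logits_per_text, labels)) / 2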
best_target_acc = 0
best_epoch = 0


def validate(classnames, templates, val_loader, model, adapter, epoch, args, criterion,
             CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder):
    global best_target_acc
    global best_epoch

    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    # Switch everything to evaluation mode.
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    model.eval()
    adapter.eval()

    end = time.time()
    logit_scale = math.exp(4.60517)  # exp(ln 100) ~= 100

    for i, (image, label, _) in enumerate(val_loader):
        image = image.cuda()
        target_target = label.cuda()

        # One unit-norm, template-averaged text embedding per ground-truth label.
        zeroshot_weights = []
        for j in range(len(label)):
            features, eot_indices = zeroshot_classifier(classnames[label[j]], templates, CLIP_Text)
            with torch.no_grad():
                class_embeddings = Text_Encoder(features, eot_indices)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding = class_embedding / class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
        input_source = zeroshot_weights.T

        # Frozen CLIP image encoder plus the image projector.
        with torch.no_grad():
            input_target_temp = CLIP_Image(image)
            input_target_add = Image_Encoder(input_target_temp)
        # output_source = adapter(input_source) * logit_scale  # text-head scoring (disabled)
        output_target = adapter(input_target_add) * logit_scale
        # At test time both heads are scored on the image logits.
        output_source = output_target

        loss_source = criterion(output_source[:, :len(classnames)], target_target)
        loss_target = criterion(output_target[:, len(classnames):], target_target)

        # Measure accuracy and record loss.
        prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_target, topk=(1, 5))
        prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
        losses_source.update(loss_source.item(), image.size(0))
        losses_target.update(loss_target.item(), image.size(0))
        top1_source.update(prec1_source[0], image.size(0))
        top1_target.update(prec1_target[0], image.size(0))

        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                  'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                  'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                  'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                      epoch, i, len(val_loader), batch_time=batch_time,
                      lossS=losses_source, lossT=losses_target,
                      top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'.format(top1S=top1_source, top1T=top1_target))

    # Track the best accuracy over either head; convert the tensor average to a float so the
    # module-level best_target_acc stays a plain number.
    current_acc = float(max(top1_target.avg, top1_source.avg))
    if current_acc > best_target_acc:
        best_target_acc = current_acc
        best_epoch = epoch
    print('best_epoch', best_epoch, ' * Current_best_target@T:', best_target_acc)

    log = open(os.path.join(args.log, 'log.txt'), 'a')
    log.write("\n")
    log.write(" Test:epoch: %d, LS: %4f, LT: %4f, Top1S: %3f, Top1T: %3f" %
              (epoch, losses_source.avg, losses_target.avg, top1_source.avg, top1_target.avg))
    log.close()
    return best_target_acc


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    # Top-k predictions, transposed to (maxk, batch) for comparison against the targets.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
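
# Usage sketch for the two utilities above, with synthetic tensors; the shapes and the
# __main__ guard are illustrative only, not part of the training pipeline:
if __name__ == '__main__':
    meter = AverageMeter()
    logits = torch.randn(8, 10)              # batch of 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))
    top1, top5 = accuracy(logits, labels, topk=(1, 5))
    meter.update(top1[0], n=logits.size(0))  # same update pattern as the training loops above
    print('top1 %.2f%% (running avg %.2f%%)' % (top1.item(), meter.avg))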