import logging
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim

from clip import clip
from opts import opts  # Project options
# from trainer import validate  # The validate (test) routine
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
    set_adapter_weights_single, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights
from Adapter import Adapter


class CompactBilinearPooling(nn.Module):
    """Compact Bilinear Pooling (Count Sketch + FFT) for feature fusion."""

    def __init__(self, input_dim, output_dim):
        super(CompactBilinearPooling, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Randomly generated (fixed) Count Sketch hash parameters
        self.hashcode = torch.randint(0, output_dim, (input_dim,), dtype=torch.long).cuda()
        self.sign = (torch.randint(0, 2, (input_dim,)) * 2 - 1).float().cuda()

    def count_sketch(self, x, hashcode, sign):
        # Vectorized scatter-add, equivalent to looping
        # output[:, hashcode[i]] += sign[i] * x[:, i] over every i.
        batch_size, input_dim = x.shape
        output = x.new_zeros(batch_size, self.output_dim)
        output.index_add_(1, hashcode, x * sign)
        return output

    def forward(self, x1, x2):
        # Apply Count Sketch to both inputs
        sketch_x1 = self.count_sketch(x1, self.hashcode, self.sign)
        sketch_x2 = self.count_sketch(x2, self.hashcode, self.sign)
        # FFT of both sketches
        fft_x1 = torch.fft.rfft(sketch_x1, n=self.output_dim, dim=1)
        fft_x2 = torch.fft.rfft(sketch_x2, n=self.output_dim, dim=1)
        # Element-wise product in the frequency domain, then inverse FFT
        # (circular convolution of the two sketches)
        return torch.fft.irfft(fft_x1 * fft_x2, n=self.output_dim, dim=1)
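# A minimal, self-contained sanity check for CompactBilinearPooling. It is
# never called by this script; it is an illustrative sketch that assumes a
# CUDA device (the module allocates its hash parameters on the GPU) and only
# verifies the output shape. The 1024/1024 dimensions mirror the defaults
# used in main() below.
def _cbp_shape_check(batch_size=8, dim=1024):
    cbp = CompactBilinearPooling(input_dim=dim, output_dim=dim)
    x1 = torch.randn(batch_size, dim).cuda()  # stands in for image features
    x2 = torch.randn(batch_size, dim).cuda()  # stands in for text features
    fused = cbp(x1, x2)
    assert fused.shape == (batch_size, dim)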
def train(classnames, templates, source_train_loader, source_train_loader_batch, model, adapter, optimizer, epoch,
          args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, cbp_layer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()
    logit_scale = model.logit_scale.exp()

    model.eval()
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.train()
    Image_Encoder.train()
    cbp_layer.train()
    adapter.train()

    new_epoch_flag = False
    end = time.time()
    # Pull exactly one batch per call; roll the epoch over once the loader
    # is exhausted.
    try:
        (image, label, _) = next(source_train_loader_batch)[1]
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = next(source_train_loader_batch)[1]

    label = label.cuda()
    input_target = image.cuda()
    input_source = calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder)
    data_time.update(time.time() - end)

    # Frozen CLIP image encoder; the trainable Image_Encoder adapter runs
    # outside no_grad so its parameters actually receive gradients.
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    input_target_add = Image_Encoder(input_target_temp)
    # Fuse image and text features with the CBP layer
    fused_features = cbp_layer(input_target_add, input_source)
    # The fused features go through the adapter, scaled by CLIP's logit scale
    output = adapter(fused_features) * logit_scale

    # Supervised classification cross-entropy loss
    loss = criterion(output, label)
    prec, _ = accuracy(output, label, topk=(1, 5))
    losses_T.update(loss.item(), input_source.size(0))
    top1_source.update(prec[0], input_source.size(0))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    batch_time.update(time.time() - end)
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time, data_time=data_time,
                  loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T,
                  top1S=top1_source, top1T=top1_target))
    return source_train_loader_batch, epoch, new_epoch_flag


def validate(best_epoch, classnames, templates, val_loader, model, adapter, epoch, args, criterion, best_prec,
             CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, cbp_layer):
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    cbp_layer.eval()
    model.eval()
    adapter.eval()

    end = time.time()
    logit_scale = model.logit_scale.exp()
    for i, (image, label, _) in enumerate(val_loader):
        input_target = image.cuda()
        target_target = label.cuda()

        # CLIP image encoder followed by the image adapter
        with torch.no_grad():
            input_target_temp = CLIP_Image(input_target)
        input_target_add = Image_Encoder(input_target_temp)
        # At test time the image features are fused with themselves:
        # label-derived text features are unavailable without leaking the
        # ground truth.
        fused_features = cbp_layer(input_target_add, input_target_add)
        output = adapter(fused_features) * logit_scale

        loss_source = criterion(output, target_target)

        # measure accuracy and record loss
        prec, _ = accuracy(output, target_target, topk=(1, 5))
        losses_source.update(loss_source.item(), image.size(0))
        top1_source.update(prec[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                  'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                  'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                  'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                      epoch, i, len(val_loader), batch_time=batch_time,
                      lossS=losses_source, lossT=losses_target,
                      top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))

    prec = max(top1_target.avg, top1_source.avg).item()
    if prec > best_prec:
        best_prec = prec
        best_epoch = epoch
    print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)
    return prec, best_epoch
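# train() above advances its loader iterator by exactly one batch per call
# and signals an epoch rollover via StopIteration; main() below owns the
# outer loop. A minimal, self-contained sketch of that driver pattern
# (illustrative only; `loader` stands for any finite iterable of batches and
# `step` for a single optimization step):
def _one_batch_per_call_driver(loader, step, num_epochs):
    batches = enumerate(loader)
    epoch = 0
    while epoch < num_epochs:
        try:
            _, batch = next(batches)
        except StopIteration:
            # Loader exhausted: count an epoch and restart the iterator
            epoch += 1
            batches = enumerate(loader)
            _, batch = next(batches)
        step(batch)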
def main():
    args = opts()
    set_seed(2023)
    # Configure the root logger; without this, the logging.info(...) calls
    # below are silently dropped (the default level is WARNING).
    logging.basicConfig(level=logging.INFO)

    model, preprocess = clip.load(args.name)
    model = model.cuda()
    model.float()
    classnames, templates, loader, train_loader = get_dataset_loader(args, preprocess)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = configure_clip_encoders(args, model, 0, 1)
    prepare_directories(args, CLIP_Text, CLIP_Image)

    # Classification layer (adapter) initialized from zero-shot weights
    weights = set_adapter_weights_single(model, classnames, templates)
    adapter = Adapter(args, classnames, weights).cuda()

    # CBP: feature dimensions for the fusion module
    input_dim = 1024
    output_dim = 1024  # or another value as needed
    # Instantiate the CBP module (it has no learnable parameters)
    cbp_layer = CompactBilinearPooling(input_dim, output_dim)

    # Loss functions (the two domain-classifier criteria are currently unused)
    criterion = nn.CrossEntropyLoss().cuda()
    criterion_classifier_target = DClassifierForTarget(nClass=len(classnames)).cuda()
    criterion_classifier_source = DClassifierForSource(nClass=len(classnames)).cuda()

    # Per-component learning rates and a shared weight decay
    lr_adapter = 0.0001
    lr_image_encoder = 0.00001
    lr_text_encoder = 0.00001
    weight_decay = 0.00001
    # ADAM_BETAS controls the decay rates of the moving averages
    ADAM_BETAS = (0.9, 0.999)
    # Create the AdamW optimizer with one parameter group per component
    optimizer = torch.optim.AdamW([
        {'params': adapter.parameters(), 'lr': lr_adapter, 'weight_decay': weight_decay, 'betas': ADAM_BETAS},
        {'params': Image_Encoder.parameters(), 'lr': lr_image_encoder, 'weight_decay': weight_decay,
         'betas': ADAM_BETAS},
        {'params': Text_Encoder.parameters(), 'lr': lr_text_encoder, 'weight_decay': weight_decay,
         'betas': ADAM_BETAS}
    ], eps=1e-4)

    # Cosine-annealing LR schedule. T_max is the total number of optimizer
    # steps (epochs * batches per epoch), since train() steps the scheduler
    # once per batch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

    source_train_loader_batch = enumerate(train_loader)
    current_epoch = 0
    best_prec = 0
    best_epoch = 0
    while current_epoch < args.epochs:
        source_train_loader_batch, current_epoch, new_epoch_flag = train(
            classnames, templates, train_loader, source_train_loader_batch, model, adapter, optimizer,
            current_epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder,
            cbp_layer)
        if new_epoch_flag:
            if (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0:
                if current_epoch >= args.valepoch:
                    prec, best_epoch = validate(best_epoch, classnames, templates, loader, model, adapter,
                                                current_epoch, args, criterion, best_prec, CLIP_Text,
                                                Text_Encoder, CLIP_Image, Image_Encoder, cbp_layer)
                    is_best = prec > best_prec
                    if prec > args.valacc and is_best:
                        save_model(current_epoch, Text_Encoder, Image_Encoder, adapter, args, prec)
                    best_prec = max(prec, best_prec)
                    # Log the result
                    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                    logging.info(
                        f"Current Time: {current_time}, Epoch: {current_epoch}, Accuracy: {prec}, Best: {best_prec}")


if __name__ == '__main__':
    main()