"""Few-shot CLIP adapter training script.

Splits a pretrained CLIP model into sub-modules via ``engine.partial_model``,
trains a linear adapter together with the ``Image_Encoder`` / ``Text_Encoder``
parts, periodically evaluates with ``trainer_1_17.validate`` and saves
checkpoints of the modules that exceed ``args.valacc``.
"""
import json
import os
import random
import shutil
import time
from clip import clip
import numpy as np
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
from data.prepare_data_shot2 import generate_dataloader  # Prepare the data and dataloader
from opts import opts  # The options for the project
from trainer_1_17 import train  # For the training process
from trainer_1_17 import validate  # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from engine import partial_model
from clip.model import ModifiedResNet, VisionTransformer
from datasets import build_dataset
from datasets.utils import build_data_loader
import torchvision.transforms as transforms
import math

best_prec1 = 0


class Weight_Adapter(nn.Module):
    """Bias-free linear layer whose weight is replaced by ``adapter_weights``."""

    def __init__(self, n_input, n_output, adapter_weights):
        super().__init__()
        self.linear1 = nn.Linear(n_input, n_output, bias=False)
        # Overwrite the randomly initialised weight with the supplied tensor.
        self.linear1.weight.data = adapter_weights

    def forward(self, x):
        """Apply the linear layer to ``x`` cast to float32."""
        return self.linear1(x.float())


class Adapter(nn.Module):
    """Bias-free linear layer followed by ReLU.

    ``residual_ratio`` is stored on the module but is not applied in
    ``forward``.
    """

    def __init__(self, n_input, n_output):
        super().__init__()
        self.residual_ratio = 0.2
        self.linear1 = nn.Linear(n_input, n_output, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(linear1(x.float()))``."""
        return self.relu(self.linear1(x.float()))


def zeroshot_classifier(classname, templates, CLIP_Text):
    """Tokenize every prompt template for one class and run ``CLIP_Text`` on them.

    Underscores in ``classname`` are replaced by spaces before formatting.
    Runs under ``torch.no_grad()`` and returns the ``(features, eot_indices)``
    pair produced by ``CLIP_Text``.
    """
    with torch.no_grad():
        classname = classname.replace('_', ' ')
        str_prompts = [template.format(classname) for template in templates]
        prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
        features, eot_indices = CLIP_Text(prompts)
    return features, eot_indices


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: 2-D score tensor, one row per sample.
        target: 1-D tensor of ground-truth class indices.
        topk: iterable of ``k`` values.

    Returns:
        A list with one single-element tensor per ``k``, holding the
        percentage of samples whose target is among the ``k`` best scores.
    """
    # Clamp so that asking for top-5 on fewer than 5 classes does not make
    # ``topk`` raise; such a k then simply counts every sample as correct.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def all_classifier(classnames, templates, model):
    """Build one averaged, normalised text embedding per class.

    For each class every template is formatted, tokenized and encoded with
    ``model.encode_text``; the per-prompt embeddings are L2-normalised,
    averaged and normalised again. The class vectors are stacked along
    ``dim=1`` and returned on the GPU.
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).cuda()  # tokenize the prompts with clip.tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights


def validate_train(classnames, templates, val_loader, model, args, zero_shots, criterion,
                   optimizer, scheduler, alpha, beta, gama,
                   CLIP_Text, CLIP_Image, Image_Encoder, Text_Encoder, adapter):
    """Run one pass over ``val_loader`` while optimising on the fused logits.

    For every batch the fused logits ``beta * logits2 + gama * logits3`` are
    scored against the labels, the cross-entropy is back-propagated and
    ``optimizer`` / ``scheduler`` are stepped. ``alpha`` is only printed and
    returned. All sub-modules are put in ``eval()`` mode first.

    Returns:
        ``(average top-1 accuracy, alpha, beta, gama)`` with the three
        weights converted via ``.item()``.
    """
    Compu1_acc = AverageMeter()
    losses = AverageMeter()
    CLIP_Text.eval()
    CLIP_Image.eval()
    Image_Encoder.eval()
    Text_Encoder.eval()
    adapter.eval()
    logit_scale = math.exp(4.60517)  # approximately 100
    n_classes = len(classnames)

    for i, (image, label) in enumerate(val_loader):
        image = image.cuda()
        label = label.cuda()

        # Text branch: one averaged prompt embedding per sample's own class.
        zeroshot_weights = []
        for j in range(len(label)):
            features, eot_indices = zeroshot_classifier(classnames[label[j]], templates, CLIP_Text)
            with torch.no_grad():
                class_embeddings = Text_Encoder(features, eot_indices)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding = class_embedding / class_embedding.norm()
            class_embedding = class_embedding / class_embedding.norm(dim=-1, keepdim=True)
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()

        input_source = zeroshot_weights.T
        input_target = image.cuda()
        target_target = label.cuda()
        target_source = label.cuda()

        input_target_clip = model.encode_image(input_target)  # CLIP image encoder
        # NOTE(review): only CLIP_Image runs under no_grad here, so gradients
        # can still reach Image_Encoder and adapter -- confirm intended scope.
        with torch.no_grad():
            input_target_temp = CLIP_Image(input_target)
        input_target_add = Image_Encoder(input_target_temp)
        output_source = adapter(input_source) * logit_scale
        output_target = adapter(input_target_add) * logit_scale

        # NOTE(review): the four values below are computed but do not feed
        # the optimised loss or the returned accuracy.
        loss_source = criterion(output_source[:, :n_classes], target_source)
        loss_target = criterion(output_target[:, n_classes:], target_target)
        prec1_source, _ = accuracy(output_source.data[:, :n_classes], target_source, topk=(1, 5))
        prec1_target, _ = accuracy(output_target.data[:, n_classes:], target_target, topk=(1, 5))

        # Zero-shot CLIP logits and the adapter logits (second half of the
        # adapter output), fused with the weights beta and gama.
        logits2 = 100. * input_target_clip.float() @ zero_shots.float()
        logits3 = output_target[:, n_classes:]
        compu1 = beta * logits2 + gama * logits3

        compu1_acc = accuracy(compu1, target_target, topk=(1, 5))
        loss = criterion(compu1, target_target)
        Compu1_acc.update(compu1_acc[0].item(), image.size(0))
        losses.update(loss.item(), image.size(0))
        print('loss:', loss.item())
        print(i, '/', len(val_loader))
        print('Compu1_acc:', Compu1_acc.val, 'alpha:', alpha.item(), 'beta:', beta.item(), 'gama:', gama.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    print('Compu1_acc.avg', Compu1_acc.avg, 'alpha:', alpha.item(), 'beta:', beta.item(), 'gama:', gama.item(),
          'losses.avg', losses.avg)
    return Compu1_acc.avg, alpha.item(), beta.item(), gama.item()


def _split_clip(model, name):
    """Split ``model`` into (CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder).

    Raises:
        ValueError: if ``name`` is not one of the supported backbones.
        TypeError: if ``model.visual`` is not the class expected for ``name``.
    """
    # backbone name -> (expected visual class, image splitter, image_layer_idx)
    splits = {
        "ViT-B/16": (VisionTransformer, partial_model.get_image_vit, 0),
        "ViT-B/32": (VisionTransformer, partial_model.get_image_vit, 0),
        "RN50": (ModifiedResNet, partial_model.get_image_resnet, 1),
        "RN101": (ModifiedResNet, partial_model.get_image_resnet, 0),
        "RN50x16": (ModifiedResNet, partial_model.get_image_resnet, 0),
    }
    if name not in splits:
        raise ValueError(
            'unsupported CLIP backbone ' + repr(name) + '; expected one of ' + ', '.join(splits))
    visual_cls, split_image, image_layer_idx = splits[name]

    CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
    if not isinstance(model.visual, visual_cls):
        raise TypeError(
            'backbone ' + repr(name) + ' expects a ' + visual_cls.__name__
            + ' visual tower, got ' + type(model.visual).__name__)
    CLIP_Image, Image_Encoder = split_image(model.visual, image_layer_idx=image_layer_idx)
    return CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder


def main():
    """Train the adapter and split encoders, evaluate, and log the best epoch."""
    global args, best_prec1

    seed = 2023
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    current_epoch = 0
    args = opts()
    clip.available_models()
    model, preprocess = clip.load(args.name)
    model.float()

    # Result-summary file and checkpoint directory.
    if os.path.exists(args.filename_dir):
        print('exist')
    else:
        os.makedirs(args.filename_dir)
    filename = args.filename_dir + args.dataset_name + '.txt'
    if os.path.exists(filename):
        print(filename + " exist!")
    else:
        print("create " + filename)
        with open(filename, "w"):
            pass
    epx_dir = args.savedir + args.dataset_name + '_epx/' + str(args.shot) + 'shot' + '/'
    if os.path.exists(epx_dir):
        print('epx_dir exist')
    else:
        os.makedirs(epx_dir)

    dataset = build_dataset(args.dataset_name, args.dataset_dir, args.shot)
    classnames = dataset.classnames
    templates = dataset.template

    loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess,
                               shuffle=False)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.5, 1),
                                     interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711)),
    ])
    train_loader_shuffle = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform,
                                             is_train=True, shuffle=True)

    criterion = nn.CrossEntropyLoss().cuda()

    # Split the CLIP text and image encoders into their sub-modules.
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = _split_clip(model, args.name)

    # Encode every class name with the full CLIP text encoder.
    model = model.cuda()
    zero_weights = all_classifier(classnames, templates, model)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = (
        CLIP_Text.cuda(), Text_Encoder.cuda(), CLIP_Image.cuda(), Image_Encoder.cuda())

    weights_path = None
    best_epoch = 0
    best_init_acc = 0

    criterion_classifier_target = DClassifierForTarget(nClass=len(classnames)).cuda()
    criterion_classifier_source = DClassifierForSource(nClass=len(classnames)).cuda()

    # The adapter starts as two stacked copies of the zero-shot text weights.
    text_weights = zero_weights
    adapter_weights = torch.cat([text_weights, text_weights], dim=1).T
    # NOTE(review): n_input is hard-coded to 1024, but the layer's weight is
    # replaced by adapter_weights, whose width comes from the text embedding.
    adapter = Weight_Adapter(1024, 2 * len(classnames), adapter_weights).cuda()

    ADAM_BETAS = (0.9, 0.999)
    if args.shot >= 18:
        optimizer = torch.optim.AdamW(
            [
                {'params': adapter.parameters(), 'lr': 0.001},
                {'params': Image_Encoder.parameters(), 'lr': 0.00001},
                {'params': Text_Encoder.parameters(), 'lr': 0.00001},
            ],
            eps=1e-5,
        )
    else:
        optimizer = torch.optim.AdamW(
            [
                {'params': adapter.parameters(), 'lr': 0.0001, 'weight_decay': 0.00001,
                 'betas': ADAM_BETAS},
                {'params': Image_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001,
                 'betas': ADAM_BETAS},
                {'params': Text_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001,
                 'betas': ADAM_BETAS},
            ],
            eps=1e-4,
        )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader_shuffle))
    source_train_loader_batch = enumerate(train_loader_shuffle)

    torch.save(CLIP_Text, os.path.join(epx_dir, 'CLIP_Text.pth'))
    torch.save(CLIP_Image, os.path.join(epx_dir, 'CLIP_Image.pth'))

    while current_epoch < args.epochs:
        source_train_loader_batch, current_epoch, new_epoch_flag = train(
            classnames, templates, train_loader_shuffle, source_train_loader_batch, model, adapter,
            criterion_classifier_source, criterion_classifier_target, optimizer, current_epoch, args,
            scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)

        # Evaluate only at epoch boundaries, on the configured schedule, and
        # once the epoch has reached args.valepoch.
        if not new_epoch_flag:
            continue
        scheduled = (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0
        if not scheduled or current_epoch < args.valepoch:
            continue

        prec1 = validate(classnames, templates, loader, model, adapter, current_epoch, args, zero_weights,
                         criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)

        # Record the best prec1 and save a checkpoint when above threshold.
        is_best = prec1 > best_prec1
        save_dir = None
        if prec1 > args.valacc:
            save_dir = os.path.join(epx_dir, 'epoch_' + str(current_epoch) + '_' + str(prec1))
            if not os.path.isdir(save_dir):
                os.mkdir(save_dir)
            torch.save(adapter, save_dir + '/_adapter_extractor.pth')
            torch.save(Text_Encoder, save_dir + '/Text_Encoder.pth')
            torch.save(Image_Encoder, save_dir + '/Image_Encoder.pth')
        best_prec1 = max(prec1, best_prec1)
        if is_best:
            # save_dir stays None when the best epoch did not clear
            # args.valacc, i.e. no checkpoint was written for it.
            weights_path = save_dir
            best_init_acc = best_prec1
            best_epoch = current_epoch

    filename = args.filename_dir + args.dataset_name + '.txt'
    strr = (str(args.shot) + 'shot' + ' ' + 'best_epoch' + ' ' + str(best_epoch)
            + ' ' + 'best_init_acc' + ' ' + str(best_init_acc))
    with open(filename, 'a') as f:
        f.write(strr + '\n')


if __name__ == '__main__':
    main()