# clip-symnets/main_1_17.py
import json
import math
import os
import random
import shutil
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
import torchvision.transforms as transforms

from clip import clip
from clip.model import ModifiedResNet, VisionTransformer
from data.prepare_data_shot2 import generate_dataloader  # Prepare the data and dataloaders
from datasets import build_dataset
from datasets.utils import build_data_loader
from engine import partial_model
from models.DomainClassifierSource import DClassifierForSource
from models.DomainClassifierTarget import DClassifierForTarget
from opts import opts  # The options for the project
from trainer_1_17 import train  # The training process
from trainer_1_17 import validate  # The validation (test) process
best_prec1 = 0
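
# Weight_Adapter: a bias-free linear classifier whose weight matrix is
# initialized externally; below it is seeded with the CLIP zero-shot text
# weights, so training starts from the zero-shot classifier.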
class Weight_Adapter(nn.Module):
    def __init__(self, n_input, n_output, adapter_weights):
        super().__init__()
        self.linear1 = nn.Linear(n_input, n_output, bias=False)
        self.linear1.weight.data = adapter_weights  # Initialize the linear layer weights

    def forward(self, x):
        x = self.linear1(x.float())
        return x
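
# Adapter: a plain linear + ReLU projection. The residual-mixing variant
# controlled by residual_ratio is kept below, commented out.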
class Adapter(nn.Module):
    def __init__(self, n_input, n_output):
        super().__init__()
        self.residual_ratio = 0.2
        self.linear1 = nn.Linear(n_input, n_output, bias=False)
        # self.linear1.weight.data = adapter_weights  # Initialize linear layer weights
        self.relu = nn.ReLU()

    def forward(self, x):
        a = x
        x = self.linear1(x.float())
        x = self.relu(x)
        # x = self.residual_ratio * x + (1 - self.residual_ratio) * a
        return x
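
# Encode a single class name with every prompt template through the frozen
# CLIP text front end; returns the per-template features and the
# end-of-text indices that the trainable text head consumes.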
def zeroshot_classifier(classname, templates, CLIP_Text):
    with torch.no_grad():
        classname = classname.replace('_', ' ')
        str_prompts = [template.format(classname) for template in templates]
        prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
        features, eot_indices = CLIP_Text(prompts)
    return features, eot_indices

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
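
# Build the zero-shot classifier with the unmodified CLIP model: for each
# class, encode all prompt templates, L2-normalize, average over templates,
# and re-normalize the mean (the standard CLIP zero-shot recipe).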
def all_classifier(classnames, templates, model):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).cuda()  # clip.tokenize turns the prompts into token id tensors
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights
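
# validate_train: run the fused classifier over a loader while optimizing
# the fused cross-entropy with the supplied optimizer. Per batch it
# rebuilds the text-side weights through the trainable text head, then
# mixes CLIP zero-shot logits (logits2) with the adapter's target-half
# logits (logits3) as beta * logits2 + gama * logits3.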
def validate_train(classnames, templates, val_loader, model, args, zero_shots, criterion,
                   optimizer, scheduler, alpha, beta, gama, CLIP_Text, CLIP_Image, Image_Encoder, Text_Encoder, adapter):
    global best_target_acc
    Compu1_acc = AverageMeter()
    losses = AverageMeter()
    # switch to evaluate mode
    CLIP_Text.eval()
    CLIP_Image.eval()
    Image_Encoder.eval()
    Text_Encoder.eval()
    adapter.eval()
    logit_scale = math.exp(4.60517)  # exp(ln(100)) = 100, the CLIP logit scale
    for i, (image, label) in enumerate(val_loader):
        image = image.cuda()
        label = label.cuda()
        # Rebuild the text-side weights for the classes in this batch
        zeroshot_weights = []
        for j in range(len(label)):
            features, eot_indices = zeroshot_classifier(classnames[label[j]], templates, CLIP_Text)
            with torch.no_grad():
                class_embeddings = Text_Encoder(features, eot_indices)
                class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
                class_embedding = class_embeddings.mean(dim=0)
                class_embedding = class_embedding / class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
        input_source = zeroshot_weights.T
        input_target = image
        target_target = label
        target_source = label
        input_target_clip = model.encode_image(input_target)
        # Frozen CLIP image front end followed by the trainable image head
        with torch.no_grad():
            input_target_temp = CLIP_Image(input_target)
        input_target_add = Image_Encoder(input_target_temp)
        output_source = adapter(input_source) * logit_scale
        output_target = adapter(input_target_add) * logit_scale
        # Per-branch losses and accuracies (diagnostics; not part of the fused objective)
        loss_source = criterion(output_source[:, :len(classnames)], target_source)
        loss_target = criterion(output_target[:, len(classnames):], target_target)
        prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
        prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
        # CLIP zero-shot logits
        logits2 = 100. * input_target_clip.float() @ zero_shots.float()
        # Adapter logits on the target half of the classifier
        logits3 = output_target[:, len(classnames):]
        # Fused prediction: beta * zero-shot + gama * adapter
        compu1 = beta * logits2 + gama * logits3
        compu1_acc = accuracy(compu1, target_target, topk=(1, 5))
        loss = criterion(compu1, target_target)
        Compu1_acc.update(compu1_acc[0].item(), image.size(0))
        losses.update(loss.item(), image.size(0))
        print('loss:', loss.item())
        print(i, '/', len(val_loader))
        print('Compu1_acc:', Compu1_acc.val, 'alpha:', alpha.item(), 'beta:', beta.item(), 'gama:', gama.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    print('Compu1_acc.avg', Compu1_acc.avg, 'alpha:', alpha.item(), 'beta:', beta.item(), 'gama:', gama.item(),
          'losses.avg', losses.avg)
    return Compu1_acc.avg, alpha.item(), beta.item(), gama.item()
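
# main: seed everything, load CLIP, build the few-shot dataset and loaders,
# split CLIP into frozen front ends and trainable heads, initialize the
# adapter from the zero-shot text weights, then train, periodically
# validate, checkpoint the best models, and log the best accuracy.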
def main():
    seed = 2023
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    global args, best_prec1
    current_epoch = 0
    args = opts()
    clip.available_models()
    model, preprocess = clip.load(args.name)
    # model = model.cuda()
    model.float()
    if os.path.exists(args.filename_dir):
        print('exists')
    else:
        os.makedirs(args.filename_dir)
    filename = args.filename_dir + args.dataset_name + '.txt'
    if os.path.exists(filename):
        print(filename + " exists!")
    else:
        print("create " + filename)
        with open(filename, "w"):
            pass  # create an empty results file
    epx_dir = args.savedir + args.dataset_name + '_epx/' + str(args.shot) + 'shot' + '/'
    if os.path.exists(epx_dir):
        print('epx_dir exists')
    else:
        os.makedirs(epx_dir)
    dataset = build_dataset(args.dataset_name, args.dataset_dir, args.shot)
    classnames = dataset.classnames
    templates = dataset.template
    # loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess,
    #                            shuffle=False)
    loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess,
                               shuffle=False)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
    ])
    # train_loader_cache = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform,
    #                                        is_train=True, shuffle=False)
    train_loader_shuffle = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform,
                                             is_train=True, shuffle=True)
    criterion = nn.CrossEntropyLoss().cuda()
    # if not os.path.isdir(args.log):
    #     os.makedirs(args.log)
    # log = open(os.path.join(args.log, 'log.txt'), 'a')
    # state = {k: v for k, v in args._get_kwargs()}
    # log.write(json.dumps(state) + '\n')
    # log.close()
    #
    # cudnn.benchmark = True  # benchmark mode speeds up computation, but because it introduces
    #                         # randomness, forward results can differ slightly between runs
    #
    # log = open(os.path.join(args.log, 'log.txt'), 'a')
    # log.write('\n-------------------------------------------\n')
    # log.write(time.asctime(time.localtime(time.time())))
    # log.write('\n-------------------------------------------')
    # log.close()
    # Process the data and prepare the dataloaders:
    # train_loader_shuffle, loader = generate_dataloader(args, preprocess)
    # Split the CLIP encoders: a frozen front end (CLIP_Text / CLIP_Image)
    # and a trainable tail (Text_Encoder / Image_Encoder)
    if args.name in ("ViT-B/16", "ViT-B/32"):
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
        assert type(model.visual) == VisionTransformer
        CLIP_Image, Image_Encoder = partial_model.get_image_vit(model.visual, image_layer_idx=0)
    elif args.name == "RN50":
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
        assert type(model.visual) == ModifiedResNet
        CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx=1)
    elif args.name in ("RN101", "RN50x16"):
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
        assert type(model.visual) == ModifiedResNet
        CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx=0)
    # Encode all class labels (e.g. the 1000 ImageNet classes) with the full CLIP model
    model = model.cuda()
    zero_weights = all_classifier(classnames, templates, model)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = CLIP_Text.cuda(), Text_Encoder.cuda(), CLIP_Image.cuda(), Image_Encoder.cuda()
    weights_path = None
    best_epoch = 0
    best_init_acc = 0
    criterion_classifier_target = DClassifierForTarget(nClass=len(classnames)).cuda()
    criterion_classifier_source = DClassifierForSource(nClass=len(classnames)).cuda()
    text_weights = zero_weights
    # Initialize the adapter from two copies of the zero-shot text weights
    # (the source and target halves of the 2 * num_classes classifier);
    # 1024 is the CLIP embedding dimension for RN50.
    adapter_weights = torch.cat([text_weights, text_weights], dim=1).T
    adapter = Weight_Adapter(1024, 2 * len(classnames), adapter_weights).cuda()
    ADAM_BETAS = (0.9, 0.999)
    if args.shot >= 18:
        optimizer = torch.optim.AdamW([{'params': adapter.parameters(), 'lr': 0.001},
                                       {'params': Image_Encoder.parameters(), 'lr': 0.00001},
                                       {'params': Text_Encoder.parameters(), 'lr': 0.00001}],
                                      eps=1e-5)
    else:
        # Earlier variants kept for reference:
        # optimizer = torch.optim.AdamW([{'params': adapter.parameters(), 'lr': 0.0001},
        #                                {'params': Image_Encoder.parameters(), 'lr': 0.00001},
        #                                {'params': Text_Encoder.parameters(), 'lr': 0.00001}],
        #                               eps=1e-5)
        # optimizer = torch.optim.AdamW([{'params': adapter.parameters()},
        #                                {'params': Image_Encoder.parameters()},
        #                                {'params': Text_Encoder.parameters()}],
        #                               eps=1e-5, lr=0.0001, weight_decay=0.0001)
        ## caltech101
        # optimizer = torch.optim.AdamW(
        #     [
        #         {'params': adapter.parameters(), 'lr': 0.0001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
        #         {'params': Image_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
        #         {'params': Text_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS}],
        #     eps=1e-4
        # )
        optimizer = torch.optim.AdamW(
            [
                {'params': adapter.parameters(), 'lr': 0.0001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
                {'params': Image_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
                {'params': Text_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS}],
            eps=1e-4
        )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader_shuffle))
    source_train_loader_batch = enumerate(train_loader_shuffle)
    torch.save(CLIP_Text, epx_dir + 'CLIP_Text.pth')
    torch.save(CLIP_Image, epx_dir + 'CLIP_Image.pth')
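    # Training loop: train() advances through the shuffled loader and sets
    # new_epoch_flag when an epoch boundary is crossed; validation runs
    # every test_freq epochs once current_epoch reaches args.valepoch.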
    while current_epoch < args.epochs:
        source_train_loader_batch, current_epoch, new_epoch_flag = train(classnames, templates,
                                                                         train_loader_shuffle,
                                                                         source_train_loader_batch,
                                                                         model,
                                                                         adapter,
                                                                         criterion_classifier_source,
                                                                         criterion_classifier_target,
                                                                         optimizer,
                                                                         current_epoch,
                                                                         args, scheduler, criterion,
                                                                         CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)
        # evaluate on the test data
        if new_epoch_flag:
            if (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0:
                if current_epoch >= args.valepoch:
                    prec1 = validate(classnames, templates, loader, model, adapter, current_epoch, args,
                                     zero_weights, criterion,
                                     CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)
                    # record the best prec1 and save a checkpoint
                    is_best = prec1 > best_prec1
                    save_dir = None
                    if prec1 > args.valacc:
                        save_dir = epx_dir + 'epoch_' + str(current_epoch) + '_' + str(prec1)
                        if not os.path.isdir(save_dir):
                            os.mkdir(save_dir)
                        torch.save(adapter, save_dir + '/_adapter_extractor.pth')
                        torch.save(Text_Encoder, save_dir + '/Text_Encoder.pth')
                        torch.save(Image_Encoder, save_dir + '/Image_Encoder.pth')
                    best_prec1 = max(prec1, best_prec1)
                    if is_best:
                        weights_path = save_dir  # None if the valacc threshold was not reached
                        best_init_acc = best_prec1
                        best_epoch = current_epoch
    # log = open(os.path.join(args.log, 'log.txt'), 'a')
    # log.write('Best acc: %3f' % (best_prec1))
    # log.close()
    filename = args.filename_dir + args.dataset_name + '.txt'
    strr = str(args.shot) + 'shot' + ' ' + 'best_epoch' + ' ' + str(best_epoch) + ' ' + 'best_init_acc' + ' ' + str(best_init_acc)
    with open(filename, 'a') as f:
        f.write(strr + '\n')
    # log = open(os.path.join(args.log, 'log.txt'), 'a')
    # log.write('\n-------------------------------------------\n')
    # log.write(time.asctime(time.localtime(time.time())))
    # log.write('\n-------------------------------------------\n')
    # log.close()


if __name__ == '__main__':
    main()