Files
clip-symnets/main_supervised_warm.py
2024-05-21 19:41:56 +08:00

422 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import os
import random
import shutil
import time
from clip import clip
import numpy as np
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
from data.prepare_data_shot2 import generate_dataloader # Prepare the data and dataloader
from opts import opts # The options for the project
from trainer_self_supervised import train # For the training process
from trainer_supervised_warm import warm_train # For the training process
from trainer_supervised_warm import validate # For the validate (test) process
from trainer_supervised_warm import warm_validate # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from engine import partial_model
from clip.model import ModifiedResNet, VisionTransformer
from datasets import build_dataset
from datasets.utils import build_data_loader
import torchvision.transforms as transforms
import math
import shutil
# Best top-1 test accuracy seen so far; mutated from main() via `global best_prec1`.
best_prec1 = 0
# adapter 0.0001 text_encoder=0 89.6146011352539  (recorded run result — kept for reference)
class Weight_Adapter(nn.Module):
    """A single linear classification head whose weight matrix is seeded
    from precomputed CLIP text embeddings (`adapter_weights`)."""

    def __init__(self, n_input, n_output, adapter_weights):
        super().__init__()
        self.linear1 = nn.Linear(n_input, n_output)
        # Overwrite the default random weights with the text-embedding matrix;
        # the bias keeps its default initialization.
        self.linear1.weight.data = adapter_weights

    def forward(self, x):
        # Cast to float32 so half-precision CLIP features match the layer dtype.
        return self.linear1(x.float())
# class Res_Adapter(nn.Module):
# def __init__(self, n_input, ):
# super().__init__()
# self.residual_ratio = 0.2
# self.fc = nn.Sequential(
# nn.Linear(n_input, n_input // 4, bias=False),
# nn.ReLU(inplace=True),
# nn.Linear(n_input // 4, n_input, bias=False),
# nn.ReLU(inplace=True)
# )
#
# def forward(self, x):
# a = self.fc(x)
# x = self.residual_ratio * a + (1 - self.residual_ratio) * x
#
# return x
def zeroshot_classifier(classname, templates, CLIP_Text):
    """Embed one class name under every prompt template using the partial
    CLIP text encoder; returns the token features and the EOT indices."""
    with torch.no_grad():
        classname = classname.replace('_', ' ')
        str_prompts = [template.format(classname) for template in templates]
        tokenized = [clip.tokenize(p) for p in str_prompts]
        prompts = torch.cat(tokenized).cuda()
        features, eot_indices = CLIP_Text(prompts)
        return features, eot_indices
class AverageMeter(object):
    """Tracks the latest value and a running (weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear all statistics back to zero.
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record `val`, weighted by `n` samples."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: (batch, num_classes) score tensor.
        target: (batch,) tensor of ground-truth class indices.
        topk: tuple of k values to report.

    Returns:
        List of 1-element tensors, one per k, each the percentage of samples
        whose target appears in the top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    # Top-maxk predicted class indices, transposed to (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # correct[k, i] is True when sample i's k-th ranked prediction hits the target.
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # A sample counts as correct@k if any of its first k rows matched.
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def all_classifier(classnames, templates, model):
    """Build the zero-shot text classifier: for each class, average the
    L2-normalized embeddings of all prompt templates, renormalize, and
    stack the results into a (dim, num_classes) weight matrix on GPU."""
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            # Fill every template with the class name and tokenize the batch.
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts).cuda()
            # Embed with the text encoder and normalize each prompt embedding.
            class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            # Average across templates, then renormalize the mean.
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        return torch.stack(zeroshot_weights, dim=1).cuda()
def main():
    """End-to-end few-shot training driver: seeds RNGs, loads CLIP, builds the
    dataset/loaders, splits the CLIP encoders into frozen + trainable parts,
    runs an 80-epoch warm-up phase, then the main training loop with periodic
    validation and best-checkpoint saving."""
    # Fix every RNG source for reproducibility.
    seed = 2023
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    global args, best_prec1
    current_epoch = 0
    args = opts()
    clip.available_models()
    model, preprocess = clip.load(args.name)
    # model = model.cuda()
    model.float()
    # Ensure the results directory and the per-dataset result file exist.
    if os.path.exists(args.filename_dir):
        print('exist')
    else:
        os.makedirs(args.filename_dir)
    filename = args.filename_dir + args.dataset_name + '.txt'
    if os.path.exists(filename):
        print(filename + " exist!")
    else:
        print("create " + filename)
        f = open(filename, "w")
        f.close()
    # Checkpoint directory for this dataset/shot experiment.
    epx_dir = args.savedir + args.dataset_name + '_epx/' + str(args.shot) + 'shot' + '/'
    if os.path.exists(epx_dir):
        print('epx_dir exist')
    else:
        os.makedirs(epx_dir)
    dataset = build_dataset(args.dataset_name, args.dataset_dir, args.shot)
    classnames = dataset.classnames
    templates = dataset.template
    # loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess,
    #                            shuffle=False)
    # Evaluation loader over the test split (no augmentation, no shuffle).
    loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess,
                               shuffle=False)
    # Training augmentation; mean/std are CLIP's published normalization statistics.
    train_tranform = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
    ])
    train_loader_shuffle = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_tranform,
                                             is_train=True,
                                             shuffle=True)
    criterion = nn.CrossEntropyLoss().cuda()
    if not os.path.isdir(args.log):
        os.makedirs(args.log)
    # Dump the full argument namespace to the log as one JSON line.
    log = open(os.path.join(args.log, 'log.txt'), 'a')
    state = {k: v for k, v in args._get_kwargs()}
    log.write(json.dumps(state) + '\n')
    log.close()
    cudnn.benchmark = True  # Benchmark mode speeds up conv kernels but adds slight run-to-run nondeterminism.
    log = open(os.path.join(args.log, 'log.txt'), 'a')
    log.write('\n-------------------------------------------\n')
    log.write(time.asctime(time.localtime(time.time())))
    log.write('\n-------------------------------------------')
    log.close()
    # process the data and prepare the dataloaders.
    # train_loader_shuffle, loader = generate_dataloader(args, preprocess)
    # Split the CLIP text/image encoders into a frozen front part (CLIP_Text / CLIP_Image)
    # and a trainable tail (Text_Encoder / Image_Encoder); the split index depends on the backbone.
    if args.name == "ViT-B/16":
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=1)
        assert type(model.visual) == VisionTransformer
        CLIP_Image, Image_Encoder = partial_model.get_image_vit(model.visual, image_layer_idx=0)
    elif args.name == "ViT-B/32":
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=1)
        assert type(model.visual) == VisionTransformer
        CLIP_Image, Image_Encoder = partial_model.get_image_vit(model.visual, image_layer_idx=0)
    elif args.name == "RN50":
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
        assert type(model.visual) == ModifiedResNet
        CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx=1)
    elif args.name == "RN101":
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
        assert type(model.visual) == ModifiedResNet
        CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx=0)
    elif args.name == "RN50x16":
        CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=1)
        assert type(model.visual) == ModifiedResNet
        CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx=0)
    # Run all class labels through CLIP to build the zero-shot text classifier weights.
    model = model.cuda()
    zero_weights = all_classifier(classnames, templates, model)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = CLIP_Text.cuda(), Text_Encoder.cuda(), CLIP_Image.cuda(), Image_Encoder.cuda()
    Init_Image_Encoder = Image_Encoder  # NOTE(review): alias, not a copy — tracks Image_Encoder as it trains; confirm intent.
    best_epoch = 0
    best_init_acc = 0
    criterion_classifier_target = DClassifierForTarget(nClass=len(classnames)).cuda()
    criterion_classifier_source = DClassifierForSource(nClass=len(classnames)).cuda()
    text_weights = zero_weights
    # The adapter head has 2*C outputs; both halves are initialized from the text classifier.
    adapter_weights = torch.cat([text_weights, text_weights], dim=1).T
    adapter = Weight_Adapter(1024, 2 * len(classnames), adapter_weights).cuda()
    ADAM_BETAS = (0.9, 0.999)
    # Optimizer choice depends on the shot count; warm-up always uses AdamW.
    if args.shot >= 18:
        optimizer = torch.optim.RMSprop([{'params': adapter.parameters(), 'lr': 0.0001},
                                         {'params': Image_Encoder.parameters(), 'lr': 0.00001},
                                         {'params': Text_Encoder.parameters(), 'lr': 0.00001}],
                                        eps=1e-5)
        warm_optimizer = torch.optim.AdamW(
            [
                {'params': adapter.parameters(), 'lr': 0.0001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
                {'params': Image_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
                {'params': Text_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS}]
            , eps=1e-4
        )
    else:
        optimizer = torch.optim.AdamW(
            [
                {'params': adapter.parameters(), 'lr': 0.0001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
                {'params': Image_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
                {'params': Text_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS}]
            , eps=1e-4
        )
        warm_optimizer = torch.optim.AdamW(
            [
                {'params': adapter.parameters(), 'lr': 0.0001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
                {'params': Image_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS},
                {'params': Text_Encoder.parameters(), 'lr': 0.00001, 'weight_decay': 0.00001, 'betas': ADAM_BETAS}]
            , eps=1e-4
        )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader_shuffle))
    # warm_train/train consume these enumerate() iterators incrementally and return them each call.
    source_train_loader_batch = enumerate(train_loader_shuffle)
    # NOTE(review): warm_scheduler wraps `optimizer`, not `warm_optimizer` — looks like a bug; confirm.
    warm_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader_shuffle))
    source_train_loader_batch_warm = enumerate(train_loader_shuffle)
    # NOTE(review): `dir` shadows the builtin dir() for the rest of this function.
    dir = args.savedir + args.dataset_name + '_epx/' + str(args.shot) + 'shot' + '/'
    torch.save(CLIP_Text, dir + '/CLIP_Text.pth')
    torch.save(CLIP_Image, dir + '/CLIP_Image.pth')
    # -------- Warm-up phase: 80 epochs of warm_train. --------
    while (current_epoch < 80):
        # NOTE(review): the inner condition mirrors the loop guard, so the else
        # branch below (warm validation + checkpointing) is unreachable — confirm.
        if (current_epoch < 80):
            source_train_loader_batch, current_epoch, new_epoch_flag = warm_train(classnames, templates,
                                                                                 train_loader_shuffle,
                                                                                 source_train_loader_batch_warm,
                                                                                 model,
                                                                                 adapter,
                                                                                 criterion_classifier_source,
                                                                                 criterion_classifier_target,
                                                                                 warm_optimizer,
                                                                                 current_epoch,
                                                                                 args, warm_scheduler, criterion, CLIP_Text,
                                                                                 Text_Encoder, CLIP_Image,
                                                                                 Image_Encoder,
                                                                                 zero_weights)
        else:
            # Periodic validation / best-checkpoint saving (currently dead code, see NOTE above).
            if (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0:
                if current_epoch >= args.valepoch:
                    prec1 = warm_validate(classnames, templates, loader, model, adapter, current_epoch, args,
                                          zero_weights,
                                          criterion,
                                          CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)
                    # record the best prec1 and save checkpoint
                    is_best = prec1 > best_prec1
                    if prec1 > args.valacc:
                        save_dir = dir + '/epoch_' + str(current_epoch) + '_' + str(
                            prec1)
                        if not os.path.isdir(save_dir):
                            os.mkdir(save_dir)
                        torch.save(adapter, save_dir + '/_adapter_extractor.pth')
                        torch.save(Text_Encoder, save_dir + '/Text_Encoder.pth')
                        torch.save(Image_Encoder, save_dir + '/Image_Encoder.pth')
                    best_prec1 = max(prec1, best_prec1)
                    if is_best:
                        save_dir = dir + '/epoch_' + str(current_epoch) + '_' + str(
                            prec1)
                        if not os.path.isdir(save_dir):
                            os.mkdir(save_dir)
                        weights_path = save_dir
                        best_init_acc = best_prec1
                        best_epoch = current_epoch
                    log = open(os.path.join(args.log, 'log.txt'), 'a')
                    log.write('Best acc: %3f' % (best_prec1))
                    log.close()
    # -------- Main training phase: args.epochs epochs of train(), with validation. --------
    current_epoch = 0
    while (current_epoch < args.epochs):
        # NOTE(review): `current_epoch < 0` is never true, so the warm_train branch is disabled — confirm.
        if (current_epoch < 0):
            source_train_loader_batch, current_epoch, new_epoch_flag = warm_train(classnames, templates,
                                                                                 train_loader_shuffle,
                                                                                 source_train_loader_batch,
                                                                                 model,
                                                                                 adapter,
                                                                                 criterion_classifier_source,
                                                                                 criterion_classifier_target,
                                                                                 optimizer,
                                                                                 current_epoch,
                                                                                 args, scheduler, criterion, CLIP_Text,
                                                                                 Text_Encoder, CLIP_Image,
                                                                                 Image_Encoder,
                                                                                 zero_weights)
        else:
            source_train_loader_batch, current_epoch, new_epoch_flag = train(classnames, templates,
                                                                             train_loader_shuffle,
                                                                             source_train_loader_batch,
                                                                             model,
                                                                             adapter,
                                                                             criterion_classifier_source,
                                                                             criterion_classifier_target,
                                                                             optimizer,
                                                                             current_epoch,
                                                                             args, scheduler, criterion, CLIP_Text,
                                                                             Text_Encoder, CLIP_Image, Image_Encoder,
                                                                             zero_weights)
        # Validate every test_freq epochs (and at epoch 0), once past the valepoch threshold.
        if (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0:
            if current_epoch >= args.valepoch:
                prec1 = validate(classnames, templates, loader, model, adapter, current_epoch, args,
                                 zero_weights,
                                 criterion,
                                 CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)
                # record the best prec1 and save checkpoint
                is_best = prec1 > best_prec1
                # Save a checkpoint whenever accuracy clears the configured threshold.
                if prec1 > args.valacc:
                    save_dir = dir + '/epoch_' + str(current_epoch) + '_' + str(
                        prec1)
                    if not os.path.isdir(save_dir):
                        os.mkdir(save_dir)
                    torch.save(adapter, save_dir + '/_adapter_extractor.pth')
                    torch.save(Text_Encoder, save_dir + '/Text_Encoder.pth')
                    torch.save(Image_Encoder, save_dir + '/Image_Encoder.pth')
                best_prec1 = max(prec1, best_prec1)
                if is_best:
                    save_dir = dir + '/epoch_' + str(current_epoch) + '_' + str(
                        prec1)
                    if not os.path.isdir(save_dir):
                        os.mkdir(save_dir)
                    weights_path = save_dir
                    best_init_acc = best_prec1
                    best_epoch = current_epoch
                log = open(os.path.join(args.log, 'log.txt'), 'a')
                log.write('Best acc: %3f' % (best_prec1))
                log.close()
    # (A commented-out alternative validation block that evaluated on the
    #  training loader with a res_adapter was removed here for readability.)
    # Append the final best result for this shot count to the per-dataset result file.
    filename = args.filename_dir + args.dataset_name + '.txt'
    strr = str(args.shot) + 'shot' + ' ' + 'best_epoch' + ' ' + str(best_epoch) + ' ' + 'best_init_acc' + ' ' + str(
        best_init_acc)
    with open(filename, 'a') as f:
        f.write(strr + '\n')
        f.close()
    # Close out the log with a timestamped footer.
    log = open(os.path.join(args.log, 'log.txt'), 'a')
    log.write('\n-------------------------------------------\n')
    log.write(time.asctime(time.localtime(time.time())))
    log.write('\n-------------------------------------------\n')
    log.close()
# Script entry point.
if __name__ == '__main__':
    main()