Files
clip-symnets/main_fusion_adapter.py
2024-05-21 19:41:56 +08:00

268 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import time
from clip import clip
import torch.nn as nn
import torch.optim
from opts import opts # The options for the project
# from trainer import validate # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights,set_adapter_weights_single
from Adapter import Weight_Adapter,Classifier,Res_Adapter
import logging
import torch.nn.functional as F
def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
          classifier, optimizer,
          epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, adapter):
    """Run a single optimisation step on one batch and report progress.

    One call consumes exactly one batch from ``source_train_loader_batch``.
    When that iterator is exhausted, the epoch counter is incremented, a fresh
    ``enumerate(source_train_loader)`` iterator is created and its first batch
    is used instead.

    The classifier is trained on the concatenation of three feature sets:
    text features of the batch labels added to the image features, the image
    features alone, and text features of a random permutation of the labels.

    Args:
        classnames: Class names, forwarded to ``calculate_zeroshot_weights``.
        templates: Prompt templates, forwarded to ``calculate_zeroshot_weights``.
        source_train_loader: Loader used to rebuild the batch iterator.
        source_train_loader_batch: ``enumerate`` iterator over the loader;
            each item is ``(index, (image, label, _))``.
        model: CLIP model; only ``model.logit_scale`` is read here.
        classifier: Classification head applied to the adapter output.
        optimizer: Optimizer stepped once per call.
        epoch: Current epoch counter.
        args: Options object; ``print_freq`` and ``epochs`` are read.
        scheduler: LR scheduler, stepped once per call (i.e. per batch).
        criterion: Loss applied to the logits and the concatenated labels.
        CLIP_Text, Text_Encoder: Text branch, forwarded to
            ``calculate_zeroshot_weights``.
        CLIP_Image: Image backbone, run under ``torch.no_grad()``.
        Image_Encoder: Image head applied to the ``CLIP_Image`` output.
        adapter: Adapter module applied before the classifier.

    Returns:
        Tuple ``(source_train_loader_batch, epoch, new_epoch_flag)`` where
        ``new_epoch_flag`` is True when the iterator was restarted in this call.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # losses_classifier, losses_G and top1_target are never updated in this
    # function; they exist only because the log line below references them.
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.train()
    Image_Encoder.train()
    adapter.train()
    model.eval()
    logit_scale = model.logit_scale.exp()
    classifier.train()

    new_epoch_flag = False
    end = time.time()

    try:
        image, label, _ = next(source_train_loader_batch)[1]
    except StopIteration:
        # Loader exhausted: count a new epoch and restart the iterator.
        epoch += 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        image, label, _ = next(source_train_loader_batch)[1]

    # Labels matching the images in this batch.
    label_self_supervised = label.cuda()
    # Randomly permuted copy of the labels, used for the text-only features.
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()
    input_target = image.cuda()

    input_source = calculate_zeroshot_weights(classnames, target_source, templates, CLIP_Text, Text_Encoder)
    self_input_source = calculate_zeroshot_weights(classnames, label_self_supervised, templates, CLIP_Text,
                                                   Text_Encoder)
    data_time.update(time.time() - end)

    # CLIP image backbone, run without gradient tracking.
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    # NOTE(review): Image_Encoder runs with gradients enabled because it is in
    # train mode here and is expected to be optimised; confirm this matches the
    # intended extent of the no_grad block.
    input_image = Image_Encoder(input_target_temp)

    # Text features of the true labels added to the image features.
    fusion_feature = self_input_source + input_image
    total_feature = torch.cat((fusion_feature, input_image, input_source), dim=0)
    total_label = torch.cat((label_self_supervised, label_self_supervised, target_source), dim=0)

    # Adapter + classifier over all three feature sets, scaled by logit_scale.
    output_source = classifier(adapter(total_feature)) * logit_scale

    # Supervised classification loss over the concatenated batch.
    loss_T = criterion(output_source, total_label)
    prec1_source, _ = accuracy(output_source.data, total_label, topk=(1, 5))
    losses_T.update(loss_T.item(), input_source.size(0))
    top1_source.update(prec1_source[0], input_source.size(0))

    optimizer.zero_grad()
    loss_T.backward()
    optimizer.step()
    scheduler.step()
    batch_time.update(time.time() - end)

    # Log when (epoch + 1) is a multiple of print_freq, or throughout epoch 0.
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time,
                  data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T,
                  top1S=top1_source,
                  top1T=top1_target))
    return source_train_loader_batch, epoch, new_epoch_flag
def validate(best_epoch, classnames, templates, val_loader, model, classifier, epoch, args, criterion, best_prec,
             CLIP_Text,
             Text_Encoder, CLIP_Image,
             Image_Encoder, adapter):
    """Evaluate the image branch on ``val_loader`` and report top-1 accuracy.

    Every batch is passed through ``CLIP_Image``, ``Image_Encoder``, the
    adapter and the classifier; the logits are scaled by
    ``model.logit_scale.exp()`` and scored against the batch labels.

    Args:
        best_epoch: Epoch of the best accuracy seen so far.
        classnames: Class names, forwarded to ``calculate_zeroshot_weights``.
        templates: Prompt templates, forwarded to ``calculate_zeroshot_weights``.
        val_loader: Iterable yielding ``(image, label, _)`` batches.
        model: CLIP model; only ``model.logit_scale`` is read here.
        classifier: Classification head applied to the adapter output.
        epoch: Current epoch, used for logging and as the new ``best_epoch``.
        args: Options object; ``print_freq`` is read.
        criterion: Loss applied to the logits and labels.
        best_prec: Best accuracy seen so far (compared against, not returned).
        CLIP_Text, Text_Encoder: Text branch modules.
        CLIP_Image, Image_Encoder: Image branch modules.
        adapter: Adapter module applied before the classifier.

    Returns:
        Tuple ``(prec, best_epoch)``: the accuracy of this run as a Python
        float, and ``epoch`` if ``prec`` exceeds ``best_prec`` (otherwise the
        ``best_epoch`` that was passed in). The updated best accuracy is only
        printed; the caller has to track it itself.
    """
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    # losses_target and top1_target are never updated in this function; they
    # are reported with whatever values AverageMeter initialises them to.
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    adapter.eval()
    model.eval()
    classifier.eval()

    end = time.time()
    # Nothing is back-propagated during evaluation, so the whole pass runs
    # without autograd bookkeeping.
    with torch.no_grad():
        logit_scale = model.logit_scale.exp()
        for i, (image, label, _) in enumerate(val_loader):
            image = image.cuda()
            label = label.cuda()
            # NOTE(review): these text features are computed but never used.
            # The call is kept in case the helper has side effects; drop it
            # once that is confirmed not to be the case.
            input_source = calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder)

            # CLIP image backbone followed by the image head.
            input_target_temp = CLIP_Image(image)
            input_target_add = Image_Encoder(input_target_temp)
            output_source = classifier(adapter(input_target_add)) * logit_scale

            loss_source = criterion(output_source, label)
            # Measure accuracy and record the loss.
            prec1_source, _ = accuracy(output_source.data, label, topk=(1, 5))
            losses_source.update(loss_source.item(), image.size(0))
            top1_source.update(prec1_source[0], image.size(0))

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                      'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                      'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                      'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                          epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source,
                          lossT=losses_target,
                          top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))

    # float() rather than .item(): max() can return a plain number instead of
    # a tensor (presumably the untouched top1_target.avg when accuracy is 0 or
    # the loader is empty), and .item() would then raise AttributeError.
    prec = float(max(top1_target.avg, top1_source.avg))
    if prec > best_prec:
        best_prec = prec
        best_epoch = epoch
    print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)
    return prec, best_epoch
def main():
    """Build the CLIP-based model, optimizer and scheduler, then train.

    Training is driven one batch at a time through ``train``. Each time
    ``train`` reports that a new epoch has started, the model is validated if
    the epoch is due for testing (``args.test_freq``) and has reached
    ``args.valepoch``. A checkpoint is saved when the accuracy exceeds both
    ``args.valacc`` and the best accuracy so far.
    """
    args = opts()
    set_seed(2023)

    model, preprocess = clip.load(args.name)
    model = model.cuda()
    model.float()

    classnames, templates, loader, train_loader = get_dataset_loader(args, preprocess)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = configure_clip_encoders(args, model, 0, 1)
    prepare_directories(args, CLIP_Text, CLIP_Image)

    # Classification head and adapter.
    weights = set_adapter_weights_single(model, classnames, templates)
    classifier = Classifier(args, classnames, weights).cuda()
    adapter = Res_Adapter(1024).cuda()

    # Loss function.
    criterion = nn.CrossEntropyLoss().cuda()

    # Per-component learning rates and the shared weight decay.
    lr_adapter = 0.0001
    lr_image_encoder = 0.0001
    lr_text_encoder = 0.00001
    weight_decay = 0.00001
    # Coefficients for Adam's running averages.
    ADAM_BETAS = (0.9, 0.999)

    # AdamW with one parameter group per trainable component.
    group_defaults = {'weight_decay': weight_decay, 'betas': ADAM_BETAS}
    param_groups = [
        {'params': adapter.parameters(), 'lr': lr_adapter, **group_defaults},
        {'params': classifier.parameters(), 'lr': lr_adapter, **group_defaults},
        {'params': Image_Encoder.parameters(), 'lr': lr_image_encoder, **group_defaults},
        {'params': Text_Encoder.parameters(), 'lr': lr_text_encoder, **group_defaults},
    ]
    optimizer = torch.optim.AdamW(param_groups, eps=1e-4)

    # Cosine annealing over the total number of batches: the scheduler is
    # stepped once per train() call, so T_max is epochs * batches per epoch.
    total_steps = args.epochs * len(train_loader)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps)

    source_train_loader_batch = enumerate(train_loader)
    current_epoch = 0
    best_prec = 0
    best_epoch = 0

    while current_epoch < args.epochs:
        source_train_loader_batch, current_epoch, new_epoch_flag = train(
            classnames, templates, train_loader, source_train_loader_batch,
            model, classifier, optimizer, current_epoch,
            args, scheduler, criterion, CLIP_Text,
            Text_Encoder, CLIP_Image, Image_Encoder, adapter)

        # Validation is only considered at epoch boundaries.
        if not new_epoch_flag:
            continue
        due_for_test = (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0
        if not due_for_test or current_epoch < args.valepoch:
            continue

        prec, best_epoch = validate(
            best_epoch, classnames, templates, loader, model, classifier,
            current_epoch, args, criterion, best_prec,
            CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, adapter)

        # Save only when above the configured threshold and a new best.
        if prec > args.valacc and prec > best_prec:
            save_model(current_epoch, Text_Encoder, Image_Encoder, classifier, args, prec)
        best_prec = max(prec, best_prec)

        # Record the result in the log.
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        logging.info(
            f"Current Time: {current_time},Epoch: {current_epoch}, Accuracy: {prec}, Best: {best_prec}")
# Script entry point: start training only when executed directly, not on import.
if __name__ == '__main__':
    main()