# Source: clip-symnets/main_DALN.py (snapshot of 2024-05-21 19:41:56 +08:00; the web viewer
# reported 284 lines, 12 KiB).
#
# NOTE: this header replaces page text captured from a web code viewer (the "Files" breadcrumb,
# "Raw / Blame / History" links, and a notice that the file contains ambiguous Unicode
# characters). That text is not Python and would prevent the module from parsing.
import time
from clip import clip
import torch.nn as nn
import torch.optim
from opts import opts # The options for the project
# from trainer import validate # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
set_adapter_weights,set_adapter_weights_single, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights
from Adapter import Weight_Adapter,Classifier,Res_Adapter
import logging
import torch.nn.functional as F
from daln.nwd import NuclearWassersteinDiscrepancy
def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
          adapter, optimizer,
          epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, discrepancy,
          res_adapter):
    """Run one optimisation step on a single batch from the training loader.

    Each call pulls exactly one batch from ``source_train_loader_batch``. If that iterator is
    exhausted, ``epoch`` is incremented, a fresh ``enumerate(source_train_loader)`` is created and
    its first batch is used instead.

    Two feature branches are classified together by ``adapter``:

    * "source": ``calculate_zeroshot_weights`` evaluated for a randomly permuted copy of the batch
      labels (presumably per-label text embeddings -- confirm), passed through ``res_adapter``;
    * "target": the batch images encoded by ``CLIP_Image`` -> ``Image_Encoder`` without gradients,
      then passed through ``res_adapter``.

    The loss is cross-entropy on each branch (permuted labels for the source half, original
    labels for the target half) plus a scaled ``discrepancy`` term. One optimizer step and one
    scheduler step are taken per call.

    Returns:
        ``(source_train_loader_batch, epoch, new_epoch_flag)``: the iterator to pass to the next
        call, the (possibly incremented) epoch counter, and whether this call rolled over into a
        new epoch.

    Raises:
        RuntimeError: if ``source_train_loader`` yields no batches at all.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # NOTE(review): never updated below, so "Loss@C" always prints the meter's initial value.
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    # CLIP towers and the encoders run in eval mode; only the two adapters are in train mode.
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    model.eval()
    logit_scale = model.logit_scale.exp()
    res_adapter.train()
    adapter.train()

    new_epoch_flag = False
    end = time.time()

    # Fetch one (image, label, _) batch; on exhaustion roll over to a new epoch.
    try:
        _, (image, label, _) = next(source_train_loader_batch)
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        try:
            _, (image, label, _) = next(source_train_loader_batch)
        except StopIteration:
            raise RuntimeError('source_train_loader yielded no batches') from None

    # Original labels supervise the image ("target") branch.
    target_target = label.cuda()
    label_self_supervised = target_target
    # Permuted labels drive the text ("source") branch.
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()
    input_target = image.cuda()
    input_source = calculate_zeroshot_weights(classnames, target_source, templates, CLIP_Text, Text_Encoder)
    input_source = res_adapter(input_source)
    data_time.update(time.time() - end)

    # CLIP image encoder path, computed without gradients.
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
        input_target_add = Image_Encoder(input_target_temp)
    input_target_add = res_adapter(input_target_add)

    # Classify both branches in one pass; chunk(2) assumes both halves have the same batch size
    # (i.e. calculate_zeroshot_weights returns one row per label -- confirm).
    x = torch.cat((input_source, input_target_add), dim=0)
    y = adapter(x) * logit_scale
    y_s, y_t = y.chunk(2, dim=0)
    cls_loss_1 = criterion(y_s, target_source)
    cls_loss_2 = criterion(y_t, label_self_supervised)

    # discrepancy(x) is negated and then multiplied by a negative trade-off, so the net
    # contribution to the loss is +100 * discrepancy(x).
    discrepancy_loss = -discrepancy(x)
    trade_off_lambda = -100
    transfer_loss = discrepancy_loss * trade_off_lambda
    loss = cls_loss_1 + cls_loss_2 + transfer_loss

    prec1_source, _ = accuracy(y_s.data, target_source, topk=(1, 5))
    prec1_target, _ = accuracy(y_t.data, target_target, topk=(1, 5))
    batch_size = input_source.size(0)
    losses_G.update((cls_loss_1 + cls_loss_2).item(), batch_size)
    losses_T.update(transfer_loss.item(), batch_size)
    top1_source.update(prec1_source[0], batch_size)
    top1_target.update(prec1_target[0], batch_size)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    batch_time.update(time.time() - end)

    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time,
                  data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T,
                  top1S=top1_source, top1T=top1_target))
    return source_train_loader_batch, epoch, new_epoch_flag
def validate(best_epoch, classnames, templates, val_loader, model, adapter, epoch, args, criterion, best_prec,
             CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, res_adapter):
    """Evaluate the image classification path on ``val_loader``.

    Each batch goes through ``CLIP_Image`` -> ``Image_Encoder`` -> ``res_adapter`` -> ``adapter``,
    and the logits are scaled by ``model.logit_scale.exp()``. The text branch is not evaluated:
    the "source" meters are fed the same logits as the "target" meters, so both report identical
    numbers.

    Args:
        best_epoch: epoch of the best result so far; replaced by ``epoch`` if this run beats it.
        classnames, templates, CLIP_Text, Text_Encoder: kept for interface compatibility; only the
            ``.eval()`` mode switch is applied to the two text modules.
        best_prec: best precision so far, used only to decide whether ``best_epoch`` changes.

    Returns:
        ``(prec, best_epoch)``: this run's top-1 precision (via ``.item()``) and the possibly
        updated best epoch.
    """
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    res_adapter.eval()
    model.eval()
    adapter.eval()
    end = time.time()
    logit_scale = model.logit_scale.exp()

    # Evaluation needs no autograd graph: avoid recording res_adapter/adapter activations per batch.
    with torch.no_grad():
        for i, (image, label, _) in enumerate(val_loader):
            image = image.cuda()
            target = label.cuda()
            # CLIP image encoder path.
            image_features = Image_Encoder(CLIP_Image(image))
            output_target = adapter(res_adapter(image_features)) * logit_scale
            # No separate text-branch prediction; "source" metrics mirror the image branch.
            output_source = output_target
            loss_source = criterion(output_source, target)
            loss_target = criterion(output_target, target)
            # Measure accuracy and record loss.
            prec1_source, _ = accuracy(output_source.data, target, topk=(1, 5))
            prec1_target, _ = accuracy(output_target.data, target, topk=(1, 5))
            losses_source.update(loss_source.item(), image.size(0))
            losses_target.update(loss_target.item(), image.size(0))
            top1_source.update(prec1_source[0], image.size(0))
            top1_target.update(prec1_target[0], image.size(0))
            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                print('Test: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                      'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                      'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                      'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                          epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source,
                          lossT=losses_target, top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))
    # NOTE(review): .item() assumes the meters' avg is a tensor; an empty val_loader may break here.
    prec = max(top1_target.avg, top1_source.avg).item()
    if prec > best_prec:
        best_prec = prec
        best_epoch = epoch
    print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)
    return prec, best_epoch
def main():
    """Build CLIP plus adapters, then alternate single-batch training steps with periodic validation.

    ``train`` is called once per batch. Right after it reports an epoch roll-over, validation runs
    when the epoch matches ``args.test_freq`` (or is 0) and is at least ``args.valepoch``. A
    checkpoint is saved when the result beats the previous best and also exceeds ``args.valacc``.
    """
    args = opts()
    set_seed(2023)
    model, preprocess = clip.load(args.name)
    model = model.cuda()
    model.float()
    classnames, templates, loader, train_loader = get_dataset_loader(args, preprocess)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = configure_clip_encoders(args, model, 0, 0)
    prepare_directories(args, CLIP_Text, CLIP_Image)

    # Classification layer, initialised from set_adapter_weights_single(...).
    weights = set_adapter_weights_single(model, classnames, templates)
    # NOTE(review): 1024 must equal the width of the features fed to res_adapter for the chosen
    # backbone (args.name) -- confirm before switching backbones.
    res_adapter = Res_Adapter(1024).cuda()
    adapter = Classifier(args, classnames, weights).cuda()
    # NWD discrepancy, evaluated with the adapter as its classifier.
    discrepancy = NuclearWassersteinDiscrepancy(adapter)
    criterion = nn.CrossEntropyLoss().cuda()

    # Per-group learning rates. weight_decay and betas are not passed, so every group uses
    # torch.optim.AdamW's defaults (weight_decay=0.01, betas=(0.9, 0.999)).
    lr_adapter = 0.0001
    lr_res_adapter = 0.0001
    lr_text_encoder = 0.00001
    optimizer = torch.optim.AdamW([
        {'params': adapter.parameters(), 'lr': lr_adapter},
        {'params': res_adapter.parameters(), 'lr': lr_res_adapter},
        {'params': Text_Encoder.parameters(), 'lr': lr_text_encoder},
    ], eps=1e-5)
    # train() calls scheduler.step() once per batch, so the cosine period spans every batch of
    # every epoch rather than the epoch count.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

    source_train_loader_batch = enumerate(train_loader)
    current_epoch = 0
    best_prec = 0
    best_epoch = 0
    while current_epoch < args.epochs:
        source_train_loader_batch, current_epoch, new_epoch_flag = train(
            classnames, templates, train_loader, source_train_loader_batch, model, adapter, optimizer,
            current_epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder,
            discrepancy, res_adapter)

        should_validate = (new_epoch_flag
                           and ((current_epoch + 1) % args.test_freq == 0 or current_epoch == 0)
                           and current_epoch >= args.valepoch)
        if not should_validate:
            continue

        prec, best_epoch = validate(best_epoch, classnames, templates, loader, model, adapter, current_epoch,
                                    args, criterion, best_prec, CLIP_Text, Text_Encoder, CLIP_Image,
                                    Image_Encoder, res_adapter)
        # Save only when this run is a new best AND clears the args.valacc threshold.
        if prec > args.valacc and prec > best_prec:
            save_model(current_epoch, Text_Encoder, Image_Encoder, adapter, args, prec)
        best_prec = max(prec, best_prec)

        # NOTE(review): nothing here configures logging; unless a helper (e.g. prepare_directories)
        # does, the root logger's default WARNING level drops these INFO records -- confirm.
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        logging.info(
            f"Current Time: {current_time},Epoch: {current_epoch}, Accuracy: {prec}, Best: {best_prec}")
# Script entry point: run training only when executed directly, not when imported.
if __name__ == '__main__':
    main()