clip-symnets/main_coral_loss_mixup.py
import time
from clip import clip
import torch.nn as nn
import numpy as np
import torch
import torch.optim
from opts import opts # The options for the project
# from trainer import validate # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights
from Adapter import Weight_Adapter
import logging
import torch.nn.functional as F
def mixup_data(x, y, labels, alpha=1.0):
    '''Apply Mixup to the text features x and the image features y, using the same
    permutation and mixing coefficient for both.'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    mixed_y = lam * y + (1 - lam) * y[index, :]
    # The mixed text and image features could also be fused, e.g. by simple concatenation
    # (adjust to the model design if needed):
    # mixed_features = torch.cat((mixed_x, mixed_y), dim=1)
    mixed_labels = lam * labels + (1 - lam) * labels[index]
    return mixed_x, mixed_y  # mixed_labels and index are computed but not used by the caller
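# Illustrative usage of mixup_data (assumed feature shapes, not part of the training loop):
#   text_feats, img_feats = torch.randn(32, 512), torch.randn(32, 512)
#   mixed_text, mixed_img = mixup_data(text_feats, img_feats, torch.arange(32), alpha=1.0)
# Both outputs keep the [batch_size, feature_dim] shape of their inputs.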
def coral_loss(source_features, target_features):
    """
    Compute the Deep CORAL loss.
    :param source_features: source-domain features of shape [batch_size, feature_dim]
    :param target_features: target-domain features of shape [batch_size, feature_dim]
    :return: CORAL loss (squared Frobenius distance between the two covariance matrices)
    """
    d = source_features.data.shape[1]  # feature dimension, only needed for the optional 1/(4*d*d) scaling below
    source_mean = torch.mean(source_features, dim=0)
    target_mean = torch.mean(target_features, dim=0)
    source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
    target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)
    loss = torch.sum(torch.pow(source_cov - target_cov, 2))  # / (4 * d * d), the scaling used in the CORAL paper
    return loss
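# Sanity sketch (assumed shapes, independent of the training code): for two
# [batch_size, feature_dim] batches the result is a non-negative scalar, e.g.
#   coral_loss(torch.randn(32, 512), torch.randn(32, 512))
# In train() below this term is added to the total loss with weight lambda_coral.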
def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
adapter, optimizer,
epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder):
batch_time = AverageMeter()
data_time = AverageMeter()
losses_classifier = AverageMeter()
losses_G = AverageMeter()
losses_T = AverageMeter()
top1_source = AverageMeter()
top1_target = AverageMeter()
CLIP_Text.eval()
CLIP_Image.eval()
Text_Encoder.train()
Image_Encoder.train()
model.eval()
logit_scale = model.logit_scale.exp()
adapter.train()
new_epoch_flag = False
end = time.time()
concatenatedCELoss = ConcatenatedCELoss(num_classes=len(classnames)).cuda()
try:
(image, label, _) = source_train_loader_batch.__next__()[1]
except StopIteration:
epoch = epoch + 1
new_epoch_flag = True
source_train_loader_batch = enumerate(source_train_loader)
(image, label, _) = source_train_loader_batch.__next__()[1]
target_target = label.cuda()
    # self-supervised labels (the original, unshuffled labels)
label_self_supervised = label.cuda()
indices = torch.randperm(len(label))
target_source = label[indices].cuda()
# target_source = label.cuda()
input_target = image.cuda()
input_source = calculate_zeroshot_weights(classnames, target_source, templates, CLIP_Text, Text_Encoder)
self_input_source = calculate_zeroshot_weights(classnames, label_self_supervised, templates, CLIP_Text,
Text_Encoder)
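    # Both "source" inputs are text features: calculate_zeroshot_weights presumably encodes the
    # class-name prompt templates for the given labels through the frozen CLIP_Text encoder and
    # the trainable Text_Encoder head (assumption based on the helper's name and arguments).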
data_time.update(time.time() - end)
    # CLIP image encoder
with torch.no_grad():
input_target_temp = CLIP_Image(input_target)
input_target_add = Image_Encoder(input_target_temp)
    mixed_x, mixed_y = mixup_data(self_input_source, input_target_add, label_self_supervised, alpha=1.0)
    # mix_pred = adapter(mixed_features)  # make sure the model can handle this form of fused features
# mix_loss_1 = soft_label_criterion(mix_pred[:,:len(classnames)], mixed_labels)
# mix_loss_2 = soft_label_criterion(mix_pred[:,len(classnames):], mixed_labels)
target_target_temp = target_target + len(classnames)
target_source_temp = target_source + len(classnames)
target_target_temp = target_target_temp.cuda()
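    # The adapter is a single 2*len(classnames)-way head (SymNets-style): columns [0, len(classnames))
    # act as the source classifier and columns [len(classnames), 2*len(classnames)) as the target
    # classifier, so the *_temp targets shifted by len(classnames) index the second half.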
    # The CORAL loss and the total loss are assembled further below.
    # feed the text features directly into the fully connected classifier head
    output_source = adapter(input_source) * logit_scale
    # feed the image features directly into the fully connected classifier head
    output_target = adapter(input_target_add) * logit_scale
    coral_loss_value = coral_loss(self_input_source, input_target_add)
    lambda_coral = 50
    loss_1 = lambda_coral * coral_loss_value  # kept for reference; loss_T below re-applies the same weighting
    # self-supervised text features through the fully connected layer
    # self_output_source = adapter(self_input_source)
    # self_output_source = F.normalize(self_output_source[:,:len(classnames)])
    # self_output_source = F.normalize(self_input_source)
    #
    # # self-supervised image features
    # # self_output_target = output_target / logit_scale
    # # self_output_target = F.normalize(self_output_target[:,len(classnames):])
    # self_output_target = F.normalize(input_target_add)
    # construct the self-supervised labels 0..batch_size-1 (0-255 for a batch of 256)
self_supervised_labels = torch.arange(mixed_x.shape[0], device="cuda:0", dtype=torch.long)
logits_per_image = logit_scale * mixed_x @ mixed_y.T
logits_per_text = logit_scale * mixed_y @ mixed_x.T
loss_self_supervised_1 = (
F.cross_entropy(logits_per_image, self_supervised_labels) +
F.cross_entropy(logits_per_text, self_supervised_labels)
) / 2
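    # loss_self_supervised_1 is the standard CLIP-style symmetric cross-entropy: the i-th mixed
    # text feature should match the i-th mixed image feature (and vice versa), since both were
    # mixed with the same permutation and coefficient in mixup_data().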
    # # self-supervised text features through the fully connected layer
# self_output_source = adapter(self_input_source)
# self_output_source = F.normalize(self_output_source[:, :len(classnames)])
# # self_output_source = F.normalize(self_input_source)
#
    # # self-supervised image features
# self_output_target = output_target / logit_scale
# self_output_target = F.normalize(self_output_target[:, len(classnames):])
# # self_output_target = F.normalize(input_target_add)
    # # # construct the self-supervised labels 0-255
# self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
# logits_per_image = logit_scale * self_output_target @ self_output_source.T
# logits_per_text = logit_scale * self_output_source @ self_output_target.T
# loss_self_supervised_2 = (
# F.cross_entropy(logits_per_image, self_supervised_labels) +
# F.cross_entropy(logits_per_text, self_supervised_labels)
# ) / 2
#
loss_self_supervised = loss_self_supervised_1 # + loss_self_supervised_2
    # supervised classification cross-entropy losses
loss_task_s_Cs = criterion(output_source[:, :len(classnames)], target_source)
loss_task_s_Ct = criterion(output_target[:, len(classnames):], target_target)
    # For source-domain samples, push as much probability as possible onto the first (source) half of the classifier; for target-domain samples, onto the second (target) half.
loss_domain_st_Cst_part1 = criterion(output_source, target_source)
loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)
    # class-level confusion
loss_category_st_G = 0.5 * criterion(output_target, target_target) + 0.5 * criterion(output_source,
target_source_temp)
    # domain-level confusion
# loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) + 0.5 * criterion_classifier_source(
# output_target)
lam = 1 # 2 / (1 + math.exp(-1 * 10 * epoch / args.epochs)) - 1
# if(epoch<30):
# self_lam= 5
# else:
self_lam = 1
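    # Total objective assembled below: the two classifier-head losses (loss_classifier), the
    # class-level confusion plus the concatenated-CE term on the target outputs (loss_G), the
    # CORAL alignment penalty weighted by lambda_coral, and the mixup contrastive loss weighted by self_lam.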
loss_confusion_target = concatenatedCELoss(output_target)
loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
loss_G = loss_category_st_G + lam * loss_confusion_target
    loss_T = loss_G + loss_classifier + lambda_coral * coral_loss_value + self_lam * loss_self_supervised
prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
losses_classifier.update(loss_classifier.item(), input_source.size(0))
losses_G.update(loss_G.item(), input_source.size(0))
losses_T.update(loss_T.item(), input_source.size(0))
top1_source.update(prec1_source[0], input_source.size(0))
top1_target.update(prec1_target[0], input_source.size(0))
optimizer.zero_grad()
loss_T.backward()
optimizer.step()
scheduler.step()
batch_time.update(time.time() - end)
if (epoch + 1) % args.print_freq == 0 or epoch == 0:
print('Train: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
epoch, args.epochs, batch_time=batch_time,
data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T, top1S=top1_source,
top1T=top1_target))
return source_train_loader_batch, epoch, new_epoch_flag
def validate(best_epoch, classnames, templates, val_loader, model, adapter, epoch, args, criterion, best_prec, CLIP_Text,
Text_Encoder, CLIP_Image,
Image_Encoder):
batch_time = AverageMeter()
losses_source = AverageMeter()
losses_target = AverageMeter()
top1_source = AverageMeter()
top1_target = AverageMeter()
CLIP_Text.eval()
CLIP_Image.eval()
Text_Encoder.eval()
Image_Encoder.eval()
model.eval()
adapter.eval()
end = time.time()
logit_scale = model.logit_scale.exp()
for i, (image, label, _) in enumerate(val_loader):
image = image.cuda()
label = label.cuda()
input_source = calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder)
input_target = image.cuda()
target_target = label.cuda()
target_source = label.cuda()
        # CLIP image encoder
with torch.no_grad():
input_target_temp = CLIP_Image(input_target)
input_target_add = Image_Encoder(input_target_temp)
# output_source = adapter(input_source) * logit_scale
output_target = adapter(input_target_add) * logit_scale
output_source = output_target
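        # At test time only image features are classified; output_source is aliased to output_target
        # so the same logits are evaluated against both the source and the target classifier halves below.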
loss_source = criterion(output_source[:, :len(classnames)], target_target)
loss_target = criterion(output_target[:, len(classnames):], target_target)
# measure accuracy and record loss
prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_target, topk=(1, 5))
prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
losses_source.update(loss_source.item(), image.size(0))
losses_target.update(loss_target.item(), image.size(0))
top1_source.update(prec1_source[0], image.size(0))
top1_target.update(prec1_target[0], image.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
print('Test: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source, lossT=losses_target,
top1S=top1_source, top1T=top1_target))
print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
.format(top1S=top1_source, top1T=top1_target))
prec = max(top1_target.avg, top1_source.avg).item()
if prec > best_prec:
best_prec = max(top1_target.avg, top1_source.avg).item()
best_epoch = epoch
print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)
    return prec, best_epoch
def main():
args = opts()
set_seed(2023)
model, preprocess = clip.load(args.name)
model = model.cuda()
    model.float()
classnames, templates, loader, train_loader = get_dataset_loader(args, preprocess)
CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = configure_clip_encoders(args, model, 0, 1)
prepare_directories(args, CLIP_Text, CLIP_Image)
    # classification head
weights = set_adapter_weights(model, classnames, templates)
adapter = Weight_Adapter(args, classnames, weights).cuda()
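    # Weight_Adapter is the 2*len(classnames)-way classification head used in train()/validate() above;
    # its weights are presumably initialized from the CLIP zero-shot text embeddings returned by
    # set_adapter_weights (assumption based on the helper names, not verified here).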
    # loss functions
criterion = nn.CrossEntropyLoss().cuda()
criterion_classifier_target = DClassifierForTarget(nClass=len(classnames)).cuda()
criterion_classifier_source = DClassifierForSource(nClass=len(classnames)).cuda()
    # learning rates and weight decay for each part of the model
lr_adapter = 0.0001
lr_image_encoder = 0.00001
lr_text_encoder = 0.00001
weight_decay = 0.00001
    # ADAM_BETAS controls the decay rates of the optimizer's moving averages
ADAM_BETAS = (0.9, 0.999)
    # create the AdamW optimizer
optimizer = torch.optim.AdamW([
{'params': adapter.parameters(), 'lr': lr_adapter, 'weight_decay': weight_decay, 'betas': ADAM_BETAS},
{'params': Image_Encoder.parameters(), 'lr': lr_image_encoder, 'weight_decay': weight_decay,
'betas': ADAM_BETAS},
{'params': Text_Encoder.parameters(), 'lr': lr_text_encoder, 'weight_decay': weight_decay, 'betas': ADAM_BETAS}
], eps=1e-4)
    # CosineAnnealingLR learning-rate scheduler.
    # T_max is the total number of optimizer steps (epochs * batches per epoch), since scheduler.step() is called once per batch in train().
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))
source_train_loader_batch = enumerate(train_loader)
current_epoch = 0
best_prec = 0
best_epoch=0
while (current_epoch < args.epochs):
source_train_loader_batch, current_epoch, new_epoch_flag = train(classnames, templates,
train_loader,
source_train_loader_batch,
model,
adapter,
optimizer,
current_epoch,
args, scheduler, criterion, CLIP_Text,
Text_Encoder, CLIP_Image, Image_Encoder)
if new_epoch_flag:
if (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0:
if current_epoch >= args.valepoch:
prec,best_epoch = validate(best_epoch,classnames, templates, loader, model, adapter, current_epoch, args, criterion,
best_prec,
CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)
is_best = prec > best_prec
if prec > args.valacc:
if is_best:
save_model(current_epoch, Text_Encoder, Image_Encoder, adapter,args, prec)
best_prec = max(prec, best_prec)
                    # update the log
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
logging.info(
f"Current Time: {current_time},Epoch: {current_epoch}, Accuracy: {prec}, Best: {best_prec}")
if __name__ == '__main__':
main()