# Files
# clip-symnets/main_coral_loss_gpt3_final.py
# 2024-05-21 19:41:56 +08:00

# 458 lines
# 20 KiB
# Python
# Raw Permalink Blame History

# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import time
from clip import clip
import torch.nn as nn
import numpy as np
import torch.optim
from opts import opts # The options for the project
# from trainer import validate # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights, gpt_clip_classifier,calculate_zeroshot_weights_GPT
from Adapter import Weight_Adapter
import logging
import torch.nn.functional as F
import yaml
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import glob
class CustomCrossAttention(nn.Module):
    """Single-head cross-attention where text features query image features.

    Queries are projected from the text features; keys and values are
    projected from the image features. Attention scores are the raw dot
    products (no ``1/sqrt(d)`` scaling) normalised with a softmax over the
    image axis.
    """

    def __init__(self, feature_dim):
        """Create the three ``feature_dim -> feature_dim`` linear projections."""
        super().__init__()
        self.query_projection = nn.Linear(feature_dim, feature_dim)
        self.key_projection = nn.Linear(feature_dim, feature_dim)
        self.value_projection = nn.Linear(feature_dim, feature_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, text_features, image_features):
        """Attend from text features to image features.

        :param text_features: 2-D tensor ``[text_batch, feature_dim]``.
        :param image_features: 2-D tensor ``[image_batch, feature_dim]``.
        :return: tensor with ``max(text_batch, image_batch)`` rows (or
            ``text_batch`` rows when the text batch is empty) and
            ``feature_dim`` columns.
        """
        text_batch_size = text_features.size(0)
        image_batch_size = image_features.size(0)

        # Tile the text features so that they match the image batch size.
        # Ceil-divide and truncate so the match is exact even when the image
        # batch is not a multiple of the text batch (floor division used to
        # leave the text batch short in that case).
        if 0 < text_batch_size < image_batch_size:
            repeat_times = -(-image_batch_size // text_batch_size)
            text_features = text_features.repeat(repeat_times, 1)[:image_batch_size]

        query = self.query_projection(text_features)
        key = self.key_projection(image_features)
        value = self.value_projection(image_features)

        # Attention weights: one row per query, normalised over the images.
        attention_scores = torch.matmul(query, key.transpose(-2, -1))
        attention_scores = self.softmax(attention_scores)

        # Weighted sum of the projected image features.
        attended_features = torch.matmul(attention_scores, value)
        return attended_features
# Legacy variant (covariance term only, summed instead of averaged), kept for reference:
# def coral_loss(source_features, target_features):
#     """
#     Compute the Deep CORAL loss.
#     :param source_features: source-domain features, shape [batch_size, feature_dim]
#     :param target_features: target-domain features, shape [batch_size, feature_dim]
#     :return: CORAL loss
#     """
#     d = source_features.data.shape[1]  # feature dimension
#     source_mean = torch.mean(source_features, dim=0)
#     target_mean = torch.mean(target_features, dim=0)
#     source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
#     target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)
#     coral_loss = torch.sum(torch.pow(source_cov - target_cov, 2))  # / (4*d*d)
#     return coral_loss
def coral_loss(source_features, target_features):
    """Compute a CORAL-style alignment loss between two feature batches.

    The loss is the mean squared difference of the per-feature means plus the
    mean squared difference of the feature covariance matrices.

    :param source_features: source-domain features, shape [batch_size, feature_dim]
    :param target_features: target-domain features, shape [batch_size, feature_dim]
    :return: scalar tensor ``mean_diff + cov_diff``
    """
    # First-order statistics.
    source_mean = torch.mean(source_features, dim=0)
    target_mean = torch.mean(target_features, dim=0)
    mean_diff = torch.pow(source_mean - target_mean, 2).mean()

    # Second-order statistics. The denominator is clamped to 1 so that a
    # single-sample batch yields a zero covariance instead of 0/0 = NaN.
    source_centered = source_features - source_mean
    target_centered = target_features - target_mean
    source_cov = source_centered.T @ source_centered / max(source_features.shape[0] - 1, 1)
    target_cov = target_centered.T @ target_centered / max(target_features.shape[0] - 1, 1)
    cov_diff = torch.pow(source_cov - target_cov, 2).mean()

    total_coral_loss = mean_diff + cov_diff
    return total_coral_loss
def shuffle_data(weights, labels):
    """Shuffle ``weights`` and ``labels`` along dim 0 with one shared permutation.

    :param weights: tensor indexed along its first dimension
    :param labels: tensor indexed with the same permutation as ``weights``
    :return: tuple ``(shuffled_weights, shuffled_labels)``
    """
    # A single permutation keeps every weight row paired with its label.
    indices = torch.randperm(len(weights))
    shuffled_weights = weights[indices]
    shuffled_labels = labels[indices]
    return shuffled_weights, shuffled_labels
def _symmetric_contrastive_loss(image_embeddings, text_embeddings, logit_scale):
    """Average of image->text and text->image cross-entropy with row i matched to row i."""
    # Labels live on the same device as the embeddings (no hard-coded "cuda:0").
    labels = torch.arange(text_embeddings.shape[0], device=text_embeddings.device, dtype=torch.long)
    logits_per_image = logit_scale * image_embeddings @ text_embeddings.T
    logits_per_text = logit_scale * text_embeddings @ image_embeddings.T
    return (
        F.cross_entropy(logits_per_image, labels) +
        F.cross_entropy(logits_per_text, labels)
    ) / 2


def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
          adapter, optimizer,
          epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, gpt_weight,
          gpt_label, gpt3_prompt):
    """Run ONE optimisation step on the next batch of the training iterator.

    Text features built from the (shuffled) batch labels act as the "source"
    domain and image features as the "target" domain. Both are fed through
    ``adapter``, whose output is split into a first half (source classifier,
    ``len(classnames)`` logits) and a second half (target classifier).

    :param source_train_loader: loader used to restart the iterator when exhausted
    :param source_train_loader_batch: ``enumerate(loader)`` iterator yielding
        ``(index, (image, label, _))``
    :param epoch: current epoch counter; incremented when the iterator is exhausted
    :return: ``(source_train_loader_batch, epoch, new_epoch_flag)`` where
        ``new_epoch_flag`` is True iff the iterator was restarted in this call
    """
    # NOTE: the meters are created on every call, so ``avg`` equals ``val`` in the log line.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_CR = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.train()
    Image_Encoder.train()
    model.eval()
    logit_scale = model.logit_scale.exp()
    adapter.train()

    new_epoch_flag = False
    end = time.time()
    num_classes = len(classnames)
    concatenatedCELoss = ConcatenatedCELoss(num_classes=num_classes).cuda()

    # Fetch the next batch; on exhaustion start a new epoch and a fresh iterator.
    try:
        (image, label, _) = next(source_train_loader_batch)[1]
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = next(source_train_loader_batch)[1]

    target_target = label.cuda()
    # Labels in image order, used for the image/text aligned (self-supervised) pairs.
    label_self_supervised = label.cuda()
    # Source (text) labels are a random permutation of the batch labels.
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()
    input_target = image.cuda()

    input_source = calculate_zeroshot_weights_GPT(classnames, target_source, templates, CLIP_Text, Text_Encoder,
                                                  gpt3_prompt)
    # The shuffled GPT weights/labels are not consumed below (their concatenation
    # into the batch is disabled); the call is kept so the random stream drawn by
    # torch.randperm stays the same.
    gpt_weight, gpt_label = shuffle_data(gpt_weight, gpt_label)

    data_time.update(time.time() - end)

    # Labels shifted into the second (target) half of the concatenated classifier.
    target_target_temp = target_target + num_classes
    target_source_temp = target_source + num_classes

    # CLIP image backbone is frozen; the trainable Image_Encoder runs with gradients.
    # NOTE(review): the no_grad scope was reconstructed from flattened source —
    # Image_Encoder is in the optimizer, so it is assumed to sit outside it.
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    input_target_add = Image_Encoder(input_target_temp)

    # Text and image features go straight into the shared adapter (classifier layer).
    output_source = adapter(input_source) * logit_scale
    output_target = adapter(input_target_add) * logit_scale

    # Text features for the un-shuffled labels, i.e. aligned row-by-row with the images.
    self_input_source = calculate_zeroshot_weights_GPT(classnames, label_self_supervised, templates, CLIP_Text,
                                                       Text_Encoder, gpt3_prompt)

    # CORAL alignment between aligned text features and image features.
    coral_loss_value = coral_loss(self_input_source, input_target_add)
    lambda_coral = 1
    coral_term = lambda_coral * coral_loss_value

    # Contrastive loss on the raw features (currently not part of the objective).
    loss_self_supervised_1 = _symmetric_contrastive_loss(
        F.normalize(input_target_add), F.normalize(self_input_source), logit_scale)

    # Contrastive loss on the adapter outputs: source half for text, target half for images.
    self_output_source = adapter(self_input_source)
    self_output_source = F.normalize(self_output_source[:, :num_classes])
    self_output_target = output_target / logit_scale
    self_output_target = F.normalize(self_output_target[:, num_classes:])
    loss_self_supervised_2 = _symmetric_contrastive_loss(self_output_target, self_output_source, logit_scale)

    loss_self_supervised = loss_self_supervised_2  # loss_self_supervised_2 + loss_self_supervised_1

    # Supervised cross-entropy of each classifier half on its own domain.
    loss_task_s_Cs = criterion(output_source[:, :num_classes], target_source)
    loss_task_s_Ct = criterion(output_target[:, num_classes:], target_target)

    # Domain discrimination on the concatenated classifier: source samples should put
    # their probability mass on the first half, target samples on the second half.
    loss_domain_st_Cst_part1 = criterion(output_source, target_source)
    loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)

    # Category-level confusion: each domain is also pushed towards the other half.
    loss_category_st_G = 0.5 * criterion(output_target, target_target) + 0.5 * criterion(output_source,
                                                                                         target_source_temp)
    # Domain-level confusion (disabled):
    # loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) + 0.5 * criterion_classifier_source(
    #     output_target)

    lam = 1  # 2 / (1 + math.exp(-1 * 10 * epoch / args.epochs)) - 1
    # if(epoch<30):
    #     self_lam= 5
    # else:
    self_lam = 0  # weight 0: the self-supervised term does not influence the update

    loss_confusion_target = concatenatedCELoss(output_target)
    loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
    loss_G = loss_category_st_G + lam * loss_confusion_target
    loss_T = loss_G + loss_classifier + self_lam * loss_self_supervised + coral_term  # + mmd_loss_val

    prec1_source, _ = accuracy(output_source.data[:, :num_classes], target_source, topk=(1, 5))
    prec1_target, _ = accuracy(output_target.data[:, num_classes:], target_target, topk=(1, 5))

    batch_size = input_source.size(0)
    losses_classifier.update(loss_classifier.item(), batch_size)
    losses_G.update(loss_G.item(), batch_size)
    # The CORAL meter tracks only the CORAL loss (it used to be fed loss_G as well).
    losses_CR.update(coral_loss_value.item(), batch_size)
    losses_T.update(loss_T.item(), batch_size)
    top1_source.update(prec1_source[0], batch_size)
    top1_target.update(prec1_target[0], batch_size)

    optimizer.zero_grad()
    loss_T.backward()
    optimizer.step()
    scheduler.step()

    batch_time.update(time.time() - end)

    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@CR {loss_cr.val:.4f} ({loss_cr.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
            epoch, args.epochs, batch_time=batch_time,
            data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_cr=losses_CR, loss_t=losses_T,
            top1S=top1_source,
            top1T=top1_target))

    return source_train_loader_batch, epoch, new_epoch_flag
def validate(best_epoch, classnames, templates, val_loader, model, adapter, epoch, args, criterion, best_prec,
             CLIP_Text,
             Text_Encoder, CLIP_Image,
             Image_Encoder):
    """Evaluate the adapter on ``val_loader`` using image features only.

    Image features are passed through ``adapter``; the first
    ``len(classnames)`` logits are scored as the "source" classifier and the
    remaining logits as the "target" classifier, both against the image labels.

    :param best_epoch: epoch of the best precision seen so far
    :param best_prec: best precision seen so far
    :return: ``(prec, best_epoch)`` where ``prec`` is the larger of the two
        average top-1 accuracies and ``best_epoch`` is updated to ``epoch``
        when ``prec`` exceeds ``best_prec``
    """
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    model.eval()
    adapter.eval()

    end = time.time()
    logit_scale = model.logit_scale.exp()
    num_classes = len(classnames)

    # Nothing is optimised here, so the whole pass runs without autograd
    # bookkeeping (previously only the CLIP backbone was under no_grad).
    with torch.no_grad():
        for i, (image, label, _) in enumerate(val_loader):
            image = image.cuda()
            label = label.cuda()

            # NOTE(review): the text features are not used below; the call is kept
            # in case the helper has side effects (e.g. draws random numbers) — confirm
            # and remove if it is pure.
            calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder)

            # CLIP image backbone followed by the trained image encoder.
            image_features = Image_Encoder(CLIP_Image(image))

            # output_source = adapter(input_source) * logit_scale
            output_target = adapter(image_features) * logit_scale
            # Both classifier halves are evaluated on the image logits.
            output_source = output_target

            loss_source = criterion(output_source[:, :num_classes], label)
            loss_target = criterion(output_target[:, num_classes:], label)

            # measure accuracy and record loss
            prec1_source, _ = accuracy(output_source.data[:, :num_classes], label, topk=(1, 5))
            prec1_target, _ = accuracy(output_target.data[:, num_classes:], label, topk=(1, 5))
            losses_source.update(loss_source.item(), image.size(0))
            losses_target.update(loss_target.item(), image.size(0))
            top1_source.update(prec1_source[0], image.size(0))
            top1_target.update(prec1_target[0], image.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                      'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                      'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                      'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                    epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source, lossT=losses_target,
                    top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))

    # Report the better of the two classifier halves.
    prec = max(top1_target.avg, top1_source.avg).item()
    if prec > best_prec:
        best_prec = prec
        best_epoch = epoch
    print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)
    return prec, best_epoch
def main():
    """Build the CLIP encoders, adapter and optimizer, then run the train/validate loop.

    Each call to ``train`` performs one optimisation step; validation runs on
    ``loader`` whenever ``train`` reports that a new epoch has started and the
    epoch matches ``args.test_freq`` / ``args.valepoch``.

    :raises FileNotFoundError: if the GPT-3 prompt JSON file is missing
    """
    args = opts()
    set_seed(2023)

    model, preprocess = clip.load(args.name)
    model = model.cuda()
    model.float()

    classnames, templates, loader, train_loader, va_loader = get_dataset_loader(args, preprocess)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = configure_clip_encoders(args, model, 0, 1)
    prepare_directories(args, CLIP_Text, CLIP_Image)
    # cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    # Load the GPT-3 generated prompts. Fail early with a clear message instead
    # of a NameError on ``gpt3_prompt`` further down when the file is absent.
    prompt_pattern = 'gpt_file/caltech_prompt.json'
    json_files = glob.glob(prompt_pattern)
    if not json_files:
        raise FileNotFoundError(f"No GPT-3 prompt file matches '{prompt_pattern}'")
    for file_path in json_files:
        with open(file_path, 'r') as f:
            gpt3_prompt = json.load(f)

    gpt_weight = gpt_clip_classifier(classnames, gpt3_prompt, CLIP_Text, Text_Encoder)
    gpt_label = torch.arange(len(classnames), device="cuda:0", dtype=torch.long)

    # Classification layer: the GPT-prompt weights are stacked twice, once for
    # the source half and once for the target half of the adapter.
    # NOTE(review): ``weights`` is computed but not used; kept in case
    # set_adapter_weights has side effects — confirm and remove if it is pure.
    weights = set_adapter_weights(model, classnames, templates)
    init_weight = torch.cat([gpt_weight.T, gpt_weight.T], dim=1).T
    # adapter = Weight_Adapter(args, classnames, weights).cuda()
    adapter = Weight_Adapter(args, classnames, init_weight).cuda()

    # Loss functions. The two domain classifiers are currently unused because the
    # domain-level confusion term in train() is commented out.
    criterion = nn.CrossEntropyLoss().cuda()
    criterion_classifier_target = DClassifierForTarget(nClass=len(classnames)).cuda()
    criterion_classifier_source = DClassifierForSource(nClass=len(classnames)).cuda()

    # Per-component learning rates and weight decay.
    # lr_adapter = 0.0001
    # lr_image_encoder = 0.00001
    # lr_text_encoder = 0.00001
    # weight_decay = 0.00001
    # caltech
    lr_adapter = 0.0001
    lr_image_encoder = 0.00001
    lr_text_encoder = 0.000001
    weight_decay = 0.00001
    # Exponential decay rates of AdamW's moving averages.
    ADAM_BETAS = (0.9, 0.999)

    optimizer = torch.optim.AdamW([
        {'params': adapter.parameters(), 'lr': lr_adapter, 'weight_decay': weight_decay, 'betas': ADAM_BETAS},
        {'params': Image_Encoder.parameters(), 'lr': lr_image_encoder, 'weight_decay': weight_decay,
         'betas': ADAM_BETAS},
        {'params': Text_Encoder.parameters(), 'lr': lr_text_encoder, 'weight_decay': weight_decay,
         'betas': ADAM_BETAS}
    ], eps=1e-4)

    # Cosine annealing over the total number of iterations: train() steps the
    # scheduler once per call, i.e. once per batch, not once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

    source_train_loader_batch = enumerate(train_loader)
    current_epoch = 0
    best_prec = 0
    best_epoch = 0

    while current_epoch < args.epochs:
        source_train_loader_batch, current_epoch, new_epoch_flag = train(classnames, templates,
                                                                         train_loader,
                                                                         source_train_loader_batch,
                                                                         model,
                                                                         adapter,
                                                                         optimizer,
                                                                         current_epoch,
                                                                         args, scheduler, criterion, CLIP_Text,
                                                                         Text_Encoder, CLIP_Image, Image_Encoder,
                                                                         gpt_weight, gpt_label, gpt3_prompt)
        # Validate only right after an epoch boundary.
        if not new_epoch_flag:
            continue
        if not ((current_epoch + 1) % args.test_freq == 0 or current_epoch == 0):
            continue
        if current_epoch < args.valepoch:
            continue

        prec, best_epoch = validate(best_epoch, classnames, templates, loader, model, adapter,
                                    current_epoch, args, criterion,
                                    best_prec,
                                    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)
        is_best = prec > best_prec
        # Checkpoint only new bests that also clear the accuracy threshold.
        if is_best and prec > args.valacc:
            save_model(current_epoch, Text_Encoder, Image_Encoder, adapter, args, prec)
        best_prec = max(prec, best_prec)

        # Update the log.
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        logging.info(
            f"Current Time: {current_time},Epoch: {current_epoch}, Accuracy: {prec}, Best: {best_prec}")
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()