init

main_coral_loss_gpt3_final.py (new file, 457 lines)
@@ -0,0 +1,457 @@
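"""Domain-adaptation training for CLIP: a weight adapter on top of frozen CLIP
encoders, GPT-3 prompt-based text classifiers, and a Deep CORAL alignment loss
(this file: main_coral_loss_gpt3_final.py)."""
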
import glob
import json
import logging
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import yaml

from clip import clip
from opts import opts  # The options for the project
# from trainer import validate  # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
    set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights, \
    gpt_clip_classifier, calculate_zeroshot_weights_GPT
from Adapter import Weight_Adapter


class CustomCrossAttention(nn.Module):
    def __init__(self, feature_dim):
        super(CustomCrossAttention, self).__init__()
        self.query_projection = nn.Linear(feature_dim, feature_dim)
        self.key_projection = nn.Linear(feature_dim, feature_dim)
        self.value_projection = nn.Linear(feature_dim, feature_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, text_features, image_features):
        # Assumes the batch size of text_features is smaller than that of image_features
        text_batch_size = text_features.size(0)
        image_batch_size = image_features.size(0)

        # Repeat text_features to match the batch size of image_features
        # (assumes the image batch size is an exact multiple of the text batch size)
        if text_batch_size < image_batch_size:
            repeat_times = image_batch_size // text_batch_size
            text_features = text_features.repeat(repeat_times, 1)

        query = self.query_projection(text_features)
        key = self.key_projection(image_features)
        value = self.value_projection(image_features)

        # Compute the attention scores
        attention_scores = torch.matmul(query, key.transpose(-2, -1))
        attention_scores = self.softmax(attention_scores)

        # Apply the attention weights to the values
        attended_features = torch.matmul(attention_scores, value)

        return attended_features


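# A minimal shape check for CustomCrossAttention (an illustrative sketch; the
# feature_dim and batch sizes below are synthetic stand-ins, not values taken
# from this project):
def _cross_attention_shape_check():
    attn = CustomCrossAttention(feature_dim=512)
    text = torch.randn(16, 512)   # smaller text batch
    image = torch.randn(64, 512)  # must be an exact multiple of the text batch (64 = 4 * 16)
    out = attn(text, image)
    assert out.shape == (64, 512)  # one attended feature vector per image row
    return out

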
# Earlier sum-based variant, kept for reference:
# def coral_loss(source_features, target_features):
#     """
#     Compute the Deep CORAL loss.
#     :param source_features: source-domain features, shape [batch_size, feature_dim]
#     :param target_features: target-domain features, shape [batch_size, feature_dim]
#     :return: CORAL loss
#     """
#     d = source_features.data.shape[1]  # feature dimension
#     source_mean = torch.mean(source_features, dim=0)
#     target_mean = torch.mean(target_features, dim=0)
#     source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
#     target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)
#     coral_loss = torch.sum(torch.pow(source_cov - target_cov, 2))  # / (4*d*d)
#     return coral_loss
def coral_loss(source_features, target_features):
    """
    Compute the Deep CORAL loss.
    :param source_features: source-domain features, shape [batch_size, feature_dim]
    :param target_features: target-domain features, shape [batch_size, feature_dim]
    :return: CORAL loss
    """
    # Feature dimension (currently unused; the 4*d*d normalization is disabled)
    d = source_features.data.shape[1]

    # Per-dimension means
    source_mean = torch.mean(source_features, dim=0)
    target_mean = torch.mean(target_features, dim=0)

    # Mean discrepancy between the two domains
    # mean_diff = torch.pow(source_mean - target_mean, 2).sum().sqrt().mean()
    mean_diff = torch.pow(source_mean - target_mean, 2).mean()

    # Covariance matrices
    source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.shape[0] - 1)
    target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.shape[0] - 1)

    # Mean squared discrepancy between the covariance matrices
    # cov_diff = torch.pow(source_cov - target_cov, 2).sum().sqrt().mean()
    cov_diff = torch.pow(source_cov - target_cov, 2).mean()

    # Total CORAL loss: mean discrepancy plus covariance discrepancy
    total_coral_loss = mean_diff + cov_diff
    return total_coral_loss


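# A minimal sanity check for coral_loss (a sketch with synthetic tensors, not
# the CLIP features used in training): identical domains give exactly zero loss.
def _coral_loss_sanity_check():
    src = torch.randn(32, 512)
    tgt = torch.randn(32, 512)
    assert coral_loss(src, src).item() == 0.0  # same distribution -> zero discrepancy
    return coral_loss(src, tgt)  # positive for mismatched random features

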
def shuffle_data(weights, labels):
    # Draw a random permutation of the indices
    indices = torch.randperm(len(weights))
    # Apply the same permutation to data and labels so pairs stay aligned
    shuffled_weights = weights[indices]
    shuffled_labels = labels[indices]
    return shuffled_weights, shuffled_labels


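# Example (hypothetical tensors): shuffle_data applies one shared permutation,
# so shuffle_data(torch.randn(100, 512), torch.arange(100)) returns both
# tensors reordered identically, each row keeping its original label.

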
def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
          adapter, optimizer,
          epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder,
          gpt_weight, gpt_label, gpt3_prompt):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_CR = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.train()
    Image_Encoder.train()
    model.eval()
    logit_scale = model.logit_scale.exp()

    adapter.train()
    new_epoch_flag = False
    end = time.time()
    concatenatedCELoss = ConcatenatedCELoss(num_classes=len(classnames)).cuda()
    try:
        (image, label, _) = source_train_loader_batch.__next__()[1]
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = source_train_loader_batch.__next__()[1]
    target_target = label.cuda()

    # Labels for the self-supervised objective
    label_self_supervised = label.cuda()
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()

    # target_source = label.cuda()
    input_target = image.cuda()
    # input_source = calculate_zeroshot_weights(classnames, target_source, templates, CLIP_Text, Text_Encoder)
    input_source = calculate_zeroshot_weights_GPT(classnames, target_source, templates, CLIP_Text,
                                                  Text_Encoder, gpt3_prompt)
    gpt_weight, gpt_label = shuffle_data(gpt_weight, gpt_label)
    # input_source = torch.cat((input_source, gpt_weight), dim=0)
    # target_source = torch.cat((target_source, gpt_label), dim=0)

    data_time.update(time.time() - end)

    # Shifted labels select the second half of the doubled classifier
    target_target_temp = target_target + len(classnames)
    target_source_temp = target_source + len(classnames)
    target_target_temp = target_target_temp.cuda()

    # Frozen CLIP image encoder
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)

    input_target_add = Image_Encoder(input_target_temp)

    # Text features go straight through the adapter (a fully connected layer)
    output_source = adapter(input_source) * logit_scale
    # Image features go through the same adapter
    output_target = adapter(input_target_add) * logit_scale

    # self_input_source = calculate_zeroshot_weights(classnames, label_self_supervised, templates, CLIP_Text,
    #                                                Text_Encoder)
    self_input_source = calculate_zeroshot_weights_GPT(classnames, label_self_supervised, templates, CLIP_Text,
                                                       Text_Encoder, gpt3_prompt)
    # input_source = calculate_zeroshot_weights_GPT(classnames, target_source, templates, CLIP_Text, Text_Encoder, gpt3_prompt)

    # MMD loss (disabled)
    # mmd_loss_val = mmd_loss(self_input_source, input_target_add)
    # lambda_mmd = 1000
    # mmd_loss_val = lambda_mmd * mmd_loss_val

    # CORAL loss between text (source) and image (target) features
    coral_loss_value = coral_loss(self_input_source, input_target_add)
    lambda_coral = 1
    loss_1 = lambda_coral * coral_loss_value
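
    # The two blocks below build a CLIP-style symmetric contrastive loss: with a
    # batch of N aligned text/image features, row i of the image-to-text logits
    # should pick column i, so the targets are simply 0..N-1 and the cross-entropy
    # is averaged over both directions. Variant 1 uses the raw encoder features,
    # variant 2 the post-adapter features; only variant 2 enters the total loss.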
    # Variant 1: self-supervised text features, pre-adapter
    # self_output_source = adapter(self_input_source)
    # self_output_source = F.normalize(self_output_source[:, :len(classnames)])
    self_output_source = F.normalize(self_input_source)

    # Variant 1: self-supervised image features, pre-adapter
    # self_output_target = output_target / logit_scale
    # self_output_target = F.normalize(self_output_target[:, len(classnames):])
    self_output_target = F.normalize(input_target_add)
    # Self-supervised labels 0..batch_size-1 (each image matches its own text)
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_1 = (
        F.cross_entropy(logits_per_image, self_supervised_labels) +
        F.cross_entropy(logits_per_text, self_supervised_labels)
    ) / 2

    # Variant 2: self-supervised text features, post-adapter
    self_output_source = adapter(self_input_source)
    self_output_source = F.normalize(self_output_source[:, :len(classnames)])
    # self_output_source = F.normalize(self_input_source)

    # Variant 2: self-supervised image features, post-adapter
    self_output_target = output_target / logit_scale
    self_output_target = F.normalize(self_output_target[:, len(classnames):])
    # self_output_target = F.normalize(input_target_add)
    # Self-supervised labels 0..batch_size-1 again
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_2 = (
        F.cross_entropy(logits_per_image, self_supervised_labels) +
        F.cross_entropy(logits_per_text, self_supervised_labels)
    ) / 2

    loss_self_supervised = loss_self_supervised_2  # loss_self_supervised_2 + loss_self_supervised_1

    # Supervised cross-entropy for the classification task
    loss_task_s_Cs = criterion(output_source[:, :len(classnames)], target_source)
    loss_task_s_Ct = criterion(output_target[:, len(classnames):], target_target)

    # Domain terms: source samples should concentrate probability mass on the
    # upper half of the classifier, target samples on the lower half.
    loss_domain_st_Cst_part1 = criterion(output_source, target_source)
    loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)

    # Class-level confusion
    loss_category_st_G = 0.5 * criterion(output_target, target_target) + 0.5 * criterion(output_source,
                                                                                         target_source_temp)
    # Domain-level confusion (disabled)
    # loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) + 0.5 * criterion_classifier_source(
    #     output_target)

    lam = 1  # 2 / (1 + math.exp(-1 * 10 * epoch / args.epochs)) - 1
    # if epoch < 30:
    #     self_lam = 5
    # else:
    self_lam = 0
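
    # Total objective assembled below: the doubled-classifier cross-entropy
    # terms (loss_classifier), class-level confusion plus the concatenated-CE
    # confusion on target outputs (loss_G), the self-supervised contrastive
    # term (weighted by self_lam, currently 0), and the CORAL alignment term
    # (weighted by lambda_coral).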
    loss_confusion_target = concatenatedCELoss(output_target)
    loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
    loss_G = loss_category_st_G + lam * loss_confusion_target
    loss_T = loss_G + loss_classifier + self_lam * loss_self_supervised + lambda_coral * coral_loss_value  # + mmd_loss_val

    prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
    prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))
    losses_classifier.update(loss_classifier.item(), input_source.size(0))
    losses_G.update(loss_G.item(), input_source.size(0))
    losses_CR.update(coral_loss_value.item(), input_source.size(0))
    losses_T.update(loss_T.item(), input_source.size(0))
    top1_source.update(prec1_source[0], input_source.size(0))
    top1_target.update(prec1_target[0], input_source.size(0))

    optimizer.zero_grad()
    loss_T.backward()
    optimizer.step()
    scheduler.step()

    batch_time.update(time.time() - end)
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@CR {loss_cr.val:.4f} ({loss_cr.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time,
                  data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_cr=losses_CR,
                  loss_t=losses_T, top1S=top1_source,
                  top1T=top1_target))
    return source_train_loader_batch, epoch, new_epoch_flag


def validate(best_epoch, classnames, templates, val_loader, model, adapter, epoch, args, criterion, best_prec,
             CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder):
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()

    model.eval()
    adapter.eval()
    end = time.time()
    logit_scale = model.logit_scale.exp()

    for i, (image, label, _) in enumerate(val_loader):
        image = image.cuda()
        label = label.cuda()

        input_source = calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder)
        input_target = image.cuda()
        target_target = label.cuda()
        target_source = label.cuda()

        # Frozen CLIP image encoder
        with torch.no_grad():
            input_target_temp = CLIP_Image(input_target)
            input_target_add = Image_Encoder(input_target_temp)
        # output_source = adapter(input_source) * logit_scale
        output_target = adapter(input_target_add) * logit_scale
        # At test time, both classifier halves are scored on the image features
        output_source = output_target

        loss_source = criterion(output_source[:, :len(classnames)], target_target)
        loss_target = criterion(output_target[:, len(classnames):], target_target)

        # measure accuracy and record loss
        prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_target, topk=(1, 5))
        prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))

        losses_source.update(loss_source.item(), image.size(0))
        losses_target.update(loss_target.item(), image.size(0))

        top1_source.update(prec1_source[0], image.size(0))
        top1_target.update(prec1_target[0], image.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                  'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                  'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                  'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                      epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source, lossT=losses_target,
                      top1S=top1_source, top1T=top1_target))
    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))
    prec = max(top1_target.avg, top1_source.avg).item()
    if prec > best_prec:
        best_prec = max(top1_target.avg, top1_source.avg).item()
        best_epoch = epoch
    print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)

    return prec, best_epoch


def main():
    args = opts()

    set_seed(2023)
    model, preprocess = clip.load(args.name)
    model = model.cuda()
    model.float()
    classnames, templates, loader, train_loader, va_loader = get_dataset_loader(args, preprocess)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = configure_clip_encoders(args, model, 0, 1)

    prepare_directories(args, CLIP_Text, CLIP_Image)
    # cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
    # Collect the GPT prompt files to use (here a single JSON file)
    json_files = glob.glob('gpt_file/caltech_prompt.json')

    for file_path in json_files:
        # Read each prompt file; the last match is the one used below
        with open(file_path, 'r') as f:
            gpt3_prompt = json.load(f)
    gpt_weight = gpt_clip_classifier(classnames, gpt3_prompt, CLIP_Text, Text_Encoder)
    gpt_label = torch.arange(len(classnames), device="cuda:0", dtype=torch.long)
    # Classification layer
    weights = set_adapter_weights(model, classnames, templates)
    init_weight = torch.cat([gpt_weight.T, gpt_weight.T], dim=1).T
    # adapter = Weight_Adapter(args, classnames, weights).cuda()
    adapter = Weight_Adapter(args, classnames, init_weight).cuda()
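
    # Note: init_weight stacks the GPT prompt classifier twice, so the adapter
    # produces 2 * len(classnames) logits; train() treats the first half as the
    # source (text) classifier and the second half as the target (image) one.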

    # Loss functions
    criterion = nn.CrossEntropyLoss().cuda()
    criterion_classifier_target = DClassifierForTarget(nClass=len(classnames)).cuda()
    criterion_classifier_source = DClassifierForSource(nClass=len(classnames)).cuda()

    # Per-module learning rates and weight decay
    # lr_adapter = 0.0001
    # lr_image_encoder = 0.00001
    # lr_text_encoder = 0.00001
    # weight_decay = 0.00001
    # caltech
    lr_adapter = 0.0001
    lr_image_encoder = 0.00001
    lr_text_encoder = 0.000001
    weight_decay = 0.00001
    # ADAM_BETAS controls the decay rates of the moving averages
    ADAM_BETAS = (0.9, 0.999)

    # AdamW optimizer with one parameter group per module
    optimizer = torch.optim.AdamW([
        {'params': adapter.parameters(), 'lr': lr_adapter, 'weight_decay': weight_decay, 'betas': ADAM_BETAS},
        {'params': Image_Encoder.parameters(), 'lr': lr_image_encoder, 'weight_decay': weight_decay,
         'betas': ADAM_BETAS},
        {'params': Text_Encoder.parameters(), 'lr': lr_text_encoder, 'weight_decay': weight_decay, 'betas': ADAM_BETAS}
    ], eps=1e-4)

    # CosineAnnealingLR schedule; T_max is the total number of optimizer steps
    # (epochs * len(train_loader)), since scheduler.step() runs once per batch in train()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))
    source_train_loader_batch = enumerate(train_loader)

    current_epoch = 0
    best_prec = 0
    best_epoch = 0
    while current_epoch < args.epochs:
        source_train_loader_batch, current_epoch, new_epoch_flag = train(classnames, templates,
                                                                         train_loader,
                                                                         source_train_loader_batch,
                                                                         model,
                                                                         adapter,
                                                                         optimizer,
                                                                         current_epoch,
                                                                         args, scheduler, criterion, CLIP_Text,
                                                                         Text_Encoder, CLIP_Image, Image_Encoder,
                                                                         gpt_weight, gpt_label, gpt3_prompt)
        if new_epoch_flag:
            if (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0:
                if current_epoch >= args.valepoch:
                    prec, best_epoch = validate(best_epoch, classnames, templates, loader, model, adapter,
                                                current_epoch, args, criterion,
                                                best_prec,
                                                CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder)
                    is_best = prec > best_prec
                    if prec > args.valacc:
                        if is_best:
                            save_model(current_epoch, Text_Encoder, Image_Encoder, adapter, args, prec)
                    best_prec = max(prec, best_prec)
                    # Update the log
                    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                    logging.info(
                        f"Current Time: {current_time}, Epoch: {current_epoch}, Accuracy: {prec}, Best: {best_prec}")


if __name__ == '__main__':
    main()