295 lines
11 KiB
Python
295 lines
11 KiB
Python
import time
|
||
from clip import clip
|
||
import torch.nn as nn
|
||
import torch.optim
|
||
from opts import opts # The options for the project
|
||
# from trainer import validate # For the validate (test) process
|
||
from models.DomainClassifierTarget import DClassifierForTarget
|
||
from models.DomainClassifierSource import DClassifierForSource
|
||
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
|
||
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
|
||
set_adapter_weights_single, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights
|
||
from Adapter import Adapter
|
||
import logging
|
||
import torch.nn.functional as F
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
|
||
class CompactBilinearPooling(nn.Module):
    """Compact bilinear pooling of two feature matrices via Count Sketch + FFT.

    Each input row is projected with a Count Sketch: input column ``i`` is added,
    multiplied by a random sign ``sign[i]`` (+1 or -1), into output bucket
    ``hashcode[i]``. The two sketches are then multiplied in the frequency
    domain (``rfft`` -> elementwise product -> ``irfft``), i.e. circularly
    convolved, producing a ``(batch, output_dim)`` result.

    The hash buckets and signs are drawn once, at construction, from the global
    torch RNG (hashcode first, then sign, on the CPU), so a fixed seed always
    yields the same sketch. They are registered as buffers so they follow
    ``.to()`` / ``.cuda()`` and are stored in ``state_dict()``.

    NOTE(review): both inputs share the same hashcode/sign pair. The Tensor
    Sketch construction used for compact bilinear pooling (Gao et al., 2016)
    draws an independent pair per input; with a shared pair the (i, j) and
    (j, i) outer-product terms always land in the same bucket with the same
    sign. Kept as-is to preserve existing behaviour -- confirm this is intended.

    Args:
        input_dim: maximum number of input feature columns.
        output_dim: number of sketch buckets, which is also the output width.
    """

    def __init__(self, input_dim, output_dim):
        super(CompactBilinearPooling, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Random hash parameters, drawn on the CPU from the global RNG in a fixed
        # order (hashcode, then sign) so seeding reproduces the same sketch.
        hashcode = torch.randint(0, output_dim, (input_dim,), dtype=torch.long)
        sign = (torch.randint(0, 2, (input_dim,)) * 2 - 1).float()

        # Place the buffers on the GPU when one is available (the module is used
        # with CUDA inputs), but stay usable on CPU-only machines.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.register_buffer("hashcode", hashcode.to(device))
        self.register_buffer("sign", sign.to(device))

    def count_sketch(self, x, hashcode, sign):
        """Count-Sketch project ``x`` from (batch, d) to (batch, output_dim).

        Computes ``output[:, hashcode[i]] += sign[i] * x[:, i]`` for every input
        column ``i`` in a single vectorised ``index_add_``. If ``x`` has fewer
        columns than ``hashcode`` has entries, only the first ``d`` hash entries
        are used.

        Args:
            x: 2-D tensor of shape (batch, d) with d <= len(hashcode).
            hashcode: long tensor of bucket indices in [0, output_dim).
            sign: tensor of +1 / -1 multipliers, same length as ``hashcode``.

        Returns:
            Tensor of shape (batch, output_dim) on ``x``'s device and dtype.

        Raises:
            ValueError: if ``x`` is not 2-D or has more columns than ``hashcode``.
        """
        if x.dim() != 2:
            raise ValueError(
                f"count_sketch expects a 2-D (batch, dim) tensor, got shape {tuple(x.shape)}")
        batch_size, in_dim = x.shape
        if in_dim > hashcode.numel():
            raise ValueError(
                f"input has {in_dim} features but the sketch was built for at most {hashcode.numel()}")

        # Move the hash parameters next to the data (no-op if already there) and
        # match x's dtype so the result dtype follows x.
        idx = hashcode[:in_dim].to(x.device)
        sgn = sign[:in_dim].to(device=x.device, dtype=x.dtype)

        output = x.new_zeros(batch_size, self.output_dim)
        # Scatter-add every signed column into its bucket at once. On CUDA the
        # summation order may differ from a sequential loop in the last bits.
        output.index_add_(1, idx, x * sgn)
        return output

    def forward(self, x1, x2):
        """Fuse ``x1`` and ``x2`` (each (batch, d)) into a (batch, output_dim) tensor."""
        # Count Sketch of each input.
        sketch_x1 = self.count_sketch(x1, self.hashcode, self.sign)
        sketch_x2 = self.count_sketch(x2, self.hashcode, self.sign)

        # FFT of both sketches.
        fft_x1 = torch.fft.rfft(sketch_x1, n=self.output_dim, dim=1)
        fft_x2 = torch.fft.rfft(sketch_x2, n=self.output_dim, dim=1)

        # Elementwise product in frequency space + inverse FFT = circular convolution.
        return torch.fft.irfft(fft_x1 * fft_x2, n=self.output_dim, dim=1)
|
||
|
||
|
||
|
||
def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
          adapter, optimizer,
          epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, cbp_layer):
    """Run ONE optimisation step on one batch (despite the name, not a full epoch).

    A batch ``(image, label, _)`` is taken from ``source_train_loader_batch``.
    When that iterator is exhausted, ``epoch`` is incremented,
    ``source_train_loader_batch`` is re-created as ``enumerate(source_train_loader)``,
    and the first batch of the new pass is used.

    Pipeline: images -> CLIP_Image (no grad) -> Image_Encoder; text features from
    ``calculate_zeroshot_weights``; both are fused by ``cbp_layer``, classified by
    ``adapter`` and scaled by ``exp(model.logit_scale)``. Cross-entropy against
    ``label`` is back-propagated, then ``optimizer`` and ``scheduler`` each step
    once.

    Returns:
        tuple ``(source_train_loader_batch, epoch, new_epoch_flag)``: the (possibly
        re-created) batch iterator, the (possibly incremented) epoch counter, and
        whether the loader wrapped around during this call.
    """
    # NOTE: the meters are re-created on every call (= every batch), so each
    # printed average equals its current value. losses_classifier, losses_G and
    # top1_target are never updated here; they only keep the log format stable.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    logit_scale = model.logit_scale.exp()

    # CLIP backbones run in eval mode; the encoders, CBP layer and adapter train.
    model.eval()
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.train()
    Image_Encoder.train()
    cbp_layer.train()
    adapter.train()

    new_epoch_flag = False
    end = time.time()

    # Fetch the next batch; on exhaustion start a new pass and signal a new epoch.
    try:
        (image, label, _) = source_train_loader_batch.__next__()[1]
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = source_train_loader_batch.__next__()[1]

    label = label.cuda()
    input_target = image.cuda()
    # Data time covers batch loading and host-to-device copies only, not the
    # text-feature computation below.
    data_time.update(time.time() - end)

    # NOTE(review): input_source is computed from `label`, i.e. the ground-truth
    # classes of this batch, and is then fused into the classifier input. The
    # validate() function fuses the image features with themselves instead.
    # Confirm this train/eval asymmetry (and possible label leakage) is intended.
    input_source = calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder)

    # CLIP image backbone: no gradients needed.
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)

    # Trainable image encoder runs with autograd enabled so it can be optimised.
    input_target_add = Image_Encoder(input_target_temp)

    # Fuse image and text features with the CBP layer.
    fused_features = cbp_layer(input_target_add, input_source)

    # Classify the fused features and scale the logits by exp(logit_scale).
    output = adapter(fused_features) * logit_scale

    # Supervised cross-entropy loss.
    loss = criterion(output, label)

    prec, _ = accuracy(output, label, topk=(1, 5))
    losses_T.update(loss.item(), input_source.size(0))
    top1_source.update(prec[0], input_source.size(0))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # The scheduler is stepped once per batch.
    scheduler.step()

    batch_time.update(time.time() - end)

    # NOTE(review): `epoch` only changes when the loader wraps, so whenever this
    # condition holds it holds for every batch of that epoch (one line per batch).
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time,
                  data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T, top1S=top1_source,
                  top1T=top1_target))
    return source_train_loader_batch, epoch, new_epoch_flag
|
||
|
||
|
||
def validate(best_epoch, classnames, templates, val_loader, model, adapter, epoch, args, criterion, best_prec, CLIP_Text,
             Text_Encoder, CLIP_Image,
             Image_Encoder, cbp_layer):
    """Evaluate on ``val_loader`` and report the best accuracy seen so far.

    Every image batch goes CLIP_Image -> Image_Encoder -> cbp_layer (image
    features fused with themselves) -> adapter, with logits scaled by
    ``exp(model.logit_scale)``. Loss and top-1 accuracy are accumulated per
    batch; progress is printed every ``args.print_freq`` batches.

    ``classnames``, ``templates`` and ``Text_Encoder`` are accepted for interface
    compatibility; no text features are computed during validation.

    Returns:
        tuple ``(prec, best_epoch)``: ``prec`` is the larger of the target/source
        top-1 averages as a Python float; ``best_epoch`` is ``epoch`` when
        ``prec`` beats ``best_prec``, otherwise the value passed in.
    """
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    # losses_target / top1_target are never updated; they only keep the log format.
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    cbp_layer.eval()
    model.eval()
    adapter.eval()

    end = time.time()

    # Evaluation only: disable autograd for the whole pass, not just the CLIP
    # image backbone, so no computation graphs are built or kept.
    with torch.no_grad():
        logit_scale = model.logit_scale.exp()

        for i, (image, label, _) in enumerate(val_loader):
            input_target = image.cuda()
            target_target = label.cuda()

            # CLIP image backbone followed by the fine-tuned image encoder.
            input_target_temp = CLIP_Image(input_target)
            input_target_add = Image_Encoder(input_target_temp)

            # NOTE(review): unlike train(), which fuses image features with
            # label-derived text features, the image features are fused with
            # themselves here. Confirm this asymmetry is intended.
            fused_features = cbp_layer(input_target_add, input_target_add)
            output = adapter(fused_features) * logit_scale

            loss_source = criterion(output, target_target)

            # Measure accuracy and record loss.
            prec, _ = accuracy(output, target_target, topk=(1, 5))
            losses_source.update(loss_source.item(), image.size(0))
            top1_source.update(prec[0], image.size(0))

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                      'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                      'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                      'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                          epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source, lossT=losses_target,
                          top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))

    # float() accepts both a one-element tensor and a plain Python number.
    # `.item()` would raise AttributeError whenever max() returns a plain number,
    # e.g. the never-updated top1_target average when accuracy is exactly 0 or
    # the loader is empty.
    prec = float(max(top1_target.avg, top1_source.avg))
    if prec > best_prec:
        best_prec = prec
        best_epoch = epoch
    print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)

    return prec, best_epoch
|
||
|
||
|
||
def main():
    """Build the CLIP encoders, CBP fusion layer and adapter, then train.

    Training advances one batch per ``train()`` call. Whenever ``train()``
    reports that the loader wrapped to a new epoch, the model is validated
    (subject to ``args.test_freq`` / ``args.valepoch``). It is checkpointed via
    ``save_model`` when the accuracy is both above ``args.valacc`` and a new best.
    """
    args = opts()

    set_seed(2023)

    model, preprocess = clip.load(args.name)
    model = model.cuda()
    model.float()

    classnames, templates, loader, train_loader = get_dataset_loader(args, preprocess)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = configure_clip_encoders(args, model, 0, 1)

    prepare_directories(args, CLIP_Text, CLIP_Image)

    # Classification head (adapter) initialised from set_adapter_weights_single.
    weights = set_adapter_weights_single(model, classnames, templates)
    adapter = Adapter(args, classnames, weights).cuda()

    # Compact bilinear pooling (CBP) fusion layer.
    # NOTE(review): the dimensions are hard-coded. input_dim presumably has to
    # cover the width of the Image_Encoder output and of the text features, and
    # output_dim the adapter's input width -- confirm for the chosen args.name.
    input_dim = 1024
    output_dim = 1024  # may be changed as needed
    cbp_layer = CompactBilinearPooling(input_dim, output_dim)

    # Loss functions.
    criterion = nn.CrossEntropyLoss().cuda()
    # NOTE(review): the two domain classifiers below are not used anywhere else
    # in this function. They are kept because their construction presumably
    # consumes the global RNG, and removing them could change later randomness
    # (e.g. data shuffling) under the fixed seed.
    criterion_classifier_target = DClassifierForTarget(nClass=len(classnames)).cuda()
    criterion_classifier_source = DClassifierForSource(nClass=len(classnames)).cuda()

    # Per-group learning rates and weight decay.
    lr_adapter = 0.0001
    lr_image_encoder = 0.00001
    lr_text_encoder = 0.00001
    weight_decay = 0.00001
    # Adam (beta1, beta2): decay rates of the first/second-moment running averages.
    ADAM_BETAS = (0.9, 0.999)

    # AdamW over the adapter and the two trainable encoders (cbp_layer is not
    # included in the optimizer).
    optimizer = torch.optim.AdamW([
        {'params': adapter.parameters(), 'lr': lr_adapter, 'weight_decay': weight_decay, 'betas': ADAM_BETAS},
        {'params': Image_Encoder.parameters(), 'lr': lr_image_encoder, 'weight_decay': weight_decay,
         'betas': ADAM_BETAS},
        {'params': Text_Encoder.parameters(), 'lr': lr_text_encoder, 'weight_decay': weight_decay, 'betas': ADAM_BETAS}
    ], eps=1e-4)

    # Cosine annealing over the whole run. T_max is measured in BATCHES
    # (epochs * batches per epoch) because train() calls scheduler.step() after
    # every batch, not once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

    source_train_loader_batch = enumerate(train_loader)

    current_epoch = 0
    best_prec = 0
    best_epoch = 0
    while current_epoch < args.epochs:
        source_train_loader_batch, current_epoch, new_epoch_flag = train(
            classnames, templates, train_loader, source_train_loader_batch, model, adapter, optimizer,
            current_epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder,
            cbp_layer)

        # Validate only at epoch boundaries, every args.test_freq epochs, once
        # args.valepoch has been reached.
        should_validate = (new_epoch_flag
                           and ((current_epoch + 1) % args.test_freq == 0 or current_epoch == 0)
                           and current_epoch >= args.valepoch)
        if not should_validate:
            continue

        prec, best_epoch = validate(best_epoch, classnames, templates, loader, model, adapter, current_epoch, args,
                                    criterion, best_prec,
                                    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, cbp_layer)
        is_best = prec > best_prec
        # Checkpoint only models that clear the accuracy threshold AND improve on the best.
        if prec > args.valacc and is_best:
            save_model(current_epoch, Text_Encoder, Image_Encoder, adapter, args, prec)
        best_prec = max(prec, best_prec)

        # Log this validation result.
        # NOTE(review): logging.info is dropped unless logging was configured at
        # INFO level elsewhere (Python's default root level is WARNING).
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        logging.info(
            f"Current Time: {current_time},Epoch: {current_epoch}, Accuracy: {prec}, Best: {best_prec}")
|
||
|
||
|
||
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()
|