clip-symnets/main_cbp.py
import time
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim

from clip import clip
from opts import opts  # The options for the project
# from trainer import validate  # For the validate (test) process
from models.DomainClassifierTarget import DClassifierForTarget
from models.DomainClassifierSource import DClassifierForSource
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
    set_adapter_weights_single, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights
from Adapter import Adapter
class CompactBilinearPooling(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(CompactBilinearPooling, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Randomly generate the Count Sketch hash parameters: each input dimension
        # gets a random output bucket and a random sign. (Registering these via
        # register_buffer would let them follow .to()/state_dict, but they are
        # kept as plain CUDA tensors here, as in the original.)
        self.hashcode = torch.randint(0, output_dim, (input_dim,), dtype=torch.long).cuda()
        self.sign = (torch.randint(0, 2, (input_dim,)) * 2 - 1).float().cuda()

    def count_sketch(self, x, hashcode, sign):
        batch_size, input_dim = x.shape
        output = x.new_zeros(batch_size, self.output_dim)
        for i in range(input_dim):
            output[:, hashcode[i]] += sign[i] * x[:, i]
        return output

    def forward(self, x1, x2):
        # Apply Count Sketch to both inputs.
        sketch_x1 = self.count_sketch(x1, self.hashcode, self.sign)
        sketch_x2 = self.count_sketch(x2, self.hashcode, self.sign)
        # FFT of both sketches.
        fft_x1 = torch.fft.rfft(sketch_x1, n=self.output_dim, dim=1)
        fft_x2 = torch.fft.rfft(sketch_x2, n=self.output_dim, dim=1)
        # Element-wise product in the frequency domain followed by the inverse FFT,
        # i.e. the circular convolution of the two sketches.
        ifft_result = torch.fft.irfft(fft_x1 * fft_x2, n=self.output_dim, dim=1)
        return ifft_result
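# Compact bilinear pooling rests on the identity "sketch of the outer product equals
# the circular convolution of the individual sketches", which forward() evaluates
# cheaply in the frequency domain. Below is a minimal vectorized sketch of
# count_sketch, assuming it should be numerically equivalent to the per-dimension
# loop above; index_add_ scatters every signed input column into its hashed bucket
# in one call, avoiding the O(input_dim) Python loop. The function name is ours,
# for illustration only, and it is not wired into the class.
def count_sketch_vectorized(x, hashcode, sign, output_dim):
    # x: (batch, input_dim); hashcode: (input_dim,) long; sign: (input_dim,) float
    output = x.new_zeros(x.shape[0], output_dim)
    output.index_add_(1, hashcode, x * sign)  # output[:, hashcode[i]] += sign[i] * x[:, i]
    return output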
def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
          adapter, optimizer,
          epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, cbp_layer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()
    logit_scale = model.logit_scale.exp()

    # The pretrained CLIP encoders stay frozen; only the lightweight heads train.
    model.eval()
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.train()
    Image_Encoder.train()
    cbp_layer.train()
    adapter.train()

    # One batch is consumed per call; an exhausted iterator marks an epoch boundary,
    # at which point the loader is re-wrapped and the epoch counter advances.
    new_epoch_flag = False
    end = time.time()
    try:
        (image, label, _) = source_train_loader_batch.__next__()[1]
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = source_train_loader_batch.__next__()[1]

    label = label.cuda()
    input_target = image.cuda()
    # Label-conditioned zero-shot text features serve as the source-side input.
    input_source = calculate_zeroshot_weights(classnames, label, templates, CLIP_Text, Text_Encoder)
    data_time.update(time.time() - end)

    # Frozen CLIP image encoder.
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    input_target_add = Image_Encoder(input_target_temp)

    # Fuse image and text features with the CBP layer.
    fused_features = cbp_layer(input_target_add, input_source)
    # The fused features go straight into the fully connected adapter.
    output = adapter(fused_features) * logit_scale

    # Supervised cross-entropy classification loss.
    loss = criterion(output, label)
    prec, _ = accuracy(output, label, topk=(1, 5))
    losses_T.update(loss.item(), input_source.size(0))
    top1_source.update(prec[0], input_source.size(0))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    batch_time.update(time.time() - end)
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
            epoch, args.epochs, batch_time=batch_time,
            data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T, top1S=top1_source,
            top1T=top1_target))
    return source_train_loader_batch, epoch, new_epoch_flag
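# For clarity, a minimal sketch of the one-batch-per-call driver pattern used by
# train() above: each call pulls a single batch, and StopIteration marks the epoch
# boundary, at which point the iterator is rebuilt. Names here are hypothetical
# and the helper is not used by the rest of the file.
def _next_batch(loader, batch_iter, epoch):
    try:
        batch = next(batch_iter)[1]
        new_epoch = False
    except StopIteration:
        epoch += 1
        new_epoch = True
        batch_iter = enumerate(loader)
        batch = next(batch_iter)[1]
    return batch, batch_iter, epoch, new_epoch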
def validate(best_epoch, classnames, templates, val_loader, model, adapter, epoch, args, criterion, best_prec,
             CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, cbp_layer):
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    cbp_layer.eval()
    model.eval()
    adapter.eval()

    end = time.time()
    logit_scale = model.logit_scale.exp()
    for i, (image, label, _) in enumerate(val_loader):
        input_target = image.cuda()
        target_target = label.cuda()

        # Frozen CLIP image encoder.
        with torch.no_grad():
            input_target_temp = CLIP_Image(input_target)
        input_target_add = Image_Encoder(input_target_temp)

        # At test time the label-conditioned text features used during training are
        # unavailable, so the image feature is fused with itself.
        fused_features = cbp_layer(input_target_add, input_target_add)
        output = adapter(fused_features) * logit_scale
        loss_source = criterion(output, target_target)

        # Measure accuracy and record loss.
        prec, _ = accuracy(output, target_target, topk=(1, 5))
        losses_source.update(loss_source.item(), image.size(0))
        top1_source.update(prec[0], image.size(0))

        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                  'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                  'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                  'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source, lossT=losses_target,
                top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))
    prec = max(top1_target.avg, top1_source.avg).item()
    if prec > best_prec:
        best_prec = prec
        best_epoch = epoch
    print('best_epoch', best_epoch, ' * Current_best_target@T:', best_prec)
    return prec, best_epoch
def main():
    args = opts()
    set_seed(2023)
    model, preprocess = clip.load(args.name)
    model = model.cuda()
    model.float()
    classnames, templates, loader, train_loader = get_dataset_loader(args, preprocess)
    CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = configure_clip_encoders(args, model, 0, 1)
    prepare_directories(args, CLIP_Text, CLIP_Image)

    # Classification head, initialized from zero-shot weights.
    weights = set_adapter_weights_single(model, classnames, templates)
    adapter = Adapter(args, classnames, weights).cuda()

    # Compact bilinear pooling: define the module's feature dimensions.
    input_dim = 1024
    output_dim = 1024  # or any other value, as needed
    # Instantiate the CBP module.
    cbp_layer = CompactBilinearPooling(input_dim, output_dim)

    # Loss functions.
    criterion = nn.CrossEntropyLoss().cuda()
    criterion_classifier_target = DClassifierForTarget(nClass=len(classnames)).cuda()
    criterion_classifier_source = DClassifierForSource(nClass=len(classnames)).cuda()

    # Per-module learning rates and weight decay.
    lr_adapter = 0.0001
    lr_image_encoder = 0.00001
    lr_text_encoder = 0.00001
    weight_decay = 0.00001
    # ADAM_BETAS controls the decay rates of the moving averages.
    ADAM_BETAS = (0.9, 0.999)
    # Create the AdamW optimizer with one parameter group per module.
    optimizer = torch.optim.AdamW([
        {'params': adapter.parameters(), 'lr': lr_adapter, 'weight_decay': weight_decay, 'betas': ADAM_BETAS},
        {'params': Image_Encoder.parameters(), 'lr': lr_image_encoder, 'weight_decay': weight_decay,
         'betas': ADAM_BETAS},
        {'params': Text_Encoder.parameters(), 'lr': lr_text_encoder, 'weight_decay': weight_decay, 'betas': ADAM_BETAS}
    ], eps=1e-4)

    # CosineAnnealingLR schedule. Since train() calls scheduler.step() once per batch,
    # T_max is the total number of optimization steps (epochs * batches per epoch).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

    source_train_loader_batch = enumerate(train_loader)
    current_epoch = 0
    best_prec = 0
    best_epoch = 0
    while current_epoch < args.epochs:
        source_train_loader_batch, current_epoch, new_epoch_flag = train(classnames, templates,
                                                                         train_loader,
                                                                         source_train_loader_batch,
                                                                         model,
                                                                         adapter,
                                                                         optimizer,
                                                                         current_epoch,
                                                                         args, scheduler, criterion, CLIP_Text,
                                                                         Text_Encoder, CLIP_Image, Image_Encoder,
                                                                         cbp_layer)
        if new_epoch_flag:
            if (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0:
                if current_epoch >= args.valepoch:
                    prec, best_epoch = validate(best_epoch, classnames, templates, loader, model, adapter,
                                                current_epoch, args, criterion, best_prec,
                                                CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder, cbp_layer)
                    is_best = prec > best_prec
                    if prec > args.valacc and is_best:
                        save_model(current_epoch, Text_Encoder, Image_Encoder, adapter, args, prec)
                    best_prec = max(prec, best_prec)
                    # Update the log.
                    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                    logging.info(
                        f"Current Time: {current_time}, Epoch: {current_epoch}, Accuracy: {prec}, Best: {best_prec}")
if __name__ == '__main__':
    main()