This commit is contained in:
2024-05-21 19:41:56 +08:00
commit ca67205608
217 changed files with 201004 additions and 0 deletions

207
main_tip_adapter.py Normal file
View File

@@ -0,0 +1,207 @@
import os
import random
import argparse
import yaml
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms
from datasets import build_dataset
from datasets.utils import build_data_loader
import clip
from utils import *
def get_arguments(argv=None):
    """Parse the command-line options of the Tip-Adapter script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is what
            the script entry point relies on.

    Returns:
        argparse.Namespace: namespace whose ``config`` attribute holds the
        value passed to ``--config``.

    Raises:
        SystemExit: raised by ``argparse`` (after printing a usage message)
            when ``--config`` is missing or the arguments are malformed.
    """
    parser = argparse.ArgumentParser()
    # Nothing in this script can run without a config file, so reject a
    # missing --config here with a usage message instead of letting a None
    # path fail later with an unrelated-looking error.
    parser.add_argument(
        '--config',
        dest='config',
        required=True,
        help='settings of Tip-Adapter in yaml format',
    )
    return parser.parse_args(argv)
def run_tip_adapter(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights):
    """Report zero-shot CLIP and training-free Tip-Adapter accuracy.

    The val split is scored with ``cfg['init_beta']`` / ``cfg['init_alpha']``;
    ``search_hp`` then picks the (beta, alpha) pair that is used to score the
    test split. Every result is printed and nothing is returned.

    Args:
        cfg: Mapping read for ``'init_beta'`` and ``'init_alpha'`` and passed
            on to ``search_hp``.
        cache_keys: Right-hand operand of ``features @ cache_keys``.
        cache_values: Right-hand operand applied to the exponentiated
            affinities (presumably the few-shot label matrix -- confirm
            against ``build_cache_model``).
        val_features: Features of the val split.
        val_labels: Labels of the val split, forwarded to ``cls_acc``.
        test_features: Features of the test split.
        test_labels: Labels of the test split, forwarded to ``cls_acc``.
        clip_weights: Right-hand operand of ``features @ clip_weights``.
    """

    def zero_shot_logits(features):
        # CLIP classifier logits, scaled by 100.
        return 100. * features @ clip_weights

    def cache_model_logits(features, sharpness):
        # exp(-sharpness * (1 - affinity)) weighted by the cache values.
        similarity = features @ cache_keys
        return ((-1) * (sharpness - sharpness * similarity)).exp() @ cache_values

    print("\n-------- Searching hyperparameters on the val set. --------")

    # Zero-shot CLIP
    val_zero_shot = zero_shot_logits(val_features)
    val_zero_shot_acc = cls_acc(val_zero_shot, val_labels)
    print("\n**** Zero-shot CLIP's val accuracy: {:.2f}. ****\n".format(val_zero_shot_acc))

    # Tip-Adapter
    init_beta = cfg['init_beta']
    init_alpha = cfg['init_alpha']
    val_tip = val_zero_shot + cache_model_logits(val_features, init_beta) * init_alpha
    val_tip_acc = cls_acc(val_tip, val_labels)
    print("**** Tip-Adapter's val accuracy: {:.2f}. ****\n".format(val_tip_acc))

    # Search Hyperparameters
    best_beta, best_alpha = search_hp(cfg, cache_keys, cache_values, val_features, val_labels, clip_weights)

    print("\n-------- Evaluating on the test set. --------")

    # Zero-shot CLIP
    test_zero_shot = zero_shot_logits(test_features)
    test_zero_shot_acc = cls_acc(test_zero_shot, test_labels)
    print("\n**** Zero-shot CLIP's test accuracy: {:.2f}. ****\n".format(test_zero_shot_acc))

    # Tip-Adapter
    test_tip = test_zero_shot + cache_model_logits(test_features, best_beta) * best_alpha
    test_tip_acc = cls_acc(test_tip, test_labels)
    print("**** Tip-Adapter's test accuracy: {:.2f}. ****\n".format(test_tip_acc))
def run_tip_adapter_F(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights, clip_model, train_loader_F):
    """Fine-tune the cache keys (Tip-Adapter-F) and report test accuracy.

    The cache keys are copied into a bias-free ``nn.Linear`` layer and trained
    with AdamW and a cosine learning-rate schedule for ``cfg['train_epoch']``
    epochs. After every epoch the adapter is scored on the test features; the
    best-scoring weights are saved under ``cfg['cache_dir']`` and restored
    once training ends. ``search_hp`` then tunes (beta, alpha) on the val
    split and the final test accuracy is printed. Nothing is returned.

    Args:
        cfg: Mapping read for ``'lr'``, ``'train_epoch'``, ``'init_beta'``,
            ``'init_alpha'``, ``'cache_dir'`` and ``'shots'``; also passed to
            ``search_hp``.
        cache_keys: 2-D tensor; ``shape[0]`` is used as the adapter input
            size, ``shape[1]`` as its output size, and its transpose as the
            initial weight.
        cache_values: Right-hand operand applied to the exponentiated
            affinities.
        val_features: Features of the val split (used by ``search_hp``).
        val_labels: Labels of the val split (used by ``search_hp``).
        test_features: Features of the test split.
        test_labels: Labels of the test split, forwarded to ``cls_acc``.
        clip_weights: Right-hand operand of ``features @ clip_weights``.
        clip_model: Model providing ``encode_image`` and ``dtype``.
        train_loader_F: Iterable of ``(images, target)`` training batches.
    """
    # Enable the cached keys to be learnable
    adapter = nn.Linear(cache_keys.shape[0], cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda()
    adapter.weight = nn.Parameter(cache_keys.t())

    optimizer = torch.optim.AdamW(adapter.parameters(), lr=cfg['lr'], eps=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F))

    beta, alpha = cfg['init_beta'], cfg['init_alpha']
    best_acc, best_epoch = 0.0, 0

    checkpoint_path = cfg['cache_dir'] + "/best_F_" + str(cfg['shots']) + "shots.pt"
    # Tracks whether THIS run wrote a checkpoint, so a stale file left by an
    # earlier run is never loaded by accident.
    checkpoint_saved = False

    # The zero-shot logits of the test set do not depend on the adapter, so
    # they are computed once and reused by every evaluation below.
    test_clip_logits = 100. * test_features @ clip_weights

    for train_idx in range(cfg['train_epoch']):
        # Train
        adapter.train()
        correct_samples, all_samples = 0, 0
        loss_list = []
        print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch']))

        for images, target in tqdm(train_loader_F):
            images, target = images.cuda(), target.cuda()
            # The CLIP encoder is not trained: no gradients through it.
            with torch.no_grad():
                image_features = clip_model.encode_image(images)
                image_features /= image_features.norm(dim=-1, keepdim=True)

            affinity = adapter(image_features)
            cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values
            clip_logits = 100. * image_features @ clip_weights
            tip_logits = clip_logits + cache_logits * alpha

            loss = F.cross_entropy(tip_logits, target)

            acc = cls_acc(tip_logits, target)
            correct_samples += acc / 100 * len(tip_logits)
            all_samples += len(tip_logits)
            loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        current_lr = scheduler.get_last_lr()[0]
        # max(..., 1) keeps the summary printable when the loader is empty.
        print('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(
            current_lr,
            correct_samples / max(all_samples, 1),
            correct_samples,
            all_samples,
            sum(loss_list) / max(len(loss_list), 1),
        ))

        # Eval
        adapter.eval()
        # Evaluation needs no autograd graph.
        with torch.no_grad():
            affinity = adapter(test_features)
            cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values
            tip_logits = test_clip_logits + cache_logits * alpha
        acc = cls_acc(tip_logits, test_labels)

        print("**** Tip-Adapter-F's test accuracy: {:.2f}. ****\n".format(acc))
        if acc > best_acc:
            best_acc = acc
            best_epoch = train_idx
            torch.save(adapter.weight, checkpoint_path)
            checkpoint_saved = True

    # Restore the best weights of this run; if no epoch improved on the
    # initial accuracy (or no epoch ran), keep the adapter as it is.
    if checkpoint_saved:
        adapter.weight = torch.load(checkpoint_path)
    print(f"**** After fine-tuning, Tip-Adapter-F's best test accuracy: {best_acc:.2f}, at epoch: {best_epoch}. ****\n")

    print("\n-------- Searching hyperparameters on the val set. --------")

    # Search Hyperparameters
    best_beta, best_alpha = search_hp(cfg, cache_keys, cache_values, val_features, val_labels, clip_weights, adapter=adapter)

    print("\n-------- Evaluating on the test set. --------")

    with torch.no_grad():
        affinity = adapter(test_features)
        cache_logits = ((-1) * (best_beta - best_beta * affinity)).exp() @ cache_values
        # Use the explicitly computed test logits rather than whatever
        # `clip_logits` happened to hold after the training loop.
        tip_logits = test_clip_logits + cache_logits * best_alpha
    acc = cls_acc(tip_logits, test_labels)

    print("**** Tip-Adapter-F's test accuracy: {:.2f}. ****\n".format(max(best_acc, acc)))
def main():
    """Run Tip-Adapter and Tip-Adapter-F as described by a YAML config.

    Reads the config path from the command line, loads CLIP, builds the
    val/test/train data loaders, extracts the textual classifier weights, the
    few-shot cache and the val/test features, and finally runs
    ``run_tip_adapter`` followed by ``run_tip_adapter_F``.

    Raises:
        FileNotFoundError: if no config path was given or it does not exist.
    """
    # Load config file
    args = get_arguments()
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, and a missing --config would otherwise fail obscurely.
    if not args.config or not os.path.exists(args.config):
        raise FileNotFoundError("Config file not found: {}".format(args.config))

    # NOTE(review): yaml.Loader can construct arbitrary Python objects. Only
    # load trusted config files; consider yaml.safe_load if the configs are
    # plain data.
    with open(args.config, 'r') as config_file:
        cfg = yaml.load(config_file, Loader=yaml.Loader)

    cache_dir = os.path.join('./caches', cfg['dataset'])
    os.makedirs(cache_dir, exist_ok=True)
    cfg['cache_dir'] = cache_dir

    print("\nRunning configs.")
    print(cfg, "\n")

    # CLIP
    clip_model, preprocess = clip.load(cfg['backbone'])
    clip_model.eval()

    # Prepare dataset
    random.seed(1)
    torch.manual_seed(1)

    print("Preparing dataset.")
    dataset = build_dataset(cfg['dataset'], cfg['root_path'], cfg['shots'])

    val_loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)
    test_loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)

    # Augmentation applied to the few-shot training images.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
    ])

    # Same data and transform twice: unshuffled for building the cache,
    # shuffled for fine-tuning.
    train_loader_cache = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=False)
    train_loader_F = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=True)

    # Textual features
    print("\nGetting textual features as CLIP's classifier.")
    clip_weights = clip_classifier(dataset.classnames, dataset.template, clip_model)

    # Construct the cache model by few-shot training set
    print("\nConstructing cache model by few-shot visual features and labels.")
    cache_keys, cache_values = build_cache_model(cfg, clip_model, train_loader_cache)

    # Pre-load val features
    print("\nLoading visual features and labels from val set.")
    val_features, val_labels = pre_load_features(cfg, "val", clip_model, val_loader)

    # Pre-load test features
    print("\nLoading visual features and labels from test set.")
    test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader)

    # ------------------------------------------ Tip-Adapter ------------------------------------------
    run_tip_adapter(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights)

    # ------------------------------------------ Tip-Adapter-F ------------------------------------------
    run_tip_adapter_F(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights, clip_model, train_loader_F)
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()