From ea5e9f17ba172d7c1c39eca14300430bc9caa4cb Mon Sep 17 00:00:00 2001 From: rain-bus Date: Wed, 4 Feb 2026 10:24:11 +0800 Subject: [PATCH] scripts and template --- extract_acc.py | 119 ++++++++++++++++++++++++++++++ scripts/promptsrc/base2new_all.sh | 20 ++--- trainers/promptsrc.py | 21 +++++- 3 files changed, 149 insertions(+), 11 deletions(-) create mode 100644 extract_acc.py diff --git a/extract_acc.py b/extract_acc.py new file mode 100644 index 0000000..dcd7e55 --- /dev/null +++ b/extract_acc.py @@ -0,0 +1,119 @@ +import os +import re +from glob import glob +from collections import defaultdict + +def extract_accuracy(log_path): + """从日志文件中提取accuracy""" + try: + with open(log_path, 'r') as f: + content = f.read() + match = re.search(r'\* accuracy: (\d+\.\d+)%', content) + if match: + return float(match.group(1)) + except: + pass + return None + +def collect_model_results(root_dir, target_model): + """收集指定模型在所有数据集上的结果,z按seed分组""" + results = { + 'base': defaultdict(list), # 使用列表存储多个seed的结果 + 'new': defaultdict(list), + 'datasets': set() + } + # 查找所有base训练的log文件 + base_logs = glob(os.path.join(root_dir, '**/train_base/**/log.txt'), recursive=True) + for log_path in base_logs: + parts = log_path.split(os.sep) + dataset = parts[-6] + model = parts[-4] + + if model != target_model: + continue + + accuracy = extract_accuracy(log_path) + if accuracy is not None: + results['base'][dataset].append(accuracy) + results['datasets'].add(dataset) + + # 查找所有new测试的log文件 + new_logs = glob(os.path.join(root_dir, '**/test_new/**/log.txt'), recursive=True) + for log_path in new_logs: + parts = log_path.split(os.sep) + dataset = parts[-6] + model = parts[-4] + + if model != target_model: + continue + + accuracy = extract_accuracy(log_path) + if accuracy is not None: + results['new'][dataset].append(accuracy) + results['datasets'].add(dataset) + + return results + +def calculate_harmonic_mean(base, new): + """计算调和平均数""" + if base == 0 or new == 0: + return 0 + return 2 * base * new / (base + new) + +def calculate_average(values): + """计算平均值""" + if not values: + return None + return sum(values) / len(values) + +def print_model_results(results, model_name): + """打印指定模型在所有数据集上的结果(平均所有seed)""" + datasets = sorted(results['datasets']) + + # 准备数据用于计算总体平均值 + base_sum = 0 + new_sum = 0 + valid_datasets = 0 + + print(f"\nResults for model: {model_name}") + print(f"{'Dataset':<15} {'Base':<10} {'New':<10} {'H':<10} {'Seeds':<10}") + print("-" * 60) + + for dataset in datasets: + base_accs = results['base'].get(dataset, []) + new_accs = results['new'].get(dataset, [0.0, 0.0, 0.0]) + + if base_accs and new_accs: + avg_base = calculate_average(base_accs) + avg_new = calculate_average(new_accs) + h = calculate_harmonic_mean(avg_base, avg_new) + + # 获取seed数量(取base和new中较小的seed数) + num_seeds = min(len(base_accs), len(new_accs)) + + print(f"{dataset:<15} {avg_base:.2f}{'':<6} {avg_new:.2f}{'':<6} {h:.2f}{'':<6} {num_seeds}") + + base_sum += avg_base + new_sum += avg_new + valid_datasets += 1 + + # 计算并打印总体平均值 + if valid_datasets > 0: + avg_base = base_sum / valid_datasets + avg_new = new_sum / valid_datasets + avg_h = calculate_harmonic_mean(avg_base, avg_new) + print("-" * 60) + print(f"{'Average':<15} {avg_base:.2f}{'':<6} {avg_new:.2f}{'':<6} {avg_h:.2f}") + else: + print("No complete dataset results found for this model.") + +def main(): + root_dir = 'output' # 修改为你的output目录路径 + target_model = 'PromptSRC' # 指定要分析的模型 + + results = collect_model_results(root_dir, target_model) + print_model_results(results, target_model) + +if __name__ == '__main__': + main() + diff --git a/scripts/promptsrc/base2new_all.sh b/scripts/promptsrc/base2new_all.sh index f5efcc8..65fdd42 100644 --- a/scripts/promptsrc/base2new_all.sh +++ b/scripts/promptsrc/base2new_all.sh @@ -1,16 +1,16 @@ seeds=(1 2 3) datasets=( - "ucf101" - "eurosat" - "oxford_pets" - "food101" - "oxford_flowers" - "dtd" - "caltech101" - "fgvc_aircraft" - "stanford_cars" + # "ucf101" + # "eurosat" + # "oxford_pets" + # "food101" + # "oxford_flowers" + # "dtd" + # "caltech101" + # "fgvc_aircraft" + # "stanford_cars" # "sun397" - # "imagenet" + "imagenet" ) for dataset in "${datasets[@]}"; do diff --git a/trainers/promptsrc.py b/trainers/promptsrc.py index 4d9712a..8c09fa2 100644 --- a/trainers/promptsrc.py +++ b/trainers/promptsrc.py @@ -18,6 +18,24 @@ _tokenizer = _Tokenizer() DESC_LLM = "gpt-4.1" DESC_TOPK = 4 +CUSTOM_TEMPLATES = { + "OxfordPets": "a photo of a {}, a type of pet.", + "OxfordFlowers": "a photo of a {}, a type of flower.", + "FGVCAircraft": "a photo of a {}, a type of aircraft.", + "DescribableTextures": "a photo of a {}, a type of texture.", + "EuroSAT": "a centered satellite photo of {}.", + "StanfordCars": "a photo of a {}.", + "Food101": "a photo of {}, a type of food.", + "SUN397": "a photo of a {}.", + "Caltech101": "a photo of a {}.", + "UCF101": "a photo of a person doing {}.", + "ImageNet": "a photo of a {}.", + "ImageNetSketch": "a photo of a {}.", + "ImageNetV2": "a photo of a {}.", + "ImageNetA": "a photo of a {}.", + "ImageNetR": "a photo of a {}.", +} + def load_clip_to_cpu(cfg, zero_shot_model=False): backbone_name = cfg.MODEL.BACKBONE.NAME @@ -125,8 +143,9 @@ class VLPromptLearner(nn.Module): with open(desc_file, "r") as f: all_desc = json.load(f) + template = CUSTOM_TEMPLATES[cfg.DATASET.NAME] for cls in classnames: - cls_descs = [f"a photo of {cls}, {desc}" for desc in all_desc[cls]] + cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]] cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda() with torch.no_grad(): cls_feature = clip_model_temp.encode_text(cls_token)