scripts and template
This commit is contained in:
119
extract_acc.py
Normal file
119
extract_acc.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
from glob import glob
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
def extract_accuracy(log_path):
|
||||||
|
"""从日志文件中提取accuracy"""
|
||||||
|
try:
|
||||||
|
with open(log_path, 'r') as f:
|
||||||
|
content = f.read()
|
||||||
|
match = re.search(r'\* accuracy: (\d+\.\d+)%', content)
|
||||||
|
if match:
|
||||||
|
return float(match.group(1))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
def collect_model_results(root_dir, target_model):
|
||||||
|
"""收集指定模型在所有数据集上的结果,z按seed分组"""
|
||||||
|
results = {
|
||||||
|
'base': defaultdict(list), # 使用列表存储多个seed的结果
|
||||||
|
'new': defaultdict(list),
|
||||||
|
'datasets': set()
|
||||||
|
}
|
||||||
|
# 查找所有base训练的log文件
|
||||||
|
base_logs = glob(os.path.join(root_dir, '**/train_base/**/log.txt'), recursive=True)
|
||||||
|
for log_path in base_logs:
|
||||||
|
parts = log_path.split(os.sep)
|
||||||
|
dataset = parts[-6]
|
||||||
|
model = parts[-4]
|
||||||
|
|
||||||
|
if model != target_model:
|
||||||
|
continue
|
||||||
|
|
||||||
|
accuracy = extract_accuracy(log_path)
|
||||||
|
if accuracy is not None:
|
||||||
|
results['base'][dataset].append(accuracy)
|
||||||
|
results['datasets'].add(dataset)
|
||||||
|
|
||||||
|
# 查找所有new测试的log文件
|
||||||
|
new_logs = glob(os.path.join(root_dir, '**/test_new/**/log.txt'), recursive=True)
|
||||||
|
for log_path in new_logs:
|
||||||
|
parts = log_path.split(os.sep)
|
||||||
|
dataset = parts[-6]
|
||||||
|
model = parts[-4]
|
||||||
|
|
||||||
|
if model != target_model:
|
||||||
|
continue
|
||||||
|
|
||||||
|
accuracy = extract_accuracy(log_path)
|
||||||
|
if accuracy is not None:
|
||||||
|
results['new'][dataset].append(accuracy)
|
||||||
|
results['datasets'].add(dataset)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def calculate_harmonic_mean(base, new):
|
||||||
|
"""计算调和平均数"""
|
||||||
|
if base == 0 or new == 0:
|
||||||
|
return 0
|
||||||
|
return 2 * base * new / (base + new)
|
||||||
|
|
||||||
|
def calculate_average(values):
|
||||||
|
"""计算平均值"""
|
||||||
|
if not values:
|
||||||
|
return None
|
||||||
|
return sum(values) / len(values)
|
||||||
|
|
||||||
|
def print_model_results(results, model_name):
|
||||||
|
"""打印指定模型在所有数据集上的结果(平均所有seed)"""
|
||||||
|
datasets = sorted(results['datasets'])
|
||||||
|
|
||||||
|
# 准备数据用于计算总体平均值
|
||||||
|
base_sum = 0
|
||||||
|
new_sum = 0
|
||||||
|
valid_datasets = 0
|
||||||
|
|
||||||
|
print(f"\nResults for model: {model_name}")
|
||||||
|
print(f"{'Dataset':<15} {'Base':<10} {'New':<10} {'H':<10} {'Seeds':<10}")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
for dataset in datasets:
|
||||||
|
base_accs = results['base'].get(dataset, [])
|
||||||
|
new_accs = results['new'].get(dataset, [0.0, 0.0, 0.0])
|
||||||
|
|
||||||
|
if base_accs and new_accs:
|
||||||
|
avg_base = calculate_average(base_accs)
|
||||||
|
avg_new = calculate_average(new_accs)
|
||||||
|
h = calculate_harmonic_mean(avg_base, avg_new)
|
||||||
|
|
||||||
|
# 获取seed数量(取base和new中较小的seed数)
|
||||||
|
num_seeds = min(len(base_accs), len(new_accs))
|
||||||
|
|
||||||
|
print(f"{dataset:<15} {avg_base:.2f}{'':<6} {avg_new:.2f}{'':<6} {h:.2f}{'':<6} {num_seeds}")
|
||||||
|
|
||||||
|
base_sum += avg_base
|
||||||
|
new_sum += avg_new
|
||||||
|
valid_datasets += 1
|
||||||
|
|
||||||
|
# 计算并打印总体平均值
|
||||||
|
if valid_datasets > 0:
|
||||||
|
avg_base = base_sum / valid_datasets
|
||||||
|
avg_new = new_sum / valid_datasets
|
||||||
|
avg_h = calculate_harmonic_mean(avg_base, avg_new)
|
||||||
|
print("-" * 60)
|
||||||
|
print(f"{'Average':<15} {avg_base:.2f}{'':<6} {avg_new:.2f}{'':<6} {avg_h:.2f}")
|
||||||
|
else:
|
||||||
|
print("No complete dataset results found for this model.")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
root_dir = 'output' # 修改为你的output目录路径
|
||||||
|
target_model = 'PromptSRC' # 指定要分析的模型
|
||||||
|
|
||||||
|
results = collect_model_results(root_dir, target_model)
|
||||||
|
print_model_results(results, target_model)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
@@ -1,16 +1,16 @@
|
|||||||
seeds=(1 2 3)
|
seeds=(1 2 3)
|
||||||
datasets=(
|
datasets=(
|
||||||
"ucf101"
|
# "ucf101"
|
||||||
"eurosat"
|
# "eurosat"
|
||||||
"oxford_pets"
|
# "oxford_pets"
|
||||||
"food101"
|
# "food101"
|
||||||
"oxford_flowers"
|
# "oxford_flowers"
|
||||||
"dtd"
|
# "dtd"
|
||||||
"caltech101"
|
# "caltech101"
|
||||||
"fgvc_aircraft"
|
# "fgvc_aircraft"
|
||||||
"stanford_cars"
|
# "stanford_cars"
|
||||||
# "sun397"
|
# "sun397"
|
||||||
# "imagenet"
|
"imagenet"
|
||||||
)
|
)
|
||||||
|
|
||||||
for dataset in "${datasets[@]}"; do
|
for dataset in "${datasets[@]}"; do
|
||||||
|
|||||||
@@ -18,6 +18,24 @@ _tokenizer = _Tokenizer()
|
|||||||
DESC_LLM = "gpt-4.1"
|
DESC_LLM = "gpt-4.1"
|
||||||
DESC_TOPK = 4
|
DESC_TOPK = 4
|
||||||
|
|
||||||
|
CUSTOM_TEMPLATES = {
|
||||||
|
"OxfordPets": "a photo of a {}, a type of pet.",
|
||||||
|
"OxfordFlowers": "a photo of a {}, a type of flower.",
|
||||||
|
"FGVCAircraft": "a photo of a {}, a type of aircraft.",
|
||||||
|
"DescribableTextures": "a photo of a {}, a type of texture.",
|
||||||
|
"EuroSAT": "a centered satellite photo of {}.",
|
||||||
|
"StanfordCars": "a photo of a {}.",
|
||||||
|
"Food101": "a photo of {}, a type of food.",
|
||||||
|
"SUN397": "a photo of a {}.",
|
||||||
|
"Caltech101": "a photo of a {}.",
|
||||||
|
"UCF101": "a photo of a person doing {}.",
|
||||||
|
"ImageNet": "a photo of a {}.",
|
||||||
|
"ImageNetSketch": "a photo of a {}.",
|
||||||
|
"ImageNetV2": "a photo of a {}.",
|
||||||
|
"ImageNetA": "a photo of a {}.",
|
||||||
|
"ImageNetR": "a photo of a {}.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_clip_to_cpu(cfg, zero_shot_model=False):
|
def load_clip_to_cpu(cfg, zero_shot_model=False):
|
||||||
backbone_name = cfg.MODEL.BACKBONE.NAME
|
backbone_name = cfg.MODEL.BACKBONE.NAME
|
||||||
@@ -125,8 +143,9 @@ class VLPromptLearner(nn.Module):
|
|||||||
with open(desc_file, "r") as f:
|
with open(desc_file, "r") as f:
|
||||||
all_desc = json.load(f)
|
all_desc = json.load(f)
|
||||||
|
|
||||||
|
template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
|
||||||
for cls in classnames:
|
for cls in classnames:
|
||||||
cls_descs = [f"a photo of {cls}, {desc}" for desc in all_desc[cls]]
|
cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]]
|
||||||
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
|
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
cls_feature = clip_model_temp.encode_text(cls_token)
|
cls_feature = clip_model_temp.encode_text(cls_token)
|
||||||
|
|||||||
Reference in New Issue
Block a user