scripts and template

This commit is contained in:
2026-02-04 10:24:11 +08:00
parent f9beacf476
commit ea5e9f17ba
3 changed files with 149 additions and 11 deletions

119
extract_acc.py Normal file
View File

@@ -0,0 +1,119 @@
import os
import re
from glob import glob
from collections import defaultdict
def extract_accuracy(log_path):
    """Extract the reported accuracy from a log file.

    The file is searched for the first fragment of the form
    ``* accuracy: <number>%``; both ``85.3%`` and ``100%`` are accepted.

    Args:
        log_path: Path of the log file to read.

    Returns:
        The accuracy (in percent) as a float, or ``None`` when the file
        cannot be read or contains no accuracy entry.
    """
    try:
        with open(log_path, 'r') as f:
            content = f.read()
    except (OSError, ValueError):
        # Best effort: an unreadable or undecodable log contributes no
        # result.  UnicodeDecodeError is a ValueError subclass.
        return None

    match = re.search(r'\* accuracy: (\d+(?:\.\d+)?)%', content)
    if match is None:
        return None
    return float(match.group(1))
def _collect_split_accuracies(root_dir, split_dir, target_model):
    """Gather per-dataset accuracies of one model below a split directory."""
    accuracies = defaultdict(list)
    pattern = os.path.join(root_dir, f'**/{split_dir}/**/log.txt')
    # Sorted so the order of the collected values does not depend on the
    # file-system enumeration order.
    for log_path in sorted(glob(pattern, recursive=True)):
        parts = log_path.split(os.sep)
        # Dataset and model are read by position from the end of the path
        # (parts[-6] and parts[-4]).  Skip logs that are not nested deeply
        # enough below the split directory; otherwise the indices would
        # either raise IndexError or pick up unrelated parent directories.
        if split_dir not in parts[:-6]:
            continue
        dataset, model = parts[-6], parts[-4]
        if model != target_model:
            continue
        accuracy = extract_accuracy(log_path)
        if accuracy is not None:
            accuracies[dataset].append(accuracy)
    return accuracies


def collect_model_results(root_dir, target_model):
    """Collect the results of one model on all datasets, grouped per dataset.

    Every ``log.txt`` found below a ``train_base`` directory counts as a
    base result and every one below a ``test_new`` directory as a new
    result.  Each readable log contributes one accuracy value (one per
    run/seed directory).

    Args:
        root_dir: Directory that is searched recursively for log files.
        target_model: Only logs whose model path component equals this
            name are used.

    Returns:
        A dict with the keys ``'base'`` and ``'new'`` (each a
        ``defaultdict(list)`` mapping dataset name to a list of
        accuracies) and ``'datasets'`` (the set of dataset names that
        have at least one result).
    """
    base = _collect_split_accuracies(root_dir, 'train_base', target_model)
    new = _collect_split_accuracies(root_dir, 'test_new', target_model)
    return {
        'base': base,
        'new': new,
        'datasets': set(base) | set(new),
    }
def calculate_harmonic_mean(base, new):
    """Return the harmonic mean of two values, or 0 if either one is zero."""
    if base != 0 and new != 0:
        return 2 * base * new / (base + new)
    # With a zero operand the harmonic mean is reported as 0.
    return 0
def calculate_average(values):
    """Return the arithmetic mean of ``values``, or ``None`` if it is empty."""
    if values:
        total = sum(values)
        return total / len(values)
    return None
def print_model_results(results, model_name):
    """Print a table of seed-averaged results of one model per dataset.

    For every dataset that has both base and new accuracies, the values
    are averaged over the collected runs and printed together with their
    harmonic mean and the number of seeds (the smaller of the two counts).
    Datasets lacking either kind of result are skipped, so that missing
    results are never reported as zeros nor included in the overall
    average.  A final row shows the averages over the printed datasets;
    its H value is the harmonic mean of the averaged base and new values.

    Args:
        results: Dict with the keys ``'base'`` and ``'new'`` (mappings
            from dataset name to a list of accuracies) and ``'datasets'``
            (iterable of dataset names).
        model_name: Model name shown in the table heading.
    """
    print(f"\nResults for model: {model_name}")
    print(f"{'Dataset':<15} {'Base':<10} {'New':<10} {'H':<10} {'Seeds':<10}")
    print("-" * 60)

    # Running sums for the overall average row.
    base_sum = 0
    new_sum = 0
    valid_datasets = 0

    for dataset in sorted(results['datasets']):
        base_accs = results['base'].get(dataset, [])
        new_accs = results['new'].get(dataset, [])
        if not base_accs or not new_accs:
            # Incomplete dataset: one of the two result kinds is missing.
            continue

        avg_base = calculate_average(base_accs)
        avg_new = calculate_average(new_accs)
        h = calculate_harmonic_mean(avg_base, avg_new)
        # Seed count: the smaller of the base and new run counts.
        num_seeds = min(len(base_accs), len(new_accs))
        print(f"{dataset:<15} {avg_base:<10.2f} {avg_new:<10.2f} {h:<10.2f} {num_seeds}")

        base_sum += avg_base
        new_sum += avg_new
        valid_datasets += 1

    if valid_datasets > 0:
        avg_base = base_sum / valid_datasets
        avg_new = new_sum / valid_datasets
        avg_h = calculate_harmonic_mean(avg_base, avg_new)
        print("-" * 60)
        print(f"{'Average':<15} {avg_base:<10.2f} {avg_new:<10.2f} {avg_h:<10.2f}")
    else:
        print("No complete dataset results found for this model.")
def main(argv=None):
    """Collect and print the results of one model.

    Args:
        argv: Optional list of command-line arguments; when ``None`` the
            process arguments are used.  ``--root-dir`` defaults to
            ``output`` and ``--model`` to ``PromptSRC``, so running the
            script without arguments behaves as before.
    """
    # Imported locally so the file's top-level import block stays untouched.
    import argparse

    parser = argparse.ArgumentParser(
        description="Summarize base/new accuracies found in training logs.")
    parser.add_argument('--root-dir', default='output',
                        help="output directory that is searched for log files")
    parser.add_argument('--model', default='PromptSRC',
                        help="name of the model to analyze")
    args = parser.parse_args(argv)

    results = collect_model_results(args.root_dir, args.model)
    print_model_results(results, args.model)


if __name__ == '__main__':
    main()