refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO
- 核心变更: - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024) - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备 - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式 - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls] - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table) - 新增文件: - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行) - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例) - 配置更新: - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID) - app/config.py: Settings 类对应字段 - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等) - 删除旧文件: - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本 - 文档更新: - README.md: 更新环境变量说明 - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py) 此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
This commit is contained in:
@@ -0,0 +1,161 @@
|
||||
"""导出 DocLayout-YOLO (DocStructBench, imgsz=1024) 为 ONNX 格式.
|
||||
|
||||
一次性脚本,在独立 venv 中运行(不进运行时依赖):
|
||||
python -m venv .venv-export && source .venv-export/bin/activate
|
||||
pip install torch torchvision onnx onnxscript onnxsim onnxruntime huggingface-hub doclayout-yolo
|
||||
|
||||
两种权重来源:
|
||||
# 1) 用本地已下载的 .pt(推荐,省下载)
|
||||
.venv-export/bin/python scripts/export_doclayout_yolo_onnx.py \
|
||||
--weights /path/to/doclayout_yolo_docstructbench_imgsz1024.pt
|
||||
|
||||
# 2) 从 HuggingFace 下载(不传 --weights)
|
||||
HF_ENDPOINT=https://hf-mirror.com .venv-export/bin/python scripts/export_doclayout_yolo_onnx.py
|
||||
|
||||
输出:
|
||||
data/models/doclayout_yolo_docstructbench_imgsz1024.onnx
|
||||
|
||||
注意:model.export(simplify=True) 会清空 ONNX metadata,本脚本在导出后
|
||||
用 onnx 包把 names 重新写回 metadata,供运行时 _parse_names_from_meta 读取。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# hf-mirror (download accelerator for mainland China). Only matters when --weights is
# not passed and the checkpoint is fetched from HuggingFace; setdefault keeps any
# HF_ENDPOINT value that is already set in the environment.
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent  # two levels above this script file
|
||||
MODEL_DIR = PROJECT_ROOT / "data" / "models"  # directory that holds the default output
|
||||
DEFAULT_OUTPUT = MODEL_DIR / "doclayout_yolo_docstructbench_imgsz1024.onnx"  # default for --output
|
||||
REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"  # repo_id passed to hf_hub_download
|
||||
PT_FILENAME = "doclayout_yolo_docstructbench.pt"  # filename passed to hf_hub_download
|
||||
IMGSZ = 1024  # imgsz passed to model.export; also the fallback H/W for the dummy inference
|
||||
|
||||
|
||||
def resolve_weights(arg: str | None) -> Path:
    """Return the path of the ``.pt`` checkpoint to export.

    Args:
        arg: Value of ``--weights``. When truthy it is treated as a local
            checkpoint path; otherwise the checkpoint is downloaded from
            HuggingFace (``REPO_ID`` / ``PT_FILENAME``).

    Returns:
        Path to the ``.pt`` file.

    Raises:
        FileNotFoundError: ``arg`` was given but does not point to an
            existing regular file (missing path, or a directory).
    """
    if arg:
        local = Path(arg)
        # is_file() rather than exists(): a directory passes exists() and
        # would only blow up later, inside the model loader.
        if not local.is_file():
            raise FileNotFoundError(f"--weights not found: {local}")
        print(f"[1/5] Using local weights: {local}")
        return local

    print(f"[1/5] Downloading .pt from HuggingFace ({REPO_ID}) ...")
    # Imported here so the local-weights path never needs huggingface-hub.
    from huggingface_hub import hf_hub_download

    pt_path = Path(hf_hub_download(repo_id=REPO_ID, filename=PT_FILENAME))
    print(f" ✓ {pt_path}")
    return pt_path
|
||||
|
||||
|
||||
def export_onnx(pt_path: Path, output: Path) -> None:
    """Export a DocLayout-YOLO checkpoint to ONNX and store it at ``output``.

    Loads the checkpoint with ``doclayout_yolo.YOLOv10``, exports it with a
    fixed input size (``IMGSZ``), opset 12, static shapes and FP32, copies the
    exported file to ``output`` and finally rewrites the class-name metadata
    through ``write_names_metadata``.

    Args:
        pt_path: Path of the ``.pt`` checkpoint to load.
        output: Destination of the ONNX file; missing parent directories are
            created.
    """
    print("\n[2/5] Loading model with doclayout_yolo ...")
    from doclayout_yolo import YOLOv10

    model = YOLOv10(str(pt_path))
    names = model.names  # dict[int, str]; equivalent to model.model.names
    print(f" ✓ Loaded. names = {names}")

    print(f"\n[3/5] Exporting ONNX (imgsz={IMGSZ}, opset=12, simplify=True) ...")
    export_kwargs = {
        "format": "onnx",
        "imgsz": IMGSZ,
        "opset": 12,
        "dynamic": False,  # fixed batch=1 and fixed 1024 input: most stable for deployment
        "half": False,  # FP32, keeps CPU inference accuracy
    }
    try:
        # simplify=True needs onnxsim; on failure we fall back below.
        exported = model.export(simplify=True, **export_kwargs)
    except Exception as e:  # deliberate best-effort: any failure triggers the plain export
        print(f" ⚠ export with simplify failed ({e}); retrying without simplify")
        exported = model.export(**export_kwargs)

    exported_path = Path(exported)
    output.parent.mkdir(parents=True, exist_ok=True)
    # shutil.copy raises SameFileError when source and destination are the
    # same file, e.g. when the exporter already wrote to ``output``
    # (NOTE(review): presumably the exporter writes next to the .pt — confirm).
    already_in_place = output.exists() and output.samefile(exported_path)
    if not already_in_place:
        shutil.copy(str(exported_path), str(output))
    print(f" ✓ Exported → {output} ({output.stat().st_size / 1024 / 1024:.1f} MB)")

    print("\n[4/5] Re-writing names metadata (simplify may have dropped it) ...")
    write_names_metadata(output, names)
|
||||
|
||||
|
||||
def write_names_metadata(onnx_path: Path, names: dict) -> None:
    """Store ``names`` as JSON under the ``names`` key of the model's metadata_props.

    The model at ``onnx_path`` is loaded, any existing ``names`` entry is
    removed, a fresh one is appended, and the file is saved in place.

    Args:
        onnx_path: ONNX file to rewrite.
        names: Mapping whose keys are stringified for the JSON payload.
    """
    import onnx

    model_proto = onnx.load(str(onnx_path))

    # Rebuild metadata_props without any previous "names" entry, keeping
    # the remaining entries in their original order.
    survivors = []
    for prop in model_proto.metadata_props:
        if prop.key != "names":
            survivors.append(prop)
    del model_proto.metadata_props[:]
    model_proto.metadata_props.extend(survivors)

    # JSON object keys must be strings, so stringify each key.
    serializable = {}
    for class_id, label in names.items():
        serializable[str(class_id)] = label
    names_json = json.dumps(serializable, ensure_ascii=False)

    entry = onnx.StringStringEntryProto(key="names", value=names_json)
    model_proto.metadata_props.append(entry)

    onnx.save(model_proto, str(onnx_path))
    print(f" ✓ names metadata written: {names_json}")
|
||||
|
||||
|
||||
def inspect_onnx(onnx_path: Path) -> None:
    """Load the model with onnxruntime and print a short verification report.

    Prints the input and output tensors, the metadata keys and the ``names``
    entry, then runs one inference on random data and reports whether the
    first output has the ``[1, N, 6]`` layout.

    Args:
        onnx_path: ONNX file to verify.
    """
    print("\n[5/5] Verifying with onnxruntime ...")
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])

    # Same listing format for both sides of the graph.
    for heading, getter in ((" Inputs:", session.get_inputs), (" Outputs:", session.get_outputs)):
        print(heading)
        for tensor in getter():
            print(f" {tensor.name}: shape={tensor.shape}, dtype={tensor.type}")

    custom_meta = session.get_modelmeta().custom_metadata_map
    print(f" metadata keys: {list(custom_meta.keys())}")
    print(f" names: {custom_meta.get('names')}")

    # Dummy inference: fall back to IMGSZ for any non-integer (symbolic) dim.
    first_input = session.get_inputs()[0]
    height = first_input.shape[2]
    if not isinstance(height, int):
        height = IMGSZ
    width = first_input.shape[3]
    if not isinstance(width, int):
        width = IMGSZ
    dummy = np.random.rand(1, 3, height, width).astype(np.float32)
    outputs = session.run(None, {first_input.name: dummy})
    print(f" Inference test: output[0] shape = {outputs[0].shape}")

    result_shape = outputs[0].shape
    looks_end2end = len(result_shape) == 3 and result_shape[2] == 6
    if not looks_end2end:
        print(
            f" ⚠️ output shape {result_shape} ≠ [1, N, 6]; "
            "layout_detector._postprocess_output will warn and skip pages — "
            "adjust export (e.g. end2end/nms) or postprocess.",
        )
        return
    print(" ✓ output is [1, N, 6] (YOLOv10 end-to-end, NMS applied)")
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point: parse arguments, then export and verify.

    Steps run in order: resolve the checkpoint, export it to ONNX, inspect
    the result with onnxruntime, and print the final output path.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )
    parser.add_argument(
        "--weights",
        help="本地 .pt 路径(不传则从 HuggingFace 下载)",
    )
    parser.add_argument(
        "--output",
        help=f"输出 ONNX 路径(默认 {DEFAULT_OUTPUT})",
        default=str(DEFAULT_OUTPUT),
    )
    cli = parser.parse_args()

    output = Path(cli.output)
    export_onnx(resolve_weights(cli.weights), output)
    inspect_onnx(output)
    print(f"\n✓ Done! ONNX model saved to {output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,172 +0,0 @@
|
||||
"""导出 PicoDet-S_layout_3cls 为 ONNX 格式.
|
||||
|
||||
一次性脚本,在独立 venv 中运行:
|
||||
python -m venv .venv-export && source .venv-export/bin/activate
|
||||
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle paddleocr paddle2onnx onnxruntime opencv-python-headless
|
||||
HF_ENDPOINT=https://hf-mirror.com python scripts/export_picodet_onnx.py
|
||||
|
||||
输出:
|
||||
data/models/picodet_layout_3cls.onnx
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# hf-mirror
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
MODEL_DIR = PROJECT_ROOT / "data" / "models"
|
||||
OUTPUT_PATH = MODEL_DIR / "picodet_layout_3cls.onnx"
|
||||
MODEL_NAME = "PicoDet-S_layout_3cls"
|
||||
|
||||
|
||||
def main() -> None:
    """Export the PicoDet layout model to ONNX and verify the result.

    Steps:
        1. Instantiate ``paddleocr.LayoutDetection`` so the Paddle model is
           downloaded and cached.
        2. Locate the cached model directory under ``~/.paddlex``.
        3. Convert it with ``paddle2onnx`` (one retry without an explicit
           ``--model_filename``) into a temporary file, then move that file
           to ``OUTPUT_PATH``.
        4. Verify the ONNX file with ``_inspect_onnx``.

    Exits with status 1 when no cached model is found or the conversion
    does not produce a usable file.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)

    # ── Step 1: load the model with the paddle_static engine (triggers download) ──
    print(f"[1/4] Loading model '{MODEL_NAME}' (paddle_static engine, triggers download) ...")
    from paddleocr import LayoutDetection

    # Instantiated only for its side effect; the object itself is not used.
    LayoutDetection(
        model_name=MODEL_NAME,
        engine="paddle_static",
        device="cpu",
    )
    print(" ✓ Model loaded and cached")

    # ── Step 2: find the Paddle model files cached by PaddleX ──
    paddlex_cache = Path.home() / ".paddlex"
    print(f"\n[2/4] Searching Paddle model cache in {paddlex_cache} ...")

    # Scan once; both the layout filter and the fallback use this list.
    all_model_dirs = [
        d
        for d in paddlex_cache.rglob("*")
        if d.is_dir() and (d / "inference.pdiparams").exists()
    ]
    candidates = [
        d
        for d in all_model_dirs
        if "layout" in d.name.lower()
        or "layout" in d.parent.name.lower()
        or "picodet" in d.name.lower()
        or "PicoDet" in str(d)
    ]

    if not candidates:
        # No obvious layout directory: list every model dir, then fall back
        # to the most recently modified one (the one just downloaded).
        print(" No layout-specific dir found. All model dirs with inference.pdiparams:")
        for d in all_model_dirs:
            files = [f.name for f in d.iterdir()]
            print(f" {d} ({', '.join(files)})")
        if all_model_dirs:
            newest = max(
                all_model_dirs,
                key=lambda d: (d / "inference.pdiparams").stat().st_mtime,
            )
            candidates = [newest]

    if not candidates:
        print(" ✗ No cached model found")
        sys.exit(1)

    model_cache_dir = candidates[0]
    print(f" Using: {model_cache_dir}")
    for f in model_cache_dir.iterdir():
        print(f" {f.name} ({f.stat().st_size / 1024:.1f} KB)")

    # ── Step 3: convert with paddle2onnx ──
    print("\n[3/4] Converting to ONNX with paddle2onnx ...")
    tmp_onnx = OUTPUT_PATH.with_suffix(".tmp.onnx")
    # A leftover temp file from an earlier run must not be mistaken for the
    # result of this conversion.
    if tmp_onnx.exists():
        tmp_onnx.unlink()

    has_pdmodel = (model_cache_dir / "inference.pdmodel").exists()

    cmd = [
        sys.executable, "-m", "paddle2onnx",
        "--model_dir", str(model_cache_dir),
        "--save_file", str(tmp_onnx),
        "--opset_version", "11",
        "--enable_onnx_checker", "True",
    ]
    if has_pdmodel:
        cmd.extend(["--model_filename", "inference.pdmodel"])
    cmd.extend(["--params_filename", "inference.pdiparams"])

    print(f" Running: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.stdout:
        print(f" stdout: {result.stdout[:500]}")

    converted = result.returncode == 0
    if not converted:
        print(f" ✗ paddle2onnx failed (exit {result.returncode})")
        print(f" stderr: {result.stderr[:500]}")

        # Retry without model_filename (combined format).
        if has_pdmodel:
            print(" Retrying without explicit model_filename ...")
            cmd2 = [
                sys.executable, "-m", "paddle2onnx",
                "--model_dir", str(model_cache_dir),
                "--params_filename", "inference.pdiparams",
                "--save_file", str(tmp_onnx),
                "--opset_version", "11",
            ]
            result2 = subprocess.run(cmd2, capture_output=True, text=True)
            if result2.returncode != 0:
                print(f" ✗ Retry also failed: {result2.stderr[:500]}")
                sys.exit(1)
            converted = True

    # A failed conversion never counts as success, even if a partial file
    # was left on disk.
    if not converted or not tmp_onnx.exists() or tmp_onnx.stat().st_size < 1000:
        print(" ✗ ONNX file not created or too small")
        sys.exit(1)

    shutil.move(str(tmp_onnx), str(OUTPUT_PATH))
    print(f" ✓ ONNX saved ({OUTPUT_PATH.stat().st_size / 1024 / 1024:.2f} MB)")

    # ── Step 4: verify with onnxruntime ──
    print("\n[4/4] Verifying with onnxruntime ...")
    _inspect_onnx(OUTPUT_PATH)

    print(f"\n✓ Done! ONNX model saved to {OUTPUT_PATH}")
|
||||
|
||||
|
||||
def _inspect_onnx(onnx_path: Path) -> None:
    """Load the model with onnxruntime, print its I/O and run one dummy inference.

    Args:
        onnx_path: ONNX file to load.
    """
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])

    # Same listing format for inputs and outputs.
    for heading, getter in ((" Inputs:", session.get_inputs), (" Outputs:", session.get_outputs)):
        print(heading)
        for tensor in getter():
            print(f" {tensor.name}: shape={tensor.shape}, dtype={tensor.type}")

    # Trial inference: any non-integer (symbolic) dim falls back to a default.
    first_input = session.get_inputs()[0]
    fallback_dims = (1, 3, 480, 480)  # batch, channels, height, width
    dims = []
    for axis, default in enumerate(fallback_dims):
        declared = first_input.shape[axis]
        dims.append(declared if isinstance(declared, int) else default)

    dummy_input = np.random.rand(*dims).astype(np.float32)
    outputs = session.run(None, {first_input.name: dummy_input})

    print(" Inference test outputs:")
    for i, (out_info, out_val) in enumerate(zip(session.get_outputs(), outputs)):
        print(f" output[{i}] '{out_info.name}': shape={out_val.shape}, dtype={out_val.dtype}")
        # Only small tensors are printed in full.
        if out_val.size <= 20:
            print(f" values: {out_val}")

    print(" ✓ Inference OK")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,4 +1,4 @@
|
||||
"""批量重新提取所有论文的图片 — 下载 PDF + PicoDet 检测 + caption 匹配.
|
||||
"""批量重新提取所有论文的图片 — 下载 PDF + DocLayout-YOLO 检测 + caption 匹配.
|
||||
|
||||
用法:
|
||||
PROXY_SERVER=http://... uv run python scripts/reextract_images.py
|
||||
|
||||
+73
-43
@@ -3,8 +3,19 @@ import sys
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"required": ["arxiv_id", "title_zh", "one_line", "tags", "difficulty",
|
||||
"prerequisites", "motivation", "method", "results", "improvements", "figures"],
|
||||
"required": [
|
||||
"arxiv_id",
|
||||
"title_zh",
|
||||
"one_line",
|
||||
"tags",
|
||||
"difficulty",
|
||||
"prerequisites",
|
||||
"motivation",
|
||||
"method",
|
||||
"results",
|
||||
"improvements",
|
||||
"figures",
|
||||
],
|
||||
"properties": {
|
||||
"arxiv_id": {"type": "string"},
|
||||
"title_zh": {"type": "string"},
|
||||
@@ -15,16 +26,19 @@ schema = {
|
||||
"type": "object",
|
||||
"required": ["concepts"],
|
||||
"properties": {
|
||||
"concepts": {"type": "array", "items": {
|
||||
"type": "object",
|
||||
"required": ["term", "explanation", "why_matters"],
|
||||
"properties": {
|
||||
"term": {"type": "string"},
|
||||
"explanation": {"type": "string"},
|
||||
"why_matters": {"type": "string"}
|
||||
}
|
||||
}}
|
||||
}
|
||||
"concepts": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["term", "explanation", "why_matters"],
|
||||
"properties": {
|
||||
"term": {"type": "string"},
|
||||
"explanation": {"type": "string"},
|
||||
"why_matters": {"type": "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
"motivation": {
|
||||
"type": "object",
|
||||
@@ -32,8 +46,8 @@ schema = {
|
||||
"properties": {
|
||||
"problem": {"type": "string"},
|
||||
"goal": {"type": "string"},
|
||||
"gap": {"type": "string"}
|
||||
}
|
||||
"gap": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"method": {
|
||||
"type": "object",
|
||||
@@ -42,27 +56,36 @@ schema = {
|
||||
"overview": {"type": "string"},
|
||||
"key_idea": {"type": "string"},
|
||||
"steps": {"type": "string"},
|
||||
"novelty": {"type": "string"}
|
||||
}
|
||||
"novelty": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"results": {
|
||||
"type": "object",
|
||||
"required": ["main_findings", "benchmarks", "limitations"],
|
||||
"properties": {
|
||||
"main_findings": {"type": "string"},
|
||||
"benchmarks": {"type": "array", "items": {
|
||||
"type": "object",
|
||||
"required": ["task", "metric", "this_work", "baseline", "improvement"],
|
||||
"properties": {
|
||||
"task": {"type": "string"},
|
||||
"metric": {"type": "string"},
|
||||
"this_work": {"type": "string"},
|
||||
"baseline": {"type": "string"},
|
||||
"improvement": {"type": "string"}
|
||||
}
|
||||
}},
|
||||
"limitations": {"type": "string"}
|
||||
}
|
||||
"benchmarks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"task",
|
||||
"metric",
|
||||
"this_work",
|
||||
"baseline",
|
||||
"improvement",
|
||||
],
|
||||
"properties": {
|
||||
"task": {"type": "string"},
|
||||
"metric": {"type": "string"},
|
||||
"this_work": {"type": "string"},
|
||||
"baseline": {"type": "string"},
|
||||
"improvement": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"limitations": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"improvements": {
|
||||
"type": "object",
|
||||
@@ -70,8 +93,8 @@ schema = {
|
||||
"properties": {
|
||||
"weaknesses": {"type": "string"},
|
||||
"future_work": {"type": "string"},
|
||||
"reproducibility": {"type": "string"}
|
||||
}
|
||||
"reproducibility": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"figures": {
|
||||
"type": "array",
|
||||
@@ -83,24 +106,28 @@ schema = {
|
||||
"caption": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"reason": {"type": "string"},
|
||||
"section": {"type": "string", "enum": ["motivation", "method", "results", "limitations"]}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"section": {
|
||||
"type": "string",
|
||||
"enum": ["motivation", "method", "results", "limitations"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def validate_file(filepath):
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
# Check required fields
|
||||
for field in schema["required"]:
|
||||
if field not in data:
|
||||
print(f"❌ Missing field: {field}")
|
||||
return False
|
||||
|
||||
|
||||
# Validate nested structure
|
||||
for field, spec in schema["properties"].items():
|
||||
if field in data:
|
||||
@@ -121,17 +148,17 @@ def validate_file(filepath):
|
||||
if subfield not in data[field]:
|
||||
print(f"❌ Missing subfield: {field}.{subfield}")
|
||||
return False
|
||||
|
||||
|
||||
# Validate section enum in figures
|
||||
valid_sections = ["motivation", "method", "results", "limitations"]
|
||||
for fig in data.get("figures", []):
|
||||
if fig["section"] not in valid_sections:
|
||||
print(f"❌ Invalid section in figure: {fig['section']}")
|
||||
return False
|
||||
|
||||
|
||||
print("✅ JSON validation passed!")
|
||||
return True
|
||||
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"❌ JSON decode error: {e}")
|
||||
return False
|
||||
@@ -139,6 +166,9 @@ def validate_file(filepath):
|
||||
print(f"❌ Validation error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
filepath = sys.argv[1] if len(sys.argv) > 1 else "data/papers/2601.10592/summary.json"
|
||||
filepath = (
|
||||
sys.argv[1] if len(sys.argv) > 1 else "data/papers/2601.10592/summary.json"
|
||||
)
|
||||
validate_file(filepath)
|
||||
|
||||
Reference in New Issue
Block a user