Files
daily-paper/scripts/export_doclayout_yolo_onnx.py
T
Rain-Bus 90fe705e8f refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO
- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
2026-06-14 10:41:44 +08:00

162 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""导出 DocLayout-YOLO (DocStructBench, imgsz=1024) 为 ONNX 格式.
一次性脚本,在独立 venv 中运行(不进运行时依赖):
python -m venv .venv-export && source .venv-export/bin/activate
pip install torch torchvision onnx onnxscript onnxsim onnxruntime huggingface-hub doclayout-yolo
两种权重来源:
# 1) 用本地已下载的 .pt(推荐,省下载)
.venv-export/bin/python scripts/export_doclayout_yolo_onnx.py \
--weights /path/to/doclayout_yolo_docstructbench_imgsz1024.pt
# 2) 从 HuggingFace 下载(不传 --weights)
HF_ENDPOINT=https://hf-mirror.com .venv-export/bin/python scripts/export_doclayout_yolo_onnx.py
输出:
data/models/doclayout_yolo_docstructbench_imgsz1024.onnx
注意:model.export(simplify=True) 会清空 ONNX metadata,本脚本在导出后
用 onnx 包把 names 重新写回 metadata,供运行时 _parse_names_from_meta 读取。
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
from pathlib import Path
# Default to the hf-mirror endpoint (faster from mainland China). setdefault
# keeps any HF_ENDPOINT the caller already exported; the value only matters
# when --weights is omitted and the checkpoint is fetched from HuggingFace.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

# Two levels up from this file (the script sits in <root>/scripts/).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
MODEL_DIR = PROJECT_ROOT / "data" / "models"
DEFAULT_OUTPUT = MODEL_DIR / "doclayout_yolo_docstructbench_imgsz1024.onnx"

REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
# Checkpoint file name inside REPO_ID. The published file carries the
# "_imgsz1024" suffix (same stem as DEFAULT_OUTPUT); without it the
# HuggingFace download fails because no such entry exists in the repo.
PT_FILENAME = "doclayout_yolo_docstructbench_imgsz1024.pt"
# Square input resolution the model is exported with.
IMGSZ = 1024
def resolve_weights(arg: str | None) -> Path:
    """Return the path of the ``.pt`` checkpoint to export.

    Args:
        arg: Value of ``--weights``. A local checkpoint path, or ``None`` /
            an empty string to download the checkpoint from HuggingFace.

    Returns:
        The local path as given, or the path returned by
        ``hf_hub_download`` when no local path was supplied.

    Raises:
        FileNotFoundError: ``arg`` is given but is not an existing regular
            file (missing path, or a directory).
    """
    if arg:
        p = Path(arg)
        # is_file() rather than exists(): a directory passes exists() and
        # would only fail later, inside the model loader.
        if not p.is_file():
            raise FileNotFoundError(f"--weights not found or not a file: {p}")
        print(f"[1/5] Using local weights: {p}")
        return p

    print(f"[1/5] Downloading .pt from HuggingFace ({REPO_ID}) ...")
    # Imported lazily so the local-weights path does not need huggingface-hub.
    from huggingface_hub import hf_hub_download

    pt_path = Path(hf_hub_download(repo_id=REPO_ID, filename=PT_FILENAME))
    print(f"{pt_path}")
    return pt_path
def export_onnx(pt_path: Path, output: Path) -> None:
    """Export the checkpoint at ``pt_path`` to ONNX and place it at ``output``.

    The model is exported at a fixed ``IMGSZ`` input, opset 12, FP32. Export
    is first attempted with ``simplify=True``; if that raises, it is retried
    without simplification. The exported file is then copied to ``output``
    and the class-name mapping is written into its metadata.

    Args:
        pt_path: Path to the ``.pt`` checkpoint.
        output: Destination path of the ONNX file; parent directories are
            created when missing.
    """
    print("\n[2/5] Loading model with doclayout_yolo ...")
    # Imported lazily: only needed for this export step.
    from doclayout_yolo import YOLOv10

    model = YOLOv10(str(pt_path))
    names = model.names  # dict[int, str]; equivalent to model.model.names
    print(f" ✓ Loaded. names = {names}")

    print(f"\n[3/5] Exporting ONNX (imgsz={IMGSZ}, opset=12, simplify=True) ...")
    export_kwargs = {
        "format": "onnx",
        "imgsz": IMGSZ,
        "opset": 12,
        "dynamic": False,  # fixed batch=1 and fixed 1024 input: most stable to deploy
        "half": False,  # FP32, to preserve CPU inference accuracy
    }
    try:
        # simplify=True needs onnxsim; any failure falls back below.
        exported = model.export(simplify=True, **export_kwargs)
    except Exception as e:
        # Deliberately broad: whatever breaks the simplified export, a plain
        # export is still attempted before giving up.
        print(f" ⚠ export with simplify failed ({e}); retrying without simplify")
        exported = model.export(**export_kwargs)

    exported_path = Path(exported)
    output.parent.mkdir(parents=True, exist_ok=True)
    try:
        shutil.copy(str(exported_path), str(output))
    except shutil.SameFileError:
        # The exporter chose its own location and it already is `output`
        # (presumably when the checkpoint sits in the output directory under
        # the same stem). Nothing to copy in that case.
        pass
    print(f" ✓ Exported → {output} ({output.stat().st_size / 1024 / 1024:.1f} MB)")

    print("\n[4/5] Re-writing names metadata (simplify may have dropped it) ...")
    write_names_metadata(output, names)
def write_names_metadata(onnx_path: Path, names: dict) -> None:
    """Store the class-name mapping under the ``names`` metadata key.

    The model at ``onnx_path`` is loaded, any existing ``names`` entry in its
    ``metadata_props`` is dropped, a fresh entry holding ``names`` as JSON
    (keys converted to strings, non-ASCII kept as-is) is appended, and the
    model is saved back to the same path.

    Args:
        onnx_path: ONNX file to update in place.
        names: Mapping of class id to class name.
    """
    import onnx

    model_proto = onnx.load(str(onnx_path))

    # Keep every property except a stale "names" entry, then rebuild the list.
    retained = []
    for prop in model_proto.metadata_props:
        if prop.key != "names":
            retained.append(prop)
    del model_proto.metadata_props[:]
    model_proto.metadata_props.extend(retained)

    # JSON object keys must be strings, so the class ids are stringified.
    serializable = {}
    for class_id, label in names.items():
        serializable[str(class_id)] = label
    names_json = json.dumps(serializable, ensure_ascii=False)

    entry = onnx.StringStringEntryProto(key="names", value=names_json)
    model_proto.metadata_props.append(entry)
    onnx.save(model_proto, str(onnx_path))
    print(f" ✓ names metadata written: {names_json}")
def inspect_onnx(onnx_path: Path) -> None:
    """Load the exported model with onnxruntime and print a sanity report.

    Prints each input and output (name, shape, dtype), the custom metadata
    keys and the ``names`` entry, then runs one inference on random data and
    reports whether the first output is rank 3 with a last dimension of 6.

    Args:
        onnx_path: Path of the ONNX file to check.
    """
    print("\n[5/5] Verifying with onnxruntime ...")
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])

    for title, getter in ((" Inputs:", sess.get_inputs), (" Outputs:", sess.get_outputs)):
        print(title)
        for node in getter():
            print(f" {node.name}: shape={node.shape}, dtype={node.type}")

    custom_meta = sess.get_modelmeta().custom_metadata_map
    print(f" metadata keys: {list(custom_meta.keys())}")
    print(f" names: {custom_meta.get('names')}")

    # Dummy inference: dimensions that are not plain ints fall back to IMGSZ.
    first_input = sess.get_inputs()[0]
    dims = first_input.shape
    height, width = [dims[i] if isinstance(dims[i], int) else IMGSZ for i in (2, 3)]
    dummy = np.random.rand(1, 3, height, width).astype(np.float32)
    results = sess.run(None, {first_input.name: dummy})
    result_shape = results[0].shape
    print(f" Inference test: output[0] shape = {result_shape}")

    looks_end_to_end = len(result_shape) == 3 and result_shape[2] == 6
    if not looks_end_to_end:
        print(
            f" ⚠️ output shape {result_shape} ≠ [1, N, 6]; "
            "layout_detector._postprocess_output will warn and skip pages — "
            "adjust export (e.g. end2end/nms) or postprocess.",
        )
        return
    print(" ✓ output is [1, N, 6] (YOLOv10 end-to-end, NMS applied)")
def main() -> None:
    """Parse the CLI arguments, then resolve weights, export and verify.

    The module docstring is shown verbatim as the ``--help`` description.
    """
    ap = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    ap.add_argument("--weights", help="本地 .pt 路径(不传则从 HuggingFace 下载)")
    ap.add_argument(
        "--output",
        default=str(DEFAULT_OUTPUT),
        # argparse %-formats help strings, so the default is referenced via
        # %(default)s instead of being interpolated: a path containing "%"
        # would otherwise break --help.
        help="输出 ONNX 路径(默认 %(default)s)",
    )
    args = ap.parse_args()

    output = Path(args.output)
    pt_path = resolve_weights(args.weights)
    export_onnx(pt_path, output)
    inspect_onnx(output)
    print(f"\n✓ Done! ONNX model saved to {output}")
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()