"""Export DocLayout-YOLO (DocStructBench, imgsz=1024) to ONNX format.

One-off script; run it in a dedicated venv (kept out of the runtime dependencies):

    python -m venv .venv-export && source .venv-export/bin/activate
    pip install torch torchvision onnx onnxscript onnxsim onnxruntime huggingface-hub doclayout-yolo

Two sources for the weights:

    # 1) Use a locally downloaded .pt (recommended, skips the download)
    .venv-export/bin/python scripts/export_doclayout_yolo_onnx.py \
        --weights /path/to/doclayout_yolo_docstructbench_imgsz1024.pt

    # 2) Download from HuggingFace (omit --weights)
    HF_ENDPOINT=https://hf-mirror.com .venv-export/bin/python scripts/export_doclayout_yolo_onnx.py

Output: data/models/doclayout_yolo_docstructbench_imgsz1024.onnx

Note: model.export(simplify=True) clears the ONNX metadata, so after exporting
this script uses the onnx package to write names back into the metadata, where
the runtime's _parse_names_from_meta reads it.
"""

from __future__ import annotations

import argparse
import json
import os
import shutil
from pathlib import Path

# hf-mirror (mirror for faster downloads in mainland China); only relevant when
# --weights is omitted and the checkpoint is fetched from HuggingFace.
# setdefault keeps any endpoint the caller already exported.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

PROJECT_ROOT = Path(__file__).resolve().parent.parent
MODEL_DIR = PROJECT_ROOT / "data" / "models"
DEFAULT_OUTPUT = MODEL_DIR / "doclayout_yolo_docstructbench_imgsz1024.onnx"

REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
# Checkpoint name inside REPO_ID; matches the local file name used in the
# module docstring and the stem of DEFAULT_OUTPUT.
PT_FILENAME = "doclayout_yolo_docstructbench_imgsz1024.pt"
IMGSZ = 1024


def resolve_weights(arg: str | None) -> Path:
    """Return the path of the .pt checkpoint to export.

    Args:
        arg: Value of ``--weights``. When truthy it is used as a local path;
            otherwise the checkpoint is downloaded from HuggingFace.

    Returns:
        Path of the checkpoint file.

    Raises:
        FileNotFoundError: ``arg`` was given but is not an existing file.
    """
    if arg:
        p = Path(arg)
        # is_file() rather than exists(): a directory is not a usable checkpoint.
        if not p.is_file():
            raise FileNotFoundError(f"--weights not found: {p}")
        print(f"[1/5] Using local weights: {p}")
        return p

    print(f"[1/5] Downloading .pt from HuggingFace ({REPO_ID}) ...")
    # Imported lazily so the local-weights path does not need huggingface_hub.
    from huggingface_hub import hf_hub_download

    pt_path = Path(hf_hub_download(repo_id=REPO_ID, filename=PT_FILENAME))
    print(f" ✓ {pt_path}")
    return pt_path


def _place_export(exported_path: Path, output: Path) -> None:
    """Copy the exported file to ``output`` unless it already is that file."""
    output.parent.mkdir(parents=True, exist_ok=True)
    # shutil.copy raises SameFileError when source and destination are the same
    # file, which happens when the exporter already wrote to ``output``.
    if output.exists() and exported_path.samefile(output):
        return
    shutil.copy(str(exported_path), str(output))


def export_onnx(pt_path: Path, output: Path) -> None:
    """Export the checkpoint to ONNX, store it at ``output`` and restore names.

    Args:
        pt_path: Path of the .pt checkpoint to load.
        output: Destination path of the ONNX file; parent directories are
            created when missing.
    """
    print("\n[2/5] Loading model with doclayout_yolo ...")
    from doclayout_yolo import YOLOv10

    model = YOLOv10(str(pt_path))
    names = model.names  # dict[int, str]; equivalent to model.model.names
    print(f" ✓ Loaded. names = {names}")

    print(f"\n[3/5] Exporting ONNX (imgsz={IMGSZ}, opset=12, simplify=True) ...")
    export_kwargs = {
        "format": "onnx",
        "imgsz": IMGSZ,
        "opset": 12,
        "dynamic": False,  # fixed batch=1 and fixed 1024: most stable to deploy
        "half": False,  # FP32, keeps CPU inference accuracy
    }
    try:
        # simplify needs onnxsim; on failure fall back to the plain export below.
        exported = model.export(simplify=True, **export_kwargs)
    except Exception as e:  # deliberate best-effort: any failure triggers the retry
        print(f" ⚠ export with simplify failed ({e}); retrying without simplify")
        # Pass simplify=False explicitly so the retry does not depend on the
        # library's default value.
        exported = model.export(simplify=False, **export_kwargs)

    _place_export(Path(exported), output)
    print(f" ✓ Exported → {output} ({output.stat().st_size / 1024 / 1024:.1f} MB)")

    print("\n[4/5] Re-writing names metadata (simplify may have dropped it) ...")
    write_names_metadata(output, names)


def _names_to_json(names: dict | list | tuple) -> str:
    """Serialise class names as a JSON object keyed by stringified class index."""
    if not isinstance(names, dict):
        # A plain sequence of names is indexed by position.
        names = dict(enumerate(names))
    return json.dumps({str(k): v for k, v in names.items()}, ensure_ascii=False)


def write_names_metadata(onnx_path: Path, names: dict) -> None:
    """Write ``names`` into the ONNX model's ``metadata_props`` in place.

    Any existing ``names`` entry is replaced; all other metadata entries are
    kept. The file at ``onnx_path`` is overwritten.

    Args:
        onnx_path: Path of the ONNX file to update.
        names: Mapping of class index to class name.
    """
    import onnx

    m = onnx.load(str(onnx_path))
    # Drop a stale "names" entry but keep every other metadata property.
    keep = [p for p in m.metadata_props if p.key != "names"]
    del m.metadata_props[:]
    m.metadata_props.extend(keep)

    names_json = _names_to_json(names)
    m.metadata_props.append(onnx.StringStringEntryProto(key="names", value=names_json))
    onnx.save(m, str(onnx_path))
    print(f" ✓ names metadata written: {names_json}")


def _static_dim(dim: object, default: int) -> int:
    """Return ``dim`` when it is a concrete int, otherwise ``default``."""
    return dim if isinstance(dim, int) else default


def _is_end2end_output(shape: tuple) -> bool:
    """Tell whether ``shape`` is rank 3 with a last dimension of 6."""
    return len(shape) == 3 and shape[2] == 6


def inspect_onnx(onnx_path: Path) -> None:
    """Load the model with onnxruntime and print a short verification report.

    Prints the inputs, outputs and names metadata, then runs one inference on
    random data and reports whether the first output has shape [1, N, 6].

    Args:
        onnx_path: Path of the ONNX file to verify.
    """
    print("\n[5/5] Verifying with onnxruntime ...")
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])

    print(" Inputs:")
    for inp in session.get_inputs():
        print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
    print(" Outputs:")
    for out in session.get_outputs():
        print(f" {out.name}: shape={out.shape}, dtype={out.type}")

    meta = session.get_modelmeta()
    print(f" metadata keys: {list(meta.custom_metadata_map.keys())}")
    print(f" names: {meta.custom_metadata_map.get('names')}")

    # Dummy inference; non-integer (symbolic) spatial dims fall back to IMGSZ.
    input_info = session.get_inputs()[0]
    h = _static_dim(input_info.shape[2], IMGSZ)
    w = _static_dim(input_info.shape[3], IMGSZ)
    dummy = np.random.rand(1, 3, h, w).astype(np.float32)
    outputs = session.run(None, {input_info.name: dummy})
    print(f" Inference test: output[0] shape = {outputs[0].shape}")

    out_shape = outputs[0].shape
    if _is_end2end_output(out_shape):
        print(" ✓ output is [1, N, 6] (YOLOv10 end-to-end, NMS applied)")
    else:
        print(
            f" ⚠️ output shape {out_shape} ≠ [1, N, 6]; "
            "layout_detector._postprocess_output will warn and skip pages — "
            "adjust export (e.g. end2end/nms) or postprocess.",
        )


def main() -> None:
    """Parse the command line, then resolve weights, export and verify."""
    ap = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    ap.add_argument("--weights", help="本地 .pt 路径(不传则从 HuggingFace 下载)")
    ap.add_argument(
        "--output",
        default=str(DEFAULT_OUTPUT),
        help=f"输出 ONNX 路径(默认 {DEFAULT_OUTPUT})",
    )
    args = ap.parse_args()

    output = Path(args.output)
    pt_path = resolve_weights(args.weights)
    export_onnx(pt_path, output)
    inspect_onnx(output)
    print(f"\n✓ Done! ONNX model saved to {output}")


if __name__ == "__main__":
    main()