90fe705e8f
- 核心变更: - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024) - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备 - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式 - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls] - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table) - 新增文件: - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行) - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例) - 配置更新: - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID) - app/config.py: Settings 类对应字段 - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等) - 删除旧文件: - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本 - 文档更新: - README.md: 更新环境变量说明 - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py) 此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
162 lines
6.0 KiB
Python
162 lines
6.0 KiB
Python
"""导出 DocLayout-YOLO (DocStructBench, imgsz=1024) 为 ONNX 格式.
|
||
|
||
一次性脚本,在独立 venv 中运行(不进运行时依赖):
|
||
python -m venv .venv-export && source .venv-export/bin/activate
|
||
pip install torch torchvision onnx onnxscript onnxsim onnxruntime huggingface-hub doclayout-yolo
|
||
|
||
两种权重来源:
|
||
# 1) 用本地已下载的 .pt(推荐,省下载)
|
||
.venv/bin/python scripts/export_doclayout_yolo_onnx.py \
|
||
--weights /path/to/doclayout_yolo_docstructbench_imgsz1024.pt
|
||
|
||
# 2) 从 HuggingFace 下载(不传 --weights)
|
||
HF_ENDPOINT=https://hf-mirror.com .venv/bin/python scripts/export_doclayout_yolo_onnx.py
|
||
|
||
输出:
|
||
data/models/doclayout_yolo_docstructbench_imgsz1024.onnx
|
||
|
||
注意:model.export(simplify=True) 会清空 ONNX metadata,本脚本在导出后
|
||
用 onnx 包把 names 重新写回 metadata,供运行时 _parse_names_from_meta 读取。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import shutil
|
||
from pathlib import Path
|
||
|
||
# Default to the hf-mirror endpoint (faster from mainland China). It only
# matters when --weights is omitted and the checkpoint is fetched from
# HuggingFace; an HF_ENDPOINT already present in the environment is kept.
if "HF_ENDPOINT" not in os.environ:
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Repository root: two directory levels above this script file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Directory that holds exported model files.
MODEL_DIR = PROJECT_ROOT.joinpath("data", "models")
# Default destination of the exported ONNX model.
DEFAULT_OUTPUT = MODEL_DIR.joinpath("doclayout_yolo_docstructbench_imgsz1024.onnx")

# HuggingFace repository and file name of the PyTorch checkpoint.
REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
PT_FILENAME = "doclayout_yolo_docstructbench.pt"

# Square input size passed to the exporter and used as the dummy-input fallback.
IMGSZ = 1024
|
||
|
||
|
||
def resolve_weights(arg: str | None) -> Path:
    """Return the path of the ``.pt`` checkpoint to export.

    Args:
        arg: Value of ``--weights``. When truthy it is treated as a local
            checkpoint path; when ``None`` or empty the checkpoint is
            downloaded from HuggingFace (``REPO_ID`` / ``PT_FILENAME``).

    Returns:
        Path to the checkpoint file on local disk.

    Raises:
        FileNotFoundError: ``arg`` was given but does not exist, or exists
            but is not a regular file (e.g. a directory).
    """
    if arg:
        p = Path(arg)
        if not p.exists():
            raise FileNotFoundError(f"--weights not found: {p}")
        # A directory passes exists() but cannot be loaded as a checkpoint;
        # reject it here instead of failing later inside the model loader.
        if not p.is_file():
            raise FileNotFoundError(f"--weights is not a file: {p}")
        print(f"[1/5] Using local weights: {p}")
        return p

    print(f"[1/5] Downloading .pt from HuggingFace ({REPO_ID}) ...")
    # Imported lazily so a local --weights run does not need huggingface_hub.
    from huggingface_hub import hf_hub_download

    pt_path = Path(hf_hub_download(repo_id=REPO_ID, filename=PT_FILENAME))
    print(f" ✓ {pt_path}")
    return pt_path
|
||
|
||
|
||
def export_onnx(pt_path: Path, output: Path) -> None:
    """Export the checkpoint to ONNX, place it at ``output`` and re-attach names.

    The model is loaded with ``doclayout_yolo.YOLOv10`` and exported with
    ``simplify=True``; if that export raises, it is retried without
    simplification. The exported file is then copied to ``output`` and the
    class-name mapping is written back into its metadata via
    ``write_names_metadata``.

    Args:
        pt_path: Path to the ``.pt`` checkpoint.
        output: Destination path of the ONNX file; parent directories are
            created when missing.
    """
    print("\n[2/5] Loading model with doclayout_yolo ...")
    # Imported lazily: doclayout_yolo is only installed in the export venv.
    from doclayout_yolo import YOLOv10

    model = YOLOv10(str(pt_path))
    # Class-id -> class-name mapping (noted by the original author as
    # dict[int, str], equivalent to model.model.names).
    names = model.names
    print(f" ✓ Loaded. names = {names}")

    print(f"\n[3/5] Exporting ONNX (imgsz={IMGSZ}, opset=12, simplify=True) ...")
    # Options shared by both export attempts:
    #   dynamic=False -> fixed batch=1 and fixed 1024 input
    #   half=False    -> FP32, to keep CPU inference precision
    export_kwargs = dict(
        format="onnx",
        imgsz=IMGSZ,
        opset=12,
        dynamic=False,
        half=False,
    )
    try:
        # simplify=True needs onnxsim; fall back below when it fails.
        exported = model.export(simplify=True, **export_kwargs)
    except Exception as e:
        print(f" ⚠ export with simplify failed ({e}); retrying without simplify")
        exported = model.export(**export_kwargs)

    exported_path = Path(exported)
    output.parent.mkdir(parents=True, exist_ok=True)
    try:
        shutil.copy(str(exported_path), str(output))
    except shutil.SameFileError:
        # The exporter already wrote to the requested location (e.g. the
        # checkpoint lives next to ``output`` with the same stem); nothing
        # to copy.
        pass
    print(f" ✓ Exported → {output} ({output.stat().st_size / 1024 / 1024:.1f} MB)")

    print("\n[4/5] Re-writing names metadata (simplify may have dropped it) ...")
    write_names_metadata(output, names)
|
||
|
||
|
||
def write_names_metadata(onnx_path: Path, names: dict) -> None:
    """Store the class-name mapping in the ONNX file's ``metadata_props``.

    Any existing ``names`` entry is replaced; all other metadata entries are
    kept. The file at ``onnx_path`` is rewritten in place.

    Args:
        onnx_path: ONNX model file to update.
        names: Mapping of class id to class name; keys are stringified
            before being serialised as JSON.
    """
    import onnx

    model_proto = onnx.load(str(onnx_path))

    # Keep every metadata entry except a stale "names" one.
    preserved = []
    for prop in model_proto.metadata_props:
        if prop.key != "names":
            preserved.append(prop)
    del model_proto.metadata_props[:]
    model_proto.metadata_props.extend(preserved)

    # JSON object keys must be strings, so class ids are converted first.
    serialisable = {}
    for class_id, label in names.items():
        serialisable[str(class_id)] = label
    names_json = json.dumps(serialisable, ensure_ascii=False)

    entry = onnx.StringStringEntryProto(key="names", value=names_json)
    model_proto.metadata_props.append(entry)
    onnx.save(model_proto, str(onnx_path))
    print(f" ✓ names metadata written: {names_json}")
|
||
|
||
|
||
def inspect_onnx(onnx_path: Path) -> None:
    """Load the model with onnxruntime and print a short verification report.

    Prints the model inputs and outputs, the metadata keys and the ``names``
    entry, then runs one inference on a random input and reports whether the
    first output has the ``[1, N, 6]`` layout.

    Args:
        onnx_path: ONNX model file to verify.
    """
    print("\n[5/5] Verifying with onnxruntime ...")
    import numpy as np
    import onnxruntime as ort

    def _describe(node) -> str:
        # One report line for an input or output node.
        return f" {node.name}: shape={node.shape}, dtype={node.type}"

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])

    print(" Inputs:")
    for node in session.get_inputs():
        print(_describe(node))
    print(" Outputs:")
    for node in session.get_outputs():
        print(_describe(node))

    custom_meta = session.get_modelmeta().custom_metadata_map
    print(f" metadata keys: {list(custom_meta.keys())}")
    print(f" names: {custom_meta.get('names')}")

    # Dummy inference: use the declared spatial dims when they are concrete
    # integers, otherwise fall back to IMGSZ.
    first_input = session.get_inputs()[0]
    declared_h = first_input.shape[2]
    declared_w = first_input.shape[3]
    height = declared_h if isinstance(declared_h, int) else IMGSZ
    width = declared_w if isinstance(declared_w, int) else IMGSZ
    random_image = np.random.rand(1, 3, height, width).astype(np.float32)
    outputs = session.run(None, {first_input.name: random_image})
    print(f" Inference test: output[0] shape = {outputs[0].shape}")

    result_shape = outputs[0].shape
    if len(result_shape) == 3 and result_shape[2] == 6:
        print(" ✓ output is [1, N, 6] (YOLOv10 end-to-end, NMS applied)")
        return

    print(
        f" ⚠️ output shape {result_shape} ≠ [1, N, 6]; "
        "layout_detector._postprocess_output will warn and skip pages — "
        "adjust export (e.g. end2end/nms) or postprocess.",
    )
|
||
|
||
|
||
def main() -> None:
    """Command-line entry point: resolve weights, export to ONNX, verify.

    Parses ``--weights`` and ``--output``, then runs ``resolve_weights``,
    ``export_onnx`` and ``inspect_onnx`` in that order.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--weights",
        help="本地 .pt 路径(不传则从 HuggingFace 下载)",
    )
    parser.add_argument(
        "--output",
        default=str(DEFAULT_OUTPUT),
        help=f"输出 ONNX 路径(默认 {DEFAULT_OUTPUT})",
    )
    cli = parser.parse_args()

    onnx_target = Path(cli.output)
    checkpoint = resolve_weights(cli.weights)
    export_onnx(checkpoint, onnx_target)
    inspect_onnx(onnx_target)
    print(f"\n✓ Done! ONNX model saved to {onnx_target}")
|
||
|
||
|
||
# Run the exporter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|