21f16e6756
- Split summarizer into summary_generator and summary_persister modules - Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection - Add layout_detector service for PicoDet-S_layout_3cls integration - Add exceptions module with ConflictError and NotFoundError - Improve admin dashboard with better statistics and task management - Add design review document with system optimization suggestions - Add new tests for crawler, pdf_downloader, pipeline, and summary_utils - Update dependencies and configuration - Clean up dead code and improve error handling
173 lines
6.6 KiB
Python
173 lines
6.6 KiB
Python
"""Export PicoDet-S_layout_3cls to ONNX format.

One-off script; run it inside a dedicated virtualenv:
    python -m venv .venv-export && source .venv-export/bin/activate
    pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle paddleocr paddle2onnx onnxruntime opencv-python-headless
    HF_ENDPOINT=https://hf-mirror.com python scripts/export_picodet_onnx.py

Output:
    data/models/picodet_layout_3cls.onnx
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
import shutil
|
||
import subprocess
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Default to the hf-mirror endpoint; an HF_ENDPOINT already set in the
# environment is left untouched.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

# Name of the layout-detection model to export.
MODEL_NAME = "PicoDet-S_layout_3cls"

# This script sits one directory below the project root.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Directory that receives the exported model, and the final ONNX file path.
MODEL_DIR = PROJECT_ROOT.joinpath("data", "models")
OUTPUT_PATH = MODEL_DIR.joinpath("picodet_layout_3cls.onnx")
|
||
|
||
|
||
def _find_cached_model_dirs(cache_root: Path, model_name: str) -> list[Path]:
    """Return cached Paddle model directories under *cache_root*, best match first.

    A directory counts as a model directory when it contains
    ``inference.pdiparams``. Directories whose path looks like a layout /
    PicoDet model are preferred, and among those a directory named exactly
    *model_name* is ranked first. When nothing matches the heuristic, every
    model directory is printed for diagnosis and only the most recently
    written one is returned.

    Args:
        cache_root: Root of the cache tree to search recursively.
        model_name: Exact directory name to prefer among the matches.

    Returns:
        Candidate directories ordered by preference; empty when the cache
        holds no model directory at all.
    """
    # Walk the cache once; both the heuristic and the fallback reuse this list.
    model_dirs = [
        d
        for d in cache_root.rglob("*")
        if d.is_dir() and (d / "inference.pdiparams").exists()
    ]

    def looks_like_layout_model(d: Path) -> bool:
        name = d.name.lower()
        return (
            "layout" in name
            or "layout" in d.parent.name.lower()
            or "picodet" in name
            or "PicoDet" in str(d)
        )

    matches = [d for d in model_dirs if looks_like_layout_model(d)]
    if matches:
        # Stable sort: exact-name directories move to the front, everything
        # else keeps its traversal order.
        return sorted(matches, key=lambda d: d.name != model_name)

    # No obvious layout directory: list what is there and fall back to the
    # newest parameters file (presumably the model that was just downloaded).
    print(" No layout-specific dir found. All model dirs with inference.pdiparams:")
    for d in model_dirs:
        files = [f.name for f in d.iterdir()]
        print(f" {d} ({', '.join(files)})")
    return sorted(
        model_dirs,
        key=lambda d: (d / "inference.pdiparams").stat().st_mtime,
        reverse=True,
    )[:1]


def main() -> None:
    """Export the PicoDet layout model to ONNX and smoke-test the result.

    Steps:
        1. Instantiate the PaddleOCR layout detector so the Paddle model is
           downloaded into the local cache.
        2. Locate the cached Paddle inference model under ``~/.paddlex``.
        3. Convert it with ``paddle2onnx`` into a temporary file, then move
           the file to ``OUTPUT_PATH``.
        4. Load the exported file with onnxruntime and run a dummy inference.

    Raises:
        SystemExit: With exit code 1 when no cached model is found, when the
            conversion fails, or when the produced file is missing or
            smaller than 1000 bytes.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)

    # ── Step 1: load the model with the paddle_static engine (triggers download) ──
    print(f"[1/4] Loading model '{MODEL_NAME}' (paddle_static engine, triggers download) ...")
    from paddleocr import LayoutDetection

    # Constructed only for its download/caching side effect; the instance
    # itself is not needed afterwards.
    LayoutDetection(
        model_name=MODEL_NAME,
        engine="paddle_static",
        device="cpu",
    )
    print(" ✓ Model loaded and cached")

    # ── Step 2: find the Paddle model files cached by PaddleX ──
    paddlex_cache = Path.home() / ".paddlex"
    print(f"\n[2/4] Searching Paddle model cache in {paddlex_cache} ...")

    candidates = _find_cached_model_dirs(paddlex_cache, MODEL_NAME)
    if not candidates:
        print(" ✗ No cached model found")
        sys.exit(1)

    model_cache_dir = candidates[0]
    print(f" Using: {model_cache_dir}")
    for f in model_cache_dir.iterdir():
        print(f" {f.name} ({f.stat().st_size / 1024:.1f} KB)")

    # ── Step 3: convert with paddle2onnx ──
    print("\n[3/4] Converting to ONNX with paddle2onnx ...")
    tmp_onnx = OUTPUT_PATH.with_suffix(".tmp.onnx")

    # A leftover temp file from an earlier aborted run must never be mistaken
    # for the output of this conversion.
    if tmp_onnx.exists():
        tmp_onnx.unlink()

    # Only pass --model_filename when the classic .pdmodel graph file exists.
    has_pdmodel = (model_cache_dir / "inference.pdmodel").exists()

    cmd = [
        sys.executable, "-m", "paddle2onnx",
        "--model_dir", str(model_cache_dir),
        "--save_file", str(tmp_onnx),
        "--opset_version", "11",
        "--enable_onnx_checker", "True",
    ]
    if has_pdmodel:
        cmd.extend(["--model_filename", "inference.pdmodel"])
    # The selected directory always contains inference.pdiparams.
    cmd.extend(["--params_filename", "inference.pdiparams"])

    print(f" Running: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.stdout:
        print(f" stdout: {result.stdout[:500]}")

    if result.returncode != 0:
        print(f" ✗ paddle2onnx failed (exit {result.returncode})")
        print(f" stderr: {result.stderr[:500]}")

        if not has_pdmodel:
            # The first attempt already ran without --model_filename, so
            # there is no alternative invocation left to try.
            sys.exit(1)

        # Retry without --model_filename (combined format).
        print(" Retrying without explicit model_filename ...")
        if tmp_onnx.exists():
            # Drop any partial output written by the failed attempt.
            tmp_onnx.unlink()
        retry_cmd = [
            sys.executable, "-m", "paddle2onnx",
            "--model_dir", str(model_cache_dir),
            "--params_filename", "inference.pdiparams",
            "--save_file", str(tmp_onnx),
            "--opset_version", "11",
        ]
        retry = subprocess.run(retry_cmd, capture_output=True, text=True)
        if retry.returncode != 0:
            print(f" ✗ Retry also failed: {retry.stderr[:500]}")
            sys.exit(1)

    if not tmp_onnx.exists() or tmp_onnx.stat().st_size < 1000:
        print(" ✗ ONNX file not created or too small")
        sys.exit(1)

    shutil.move(str(tmp_onnx), str(OUTPUT_PATH))
    print(f" ✓ ONNX saved ({OUTPUT_PATH.stat().st_size / 1024 / 1024:.2f} MB)")

    # ── Step 4: verify with onnxruntime ──
    print("\n[4/4] Verifying with onnxruntime ...")
    _inspect_onnx(OUTPUT_PATH)

    print(f"\n✓ Done! ONNX model saved to {OUTPUT_PATH}")
|
||
|
||
|
||
def _inspect_onnx(onnx_path: Path) -> None:
    """Load the model with onnxruntime, print its I/O signature and run a dummy inference.

    Every declared input is fed, not just the first one, so models with
    more than one required input can be smoke-tested too.
    NOTE(review): detection exports often declare auxiliary inputs besides
    the image tensor — confirm against the printed "Inputs:" list.

    Args:
        onnx_path: Path of the ONNX file to load.
    """
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    input_infos = session.get_inputs()
    output_infos = session.get_outputs()

    print(" Inputs:")
    for inp in input_infos:
        print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")

    print(" Outputs:")
    for out in output_infos:
        print(f" {out.name}: shape={out.shape}, dtype={out.type}")

    # Trial inference.
    # Fallback sizes for dynamic axes of a rank-4 input (batch, channels,
    # height, width); any other dynamic axis falls back to 1.
    image_defaults = (1, 3, 480, 480)
    # onnxruntime reports element types as strings; unknown ones use float32.
    dtype_by_ort_type = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
        "tensor(int32)": np.int32,
        "tensor(int64)": np.int64,
        "tensor(bool)": np.bool_,
    }

    feeds = {}
    for inp in input_infos:
        dtype = dtype_by_ort_type.get(inp.type, np.float32)
        is_image = len(inp.shape) == 4
        dims = tuple(
            dim
            if isinstance(dim, int) and dim > 0
            else (image_defaults[axis] if is_image else 1)
            for axis, dim in enumerate(inp.shape)
        )
        if is_image:
            feeds[inp.name] = np.random.rand(*dims).astype(dtype)
        else:
            # Neutral values for auxiliary (non rank-4) inputs.
            feeds[inp.name] = np.ones(dims, dtype=dtype)

    outputs = session.run(None, feeds)

    print(" Inference test outputs:")
    for i, (out_info, out_val) in enumerate(zip(output_infos, outputs)):
        print(f" output[{i}] '{out_info.name}': shape={out_val.shape}, dtype={out_val.dtype}")
        if out_val.size <= 20:
            # Small tensors are cheap to show in full.
            print(f" values: {out_val}")

    print(" ✓ Inference OK")
|
||
|
||
|
||
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|