90fe705e8f
- 核心变更: - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024) - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备 - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式 - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls] - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table) - 新增文件: - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行) - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例) - 配置更新: - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID) - app/config.py: Settings 类对应字段 - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等) - 删除旧文件: - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本 - 文档更新: - README.md: 更新环境变量说明 - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py) 此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
344 lines
12 KiB
Python
344 lines
12 KiB
Python
"""DocLayout-YOLO 布局检测 — ONNX Runtime 推理,支持 CPU/GPU/NPU 多设备.
|
||
|
||
用 onnxruntime 加载 DocLayout-YOLO(DocStructBench, imgsz=1024)ONNX 模型,
|
||
检测 PDF 页面中的 figure / table 区域。
|
||
|
||
预处理:letterbox(保比例缩放 + 灰边 padding 到 imgsz×imgsz),RGB,仅 /255 归一化
|
||
(不做 ImageNet mean/std)。缩放由 pymupdf Matrix 完成,不依赖 OpenCV。
|
||
后处理:YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls](已内置 NMS)。
|
||
坐标还原:(model_coord - padding) / ratio —— 渲染缩放与 letterbox 缩放在 pymupdf
|
||
渲染阶段合二为一,故只需一次除法。
|
||
|
||
设备:resolve_providers() 按 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表;
|
||
_init_session() 逐个 try,首个不可用则降级,CPU 永远兜底。
|
||
|
||
输入:
|
||
images: (1, 3, imgsz, imgsz) float32 —— letterbox + /255 后的图
|
||
|
||
输出:
|
||
output0: (1, N, 6) float32 —— [x1, y1, x2, y2, conf, cls],已 NMS
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
import pymupdf
|
||
|
||
from app.config import settings
|
||
|
||
logger = logging.getLogger(__name__)

# The 10 standard DocLayout-YOLO DocStructBench classes, indexed by class id.
# Only a fallback for when the names cannot be read from the ONNX metadata;
# whatever the model itself declares takes precedence.
_FALLBACK_NAMES: dict[int, str] = dict(
    enumerate(
        (
            "title",  # 0
            "plain text",  # 1
            "abandon",  # 2
            "figure",  # 3
            "figure_caption",  # 4
            "table",  # 5
            "table_caption",  # 6
            "table_footnote",  # 7
            "isolate_formula",  # 8
            "formula_caption",  # 9
        )
    )
)

# Downstream only needs picture/table. Classes are matched by name rather than
# by index so that differing class orders between DocStructBench releases do
# not matter.
_PICTURE_NAMES = {"figure", "figure_group"}
_TABLE_NAMES = {"table", "table_group"}

# Gray value used for letterbox padding (the ultralytics training standard;
# do not change it to 0/128 or accuracy drops).
_PAD_VALUE = 114

# Minimum bounding-box size, in PDF points.
_MIN_BOX_SIZE = 20

# device name -> onnxruntime ExecutionProvider name
_PROVIDER_MAP: dict[str, str] = {
    "cpu": "CPUExecutionProvider",
    "cuda": "CUDAExecutionProvider",
    "directml": "DmlExecutionProvider",
    "openvino": "OpenVINOExecutionProvider",
    "cann": "CannExecutionProvider",
    "tensorrt": "TensorrtExecutionProvider",
    "qnn": "QNNExecutionProvider",
}

# Probe order for "auto" (cpu is excluded: it is always the last resort).
_AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
|
||
|
||
|
||
@dataclass
class LayoutBox:
    """One detected layout region.

    The box spans (x0, y0)-(x1, y1) in PDF points; ``boxclass`` is either
    ``"picture"`` or ``"table"``.
    """

    x0: float
    y0: float
    x1: float
    y1: float
    boxclass: str
|
||
|
||
|
||
# ── 设备选择 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def resolve_providers(device: str, device_id: int) -> list[tuple[str, dict]]:
    """Build the candidate ExecutionProvider list for a ``LAYOUT_DEVICE`` value.

    The preferred provider comes first and every result ends with the CPU
    provider as the last resort. This function only proposes candidates:
    onnxruntime raises when a session is created with a provider that is not
    registered in the installed build, so the actual fallback is carried out
    by ``_init_session()``, which tries the candidates one by one.

    Args:
        device: ``"cpu"``, ``"auto"`` or a key of ``_PROVIDER_MAP``. Matching
            is case-insensitive and ignores surrounding whitespace, so values
            such as ``"CUDA"`` read from an environment file are accepted
            instead of being treated as unknown.
        device_id: Device index, passed (stringified) to the non-CPU provider
            as its ``device_id`` option.

    Returns:
        ``[(ep_name, provider_options), ...]`` in the order to try them.
    """
    normalized = device.strip().lower()

    if normalized == "cpu":
        return [("CPUExecutionProvider", {})]

    opts = {"device_id": str(device_id)}

    if normalized == "auto":
        # Pick the first accelerator, by priority, that this onnxruntime
        # build reports as available.
        available = set(ort.get_available_providers())
        for dev in _AUTO_PRIORITY:
            ep = _PROVIDER_MAP[dev]
            if ep in available:
                logger.info("auto: selected provider %s", ep)
                return [(ep, opts), ("CPUExecutionProvider", {})]
        logger.info("auto: no GPU/NPU provider available, using CPU")
        return [("CPUExecutionProvider", {})]

    ep = _PROVIDER_MAP.get(normalized)
    if ep is None:
        # Log the value exactly as configured to make misconfiguration obvious.
        logger.warning("Unknown LAYOUT_DEVICE=%r, falling back to CPU", device)
        return [("CPUExecutionProvider", {})]
    return [(ep, opts), ("CPUExecutionProvider", {})]
|
||
|
||
|
||
# ── 预处理:渲染几何与 letterbox ────────────────────────────────────────
|
||
|
||
|
||
def _compute_render_geometry(page_w: float, page_h: float, imgsz: int) -> float:
    """Return the uniform letterbox render scale for a page.

    The scale is the smaller of ``imgsz / page_w`` and ``imgsz / page_h``.
    Rendering with pymupdf ``Matrix(scale, scale)`` makes the long side reach
    ``imgsz`` while the short side is left to be padded with gray.
    """
    scale_w = imgsz / page_w
    scale_h = imgsz / page_h
    return scale_h if scale_h < scale_w else scale_w
|
||
|
||
|
||
def _letterbox_padding(
    content_w: float, content_h: float, imgsz: int
) -> tuple[float, float]:
    """Return the centering padding ``(dw, dh)`` for a letterboxed image.

    Each value is ``(imgsz - content) / 2``; ``content_w`` / ``content_h`` are
    the actual (already integer-rounded) pixmap dimensions.
    """
    pad_x = (imgsz - content_w) / 2.0
    pad_y = (imgsz - content_h) / 2.0
    return pad_x, pad_y
|
||
|
||
|
||
def _padded_nchw_from_pixmap(
    pix: pymupdf.Pixmap, imgsz: int, dw: float, dh: float
) -> np.ndarray:
    """Convert a pixmap into a letterboxed ``(1, 3, imgsz, imgsz)`` float32 tensor.

    The pixmap is pasted onto an ``imgsz x imgsz`` canvas filled with
    ``_PAD_VALUE`` at the offset given by the rounded ``dw`` / ``dh``, then
    scaled by 1/255 and reordered from HWC to NCHW.
    """
    pixels = np.frombuffer(pix.samples, dtype=np.uint8)
    pixels = pixels.reshape(pix.height, pix.width, pix.n)
    # Defensive: drop an alpha channel (not expected with csRGB, alpha=False).
    if pixels.shape[2] == 4:
        pixels = pixels[:, :, :3]

    row0 = int(round(dh))
    col0 = int(round(dw))
    canvas = np.full((imgsz, imgsz, 3), _PAD_VALUE, dtype=np.uint8)
    canvas[row0 : row0 + pix.height, col0 : col0 + pix.width] = pixels

    chw = (canvas.astype(np.float32) / 255.0).transpose(2, 0, 1)
    return chw[np.newaxis]  # (1, 3, imgsz, imgsz)
|
||
|
||
|
||
def _model_to_pdf(
    model_x: float, model_y: float, dw: float, dh: float, ratio: float
) -> tuple[float, float]:
    """Map a point from the padded model space back to PDF points.

    Applies ``(model - padding) / ratio`` independently to x and y.
    """
    pdf_x = (model_x - dw) / ratio
    pdf_y = (model_y - dh) / ratio
    return pdf_x, pdf_y
|
||
|
||
|
||
# ── 后处理 ──────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _postprocess_output(
    output: np.ndarray, threshold: float, names: dict[int, str]
) -> list[tuple[int, float, float, float, float]]:
    """Parse the YOLOv10 end-to-end output and drop rows with conf < threshold.

    Args:
        output: First output of ``session.run``, expected shape ``[1, N, 6]``
            with rows ``[x1, y1, x2, y2, conf, cls]``.
        threshold: Confidence threshold.
        names: class id -> name mapping. Accepted for interface
            compatibility; the filtering here does not use it.

    Returns:
        ``[(cls_id, x1, y1, x2, y2), ...]`` with coordinates in the padded
        ``imgsz`` model space. An empty list (plus a warning) when the output
        does not have the expected ``[N, 6]`` layout.
    """
    out = output[0]  # strip the batch dimension
    if not (out.ndim == 2 and out.shape[1] == 6):
        logger.warning(
            "Unexpected DocLayout-YOLO output shape %s (expected [N,6]); skip page",
            tuple(out.shape),
        )
        return []

    return [
        (int(cls), x1, y1, x2, y2)
        for x1, y1, x2, y2, conf, cls in out.tolist()
        if not conf < threshold
    ]
|
||
|
||
|
||
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
    """Translate a model class id into ``"picture"``, ``"table"`` or ``None``.

    The class name is looked up in ``names``, stripped and lower-cased, then
    matched against ``_PICTURE_NAMES`` and ``_TABLE_NAMES``. Unknown ids and
    every other class yield ``None``.
    """
    label = names.get(cls_id, "").strip().lower()
    if label in _PICTURE_NAMES:
        return "picture"
    return "table" if label in _TABLE_NAMES else None
|
||
|
||
|
||
def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
    """Read the class-name table from the ONNX model metadata.

    The ``names`` entry of the custom metadata map is accepted in two forms:

    * JSON, e.g. ``{"0": "title", "3": "figure"}``;
    * a Python literal, e.g. ``{0: 'title', 3: 'figure'}``, which is what
      ultralytics-style exporters produce when they store ``str(names)``.
      ``json.loads`` rejects that form, so it is parsed with
      ``ast.literal_eval`` (literals only; nothing is executed).

    A list/tuple of names is treated as indexed by position.

    Args:
        session: Loaded inference session whose model metadata is inspected.

    Returns:
        class id -> class name. A copy of ``_FALLBACK_NAMES`` when the
        metadata is missing, empty or cannot be parsed.
    """
    import ast  # local import: only needed for the Python-literal form

    try:
        raw = session.get_modelmeta().custom_metadata_map.get("names")
    except Exception:
        # Best effort: a model without readable metadata is not an error.
        raw = None

    if raw:
        try:
            try:
                parsed = json.loads(raw)
            except ValueError:
                # Not JSON (JSONDecodeError is a ValueError); try the
                # str(dict) form written by ultralytics-style exporters.
                parsed = ast.literal_eval(raw)
            if isinstance(parsed, (list, tuple)):
                parsed = dict(enumerate(parsed))
            return {int(k): str(v) for k, v in parsed.items()}
        except Exception:
            # Best effort: never let malformed metadata break detection.
            logger.warning("Failed to parse ONNX names metadata; using fallback")

    return dict(_FALLBACK_NAMES)
|
||
|
||
|
||
# ── 检测器单例 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class _LayoutDetector:
    """Singleton that owns the lifecycle of the ONNX InferenceSession."""

    def __init__(self) -> None:
        # The session is created lazily on first use, see _init_session().
        self._session: ort.InferenceSession | None = None
        self._names: dict[int, str] = {}
        self._input_name: str = ""
        self._imgsz: int = settings.LAYOUT_IMGSZ

    def _init_session(self) -> ort.InferenceSession:
        """Create the inference session on first call and return it.

        Candidate providers from ``resolve_providers()`` are tried in order;
        when one fails to initialise, the next one is attempted. On success
        the input name, class names and image size are cached on the instance.

        Returns:
            The (cached) inference session.

        Raises:
            FileNotFoundError: The model file does not exist.
            RuntimeError: No candidate provider could create a session; the
                last provider error is attached as the cause.
        """
        if self._session is not None:
            return self._session

        model_path = Path(settings.LAYOUT_MODEL_PATH)
        if not model_path.exists():
            raise FileNotFoundError(
                f"Layout model not found: {model_path}. "
                "Run scripts/export_doclayout_yolo_onnx.py first."
            )

        eps = resolve_providers(settings.LAYOUT_DEVICE, settings.LAYOUT_DEVICE_ID)
        logger.info(
            "Loading layout model %s, candidate providers: %s",
            model_path,
            [ep[0] for ep in eps],
        )

        # Try each provider in turn; fall back to the next one on failure.
        last_err: Exception | None = None
        for idx, (ep_name, ep_opts) in enumerate(eps):
            try:
                self._session = ort.InferenceSession(
                    str(model_path), providers=[(ep_name, ep_opts)]
                )
                break
            except Exception as e:
                last_err = e
                if idx < len(eps) - 1:
                    logger.warning(
                        "Provider %s unavailable (%s); falling back to %s",
                        ep_name,
                        e,
                        eps[idx + 1][0],
                    )
        else:
            # Every candidate failed: keep the underlying error as the cause
            # so the traceback shows why the last provider was rejected.
            raise RuntimeError(
                f"Failed to create layout session: {last_err}"
            ) from last_err

        logger.info(
            "Layout session active providers: %s", self._session.get_providers()
        )
        self._input_name = self._session.get_inputs()[0].name
        self._names = _parse_names_from_meta(self._session)
        self._imgsz = settings.LAYOUT_IMGSZ
        return self._session

    def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
        """Detect the figure / table regions of a single PDF page.

        Steps:
            1. Letterbox render: aspect-preserving scale so the long side is
               ``imgsz``; the short side is padded with gray.
            2. /255 + NCHW, then ONNX inference.
            3. YOLOv10 end-to-end post-processing (NMS already applied).
            4. Model coordinates -> PDF points.
            5. Drop classes other than figure/table, clip to the page and
               drop tiny boxes.

        Args:
            page: The pymupdf page to analyse.

        Returns:
            List of LayoutBox, coordinates in PDF points.
        """
        session = self._init_session()

        page_w = page.rect.width
        page_h = page.rect.height
        ratio = _compute_render_geometry(page_w, page_h, self._imgsz)

        # 1. Aspect-preserving render (long side reaches imgsz).
        pix = page.get_pixmap(
            matrix=pymupdf.Matrix(ratio, ratio),
            colorspace=pymupdf.csRGB,
            alpha=False,
        )
        # Padding is derived from the real (integer) pixmap size so that
        # render rounding does not shift the coordinates.
        dw, dh = _letterbox_padding(pix.width, pix.height, self._imgsz)
        # The content is pasted at an integer pixel offset
        # (_padded_nchw_from_pixmap uses int(round(...))). Snap the padding to
        # that same offset here so that coordinate restoration subtracts
        # exactly what was added; otherwise an x.5 padding introduces a
        # half-pixel bias in every box.
        dw, dh = float(int(round(dw))), float(int(round(dh)))
        tensor = _padded_nchw_from_pixmap(pix, self._imgsz, dw, dh)

        # 2. Inference.
        outputs = session.run(None, {self._input_name: tensor})
        detections = _postprocess_output(
            outputs[0], settings.LAYOUT_THRESHOLD, self._names
        )

        # 3. Coordinate restoration + filtering.
        result: list[LayoutBox] = []
        for cls_id, x1m, y1m, x2m, y2m in detections:
            boxclass = _map_class_to_boxclass(cls_id, self._names)
            if boxclass is None:
                continue
            x0, y0 = _model_to_pdf(x1m, y1m, dw, dh, ratio)
            x1, y1 = _model_to_pdf(x2m, y2m, dw, dh, ratio)
            # Clip to the page bounds.
            x0 = max(0.0, min(x0, page_w))
            y0 = max(0.0, min(y0, page_h))
            x1 = max(0.0, min(x1, page_w))
            y1 = max(0.0, min(y1, page_h))
            if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
                continue
            result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))

        return result
|
||
|
||
|
||
# Module-level singleton shared by every caller of detect_page_layout().
_detector = _LayoutDetector()


def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
    """Detect the figure / table regions on a PDF page.

    Delegates to the module-level ``_LayoutDetector`` instance.

    Returns:
        List of LayoutBox in PDF points, containing only picture/table boxes.
    """
    boxes = _detector.detect_page(page)
    return boxes
|