406 lines
15 KiB
Python
406 lines
15 KiB
Python
"""DocLayout-YOLO 布局检测 — ONNX Runtime 推理,支持 CPU/GPU/NPU 多设备.
|
||
|
||
用 onnxruntime 加载 DocLayout-YOLO(DocStructBench, imgsz=1024)ONNX 模型,
|
||
检测 PDF 页面中的 figure / table 区域。
|
||
|
||
预处理:letterbox(保比例缩放 + 灰边 padding 到 imgsz×imgsz),RGB,仅 /255 归一化
|
||
(不做 ImageNet mean/std)。缩放由 pymupdf Matrix 完成,不依赖 OpenCV。
|
||
后处理:YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls](已内置 NMS)。
|
||
坐标还原:(model_coord - padding) / ratio —— 渲染缩放与 letterbox 缩放在 pymupdf
|
||
渲染阶段合二为一,故只需一次除法。
|
||
|
||
设备:resolve_providers() 按 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表;
|
||
_init_session() 逐个 try,首个不可用则降级,CPU 永远兜底。
|
||
|
||
输入:
|
||
images: (1, 3, imgsz, imgsz) float32 —— letterbox + /255 后的图
|
||
|
||
输出:
|
||
output0: (1, N, 6) float32 —— [x1, y1, x2, y2, conf, cls],已 NMS
|
||
"""
|
||
|
||
from __future__ import annotations

import ast
import json
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import onnxruntime as ort
import pymupdf

from app.config import settings
|
||
|
||
logger = logging.getLogger(__name__)

# Standard 10 DocStructBench classes of DocLayout-YOLO. Only used when the ONNX
# metadata carries no usable `names`; the model's own metadata wins otherwise.
_FALLBACK_NAMES: dict[int, str] = dict(
    enumerate(
        (
            "title",
            "plain text",
            "abandon",
            "figure",
            "figure_caption",
            "table",
            "table_caption",
            "table_footnote",
            "isolate_formula",
            "formula_caption",
        )
    )
)

# Downstream needs pictures/tables and their captions. Classes are matched by
# name string, not by index, so differing class orders between DocStructBench
# releases do not matter.
_PICTURE_NAMES = {"figure_group", "figure"}
_TABLE_NAMES = {"table_group", "table"}
_FIGURE_CAPTION_NAMES = {"figure_caption"}
_TABLE_CAPTION_NAMES = {"table_caption"}

# Letterbox padding gray level (ultralytics training convention; changing it
# to 0 or 128 degrades accuracy).
_PAD_VALUE = 114

# Minimum accepted box sizes, in PDF points.
_MIN_BOX_SIZE = 20  # pictures / tables: applies to both width and height
_MIN_CAPTION_BOX_WIDTH = 30
_MIN_CAPTION_BOX_HEIGHT = 6

# LAYOUT_DEVICE value -> onnxruntime ExecutionProvider name.
_PROVIDER_MAP: dict[str, str] = dict(
    cpu="CPUExecutionProvider",
    cuda="CUDAExecutionProvider",
    directml="DmlExecutionProvider",
    openvino="OpenVINOExecutionProvider",
    cann="CannExecutionProvider",
    tensorrt="TensorrtExecutionProvider",
    qnn="QNNExecutionProvider",
)

# Probe order for LAYOUT_DEVICE=auto. "cpu" is deliberately absent: it is
# always appended as the final fallback.
_AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
|
||
|
||
|
||
@dataclass
class LayoutBox:
    """One detected layout region, expressed in PDF points.

    ``(x0, y0)`` is the minimum corner and ``(x1, y1)`` the maximum corner of
    the box; ``boxclass`` is the downstream label assigned by the detector.
    """

    # Minimum corner (left, top).
    x0: float
    y0: float
    # Maximum corner (right, bottom).
    x1: float
    y1: float
    # Downstream label, e.g. "picture" or "table".
    boxclass: str
|
||
|
||
|
||
# ── 设备选择 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def resolve_providers(device: str, device_id: int) -> list[tuple[str, dict]]:
    """Build the candidate ExecutionProvider list for ``LAYOUT_DEVICE``.

    The preferred provider comes first and every non-CPU list ends with a CPU
    fallback; ``_init_session()`` tries the candidates in order. Depending on
    the onnxruntime build/version, requesting an EP that is not registered may
    raise or may only warn and skip it, so degradation is handled by the
    caller; this function only produces candidates.

    Args:
        device: ``"cpu"``, ``"auto"`` or a key of ``_PROVIDER_MAP``
            (``"cuda"``, ``"directml"``, ...). Matching ignores case and
            surrounding whitespace, so values like ``" CUDA "`` from an env
            file are accepted.
        device_id: accelerator index, passed to the chosen EP as the
            ``device_id`` provider option (as a string).

    Returns:
        ``[(ep_name, provider_options), ...]`` in priority order. Unknown
        device names degrade to CPU with a warning.
    """
    # Normalize config input: env vars often carry stray whitespace/casing.
    dev = (device or "").strip().lower()
    cpu_only: list[tuple[str, dict]] = [("CPUExecutionProvider", {})]

    if dev == "cpu":
        return cpu_only

    opts = {"device_id": str(device_id)}

    if dev == "auto":
        available = set(ort.get_available_providers())
        for name in _AUTO_PRIORITY:
            ep = _PROVIDER_MAP[name]
            if ep in available:
                logger.info("auto: selected provider %s", ep)
                return [(ep, opts), ("CPUExecutionProvider", {})]
        logger.info("auto: no GPU/NPU provider available, using CPU")
        return cpu_only

    ep = _PROVIDER_MAP.get(dev)
    if ep is None:
        logger.warning("Unknown LAYOUT_DEVICE=%r, falling back to CPU", device)
        return cpu_only
    return [(ep, opts), ("CPUExecutionProvider", {})]
|
||
|
||
|
||
# ── 预处理:渲染几何与 letterbox ────────────────────────────────────────
|
||
|
||
|
||
def _compute_render_geometry(page_w: float, page_h: float, imgsz: int) -> float:
    """Return the uniform render scale that fits the page inside ``imgsz``.

    The page is rendered with ``Matrix(ratio, ratio)``: the longer side lands
    on ``imgsz`` and the shorter side is later padded (letterbox). A zero page
    dimension raises ``ZeroDivisionError``.
    """
    fit_width = imgsz / page_w
    fit_height = imgsz / page_h
    return min(fit_width, fit_height)
|
||
|
||
|
||
def _letterbox_padding(
    content_w: float, content_h: float, imgsz: int
) -> tuple[float, float]:
    """Return the centering offsets ``(dw, dh)`` for content in an ``imgsz`` square.

    ``content_w``/``content_h`` are the actual (already integer) pixmap sizes.
    The offsets are rounded to whole pixels because the content can only be
    pasted at an integer canvas position (``_padded_nchw_from_pixmap`` uses
    ``int(round(...))``). Returning the rounded value, not the raw
    ``(imgsz - content) / 2``, keeps the paste position and the coordinate
    restoration ``(model - pad) / ratio`` on exactly the same offset. With the
    raw half-pixel value an odd leftover shifted every restored box by 0.5
    model pixels.

    Returns:
        ``(dw, dh)`` as floats holding integral values.
    """
    dw = float(round((imgsz - content_w) / 2.0))
    dh = float(round((imgsz - content_h) / 2.0))
    return dw, dh
|
||
|
||
|
||
def _padded_nchw_from_pixmap(
    pix: pymupdf.Pixmap, imgsz: int, dw: float, dh: float
) -> np.ndarray:
    """Paste a rendered page onto a letterbox canvas and build the model input.

    Args:
        pix: rendered page pixmap; an alpha channel (4 components), if
            present, is dropped.
        imgsz: side length of the square model input.
        dw, dh: left / top padding in pixels; rounded to whole pixels here.

    Returns:
        C-contiguous float32 array of shape ``(1, 3, imgsz, imgsz)``. The
        padding is filled with ``_PAD_VALUE`` and every value is scaled by 1/255
        (no mean/std normalization).
    """
    arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
        pix.height, pix.width, pix.n
    )
    if arr.shape[2] == 4:  # drop alpha (defensive: callers render alpha=False)
        arr = arr[:, :, :3]

    top = max(0, int(round(dh)))
    left = max(0, int(round(dw)))
    # A render that overshoots imgsz (e.g. by one pixel from float rounding in
    # the renderer) would otherwise make the slice assignment fail with a
    # broadcast error; crop the excess instead.
    rows = max(0, min(arr.shape[0], imgsz - top))
    cols = max(0, min(arr.shape[1], imgsz - left))

    canvas = np.full((imgsz, imgsz, 3), _PAD_VALUE, dtype=np.uint8)
    canvas[top : top + rows, left : left + cols] = arr[:rows, :cols]

    # HWC uint8 -> NCHW float32. Divide in place to avoid allocating a second
    # full-size float buffer, then materialize the transposed view contiguously
    # so onnxruntime does not have to copy it internally.
    out = canvas.astype(np.float32)
    out /= 255.0
    return np.ascontiguousarray(out.transpose(2, 0, 1)[np.newaxis])
|
||
|
||
|
||
def _model_to_pdf(
    model_x: float, model_y: float, dw: float, dh: float, ratio: float
) -> tuple[float, float]:
    """Map a point from padded model-input space back to PDF points.

    Undo the letterbox in reverse order: strip the padding offset, then divide
    by the render scale. Rendering and letterbox scaling share one ``ratio``,
    so a single division suffices.
    """
    unpadded_x = model_x - dw
    unpadded_y = model_y - dh
    return unpadded_x / ratio, unpadded_y / ratio
|
||
|
||
|
||
# ── 后处理 ──────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _postprocess_output(
    output: np.ndarray, threshold: float, names: dict[int, str]
) -> list[tuple[int, float, float, float, float]]:
    """Decode the YOLOv10 end-to-end output and keep confident detections.

    The model already applies NMS, so this only filters rows and reshapes them.

    Args:
        output: first output of ``session.run``, shape ``[1, N, 6]`` with rows
            ``[x1, y1, x2, y2, conf, cls]``.
        threshold: minimum confidence; rows with ``conf >= threshold`` are kept.
        names: class id -> name. Currently unused; kept for call-site
            compatibility.

    Returns:
        ``[(cls_id, x1, y1, x2, y2), ...]`` in the padded ``imgsz`` model
        space. Rows containing any non-finite value (NaN/inf) are dropped.
    """
    out = output[0]  # drop the batch dimension
    if out.ndim != 2 or out.shape[1] != 6:
        logger.warning(
            "Unexpected DocLayout-YOLO output shape %s (expected [N,6]); skip page",
            tuple(out.shape),
        )
        return []

    # Vectorized filter. The former row-wise `conf < threshold: continue` let
    # NaN confidences through (NaN comparisons are always False), and a NaN
    # class made `int()` raise; requiring finite rows rejects both up front.
    keep = np.isfinite(out).all(axis=1) & (out[:, 4] >= threshold)
    return [
        (int(cls), x1, y1, x2, y2)
        for x1, y1, x2, y2, _conf, cls in out[keep].tolist()
    ]
|
||
|
||
|
||
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
    """Translate a model class id into a downstream box label.

    The class name is looked up in ``names``, trimmed and lower-cased, then
    matched against the picture / table / caption name sets in that order.
    Classes downstream does not care about (and unknown ids) yield ``None``.
    """
    normalized = names.get(cls_id, "").strip().lower()
    for candidates, label in (
        (_PICTURE_NAMES, "picture"),
        (_TABLE_NAMES, "table"),
        (_FIGURE_CAPTION_NAMES, "figure_caption"),
        (_TABLE_CAPTION_NAMES, "table_caption"),
    ):
        if normalized in candidates:
            return label
    return None
|
||
|
||
|
||
def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
    """Read the class-id -> name table from the ONNX model metadata.

    The ``names`` metadata entry is accepted in two encodings:

    * JSON, e.g. ``{"0": "title", ...}`` (keys converted to int);
    * a Python literal, e.g. ``{0: 'title', 1: 'plain text'}``. Ultralytics'
      ONNX exporter stores metadata values via ``str(...)``, which produces
      this form and which ``json.loads`` rejects.

    A list of names is mapped by position. ``ast.literal_eval`` only evaluates
    literals, so metadata from an untrusted model file cannot execute code.

    Returns:
        The parsed mapping, or a copy of ``_FALLBACK_NAMES`` when the entry is
        missing, unparsable or empty (a warning is logged in the latter two
        cases).
    """
    try:
        raw = session.get_modelmeta().custom_metadata_map.get("names")
    except Exception:
        # Metadata access is best-effort; any failure means "use the fallback".
        raw = None

    if raw:
        for decode in (json.loads, ast.literal_eval):
            try:
                data = decode(raw)
                if isinstance(data, (list, tuple)):
                    parsed = {i: str(v) for i, v in enumerate(data)}
                else:
                    parsed = {int(k): str(v) for k, v in data.items()}
            except Exception:
                # Not valid for this decoder, or an unexpected shape: try the
                # next decoder, then fall back.
                continue
            if parsed:
                return parsed
        logger.warning("Failed to parse ONNX names metadata; using fallback")
    return dict(_FALLBACK_NAMES)
|
||
|
||
|
||
# ── 检测器单例 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class _Singleton(type):
    """Metaclass that makes ``cls()`` always return one shared instance per class.

    Production code should only instantiate ``_LayoutDetector`` once, at module
    level. Any later ``_LayoutDetector()`` call gets that same instance (same
    ONNX session, same lock). This prevents concurrent inference from building
    one session per caller, which doubled peak memory and crashed 8GB machines.

    First construction uses double-checked locking: a lock-free lookup, then a
    re-check under the lock, so two threads racing on the first call still
    build only one instance. Removing a class from ``_instances`` (see
    ``reset_instance``) makes the next call build a fresh instance.
    """

    # Class -> its single instance; shared by every class using this metaclass.
    _instances: dict[type, Any] = {}
    # Guards first-time construction only.
    _lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        registry = _Singleton._instances
        # Fast path: already built, no locking needed.
        if cls in registry:
            return registry[cls]
        with _Singleton._lock:
            # Re-check: another thread may have built it while we waited.
            if cls not in registry:
                registry[cls] = super().__call__(*args, **kwargs)
            return registry[cls]
|
||
|
||
|
||
class _LayoutDetector(metaclass=_Singleton):
    """Owns the ONNX InferenceSession for layout detection (process-wide singleton).

    The ``_Singleton`` metaclass guarantees one instance per process: a repeated
    ``_LayoutDetector()`` returns the existing instance (same loaded session,
    same lock) instead of creating another. ``reset_instance()`` clears that
    cache and is meant for test isolation only.
    """

    def __init__(self) -> None:
        # Serializes detect_page(), which also covers lazy session creation.
        self._lock = threading.Lock()
        # Created lazily by _init_session() on first use.
        self._session: ort.InferenceSession | None = None
        # Class id -> class name, read from model metadata (or the fallback).
        self._names: dict[int, str] = {}
        # Name of the model's first input, used as the feed key for run().
        self._input_name: str = ""
        # Side length of the square model input (letterbox target).
        self._imgsz: int = settings.LAYOUT_IMGSZ

    @classmethod
    def reset_instance(cls) -> None:
        """Drop the cached singleton so the next ``_LayoutDetector()`` builds a new one.

        The new instance gets a new lock and an empty session. Test isolation
        only: production code must never call this, since it discards the
        loaded model session.
        """
        _Singleton._instances.pop(cls, None)

    def _init_session(self) -> ort.InferenceSession:
        """Return the InferenceSession, creating it on the first call.

        Candidate providers come from ``resolve_providers``. They are tried in
        order, and the first that constructs successfully is used. Instance
        state (input name, class names, session) is committed only after
        everything has been read. A failure therefore leaves the detector
        uninitialized rather than half-initialized, and the next call retries.

        Raises:
            FileNotFoundError: the model file does not exist.
            RuntimeError: every candidate provider failed; the last provider
                error is chained as ``__cause__``.
        """
        if self._session is not None:
            return self._session

        model_path = Path(settings.LAYOUT_MODEL_PATH)
        if not model_path.exists():
            raise FileNotFoundError(
                f"Layout model not found: {model_path}. "
                "Run scripts/export_doclayout_yolo_onnx.py first."
            )

        eps = resolve_providers(settings.LAYOUT_DEVICE, settings.LAYOUT_DEVICE_ID)
        logger.info(
            "Loading layout model %s, candidate providers: %s",
            model_path,
            [ep[0] for ep in eps],
        )

        # Try each EP in turn; degrade to the next candidate on failure.
        session: ort.InferenceSession | None = None
        last_err: Exception | None = None
        for idx, (ep_name, ep_opts) in enumerate(eps):
            try:
                session = ort.InferenceSession(
                    str(model_path), providers=[(ep_name, ep_opts)]
                )
                break
            except Exception as e:
                last_err = e
                if idx < len(eps) - 1:
                    logger.warning(
                        "Provider %s unavailable (%s); falling back to %s",
                        ep_name,
                        e,
                        eps[idx + 1][0],
                    )
        if session is None:
            raise RuntimeError(
                f"Failed to create layout session: {last_err}"
            ) from last_err

        logger.info("Layout session active providers: %s", session.get_providers())
        input_name = session.get_inputs()[0].name
        names = _parse_names_from_meta(session)

        self._input_name = input_name
        self._names = names
        self._imgsz = settings.LAYOUT_IMGSZ
        self._session = session
        return session

    def _detect_page_impl(self, page: pymupdf.Page) -> list[LayoutBox]:
        """Detect figure / table (and caption) regions on one PDF page.

        Steps:
            1. Letterbox render: aspect-preserving scale so the long side equals
               imgsz; the short side is padded.
            2. /255 + NCHW tensor -> ONNX inference.
            3. YOLOv10 end-to-end post-processing (NMS already in the model).
            4. Model coordinates -> PDF points.
            5. Drop unwanted classes and too-small boxes; clip to the page.

        Returns:
            ``LayoutBox`` list in PDF points. A page with zero width or height
            yields an empty list.
        """
        session = self._init_session()

        page_w = page.rect.width
        page_h = page.rect.height
        if page_w <= 0 or page_h <= 0:
            # Nothing to detect, and the render scale would divide by zero.
            logger.warning(
                "Skip layout detection: degenerate page size %sx%s", page_w, page_h
            )
            return []
        ratio = _compute_render_geometry(page_w, page_h, self._imgsz)

        # 1. Aspect-preserving render (long side lands on imgsz).
        pix = page.get_pixmap(
            matrix=pymupdf.Matrix(ratio, ratio),
            colorspace=pymupdf.csRGB,
            alpha=False,
        )
        # Pad from the actual (integer) pixmap size, so renderer rounding does
        # not offset the restored coordinates.
        dw, dh = _letterbox_padding(pix.width, pix.height, self._imgsz)
        tensor = _padded_nchw_from_pixmap(pix, self._imgsz, dw, dh)
        # The tensor is a separate buffer; release the pixmap before inference
        # so both are not held at the same time.
        del pix

        # 2. Inference.
        outputs = session.run(None, {self._input_name: tensor})
        detections = _postprocess_output(
            outputs[0], settings.LAYOUT_THRESHOLD, self._names
        )

        # 3. Coordinate restoration + filtering.
        result: list[LayoutBox] = []
        for cls_id, x1m, y1m, x2m, y2m in detections:
            boxclass = _map_class_to_boxclass(cls_id, self._names)
            if boxclass is None:
                continue
            x0, y0 = _model_to_pdf(x1m, y1m, dw, dh, ratio)
            x1, y1 = _model_to_pdf(x2m, y2m, dw, dh, ratio)
            # Clip to the page.
            x0 = max(0.0, min(x0, page_w))
            y0 = max(0.0, min(y0, page_h))
            x1 = max(0.0, min(x1, page_w))
            y1 = max(0.0, min(y1, page_h))
            if boxclass in ("figure_caption", "table_caption"):
                if (x1 - x0) < _MIN_CAPTION_BOX_WIDTH or (
                    y1 - y0
                ) < _MIN_CAPTION_BOX_HEIGHT:
                    continue
            else:
                if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
                    continue
            result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))

        return result

    def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
        """Public entry point: run ``_detect_page_impl`` under the instance lock.

        The lock covers the whole pipeline: pixmap render, tensor build and
        ``session.run``. At most one inference runs at a time, so concurrent
        worker threads (e.g. SUMMARY_CONCURRENCY > 1) cannot multiply peak
        memory. First-time session loading inside ``_init_session`` is
        serialized too, so concurrent first calls cannot each build a session.
        """
        with self._lock:
            return self._detect_page_impl(page)
|
||
|
||
|
||
# Module-level singleton: the only place production code instantiates the
# detector (the _Singleton metaclass makes any further _LayoutDetector() call
# return this same instance). Construction only records settings; the ONNX
# session itself is created lazily on the first detect_page() call.
_detector = _LayoutDetector()
|
||
|
||
|
||
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
    """Detect figure / table regions and their captions on one PDF page.

    Thin public wrapper around the module-level detector singleton; the
    detector serializes inference with its own lock.

    Returns:
        ``LayoutBox`` list in PDF points whose ``boxclass`` is one of
        "picture", "table", "figure_caption" or "table_caption".
    """
    detector = _detector
    return detector.detect_page(page)
|