Files
daily-paper/app/services/layout_detector.py
T

406 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""DocLayout-YOLO 布局检测 — ONNX Runtime 推理,支持 CPU/GPU/NPU 多设备.
用 onnxruntime 加载 DocLayout-YOLODocStructBench, imgsz=1024ONNX 模型,
检测 PDF 页面中的 figure / table 区域。
预处理:letterbox(保比例缩放 + 灰边 padding 到 imgsz×imgsz),RGB,仅 /255 归一化
(不做 ImageNet mean/std)。缩放由 pymupdf Matrix 完成,不依赖 OpenCV。
后处理:YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls](已内置 NMS)。
坐标还原:(model_coord - padding) / ratio —— 渲染缩放与 letterbox 缩放在 pymupdf
渲染阶段合二为一,故只需一次除法。
设备:resolve_providers() 按 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表;
_init_session() 逐个 try,首个不可用则降级,CPU 永远兜底。
输入:
images: (1, 3, imgsz, imgsz) float32 —— letterbox + /255 后的图
输出:
output0: (1, N, 6) float32 —— [x1, y1, x2, y2, conf, cls],已 NMS
"""
from __future__ import annotations

import ast
import json
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import onnxruntime as ort
import pymupdf

from app.config import settings
logger = logging.getLogger(__name__)
# The 10 standard DocLayout-YOLO DocStructBench classes. Only a fallback for when
# the ONNX metadata cannot be read — the model's own names take precedence.
_FALLBACK_NAMES: dict[int, str] = {
    0: "title",
    1: "plain text",
    2: "abandon",
    3: "figure",
    4: "figure_caption",
    5: "table",
    6: "table_caption",
    7: "table_footnote",
    8: "isolate_formula",
    9: "formula_caption",
}

# Downstream consumers need pictures/tables and their captions. Classes are matched
# by *name* rather than by index, which sidesteps class-order differences between
# DocStructBench releases.
_PICTURE_NAMES = {"figure", "figure_group"}
_TABLE_NAMES = {"table", "table_group"}
_FIGURE_CAPTION_NAMES = {"figure_caption"}
_TABLE_CAPTION_NAMES = {"table_caption"}

# Letterbox padding gray level (the ultralytics training convention; do not change
# it to 0 or 128, accuracy drops).
_PAD_VALUE = 114

# Minimum bbox sizes, in PDF points.
_MIN_BOX_SIZE = 20
_MIN_CAPTION_BOX_WIDTH = 30
_MIN_CAPTION_BOX_HEIGHT = 6

# Device keyword -> onnxruntime ExecutionProvider name.
_PROVIDER_MAP: dict[str, str] = dict(
    cpu="CPUExecutionProvider",
    cuda="CUDAExecutionProvider",
    directml="DmlExecutionProvider",
    openvino="OpenVINOExecutionProvider",
    cann="CannExecutionProvider",
    tensorrt="TensorrtExecutionProvider",
    qnn="QNNExecutionProvider",
)

# Probe order for device "auto". CPU is deliberately absent: it is always the
# final fallback.
_AUTO_PRIORITY = [
    "cuda",
    "directml",
    "openvino",
    "cann",
    "tensorrt",
    "qnn",
]
@dataclass
class LayoutBox:
    """One detected layout region; all coordinates are in PDF points."""

    # Top-left corner.
    x0: float
    y0: float
    # Bottom-right corner.
    x1: float
    y1: float
    # Layout category label assigned to this region.
    boxclass: str
# ── 设备选择 ────────────────────────────────────────────────────────────
def resolve_providers(device: str, device_id: int) -> list[tuple[str, dict]]:
    """Build the ordered list of candidate ExecutionProviders for ``device``.

    The preferred provider comes first and ``CPUExecutionProvider`` is always the
    last entry. This function only *proposes* candidates: creating a session with
    a provider that is not registered in the installed onnxruntime build raises,
    so the actual fallback is performed by the caller trying the entries in order.

    Args:
        device: Device keyword — ``"cpu"``, ``"auto"`` or a key of
            ``_PROVIDER_MAP``. Matching ignores case and surrounding whitespace.
        device_id: Device index, passed to the accelerator provider as the
            ``device_id`` option (stringified).

    Returns:
        ``[(provider_name, provider_options), ...]``.
    """
    # Settings usually come from environment variables, where "CUDA" or " cpu "
    # are easy to end up with; without normalisation they would silently take the
    # "unknown device" CPU path below.
    requested = device
    device = (device or "").strip().lower()

    if device == "cpu":
        return [("CPUExecutionProvider", {})]

    opts = {"device_id": str(device_id)}

    if device == "auto":
        available = set(ort.get_available_providers())
        for dev in _AUTO_PRIORITY:
            ep = _PROVIDER_MAP[dev]
            if ep in available:
                logger.info("auto: selected provider %s", ep)
                return [(ep, opts), ("CPUExecutionProvider", {})]
        logger.info("auto: no GPU/NPU provider available, using CPU")
        return [("CPUExecutionProvider", {})]

    ep = _PROVIDER_MAP.get(device)
    if ep is None:
        logger.warning("Unknown LAYOUT_DEVICE=%r, falling back to CPU", requested)
        return [("CPUExecutionProvider", {})]
    return [(ep, opts), ("CPUExecutionProvider", {})]
# ── 预处理:渲染几何与 letterbox ────────────────────────────────────────
def _compute_render_geometry(page_w: float, page_h: float, imgsz: int) -> float:
    """Return the letterbox render scale: the smaller of imgsz/page_w and imgsz/page_h.

    Rendering the page with ``Matrix(ratio, ratio)`` makes the long side span
    ``imgsz`` while the short side is left to be padded.
    """
    scale_w = imgsz / page_w
    scale_h = imgsz / page_h
    return scale_h if scale_h < scale_w else scale_w
def _letterbox_padding(
    content_w: float, content_h: float, imgsz: int
) -> tuple[float, float]:
    """Return the (horizontal, vertical) padding that centres the content.

    Each value is ``(imgsz - content) / 2``; ``content_*`` is the size of the
    rendered (already integer-sized) pixmap.
    """
    pad_x = (imgsz - content_w) / 2.0
    pad_y = (imgsz - content_h) / 2.0
    return pad_x, pad_y
def _padded_nchw_from_pixmap(
    pix: pymupdf.Pixmap, imgsz: int, dw: float, dh: float
) -> np.ndarray:
    """Paste a pixmap onto a gray square canvas and return it as a model tensor.

    Args:
        pix: Rendered page pixmap.
        imgsz: Side length of the square canvas.
        dw: Horizontal padding; the paste column is ``round(dw)``.
        dh: Vertical padding; the paste row is ``round(dh)``.

    Returns:
        float32 array of shape (1, 3, imgsz, imgsz), values divided by 255.
    """
    rows, cols = pix.height, pix.width
    pixels = np.frombuffer(pix.samples, dtype=np.uint8).reshape(rows, cols, pix.n)
    # Defensive: drop an alpha channel if one shows up (not expected for an RGB
    # pixmap rendered with alpha=False).
    if pixels.shape[2] == 4:
        pixels = pixels[:, :, :3]

    y_off = int(round(dh))
    x_off = int(round(dw))
    board = np.full((imgsz, imgsz, 3), _PAD_VALUE, dtype=np.uint8)
    board[y_off : y_off + rows, x_off : x_off + cols] = pixels

    scaled = board.astype(np.float32) / 255.0
    # HWC -> CHW, then add the leading batch axis: (1, 3, imgsz, imgsz).
    return np.expand_dims(scaled.transpose(2, 0, 1), axis=0)
def _model_to_pdf(
    model_x: float, model_y: float, dw: float, dh: float, ratio: float
) -> tuple[float, float]:
    """Convert a point from padded model space to PDF points.

    Undoes the letterbox: subtract the padding, then divide by the render ratio.
    """
    pdf_x = (model_x - dw) / ratio
    pdf_y = (model_y - dh) / ratio
    return pdf_x, pdf_y
# ── 后处理 ──────────────────────────────────────────────────────────────
def _postprocess_output(
    output: np.ndarray, threshold: float, names: dict[int, str]
) -> list[tuple[int, float, float, float, float]]:
    """Decode the end-to-end detection output and drop low-confidence rows.

    Args:
        output: First output of ``session.run``, expected shape [1, N, 6] with
            rows laid out as [x1, y1, x2, y2, conf, cls].
        threshold: Confidence threshold; rows with ``conf < threshold`` are dropped.
        names: class id -> name mapping. Not consulted by the filtering.

    Returns:
        ``[(cls_id, x1, y1, x2, y2), ...]`` in padded model (imgsz) coordinates.
        An empty list is returned, with a warning logged, when the per-image
        array is not shaped [N, 6].
    """
    detections = output[0]  # strip the batch axis
    shape_ok = detections.ndim == 2 and detections.shape[1] == 6
    if not shape_ok:
        logger.warning(
            "Unexpected DocLayout-YOLO output shape %s (expected [N,6]); skip page",
            tuple(detections.shape),
        )
        return []

    return [
        (int(cls), x1, y1, x2, y2)
        for x1, y1, x2, y2, conf, cls in detections.tolist()
        if not conf < threshold
    ]
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
    """Translate a class id into the downstream layout category.

    The class *name* (trimmed, lower-cased) decides the category. Returns
    "picture", "table", "figure_caption" or "table_caption", or ``None`` for a
    class downstream does not care about (including an unknown id).
    """
    key = names.get(cls_id, "").strip().lower()
    categories = (
        (_PICTURE_NAMES, "picture"),
        (_TABLE_NAMES, "table"),
        (_FIGURE_CAPTION_NAMES, "figure_caption"),
        (_TABLE_CAPTION_NAMES, "table_caption"),
    )
    for members, label in categories:
        if key in members:
            return label
    return None
def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
    """Read the class-id -> name table from the model's ONNX custom metadata.

    The ``names`` entry is accepted in two encodings:

    * JSON, e.g. ``{"0": "title", ...}``;
    * a Python literal, e.g. ``{0: 'title', ...}`` — the ``str(dict)`` form that
      ultralytics-style exporters write, which is not valid JSON.

    A list/tuple of names is also accepted and indexed by position.

    Args:
        session: Loaded inference session whose model metadata is inspected.

    Returns:
        The parsed mapping, or a copy of ``_FALLBACK_NAMES`` when the metadata is
        missing, unreadable, unparseable or empty.
    """
    try:
        raw = session.get_modelmeta().custom_metadata_map.get("names")
    except Exception:
        # Best-effort: any failure reading metadata just means "use the fallback".
        raw = None
    if not raw:
        return dict(_FALLBACK_NAMES)

    # literal_eval only evaluates Python literals, so it is safe on file content.
    for loader in (json.loads, ast.literal_eval):
        try:
            parsed = loader(raw)
            if isinstance(parsed, (list, tuple)):
                parsed = dict(enumerate(parsed))
            names = {int(k): str(v) for k, v in parsed.items()}
        except (
            ValueError,
            SyntaxError,
            TypeError,
            AttributeError,
            MemoryError,
            RecursionError,
        ):
            continue
        # An empty table would make every detection unmappable; treat as unusable.
        if names:
            return names

    logger.warning("Failed to parse ONNX names metadata; using fallback")
    return dict(_FALLBACK_NAMES)
# ── 检测器单例 ──────────────────────────────────────────────────────────
class _Singleton(type):
    """Metaclass that gives each class using it exactly one instance.

    ``cls()`` always returns the cached instance of ``cls``, creating it on the
    first call; removing the entry from ``_instances`` lets the next call build a
    fresh one. The cache is checked once without the lock and once more while
    holding it, so concurrent first calls construct only a single instance.

    Original author's note: production code should instantiate only once at
    module level; funnelling every construction through one instance (one ONNX
    session, one lock) avoids duplicate sessions doubling peak memory, which was
    the root cause of crashes on 8 GB machines.
    """

    # One registry and one creation lock shared by every class using this metaclass.
    _instances: dict[type, Any] = {}
    _lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        registry = _Singleton._instances
        try:
            # Fast path: already built, no locking needed.
            return registry[cls]
        except KeyError:
            pass
        with _Singleton._lock:
            # Re-check under the lock: another thread may have built it meanwhile.
            if cls not in registry:
                registry[cls] = super().__call__(*args, **kwargs)
            return registry[cls]
class _LayoutDetector(metaclass=_Singleton):
    """Process-wide singleton that owns the ONNX ``InferenceSession``.

    The ``_Singleton`` metaclass makes every ``_LayoutDetector()`` call return the
    same instance (same loaded session, same lock) instead of building a new one.
    ``reset_instance()`` clears that cache and exists for test isolation only.
    """

    def __init__(self) -> None:
        # Serialises detect_page(), and through it the lazy session creation.
        self._lock = threading.Lock()
        # Created lazily by _init_session() on first use.
        self._session: ort.InferenceSession | None = None
        self._names: dict[int, str] = {}
        self._input_name: str = ""
        self._imgsz: int = settings.LAYOUT_IMGSZ

    @classmethod
    def reset_instance(cls) -> None:
        """Drop the cached singleton so the next ``_LayoutDetector()`` builds anew.

        The new instance starts with a fresh lock and no session. Test isolation
        only — production code should never call this, since it discards the
        already loaded model session.
        """
        _Singleton._instances.pop(cls, None)

    def _init_session(self) -> ort.InferenceSession:
        """Return the inference session, creating it on first use.

        Candidate providers from ``resolve_providers()`` are tried in order; a
        failure is logged and the next candidate is tried.

        Raises:
            FileNotFoundError: The model file does not exist.
            RuntimeError: No candidate provider could create a session (chained
                to the last provider error).
        """
        if self._session is not None:
            return self._session

        model_path = Path(settings.LAYOUT_MODEL_PATH)
        if not model_path.exists():
            raise FileNotFoundError(
                f"Layout model not found: {model_path}. "
                "Run scripts/export_doclayout_yolo_onnx.py first."
            )

        eps = resolve_providers(settings.LAYOUT_DEVICE, settings.LAYOUT_DEVICE_ID)
        logger.info(
            "Loading layout model %s, candidate providers: %s",
            model_path,
            [ep[0] for ep in eps],
        )

        # Try each provider in turn; fall through to the next one on failure.
        session: ort.InferenceSession | None = None
        last_err: Exception | None = None
        for idx, (ep_name, ep_opts) in enumerate(eps):
            try:
                session = ort.InferenceSession(
                    str(model_path), providers=[(ep_name, ep_opts)]
                )
            except Exception as e:
                last_err = e
                if idx < len(eps) - 1:
                    logger.warning(
                        "Provider %s unavailable (%s); falling back to %s",
                        ep_name,
                        e,
                        eps[idx + 1][0],
                    )
            else:
                break
        if session is None:
            raise RuntimeError(
                f"Failed to create layout session: {last_err}"
            ) from last_err

        logger.info("Layout session active providers: %s", session.get_providers())
        self._input_name = session.get_inputs()[0].name
        self._names = _parse_names_from_meta(session)
        self._imgsz = settings.LAYOUT_IMGSZ
        # Publish the session last: if anything above raises, the detector is not
        # left with a session but no input name / class table, and the next call
        # simply retries the whole initialisation.
        self._session = session
        return session

    def _detect_page_impl(self, page: pymupdf.Page) -> list[LayoutBox]:
        """Detect picture / table regions and their captions on one PDF page.

        Steps:
            1. Letterbox render: aspect-preserving scale so the long side spans
               imgsz; the short side is padded with gray.
            2. /255 + NCHW, then ONNX inference.
            3. Decode the end-to-end output (NMS already applied by the model).
            4. Map model coordinates back to PDF points.
            5. Drop uninteresting classes, clip to the page, drop tiny boxes.

        Returns:
            ``LayoutBox`` list in PDF points.
        """
        session = self._init_session()
        page_w = page.rect.width
        page_h = page.rect.height
        # A degenerate page has nothing to detect and would divide by zero below.
        if page_w <= 0 or page_h <= 0:
            return []
        ratio = _compute_render_geometry(page_w, page_h, self._imgsz)

        # 1. Aspect-preserving render (long side spans imgsz).
        pix = page.get_pixmap(
            matrix=pymupdf.Matrix(ratio, ratio),
            colorspace=pymupdf.csRGB,
            alpha=False,
        )
        # Padding is derived from the pixmap's actual integer size.
        dw, dh = _letterbox_padding(pix.width, pix.height, self._imgsz)
        # The image is pasted at round(padding), so snap to that same integer
        # offset here; restoring coordinates with the fractional value would shift
        # every box by up to half a pixel.
        dw, dh = float(round(dw)), float(round(dh))
        tensor = _padded_nchw_from_pixmap(pix, self._imgsz, dw, dh)

        # 2. Inference.
        outputs = session.run(None, {self._input_name: tensor})
        detections = _postprocess_output(
            outputs[0], settings.LAYOUT_THRESHOLD, self._names
        )

        # 3. Coordinate restore + filtering.
        result: list[LayoutBox] = []
        for cls_id, x1m, y1m, x2m, y2m in detections:
            boxclass = _map_class_to_boxclass(cls_id, self._names)
            if boxclass is None:
                continue
            x0, y0 = _model_to_pdf(x1m, y1m, dw, dh, ratio)
            x1, y1 = _model_to_pdf(x2m, y2m, dw, dh, ratio)
            # Clip to the page.
            x0 = max(0.0, min(x0, page_w))
            y0 = max(0.0, min(y0, page_h))
            x1 = max(0.0, min(x1, page_w))
            y1 = max(0.0, min(y1, page_h))
            # Captions are short and wide, so they get their own size limits.
            if boxclass in ("figure_caption", "table_caption"):
                too_small = (x1 - x0) < _MIN_CAPTION_BOX_WIDTH or (
                    y1 - y0
                ) < _MIN_CAPTION_BOX_HEIGHT
            else:
                too_small = (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE
            if too_small:
                continue
            result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
        return result

    def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
        """Public entry point: run detection under the instance lock.

        The lock wraps the whole of ``_detect_page_impl`` (pixmap render, tensor
        construction, ``session.run``), so only one inference runs at a time. As
        ``_init_session`` is called from inside it, the first model load is
        serialised too. Original author's note: this keeps concurrent worker
        threads (SUMMARY_CONCURRENCY > 1) from doubling peak memory, the root
        cause of crashes on 8 GB machines.
        """
        with self._lock:
            return self._detect_page_impl(page)
# Module-level singleton — the only place production code instantiates the detector
# (the _Singleton metaclass guarantees there is never a second instance).
_detector = _LayoutDetector()
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
    """Detect figure / table / caption regions on a PDF page.

    Delegates to the module-level detector.

    Returns:
        ``LayoutBox`` list in PDF points, containing only pictures, tables and
        their captions.
    """
    boxes = _detector.detect_page(page)
    return boxes