Files
daily-paper/app/services/layout_detector.py
T
Rain-Bus 90fe705e8f refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO
- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
2026-06-14 10:41:44 +08:00

344 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""DocLayout-YOLO 布局检测 — ONNX Runtime 推理,支持 CPU/GPU/NPU 多设备.
用 onnxruntime 加载 DocLayout-YOLO(DocStructBench, imgsz=1024)ONNX 模型,
检测 PDF 页面中的 figure / table 区域。
预处理:letterbox(保比例缩放 + 灰边 padding 到 imgsz×imgsz),RGB,仅 /255 归一化
(不做 ImageNet mean/std)。缩放由 pymupdf Matrix 完成,不依赖 OpenCV。
后处理:YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls](已内置 NMS)。
坐标还原:(model_coord - padding) / ratio —— 渲染缩放与 letterbox 缩放在 pymupdf
渲染阶段合二为一,故只需一次除法。
设备:resolve_providers() 按 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表;
_init_session() 逐个 try,首个不可用则降级,CPU 永远兜底。
输入:
images: (1, 3, imgsz, imgsz) float32 —— letterbox + /255 后的图
输出:
output0: (1, N, 6) float32 —— [x1, y1, x2, y2, conf, cls],已 NMS
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import onnxruntime as ort
import pymupdf
from app.config import settings
logger = logging.getLogger(__name__)
# DocLayout-YOLO DocStructBench 标准 10 类(ONNX metadata 读不到时的兜底,以实际为准)
_FALLBACK_NAMES: dict[int, str] = {
0: "title",
1: "plain text",
2: "abandon",
3: "figure",
4: "figure_caption",
5: "table",
6: "table_caption",
7: "table_footnote",
8: "isolate_formula",
9: "formula_caption",
}
# 下游只需 picture/table —— 按 class name 字符串动态匹配(不依赖 class index
# 规避 DocStructBench 不同发布的类别顺序差异)
_PICTURE_NAMES = {"figure", "figure_group"}
_TABLE_NAMES = {"table", "table_group"}
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
_PAD_VALUE = 114
# 最小 bbox 尺寸(PDF 点)
_MIN_BOX_SIZE = 20
# device → ExecutionProvider 映射
_PROVIDER_MAP: dict[str, str] = {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
"directml": "DmlExecutionProvider",
"openvino": "OpenVINOExecutionProvider",
"cann": "CannExecutionProvider",
"tensorrt": "TensorrtExecutionProvider",
"qnn": "QNNExecutionProvider",
}
# auto 探测优先级(不含 cpu,cpu 永远兜底)
_AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
@dataclass
class LayoutBox:
    """A detected layout region.

    Coordinates are expressed in PDF points. ``boxclass`` is either
    ``"picture"`` or ``"table"``.
    """

    # First corner of the box (x, y).
    x0: float
    y0: float
    # Second corner of the box (x, y).
    x1: float
    y1: float
    # Region category: "picture" or "table".
    boxclass: str
# ── 设备选择 ────────────────────────────────────────────────────────────
def resolve_providers(device: str, device_id: int) -> list[tuple[str, dict]]:
    """Build the candidate ExecutionProvider list for a LAYOUT_DEVICE value.

    The preferred provider comes first and CPU is always the last entry, so a
    caller can try the candidates in order and degrade gracefully. This
    function only produces candidates; it never creates a session.

    Args:
        device: Device selector. ``"cpu"``, ``"auto"``, or a key of
            ``_PROVIDER_MAP``. Matching ignores case and surrounding
            whitespace, so values such as ``"CUDA"`` or ``" cpu "`` coming
            from an environment file are not silently treated as unknown.
        device_id: Device index, forwarded to non-CPU providers as the
            ``device_id`` provider option (as a string).

    Returns:
        A list of ``(provider_name, provider_options)`` tuples.
    """
    # Normalize once; the raw value is still used in the warning below so the
    # log shows exactly what was configured.
    key = device.strip().lower()

    if key == "cpu":
        return [("CPUExecutionProvider", {})]

    opts = {"device_id": str(device_id)}

    if key == "auto":
        available = set(ort.get_available_providers())
        # Pick the first provider, in priority order, that this onnxruntime
        # build reports as available.
        for dev in _AUTO_PRIORITY:
            ep = _PROVIDER_MAP[dev]
            if ep in available:
                logger.info("auto: selected provider %s", ep)
                return [(ep, opts), ("CPUExecutionProvider", {})]
        logger.info("auto: no GPU/NPU provider available, using CPU")
        return [("CPUExecutionProvider", {})]

    ep = _PROVIDER_MAP.get(key)
    if ep is None:
        logger.warning("Unknown LAYOUT_DEVICE=%r, falling back to CPU", device)
        return [("CPUExecutionProvider", {})]
    return [(ep, opts), ("CPUExecutionProvider", {})]
# ── 预处理:渲染几何与 letterbox ────────────────────────────────────────
def _compute_render_geometry(page_w: float, page_h: float, imgsz: int) -> float:
"""letterbox 渲染缩放 ratio = min(imgsz/page_w, imgsz/page_h)。
pymupdf 以 Matrix(ratio, ratio) 渲染,长边贴到 imgsz,短边留灰边。
"""
return min(imgsz / page_w, imgsz / page_h)
def _letterbox_padding(
content_w: float, content_h: float, imgsz: int
) -> tuple[float, float]:
"""居中 padding(imgsz - content) / 2。content 为实际 pixmap 尺寸(已取整)。"""
return (imgsz - content_w) / 2.0, (imgsz - content_h) / 2.0
def _padded_nchw_from_pixmap(
    pix: pymupdf.Pixmap, imgsz: int, dw: float, dh: float
) -> np.ndarray:
    """Convert a pixmap into a letterboxed ``(1, 3, imgsz, imgsz)`` float32 tensor.

    The pixmap is pasted onto an ``imgsz x imgsz`` canvas pre-filled with
    ``_PAD_VALUE``, at the offsets ``dw`` / ``dh`` rounded to whole pixels.
    The result is then scaled by 1/255.
    """
    height, width = pix.height, pix.width
    pixels = np.frombuffer(pix.samples, dtype=np.uint8).reshape(height, width, pix.n)
    # Defensive: drop an alpha channel if one is present.
    if pixels.shape[2] == 4:
        pixels = pixels[:, :, :3]

    row0 = int(round(dh))
    col0 = int(round(dw))
    canvas = np.full((imgsz, imgsz, 3), _PAD_VALUE, dtype=np.uint8)
    canvas[row0 : row0 + height, col0 : col0 + width] = pixels

    scaled = canvas.astype(np.float32) / 255.0
    # HWC -> CHW, then add the batch axis: (1, 3, imgsz, imgsz).
    return scaled.transpose(2, 0, 1)[np.newaxis]
def _model_to_pdf(
model_x: float, model_y: float, dw: float, dh: float, ratio: float
) -> tuple[float, float]:
"""模型 imgsz 空间坐标 → PDF 点:(model - padding) / ratio。"""
return (model_x - dw) / ratio, (model_y - dh) / ratio
# ── 后处理 ──────────────────────────────────────────────────────────────
def _postprocess_output(
output: np.ndarray, threshold: float, names: dict[int, str]
) -> list[tuple[int, float, float, float, float]]:
"""解析 YOLOv10 end-to-end 输出,过滤 conf < threshold。
Args:
output: session.run 返回的第一个输出,shape [1, N, 6]
threshold: 置信度阈值
names: class id → name(仅用于日志,过滤不依赖)
Returns:
[(cls_id, x1, y1, x2, y2), ...],坐标为模型 imgsz padded 空间。
"""
out = output[0] # 去 batch 维
if out.ndim != 2 or out.shape[1] != 6:
logger.warning(
"Unexpected DocLayout-YOLO output shape %s (expected [N,6]); skip page",
tuple(out.shape),
)
return []
results: list[tuple[int, float, float, float, float]] = []
for row in out:
x1, y1, x2, y2, conf, cls = row.tolist()
if conf < threshold:
continue
results.append((int(cls), x1, y1, x2, y2))
return results
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
    """Translate a model class id into the downstream box class.

    The class name is looked up in ``names`` and compared after trimming and
    lower-casing. Returns ``"picture"`` for names in ``_PICTURE_NAMES``,
    ``"table"`` for names in ``_TABLE_NAMES``, and ``None`` otherwise
    (including unknown class ids).
    """
    label = names.get(cls_id, "").strip().lower()
    if label in _PICTURE_NAMES:
        return "picture"
    return "table" if label in _TABLE_NAMES else None
def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
"""从 ONNX metadata 读 namesultralytics 导出写入的 JSON),读不到用兜底。"""
raw = None
try:
raw = session.get_modelmeta().custom_metadata_map.get("names")
except Exception:
raw = None
if raw:
try:
d = json.loads(raw)
return {int(k): str(v) for k, v in d.items()}
except Exception:
logger.warning("Failed to parse ONNX names metadata; using fallback")
return dict(_FALLBACK_NAMES)
# ── 检测器单例 ──────────────────────────────────────────────────────────
class _LayoutDetector:
    """Owns the ONNX Runtime session used for layout detection.

    The session is created lazily on first use and then reused for every
    subsequent page.
    """

    def __init__(self) -> None:
        self._session: ort.InferenceSession | None = None
        self._names: dict[int, str] = {}
        self._input_name: str = ""
        self._imgsz: int = settings.LAYOUT_IMGSZ

    def _init_session(self) -> ort.InferenceSession:
        """Create the inference session on first call; return the cached one after.

        Candidate providers from ``resolve_providers()`` are tried in order;
        a failure falls through to the next candidate.

        Raises:
            FileNotFoundError: The configured model file does not exist.
            RuntimeError: No candidate provider could create a session. The
                last provider error is attached as ``__cause__``.
        """
        if self._session is not None:
            return self._session

        model_path = Path(settings.LAYOUT_MODEL_PATH)
        if not model_path.exists():
            raise FileNotFoundError(
                f"Layout model not found: {model_path}. "
                "Run scripts/export_doclayout_yolo_onnx.py first."
            )

        eps = resolve_providers(settings.LAYOUT_DEVICE, settings.LAYOUT_DEVICE_ID)
        logger.info(
            "Loading layout model %s, candidate providers: %s",
            model_path,
            [ep[0] for ep in eps],
        )

        # Try each provider in turn; fall through to the next one on failure.
        last_err: Exception | None = None
        for idx, (ep_name, ep_opts) in enumerate(eps):
            try:
                self._session = ort.InferenceSession(
                    str(model_path), providers=[(ep_name, ep_opts)]
                )
                break
            except Exception as e:
                last_err = e
                if idx < len(eps) - 1:
                    logger.warning(
                        "Provider %s unavailable (%s); falling back to %s",
                        ep_name,
                        e,
                        eps[idx + 1][0],
                    )
        else:
            # Chain the underlying error so its traceback is not lost.
            raise RuntimeError(
                f"Failed to create layout session: {last_err}"
            ) from last_err

        logger.info(
            "Layout session active providers: %s", self._session.get_providers()
        )
        self._input_name = self._session.get_inputs()[0].name
        self._names = _parse_names_from_meta(self._session)
        self._imgsz = settings.LAYOUT_IMGSZ
        return self._session

    def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
        """Detect figure / table regions on a single PDF page.

        Steps:
            1. Render the page with a uniform scale (letterbox geometry).
            2. Pad to ``imgsz x imgsz``, scale by 1/255, run inference.
            3. Parse the ``[N, 6]`` detections and apply the confidence
               threshold.
            4. Map model coordinates back to PDF points.
            5. Drop classes other than picture/table, clip to the page and
               drop boxes smaller than ``_MIN_BOX_SIZE``.

        Returns:
            A list of ``LayoutBox`` with coordinates in PDF points.
        """
        session = self._init_session()
        page_w = page.rect.width
        page_h = page.rect.height
        ratio = _compute_render_geometry(page_w, page_h, self._imgsz)

        # 1. Aspect-preserving render.
        pix = page.get_pixmap(
            matrix=pymupdf.Matrix(ratio, ratio),
            colorspace=pymupdf.csRGB,
            alpha=False,
        )
        # Compute padding from the actual (integer) pixmap size.
        dw, dh = _letterbox_padding(pix.width, pix.height, self._imgsz)
        # The pixmap is pasted at whole-pixel offsets (int(round(...))). Snap
        # the padding the same way so the offsets used to restore coordinates
        # match where the content really sits; otherwise an odd padding total
        # shifts every box by half a pixel.
        dw, dh = float(round(dw)), float(round(dh))
        tensor = _padded_nchw_from_pixmap(pix, self._imgsz, dw, dh)

        # 2. Inference.
        outputs = session.run(None, {self._input_name: tensor})
        detections = _postprocess_output(
            outputs[0], settings.LAYOUT_THRESHOLD, self._names
        )

        # 3. Coordinate restoration + filtering.
        result: list[LayoutBox] = []
        for cls_id, x1m, y1m, x2m, y2m in detections:
            boxclass = _map_class_to_boxclass(cls_id, self._names)
            if boxclass is None:
                continue
            x0, y0 = _model_to_pdf(x1m, y1m, dw, dh, ratio)
            x1, y1 = _model_to_pdf(x2m, y2m, dw, dh, ratio)
            # Clip to the page bounds.
            x0 = max(0.0, min(x0, page_w))
            y0 = max(0.0, min(y0, page_h))
            x1 = max(0.0, min(x1, page_w))
            y1 = max(0.0, min(y1, page_h))
            if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
                continue
            result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
        return result
# Module-level singleton detector instance.
_detector = _LayoutDetector()
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
    """Detect figure / table regions on a PDF page.

    Delegates to the module-level detector instance.

    Returns:
        A list of ``LayoutBox`` in PDF points, containing only picture/table
        regions.
    """
    boxes = _detector.detect_page(page)
    return boxes