Files
daily-paper/tests/test_layout_detector.py
T
Rain-Bus 90fe705e8f refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO
- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
2026-06-14 10:41:44 +08:00

404 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""layout_detector 测试 — 坐标还原数学、设备探测、类别映射、端到端 detect_page.
纯函数测试不依赖真实模型;TestDetectPage 用 MagicMock mock ort.InferenceSession
与 pymupdf.Page,参考 test_summary_utils.py 的 mock 模式。
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock
import numpy as np
import onnxruntime as ort
import pytest
from app.config import settings
from app.services import layout_detector as mod
from app.services.layout_detector import (
LayoutBox,
_compute_render_geometry,
_FALLBACK_NAMES,
_letterbox_padding,
_map_class_to_boxclass,
_model_to_pdf,
_parse_names_from_meta,
_postprocess_output,
detect_page_layout,
resolve_providers,
)
# Target model-input size handed to the geometry helpers exercised below
# (_compute_render_geometry / _letterbox_padding); the same value is used for
# both the width and height of the letterboxed canvas in these tests.
IMGSZ = 1024
# ═══════════════════════════════════════════════════════════════════════
# 渲染几何与 letterbox padding
# ═══════════════════════════════════════════════════════════════════════
class TestComputeRenderGeometry:
    """Scale-ratio checks for ``_compute_render_geometry``."""

    def test_a4_portrait_short_edge_pads(self):
        """A4 portrait (595x842): height touches the edge, width gets padded."""
        page_w, page_h = 595, 842
        scale = _compute_render_geometry(page_w, page_h, IMGSZ)
        tightest = min(IMGSZ / page_w, IMGSZ / page_h)
        assert scale == pytest.approx(tightest)
        # The height direction is the limiting one for a portrait page.
        assert scale == pytest.approx(IMGSZ / page_h)

    def test_wide_page_width_pads(self):
        """Landscape 1600x900: width is the limiting edge."""
        page_w, page_h = 1600, 900
        scale = _compute_render_geometry(page_w, page_h, IMGSZ)
        assert scale == pytest.approx(IMGSZ / page_w)

    def test_square_no_letterbox(self):
        """A square page scales uniformly by 1024 / 100."""
        side = 100
        scale = _compute_render_geometry(side, side, IMGSZ)
        assert scale == pytest.approx(10.24)
class TestLetterboxPadding:
    """Padding checks for ``_letterbox_padding``."""

    def test_centered_padding(self):
        """A 723x1024 pixmap fills the height; width is padded equally on both sides."""
        pix_w, pix_h = 723, 1024
        pad_x, pad_y = _letterbox_padding(pix_w, pix_h, IMGSZ)
        assert pad_x == pytest.approx((IMGSZ - pix_w) / 2)
        assert pad_y == pytest.approx(0.0)

    def test_square_no_padding(self):
        """A pixmap that already matches the target size needs no padding."""
        pad_x, pad_y = _letterbox_padding(IMGSZ, IMGSZ, IMGSZ)
        assert pad_x == 0.0
        assert pad_y == 0.0
# ═══════════════════════════════════════════════════════════════════════
# 坐标还原(核心)—— pdf = (model - padding) / ratio
# ═══════════════════════════════════════════════════════════════════════
class TestModelToPdf:
    """Coordinate restoration: pdf = (model - padding) / ratio."""

    def test_padding_corner_maps_to_origin(self):
        """The model-space point (dw, dh) maps back to PDF (0, 0)."""
        pad_x, pad_y, scale = 150.5, 0.0, 1.2157
        pdf_x, pdf_y = _model_to_pdf(pad_x, pad_y, pad_x, pad_y, scale)
        assert pdf_x == pytest.approx(0.0, abs=0.01)
        assert pdf_y == pytest.approx(0.0, abs=0.01)

    def test_round_trip(self):
        """PDF (100, 200) -> model space -> PDF returns the starting point."""
        pad_x, pad_y, scale = 150.5, 0.0, 1.2157
        # Forward transform done by hand, then undone by the helper under test.
        model_x = 100 * scale + pad_x
        model_y = 200 * scale + pad_y
        back_x, back_y = _model_to_pdf(model_x, model_y, pad_x, pad_y, scale)
        assert back_x == pytest.approx(100, abs=0.01)
        assert back_y == pytest.approx(200, abs=0.01)

    def test_full_a4_page_box(self):
        """A full-page box in model space restores to the A4 page size."""
        page_w, page_h = 595, 842
        scale = _compute_render_geometry(page_w, page_h, IMGSZ)
        pix_w = round(page_w * scale)
        pix_h = round(page_h * scale)
        pad_x, pad_y = _letterbox_padding(pix_w, pix_h, IMGSZ)

        # In model space the page spans (dw, dh) to (dw + pix_w, dh + pix_h).
        left, top = _model_to_pdf(pad_x, pad_y, pad_x, pad_y, scale)
        right, bottom = _model_to_pdf(
            pad_x + pix_w, pad_y + pix_h, pad_x, pad_y, scale
        )

        assert left == pytest.approx(0.0, abs=1.0)
        assert top == pytest.approx(0.0, abs=1.0)
        assert right == pytest.approx(595, abs=1.0)
        assert bottom == pytest.approx(842, abs=1.0)
# ═══════════════════════════════════════════════════════════════════════
# 设备探测
# ═══════════════════════════════════════════════════════════════════════
class TestResolveProviders:
    """Provider-list resolution for explicit and automatic device choices."""

    @staticmethod
    def _advertise(monkeypatch, provider_names):
        """Make ``ort.get_available_providers`` report ``provider_names``."""
        monkeypatch.setattr(
            ort, "get_available_providers", lambda: list(provider_names)
        )

    def test_cpu(self):
        expected = [("CPUExecutionProvider", {})]
        assert resolve_providers("cpu", 0) == expected

    def test_cuda_with_cpu_fallback(self):
        providers = resolve_providers("cuda", 0)
        assert providers[0] == ("CUDAExecutionProvider", {"device_id": "0"})
        assert providers[1] == ("CPUExecutionProvider", {})

    def test_directml_device_id(self):
        providers = resolve_providers("directml", 2)
        assert providers[0] == ("DmlExecutionProvider", {"device_id": "2"})

    def test_auto_picks_cuda_if_available(self, monkeypatch):
        self._advertise(
            monkeypatch, ["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        providers = resolve_providers("auto", 0)
        first_name = providers[0][0]
        assert first_name == "CUDAExecutionProvider"
        assert providers[-1] == ("CPUExecutionProvider", {})

    def test_auto_falls_back_to_cpu(self, monkeypatch):
        self._advertise(monkeypatch, ["CPUExecutionProvider"])
        expected = [("CPUExecutionProvider", {})]
        assert resolve_providers("auto", 0) == expected

    def test_auto_prefers_cuda_over_directml(self, monkeypatch):
        # DirectML is listed first on purpose: the reported order must not
        # decide which provider is chosen.
        self._advertise(
            monkeypatch,
            [
                "DmlExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        providers = resolve_providers("auto", 0)
        first_name = providers[0][0]
        assert first_name == "CUDAExecutionProvider"

    def test_unknown_device_falls_back(self):
        expected = [("CPUExecutionProvider", {})]
        assert resolve_providers("tpu", 0) == expected
# ═══════════════════════════════════════════════════════════════════════
# 类别映射与 names 解析
# ═══════════════════════════════════════════════════════════════════════
class TestClassMapping:
    """Mapping of model class ids/names to ``picture`` / ``table`` / ``None``."""

    def test_figure_to_picture(self):
        names = {3: "figure"}
        assert _map_class_to_boxclass(3, names) == "picture"

    def test_figure_group_to_picture(self):
        names = {0: "figure_group"}
        assert _map_class_to_boxclass(0, names) == "picture"

    def test_table(self):
        names = {5: "table"}
        assert _map_class_to_boxclass(5, names) == "table"

    def test_caption_ignored(self):
        """Caption classes are not mapped to any box class."""
        names = {4: "figure_caption", 6: "table_caption"}
        for class_id in (4, 6):
            assert _map_class_to_boxclass(class_id, names) is None

    def test_other_classes_ignored(self):
        """Non-figure, non-table classes are not mapped."""
        names = {
            0: "title",
            1: "plain text",
            2: "abandon",
            8: "isolate_formula",
        }
        for class_id in names:
            assert _map_class_to_boxclass(class_id, names) is None

    def test_case_insensitive(self):
        """Class names match regardless of letter case."""
        cases = (("Figure", "picture"), ("TABLE", "table"))
        for raw_name, expected in cases:
            assert _map_class_to_boxclass(0, {0: raw_name}) == expected

    def test_unknown_class_id(self):
        """A class id missing from the names table yields ``None``."""
        names = {0: "figure"}
        assert _map_class_to_boxclass(99, names) is None
class TestParseNamesFromMeta:
    """Reading the class-name table from ONNX model metadata."""

    @staticmethod
    def _session_with_metadata(custom_map):
        """Return a mock session whose model metadata carries ``custom_map``."""
        model_meta = MagicMock()
        model_meta.custom_metadata_map = custom_map
        session = MagicMock()
        session.get_modelmeta.return_value = model_meta
        return session

    def test_reads_json_metadata(self):
        session = self._session_with_metadata(
            {"names": '{"0": "title", "3": "figure", "5": "table"}'}
        )
        parsed = _parse_names_from_meta(session)
        # String keys from the JSON payload come back as integer class ids.
        assert parsed == {0: "title", 3: "figure", 5: "table"}

    def test_fallback_when_missing(self):
        session = self._session_with_metadata({})
        assert _parse_names_from_meta(session) == _FALLBACK_NAMES

    def test_fallback_on_garbage(self):
        session = self._session_with_metadata({"names": "not json"})
        assert _parse_names_from_meta(session) == _FALLBACK_NAMES
# ═══════════════════════════════════════════════════════════════════════
# 后处理
# ═══════════════════════════════════════════════════════════════════════
class TestPostprocessOutput:
    """Parsing of the raw model output tensor by ``_postprocess_output``."""

    def test_parses_end_to_end_filters_by_conf(self):
        """Rows are [x1, y1, x2, y2, conf, cls]; low-confidence rows are dropped."""
        confident_row = [10, 20, 30, 40, 0.9, 3]
        weak_row = [50, 60, 70, 80, 0.1, 5]
        raw = np.array([[confident_row, weak_row]], dtype=np.float32)
        names = {3: "figure", 5: "table"}

        detections = _postprocess_output(raw, 0.2, names)

        assert detections == [(3, 10.0, 20.0, 30.0, 40.0)]

    def test_empty_output(self):
        """An output with zero rows yields no detections."""
        raw = np.zeros((1, 0, 6), dtype=np.float32)
        detections = _postprocess_output(raw, 0.2, {})
        assert detections == []

    def test_unexpected_shape_returns_empty(self):
        """An output that is not [1, N, 6] is treated as having no detections."""
        raw = np.zeros((1, 84, 8400), dtype=np.float32)
        detections = _postprocess_output(raw, 0.2, {})
        assert detections == []
# ═══════════════════════════════════════════════════════════════════════
# detect_page end-to-end (mock ort.InferenceSession + pymupdf.Page)
# ═══════════════════════════════════════════════════════════════════════
class TestDetectPage:
    """End-to-end ``detect_page_layout`` tests against a mocked session and page."""

    # Every test in this class uses an A4 portrait page, in PDF points.
    _PAGE_W = 595
    _PAGE_H = 842

    @pytest.fixture(autouse=True)
    def _reset_detector(self):
        """Reset the module-level singleton around each test.

        This keeps one test from reusing the mock session of the previous one.
        """
        mod._detector = mod._LayoutDetector()
        yield
        mod._detector = mod._LayoutDetector()

    @staticmethod
    def _build_mock_session(page_w, page_h, boxes, names):
        """Build a mock InferenceSession that returns the given detections.

        boxes: list of (cls_id, pdf_x0, pdf_y0, pdf_x1, pdf_y1, conf).
            Coordinates are PDF points; they are converted to model-space
            coordinates before being placed in the fake output tensor.
        names: dict[int, str] written into the model metadata so that
            _parse_names_from_meta can read it.

        Returns ``(session, (pix_w, pix_h))``.
        """
        ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
        pix_w = round(page_w * ratio)
        pix_h = round(page_h * ratio)
        dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)

        # Forward transform: model = pdf * ratio + padding.
        rows = [
            [
                x0 * ratio + dw,
                y0 * ratio + dh,
                x1 * ratio + dw,
                y1 * ratio + dh,
                conf,
                cls_id,
            ]
            for cls_id, x0, y0, x1, y1, conf in boxes
        ]
        if rows:
            fake_output = np.array([rows], dtype=np.float32)
        else:
            fake_output = np.zeros((1, 0, 6), dtype=np.float32)

        model_input = MagicMock()
        # Assigned after construction: MagicMock(name=...) would only set the
        # mock's repr name, not a ``.name`` attribute.
        model_input.name = "images"

        model_meta = MagicMock()
        model_meta.custom_metadata_map = {
            "names": json.dumps({str(k): v for k, v in names.items()})
        }

        session = MagicMock()
        session.get_inputs.return_value = [model_input]
        session.run.return_value = [fake_output]
        session.get_providers.return_value = ["CPUExecutionProvider"]
        session.get_modelmeta.return_value = model_meta
        return session, (pix_w, pix_h)

    @staticmethod
    def _make_mock_page(page_w, page_h, pix_w, pix_h):
        """Build a mock page whose pixmap is a uniform 3-channel image."""
        pixmap = MagicMock()
        pixmap.width = pix_w
        pixmap.height = pix_h
        pixmap.n = 3
        # Every sample byte is 128.
        pixmap.samples = b"\x80" * (pix_w * pix_h * 3)

        page = MagicMock()
        page.rect.width = page_w
        page.rect.height = page_h
        page.get_pixmap.return_value = pixmap
        return page

    def _setup(self, monkeypatch, tmp_path, sess):
        """Point the settings at a dummy model file and stub session creation."""
        model_file = tmp_path / "m.onnx"
        monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(model_file))
        model_file.write_bytes(b"x")
        monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)

    def _detect(self, monkeypatch, tmp_path, detections, names):
        """Run ``detect_page_layout`` on a mocked A4 page with the given detections."""
        sess, (pix_w, pix_h) = self._build_mock_session(
            self._PAGE_W, self._PAGE_H, detections, names
        )
        self._setup(monkeypatch, tmp_path, sess)
        page = self._make_mock_page(self._PAGE_W, self._PAGE_H, pix_w, pix_h)
        return detect_page_layout(page)

    def test_returns_picture_box(self, monkeypatch, tmp_path):
        found = self._detect(
            monkeypatch,
            tmp_path,
            [(3, 100, 100, 300, 400, 0.9)],
            {3: "figure", 5: "table"},
        )
        assert len(found) == 1
        box = found[0]
        assert isinstance(box, LayoutBox)
        assert box.boxclass == "picture"
        assert box.x0 == pytest.approx(100, abs=1.0)
        assert box.y0 == pytest.approx(100, abs=1.0)
        assert box.x1 == pytest.approx(300, abs=1.0)
        assert box.y1 == pytest.approx(400, abs=1.0)

    def test_returns_table_box(self, monkeypatch, tmp_path):
        found = self._detect(
            monkeypatch,
            tmp_path,
            [(5, 50, 50, 400, 300, 0.85)],
            {3: "figure", 5: "table"},
        )
        assert len(found) == 1
        assert found[0].boxclass == "table"

    def test_filters_low_confidence(self, monkeypatch, tmp_path):
        # conf=0.1 < LAYOUT_THRESHOLD (0.2) -> filtered out
        found = self._detect(
            monkeypatch,
            tmp_path,
            [(3, 100, 100, 300, 400, 0.1)],
            {3: "figure"},
        )
        assert found == []

    def test_filters_small_box(self, monkeypatch, tmp_path):
        # Restored box is 5x5 pt < _MIN_BOX_SIZE (20) -> filtered out
        found = self._detect(
            monkeypatch,
            tmp_path,
            [(3, 100, 100, 105, 105, 0.9)],
            {3: "figure"},
        )
        assert found == []

    def test_mixed_picture_and_table(self, monkeypatch, tmp_path):
        found = self._detect(
            monkeypatch,
            tmp_path,
            [
                (3, 100, 100, 300, 400, 0.9),
                (5, 50, 500, 400, 700, 0.8),
            ],
            {3: "figure", 5: "table"},
        )
        box_classes = sorted(box.boxclass for box in found)
        assert box_classes == ["picture", "table"]

    def test_empty_output(self, monkeypatch, tmp_path):
        found = self._detect(monkeypatch, tmp_path, [], {3: "figure"})
        assert found == []

    def test_ignored_class_skipped(self, monkeypatch, tmp_path):
        # The title class (cls_id=0) must not produce a LayoutBox.
        found = self._detect(
            monkeypatch,
            tmp_path,
            [(0, 100, 100, 400, 150, 0.9), (3, 100, 200, 300, 400, 0.9)],
            {0: "title", 3: "figure"},
        )
        assert len(found) == 1
        assert found[0].boxclass == "picture"