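# NOTE(review): the lines below are residue from a web code viewer that was
# pasted into this file; they are kept as comments so the module parses.
# Files
# daily-paper/tests/test_layout_detector.py
# T
#
# 560 lines
# 22 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.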
"""layout_detector 测试 — 坐标还原数学、设备探测、类别映射、端到端 detect_page.
纯函数测试不依赖真实模型;TestDetectPage 用 MagicMock mock ort.InferenceSession
与 pymupdf.Page,参考 test_summary_utils.py 的 mock 模式。
"""
from __future__ import annotations
import json
import threading
import time
from unittest.mock import MagicMock
import numpy as np
import onnxruntime as ort
import pytest
from app.config import settings
from app.services import layout_detector as mod
from app.services.layout_detector import (
LayoutBox,
_compute_render_geometry,
_FALLBACK_NAMES,
_letterbox_padding,
_map_class_to_boxclass,
_model_to_pdf,
_parse_names_from_meta,
_postprocess_output,
detect_page_layout,
resolve_providers,
)
# Model input size (pixels) passed as the last argument of
# _compute_render_geometry / _letterbox_padding throughout these tests.
IMGSZ: int = 1024
# ═══════════════════════════════════════════════════════════════════════
# 渲染几何与 letterbox padding
# ═══════════════════════════════════════════════════════════════════════
class TestComputeRenderGeometry:
    """_compute_render_geometry: ratio == min(IMGSZ / width, IMGSZ / height)."""

    def test_a4_portrait_short_edge_pads(self):
        # A4 (595×842 pt): the height touches the canvas edge, so the grey
        # letterbox padding ends up along the width.
        width, height = 595, 842
        scale = _compute_render_geometry(width, height, IMGSZ)
        fit_w, fit_h = IMGSZ / width, IMGSZ / height
        assert scale == pytest.approx(min(fit_w, fit_h))
        assert scale == pytest.approx(fit_h)  # height-bound

    def test_wide_page_width_pads(self):
        # 1600×900 landscape: the width touches the canvas edge.
        scale = _compute_render_geometry(1600, 900, IMGSZ)
        assert scale == pytest.approx(IMGSZ / 1600)

    def test_square_no_letterbox(self):
        side = 100
        scale = _compute_render_geometry(side, side, IMGSZ)
        assert scale == pytest.approx(IMGSZ / side)  # == 10.24


class TestLetterboxPadding:
    """_letterbox_padding centres the pixmap, splitting the leftover evenly."""

    def test_centered_padding(self):
        # A 723×1024 pixmap fills the height; each side of the width gets
        # (1024 - 723) / 2 of padding.
        pad_x, pad_y = _letterbox_padding(723, 1024, IMGSZ)
        assert pad_x == pytest.approx((IMGSZ - 723) / 2)
        assert pad_y == pytest.approx(0.0)

    def test_square_no_padding(self):
        pad_x, pad_y = _letterbox_padding(IMGSZ, IMGSZ, IMGSZ)
        assert (pad_x, pad_y) == (0.0, 0.0)


# ═══════════════════════════════════════════════════════════════════════
# 坐标还原(核心)—— pdf = (model - padding) / ratio
# ═══════════════════════════════════════════════════════════════════════
class TestModelToPdf:
    """_model_to_pdf undoes the letterbox transform: pdf = (model - pad) / ratio."""

    def test_padding_corner_maps_to_origin(self):
        # The model-space top-left of the content, (dw, dh), is PDF (0, 0).
        pad_x, pad_y, scale = 150.5, 0.0, 1.2157
        px, py = _model_to_pdf(pad_x, pad_y, pad_x, pad_y, scale)
        assert px == pytest.approx(0.0, abs=0.01)
        assert py == pytest.approx(0.0, abs=0.01)

    def test_round_trip(self):
        pad_x, pad_y, scale = 150.5, 0.0, 1.2157
        # Forward-map PDF (100, 200) into model space, then map it back.
        pdf_x, pdf_y = 100, 200
        model_x = pdf_x * scale + pad_x
        model_y = pdf_y * scale + pad_y
        back_x, back_y = _model_to_pdf(model_x, model_y, pad_x, pad_y, scale)
        assert back_x == pytest.approx(pdf_x, abs=0.01)
        assert back_y == pytest.approx(pdf_y, abs=0.01)

    def test_full_a4_page_box(self):
        # The whole-page box spans (dw, dh)-(dw + pix_w, dh + pix_h) in model
        # space and must map back to the full page size.
        page_w, page_h = 595, 842
        scale = _compute_render_geometry(page_w, page_h, IMGSZ)
        pix_w = round(page_w * scale)
        pix_h = round(page_h * scale)
        pad_x, pad_y = _letterbox_padding(pix_w, pix_h, IMGSZ)
        left, top = _model_to_pdf(pad_x, pad_y, pad_x, pad_y, scale)
        right, bottom = _model_to_pdf(
            pad_x + pix_w, pad_y + pix_h, pad_x, pad_y, scale
        )
        assert left == pytest.approx(0.0, abs=1.0)
        assert top == pytest.approx(0.0, abs=1.0)
        assert right == pytest.approx(page_w, abs=1.0)
        assert bottom == pytest.approx(page_h, abs=1.0)


# ═══════════════════════════════════════════════════════════════════════
# 设备探测
# ═══════════════════════════════════════════════════════════════════════
class TestResolveProviders:
    """resolve_providers maps a device name to ONNX Runtime provider entries."""

    # The CPU entry every provider list is expected to fall back to.
    CPU = ("CPUExecutionProvider", {})

    @staticmethod
    def _fake_available(monkeypatch, *providers):
        """Make ort.get_available_providers() report exactly ``providers``."""
        monkeypatch.setattr(
            ort, "get_available_providers", lambda: list(providers)
        )

    def test_cpu(self):
        assert resolve_providers("cpu", 0) == [self.CPU]

    def test_cuda_with_cpu_fallback(self):
        providers = resolve_providers("cuda", 0)
        assert providers[0] == ("CUDAExecutionProvider", {"device_id": "0"})
        assert providers[1] == self.CPU

    def test_directml_device_id(self):
        providers = resolve_providers("directml", 2)
        assert providers[0] == ("DmlExecutionProvider", {"device_id": "2"})

    def test_auto_picks_cuda_if_available(self, monkeypatch):
        self._fake_available(
            monkeypatch, "CUDAExecutionProvider", "CPUExecutionProvider"
        )
        providers = resolve_providers("auto", 0)
        assert providers[0][0] == "CUDAExecutionProvider"
        assert providers[-1] == self.CPU

    def test_auto_falls_back_to_cpu(self, monkeypatch):
        self._fake_available(monkeypatch, "CPUExecutionProvider")
        assert resolve_providers("auto", 0) == [self.CPU]

    def test_auto_prefers_cuda_over_directml(self, monkeypatch):
        # DirectML is listed first on purpose: CUDA must still win.
        self._fake_available(
            monkeypatch,
            "DmlExecutionProvider",
            "CUDAExecutionProvider",
            "CPUExecutionProvider",
        )
        providers = resolve_providers("auto", 0)
        assert providers[0][0] == "CUDAExecutionProvider"

    def test_unknown_device_falls_back(self):
        assert resolve_providers("tpu", 0) == [self.CPU]


# ═══════════════════════════════════════════════════════════════════════
# 类别映射与 names 解析
# ═══════════════════════════════════════════════════════════════════════
class TestClassMapping:
    """_map_class_to_boxclass keeps picture/table/caption classes, drops the rest."""

    def test_figure_to_picture(self):
        assert _map_class_to_boxclass(3, {3: "figure"}) == "picture"

    def test_figure_group_to_picture(self):
        assert _map_class_to_boxclass(0, {0: "figure_group"}) == "picture"

    def test_table(self):
        assert _map_class_to_boxclass(5, {5: "table"}) == "table"

    def test_caption_classes(self):
        names = {4: "figure_caption", 6: "table_caption"}
        for cls_id, expected in ((4, "figure_caption"), (6, "table_caption")):
            assert _map_class_to_boxclass(cls_id, names) == expected

    def test_other_classes_ignored(self):
        names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
        # Any id listed here was unexpectedly mapped to a box class.
        leaked = [
            cls_id
            for cls_id in names
            if _map_class_to_boxclass(cls_id, names) is not None
        ]
        assert leaked == []

    def test_case_insensitive(self):
        assert _map_class_to_boxclass(0, {0: "Figure"}) == "picture"
        assert _map_class_to_boxclass(0, {0: "TABLE"}) == "table"

    def test_unknown_class_id(self):
        assert _map_class_to_boxclass(99, {0: "figure"}) is None


class TestParseNamesFromMeta:
    """_parse_names_from_meta reads the JSON ``names`` entry of the model metadata."""

    @staticmethod
    def _session_with_metadata(custom_map):
        """Mock session whose get_modelmeta().custom_metadata_map is ``custom_map``."""
        meta = MagicMock()
        meta.custom_metadata_map = custom_map
        session = MagicMock()
        session.get_modelmeta.return_value = meta
        return session

    def test_reads_json_metadata(self):
        session = self._session_with_metadata(
            {"names": '{"0": "title", "3": "figure", "5": "table"}'}
        )
        expected = {0: "title", 3: "figure", 5: "table"}
        assert _parse_names_from_meta(session) == expected

    def test_fallback_when_missing(self):
        session = self._session_with_metadata({})
        assert _parse_names_from_meta(session) == _FALLBACK_NAMES

    def test_fallback_on_garbage(self):
        session = self._session_with_metadata({"names": "not json"})
        assert _parse_names_from_meta(session) == _FALLBACK_NAMES


# ═══════════════════════════════════════════════════════════════════════
# 后处理
# ═══════════════════════════════════════════════════════════════════════
class TestPostprocessOutput:
    """_postprocess_output parses (1, N, 6) rows of [x0, y0, x1, y1, conf, cls]."""

    def test_parses_end_to_end_filters_by_conf(self):
        kept = [10, 20, 30, 40, 0.9, 3]
        dropped = [50, 60, 70, 80, 0.1, 5]  # conf 0.1 is below the 0.2 threshold
        raw = np.array([[kept, dropped]], dtype=np.float32)
        parsed = _postprocess_output(raw, 0.2, {3: "figure", 5: "table"})
        assert parsed == [(3, 10.0, 20.0, 30.0, 40.0)]

    def test_empty_output(self):
        no_rows = np.zeros((1, 0, 6), dtype=np.float32)
        assert _postprocess_output(no_rows, 0.2, {}) == []

    def test_unexpected_shape_returns_empty(self):
        # Any layout other than (1, N, 6) yields no detections.
        wrong_layout = np.zeros((1, 84, 8400), dtype=np.float32)
        assert _postprocess_output(wrong_layout, 0.2, {}) == []


# ═══════════════════════════════════════════════════════════════════════
# detect_page 端到端(mock ort.InferenceSession + pymupdf.Page)
# ═══════════════════════════════════════════════════════════════════════
class _DetectorHarness:
    """Shared fixture and mock builders for the detect_page_layout test classes.

    TestDetectPage and TestDetectPageConcurrency both inherit from this class,
    so the singleton reset, the mock InferenceSession / pymupdf.Page builders
    and the model-path setup exist once instead of being copied per class.
    Its name does not match pytest's default ``Test*`` pattern, so it is not
    collected on its own.
    """

    # Page size (PDF points) used by every mocked page in these tests.
    _A4 = (595, 842)

    @pytest.fixture(autouse=True)
    def _reset_detector(self):
        """Recreate the detector singleton before and after every test.

        Each test starts with a fresh instance (new lock, no cached session),
        so a mock session installed by an earlier test is never reused and
        lock state cannot leak between tests.
        """
        mod._LayoutDetector.reset_instance()
        mod._detector = mod._LayoutDetector()
        yield
        mod._LayoutDetector.reset_instance()
        mod._detector = mod._LayoutDetector()

    @staticmethod
    def _build_mock_session(page_w, page_h, boxes, names):
        """Build a mock ``ort.InferenceSession`` whose ``run`` returns ``boxes``.

        Args:
            page_w, page_h: page size in PDF points.
            boxes: list of ``(cls_id, pdf_x0, pdf_y0, pdf_x1, pdf_y1, conf)``.
                The PDF coordinates are converted to model space (scaled by
                the render ratio, shifted by the letterbox padding) before
                being placed in the fake output.
            names: ``{class_id: name}``, written into the model metadata as
                JSON so ``_parse_names_from_meta`` can read it.

        Returns:
            ``(session, (pix_w, pix_h))``. The fake output array itself is
            ``session.run.return_value[0]``.
        """
        ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
        pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
        dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
        rows = [
            [
                x0 * ratio + dw,
                y0 * ratio + dh,
                x1 * ratio + dw,
                y1 * ratio + dh,
                conf,
                cls_id,
            ]
            for cls_id, x0, y0, x1, y1, conf in boxes
        ]
        # np.array([[]]) would have shape (1, 0), so build the (1, 0, 6)
        # no-detection array explicitly.
        fake_output = (
            np.array([rows], dtype=np.float32)
            if rows
            else np.zeros((1, 0, 6), dtype=np.float32)
        )
        model_input = MagicMock()
        model_input.name = "images"
        meta = MagicMock()
        meta.custom_metadata_map = {
            "names": json.dumps({str(k): v for k, v in names.items()})
        }
        sess = MagicMock()
        sess.get_inputs.return_value = [model_input]
        sess.run.return_value = [fake_output]
        sess.get_providers.return_value = ["CPUExecutionProvider"]
        sess.get_modelmeta.return_value = meta
        return sess, (pix_w, pix_h)

    @staticmethod
    def _make_mock_page(page_w, page_h, pix_w, pix_h):
        """Build a mock pymupdf.Page of ``page_w`` × ``page_h`` points.

        ``get_pixmap()`` returns a mock 3-channel ``pix_w`` × ``pix_h`` pixmap
        whose samples are all 128.
        """
        pixmap = MagicMock()
        pixmap.width = pix_w
        pixmap.height = pix_h
        pixmap.n = 3
        # Repeat a 1-byte bytes object instead of materialising a list of
        # millions of ints first; the resulting bytes are identical.
        pixmap.samples = bytes([128]) * (pix_w * pix_h * 3)
        page = MagicMock()
        page.rect.width = page_w
        page.rect.height = page_h
        page.get_pixmap.return_value = pixmap
        return page

    def _setup(self, monkeypatch, tmp_path, sess, session_factory=None):
        """Point settings at a dummy model file and stub ort.InferenceSession.

        By default every ``ort.InferenceSession(...)`` call returns ``sess``.
        Pass ``session_factory`` to install a custom constructor instead,
        e.g. one that counts how often it is called.
        """
        model_path = tmp_path / "m.onnx"
        monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(model_path))
        model_path.write_bytes(b"x")
        factory = session_factory or (lambda *a, **kw: sess)
        monkeypatch.setattr(ort, "InferenceSession", factory)

    def _detect_a4(self, monkeypatch, tmp_path, boxes, names):
        """Run detect_page_layout on a mocked A4 page whose model emits ``boxes``."""
        page_w, page_h = self._A4
        sess, (pix_w, pix_h) = self._build_mock_session(page_w, page_h, boxes, names)
        self._setup(monkeypatch, tmp_path, sess)
        page = self._make_mock_page(page_w, page_h, pix_w, pix_h)
        return detect_page_layout(page)

    @staticmethod
    def _run_concurrently(pages, timeout=10.0):
        """Call detect_page_layout on every page, one thread per page.

        Unlike a bare ``threading.Thread(target=detect_page_layout)``, an
        exception raised in a worker is re-raised here so the test fails
        instead of only printing a traceback. A worker still running after
        ``timeout`` seconds fails the test instead of blocking ``join()``
        forever.

        Returns:
            The per-page results, in the same order as ``pages``.
        """
        results = [None] * len(pages)
        errors = []

        def worker(idx, page):
            try:
                results[idx] = detect_page_layout(page)
            except Exception as exc:  # surfaced on the test thread below
                errors.append(exc)

        threads = [
            threading.Thread(target=worker, args=(i, p), daemon=True)
            for i, p in enumerate(pages)
        ]
        for t in threads:
            t.start()
        deadline = time.monotonic() + timeout
        for t in threads:
            t.join(max(0.0, deadline - time.monotonic()))
        hung = [t for t in threads if t.is_alive()]
        assert not hung, f"{len(hung)} detect_page_layout thread(s) did not finish"
        if errors:
            raise errors[0]
        return results


class TestDetectPage(_DetectorHarness):
    """End-to-end detect_page_layout against a mocked session and A4 page."""

    def test_returns_picture_box(self, monkeypatch, tmp_path):
        boxes = self._detect_a4(
            monkeypatch,
            tmp_path,
            [(3, 100, 100, 300, 400, 0.9)],
            {3: "figure", 5: "table"},
        )
        assert len(boxes) == 1
        b = boxes[0]
        assert isinstance(b, LayoutBox)
        assert b.boxclass == "picture"
        assert b.x0 == pytest.approx(100, abs=1.0)
        assert b.y0 == pytest.approx(100, abs=1.0)
        assert b.x1 == pytest.approx(300, abs=1.0)
        assert b.y1 == pytest.approx(400, abs=1.0)

    def test_returns_table_box(self, monkeypatch, tmp_path):
        boxes = self._detect_a4(
            monkeypatch,
            tmp_path,
            [(5, 50, 50, 400, 300, 0.85)],
            {3: "figure", 5: "table"},
        )
        assert len(boxes) == 1
        assert boxes[0].boxclass == "table"

    def test_returns_caption_box_with_small_height(self, monkeypatch, tmp_path):
        boxes = self._detect_a4(
            monkeypatch,
            tmp_path,
            [(4, 100, 405, 300, 417, 0.9)],
            {4: "figure_caption"},
        )
        assert len(boxes) == 1
        assert boxes[0].boxclass == "figure_caption"
        assert boxes[0].y1 - boxes[0].y0 == pytest.approx(12, abs=1.0)

    def test_filters_low_confidence(self, monkeypatch, tmp_path):
        # conf=0.1 < LAYOUT_THRESHOLD (0.2) -> filtered out
        boxes = self._detect_a4(
            monkeypatch, tmp_path, [(3, 100, 100, 300, 400, 0.1)], {3: "figure"}
        )
        assert boxes == []

    def test_filters_small_box(self, monkeypatch, tmp_path):
        # The restored box is 5×5 pt, below _MIN_BOX_SIZE (20) -> filtered out
        boxes = self._detect_a4(
            monkeypatch, tmp_path, [(3, 100, 100, 105, 105, 0.9)], {3: "figure"}
        )
        assert boxes == []

    def test_mixed_picture_and_table(self, monkeypatch, tmp_path):
        boxes = self._detect_a4(
            monkeypatch,
            tmp_path,
            [
                (3, 100, 100, 300, 400, 0.9),
                (5, 50, 500, 400, 700, 0.8),
            ],
            {3: "figure", 5: "table"},
        )
        assert sorted(b.boxclass for b in boxes) == ["picture", "table"]

    def test_empty_output(self, monkeypatch, tmp_path):
        assert self._detect_a4(monkeypatch, tmp_path, [], {3: "figure"}) == []

    def test_ignored_class_skipped(self, monkeypatch, tmp_path):
        # The title class (cls_id=0) must not produce a LayoutBox.
        boxes = self._detect_a4(
            monkeypatch,
            tmp_path,
            [(0, 100, 100, 400, 150, 0.9), (3, 100, 200, 300, 400, 0.9)],
            {0: "title", 3: "figure"},
        )
        assert len(boxes) == 1
        assert boxes[0].boxclass == "picture"


# ═══════════════════════════════════════════════════════════════════════
# Concurrency safety: the lock serializes inference and the singleton
# session is initialized only once
# ═══════════════════════════════════════════════════════════════════════
class TestDetectPageConcurrency(_DetectorHarness):
    """Concurrent calls once the lock wraps the whole detect_page call."""

    def test_detect_page_serializes_concurrent_calls(self, monkeypatch, tmp_path):
        """Concurrent detect_page_layout calls: at most one thread inside session.run."""
        sess, (pw, ph) = self._build_mock_session(
            595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
        )
        fake_output = sess.run.return_value[0]
        in_critical = 0
        max_concurrent = 0
        counter_lock = threading.Lock()

        def counting_run(*args, **kwargs):
            nonlocal in_critical, max_concurrent
            with counter_lock:
                in_critical += 1
                max_concurrent = max(max_concurrent, in_critical)
            try:
                # Widen the race window so unserialized threads would overlap.
                time.sleep(0.02)
                return [fake_output]
            finally:
                with counter_lock:
                    in_critical -= 1

        sess.run.side_effect = counting_run
        self._setup(monkeypatch, tmp_path, sess)
        pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(8)]
        results = self._run_concurrently(pages)
        # With the lock only one thread is ever inside run(); without it this
        # would exceed 1 (regression guard).
        assert max_concurrent == 1
        # Serialization must not corrupt results: every call sees the figure.
        assert all(len(r) == 1 and r[0].boxclass == "picture" for r in results)

    def test_session_created_once_under_concurrency(self, monkeypatch, tmp_path):
        """Concurrent first calls create the InferenceSession exactly once."""
        sess, (pw, ph) = self._build_mock_session(
            595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
        )
        create_count = 0
        create_lock = threading.Lock()

        def counting_init(*args, **kwargs):
            nonlocal create_count
            with create_lock:
                create_count += 1
            time.sleep(0.02)  # widen the window so concurrent first calls race
            return sess

        self._setup(monkeypatch, tmp_path, sess, session_factory=counting_init)
        pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(6)]
        results = self._run_concurrently(pages)
        assert create_count == 1
        assert all(len(r) == 1 and r[0].boxclass == "picture" for r in results)