refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO

- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
This commit is contained in:
2026-06-14 10:41:44 +08:00
parent 743d69efd0
commit 90fe705e8f
22 changed files with 2220 additions and 356 deletions
+1 -4
View File
@@ -301,10 +301,7 @@ class TestAdminPapers:
assert resp.status_code == 200
assert resp.json()["count"] == 2
remaining = db_session.execute(
text(
"SELECT rowid FROM papers_fts "
"WHERE rowid IN (:id1, :id2)"
),
text("SELECT rowid FROM papers_fts WHERE rowid IN (:id1, :id2)"),
{"id1": target_ids[0], "id2": target_ids[1]},
).fetchall()
assert remaining == []
+3 -1
View File
@@ -54,7 +54,9 @@ class TestReindexChroma:
def test_reindex_chroma_indexes_only_summarized_papers(
self, db_session, sample_papers_with_summary
):
with patch("app.services.embedder.index_paper", return_value=True) as mock_index:
with patch(
"app.services.embedder.index_paper", return_value=True
) as mock_index:
result = reindex_chroma(db_session)
assert result["status"] == "success"
+403
View File
@@ -0,0 +1,403 @@
"""layout_detector 测试 — 坐标还原数学、设备探测、类别映射、端到端 detect_page.
纯函数测试不依赖真实模型;TestDetectPage 用 MagicMock mock ort.InferenceSession
与 pymupdf.Page,参考 test_summary_utils.py 的 mock 模式。
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock
import numpy as np
import onnxruntime as ort
import pytest
from app.config import settings
from app.services import layout_detector as mod
from app.services.layout_detector import (
LayoutBox,
_compute_render_geometry,
_FALLBACK_NAMES,
_letterbox_padding,
_map_class_to_boxclass,
_model_to_pdf,
_parse_names_from_meta,
_postprocess_output,
detect_page_layout,
resolve_providers,
)
IMGSZ = 1024
# ═══════════════════════════════════════════════════════════════════════
# 渲染几何与 letterbox padding
# ═══════════════════════════════════════════════════════════════════════
class TestComputeRenderGeometry:
    """Scale factor returned by ``_compute_render_geometry`` for typical page shapes."""

    def test_a4_portrait_short_edge_pads(self):
        """A4 portrait (595x842 pt): height is the limiting edge, width gets padded."""
        page_w, page_h = 595, 842
        scale = _compute_render_geometry(page_w, page_h, IMGSZ)
        # The scale is the smaller of the two per-axis factors ...
        assert scale == pytest.approx(min(IMGSZ / page_w, IMGSZ / page_h))
        # ... which for a portrait page is the height factor.
        assert scale == pytest.approx(IMGSZ / page_h)

    def test_wide_page_width_pads(self):
        """Landscape 1600x900 page: the width is the limiting edge.

        NOTE(review): despite the method name, the asserted ratio shows the
        width sits flush against the model input here, not padded.
        """
        page_w, page_h = 1600, 900
        scale = _compute_render_geometry(page_w, page_h, IMGSZ)
        assert scale == pytest.approx(IMGSZ / page_w)

    def test_square_no_letterbox(self):
        """A 100x100 page scales by IMGSZ / 100 (10.24 for IMGSZ == 1024)."""
        side = 100
        scale = _compute_render_geometry(side, side, IMGSZ)
        assert scale == pytest.approx(10.24)
class TestLetterboxPadding:
    """Padding offsets returned by ``_letterbox_padding`` for a rendered pixmap."""

    def test_centered_padding(self):
        """A 723x1024 pixmap fills the height; the spare width is split in two."""
        pix_w, pix_h = 723, 1024
        pad_x, pad_y = _letterbox_padding(pix_w, pix_h, IMGSZ)
        assert pad_x == pytest.approx((IMGSZ - pix_w) / 2)
        assert pad_y == pytest.approx(0.0)

    def test_square_no_padding(self):
        """A pixmap that is already IMGSZ x IMGSZ needs no padding on either axis."""
        pad_x, pad_y = _letterbox_padding(IMGSZ, IMGSZ, IMGSZ)
        assert pad_x == 0.0
        assert pad_y == 0.0
# ═══════════════════════════════════════════════════════════════════════
# 坐标还原(核心)—— pdf = (model - padding) / ratio
# ═══════════════════════════════════════════════════════════════════════
class TestModelToPdf:
    """Checks ``_model_to_pdf``, which maps model-space points back to PDF points.

    The expected relation exercised here is ``pdf = (model - padding) / ratio``.
    """

    # Letterbox parameters shared by the first two tests:
    # horizontal padding, vertical padding, render scale.
    PAD_X = 150.5
    PAD_Y = 0.0
    SCALE = 1.2157

    def test_padding_corner_maps_to_origin(self):
        """Model-space point (dw, dh), the page's top-left corner, maps to PDF (0, 0)."""
        pad_x, pad_y, scale = self.PAD_X, self.PAD_Y, self.SCALE
        pdf_x, pdf_y = _model_to_pdf(pad_x, pad_y, pad_x, pad_y, scale)
        assert pdf_x == pytest.approx(0.0, abs=0.01)
        assert pdf_y == pytest.approx(0.0, abs=0.01)

    def test_round_trip(self):
        """PDF (100, 200) projected into model space converts back to itself."""
        pad_x, pad_y, scale = self.PAD_X, self.PAD_Y, self.SCALE
        # Forward projection: scale first, then shift by the padding.
        model_x = 100 * scale + pad_x
        model_y = 200 * scale + pad_y
        pdf_x, pdf_y = _model_to_pdf(model_x, model_y, pad_x, pad_y, scale)
        assert pdf_x == pytest.approx(100, abs=0.01)
        assert pdf_y == pytest.approx(200, abs=0.01)

    def test_full_a4_page_box(self):
        """A box covering the whole rendered A4 page maps back to the page size.

        In model space that box spans (dw, dh) to (dw + pix_w, dh + pix_h).
        The tolerance is one point because the pixmap size is rounded.
        """
        scale = _compute_render_geometry(595, 842, IMGSZ)
        pix_w = round(595 * scale)
        pix_h = round(842 * scale)
        pad_x, pad_y = _letterbox_padding(pix_w, pix_h, IMGSZ)

        left, top = _model_to_pdf(pad_x, pad_y, pad_x, pad_y, scale)
        right, bottom = _model_to_pdf(
            pad_x + pix_w, pad_y + pix_h, pad_x, pad_y, scale
        )

        assert left == pytest.approx(0.0, abs=1.0)
        assert top == pytest.approx(0.0, abs=1.0)
        assert right == pytest.approx(595, abs=1.0)
        assert bottom == pytest.approx(842, abs=1.0)
# ═══════════════════════════════════════════════════════════════════════
# 设备探测
# ═══════════════════════════════════════════════════════════════════════
class TestResolveProviders:
    """Execution-provider lists produced by ``resolve_providers`` per device setting."""

    # CPU provider entry used in the expected values below.
    _CPU = ("CPUExecutionProvider", {})

    @staticmethod
    def _advertise(monkeypatch, providers):
        """Make ``ort.get_available_providers`` report exactly *providers*."""
        # Hand out a fresh list on every call, as an inline list literal would.
        monkeypatch.setattr(ort, "get_available_providers", lambda: list(providers))

    def test_cpu(self):
        """Requesting "cpu" yields the CPU provider alone."""
        assert resolve_providers("cpu", 0) == [self._CPU]

    def test_cuda_with_cpu_fallback(self):
        """Requesting "cuda" puts CUDA first and CPU second."""
        providers = resolve_providers("cuda", 0)
        assert providers[0] == ("CUDAExecutionProvider", {"device_id": "0"})
        assert providers[1] == self._CPU

    def test_directml_device_id(self):
        """The device id is forwarded to DirectML as a string option."""
        providers = resolve_providers("directml", 2)
        assert providers[0] == ("DmlExecutionProvider", {"device_id": "2"})

    def test_auto_picks_cuda_if_available(self, monkeypatch):
        """"auto" selects CUDA when it is advertised and keeps CPU as the last entry."""
        self._advertise(monkeypatch, ["CUDAExecutionProvider", "CPUExecutionProvider"])
        providers = resolve_providers("auto", 0)
        assert providers[0][0] == "CUDAExecutionProvider"
        assert providers[-1] == self._CPU

    def test_auto_falls_back_to_cpu(self, monkeypatch):
        """"auto" yields only the CPU provider when nothing else is advertised."""
        self._advertise(monkeypatch, ["CPUExecutionProvider"])
        assert resolve_providers("auto", 0) == [self._CPU]

    def test_auto_prefers_cuda_over_directml(self, monkeypatch):
        """"auto" ranks CUDA above DirectML even if DirectML is advertised first."""
        self._advertise(
            monkeypatch,
            [
                "DmlExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        providers = resolve_providers("auto", 0)
        assert providers[0][0] == "CUDAExecutionProvider"

    def test_unknown_device_falls_back(self):
        """An unrecognised device name yields the CPU provider alone."""
        assert resolve_providers("tpu", 0) == [self._CPU]
# ═══════════════════════════════════════════════════════════════════════
# 类别映射与 names 解析
# ═══════════════════════════════════════════════════════════════════════
class TestClassMapping:
    """``_map_class_to_boxclass``: model class id + names table -> box class or None."""

    def test_figure_to_picture(self):
        """"figure" is reported as a picture."""
        labels = {3: "figure"}
        assert _map_class_to_boxclass(3, labels) == "picture"

    def test_figure_group_to_picture(self):
        """"figure_group" is reported as a picture as well."""
        labels = {0: "figure_group"}
        assert _map_class_to_boxclass(0, labels) == "picture"

    def test_table(self):
        """"table" is reported as a table."""
        labels = {5: "table"}
        assert _map_class_to_boxclass(5, labels) == "table"

    def test_caption_ignored(self):
        """Figure and table captions do not map to any box class."""
        labels = {4: "figure_caption", 6: "table_caption"}
        for class_id in (4, 6):
            assert _map_class_to_boxclass(class_id, labels) is None

    def test_other_classes_ignored(self):
        """Classes unrelated to figures and tables map to None."""
        labels = {
            0: "title",
            1: "plain text",
            2: "abandon",
            8: "isolate_formula",
        }
        for class_id in labels:
            assert _map_class_to_boxclass(class_id, labels) is None

    def test_case_insensitive(self):
        """Class names are matched regardless of letter case."""
        assert _map_class_to_boxclass(0, {0: "Figure"}) == "picture"
        assert _map_class_to_boxclass(0, {0: "TABLE"}) == "table"

    def test_unknown_class_id(self):
        """A class id that is missing from the names table maps to None."""
        labels = {0: "figure"}
        assert _map_class_to_boxclass(99, labels) is None
class TestParseNamesFromMeta:
    """``_parse_names_from_meta``: reads the class-name table from model metadata."""

    @staticmethod
    def _session_with_metadata(custom_map):
        """Return a mock session whose model metadata carries *custom_map*."""
        model_meta = MagicMock()
        model_meta.custom_metadata_map = custom_map
        session = MagicMock()
        session.get_modelmeta.return_value = model_meta
        return session

    def test_reads_json_metadata(self):
        """A JSON "names" entry is parsed into a dict with integer keys."""
        session = self._session_with_metadata(
            {"names": '{"0": "title", "3": "figure", "5": "table"}'}
        )
        expected = {0: "title", 3: "figure", 5: "table"}
        assert _parse_names_from_meta(session) == expected

    def test_fallback_when_missing(self):
        """Without a "names" entry the built-in fallback table is returned."""
        session = self._session_with_metadata({})
        assert _parse_names_from_meta(session) == _FALLBACK_NAMES

    def test_fallback_on_garbage(self):
        """A "names" entry that is not valid JSON also yields the fallback table."""
        session = self._session_with_metadata({"names": "not json"})
        assert _parse_names_from_meta(session) == _FALLBACK_NAMES
# ═══════════════════════════════════════════════════════════════════════
# 后处理
# ═══════════════════════════════════════════════════════════════════════
class TestPostprocessOutput:
    """``_postprocess_output``: raw model output -> list of (cls, x1, y1, x2, y2)."""

    def test_parses_end_to_end_filters_by_conf(self):
        """Rows are [x1, y1, x2, y2, conf, cls]; rows below the threshold are dropped."""
        confident_row = [10, 20, 30, 40, 0.9, 3]
        weak_row = [50, 60, 70, 80, 0.1, 5]  # conf 0.1 is under the 0.2 threshold
        raw = np.array([[confident_row, weak_row]], dtype=np.float32)

        detections = _postprocess_output(raw, 0.2, {3: "figure", 5: "table"})

        assert detections == [(3, 10.0, 20.0, 30.0, 40.0)]

    def test_empty_output(self):
        """An output holding zero rows produces no detections."""
        raw = np.zeros((1, 0, 6), dtype=np.float32)
        assert _postprocess_output(raw, 0.2, {}) == []

    def test_unexpected_shape_returns_empty(self):
        """An output whose last axis is not 6 wide is rejected with an empty result."""
        raw = np.zeros((1, 84, 8400), dtype=np.float32)
        assert _postprocess_output(raw, 0.2, {}) == []
# ═══════════════════════════════════════════════════════════════════════
# detect_page 端到端(mock ort.InferenceSession + pymupdf.Page)
# ═══════════════════════════════════════════════════════════════════════
class TestDetectPage:
    """End-to-end tests for ``detect_page_layout`` with mocked ONNX and PDF objects.

    ``ort.InferenceSession`` is replaced by a ``MagicMock`` that returns a
    hand-built output tensor, and the page is a ``MagicMock`` standing in for
    ``pymupdf.Page``. No real model file or PDF is needed.
    """

    @pytest.fixture(autouse=True)
    def _reset_detector(self):
        """Install a fresh module-level detector before and after every test.

        Keeps a test from reusing the mock session set up by an earlier one.
        """
        mod._detector = mod._LayoutDetector()
        yield
        mod._detector = mod._LayoutDetector()

    @staticmethod
    def _build_mock_session(page_w, page_h, boxes, names):
        """Build a mock inference session that "detects" the given boxes.

        Args:
            page_w: Page width in PDF points.
            page_h: Page height in PDF points.
            boxes: List of ``(cls_id, pdf_x0, pdf_y0, pdf_x1, pdf_y1, conf)``
                tuples. Coordinates are PDF points and are converted to
                model-space coordinates before being placed in the output.
            names: ``dict[int, str]`` written into the model metadata so that
                ``_parse_names_from_meta`` can read it back.

        Returns:
            ``(session, (pix_w, pix_h))`` where the second item is the size of
            the pixmap a page of this size is expected to render to.
        """
        ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
        pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
        dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)

        # Forward projection (PDF -> model space): scale, then add the padding.
        # Output row layout is [x1, y1, x2, y2, conf, cls].
        rows = [
            [
                x0 * ratio + dw,
                y0 * ratio + dh,
                x1 * ratio + dw,
                y1 * ratio + dh,
                conf,
                cls_id,
            ]
            for cls_id, x0, y0, x1, y1, conf in boxes
        ]
        if rows:
            fake_output = np.array([rows], dtype=np.float32)
        else:
            # Keep the trailing dimension of 6 even when nothing is detected.
            fake_output = np.zeros((1, 0, 6), dtype=np.float32)

        model_input = MagicMock()
        model_input.name = "images"

        meta = MagicMock()
        # JSON object keys must be strings, so the integer class ids are stringified.
        meta.custom_metadata_map = {
            "names": json.dumps({str(k): v for k, v in names.items()})
        }

        sess = MagicMock()
        sess.get_inputs.return_value = [model_input]
        sess.run.return_value = [fake_output]
        sess.get_providers.return_value = ["CPUExecutionProvider"]
        sess.get_modelmeta.return_value = meta
        return sess, (pix_w, pix_h)

    @staticmethod
    def _make_mock_page(page_w, page_h, pix_w, pix_h):
        """Build a mock page whose ``get_pixmap`` returns a uniform grey RGB pixmap.

        Args:
            page_w: Value exposed as ``page.rect.width``.
            page_h: Value exposed as ``page.rect.height``.
            pix_w: Pixmap width in pixels.
            pix_h: Pixmap height in pixels.
        """
        pix = MagicMock()
        pix.width = pix_w
        pix.height = pix_h
        pix.n = 3
        # Repeat a single byte rather than materialising a list with millions
        # of ints first; the resulting buffer is identical.
        pix.samples = bytes([128]) * (pix_w * pix_h * 3)

        page = MagicMock()
        page.rect.width = page_w
        page.rect.height = page_h
        page.get_pixmap.return_value = pix
        return page

    def _setup(self, monkeypatch, tmp_path, sess):
        """Point the model path at a dummy file and make ORT return *sess*.

        NOTE(review): only ``LAYOUT_MODEL_PATH`` is pinned here. The tests
        presumably rely on the configured image size matching ``IMGSZ`` and on
        the default confidence threshold; confirm against ``app.config``.
        """
        model_file = tmp_path / "m.onnx"
        monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(model_file))
        model_file.write_bytes(b"x")
        monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)

    def _detect(self, monkeypatch, tmp_path, boxes, names, page_w=595, page_h=842):
        """Run ``detect_page_layout`` on a mock page whose model output is *boxes*.

        The page defaults to A4 portrait (595x842 pt). *boxes* and *names* have
        the same meaning as in ``_build_mock_session``.
        """
        sess, (pix_w, pix_h) = self._build_mock_session(page_w, page_h, boxes, names)
        self._setup(monkeypatch, tmp_path, sess)
        page = self._make_mock_page(page_w, page_h, pix_w, pix_h)
        return detect_page_layout(page)

    def test_returns_picture_box(self, monkeypatch, tmp_path):
        """A confident "figure" detection comes back as a picture at its PDF coords."""
        boxes = self._detect(
            monkeypatch,
            tmp_path,
            [(3, 100, 100, 300, 400, 0.9)],
            {3: "figure", 5: "table"},
        )
        assert len(boxes) == 1
        box = boxes[0]
        assert isinstance(box, LayoutBox)
        assert box.boxclass == "picture"
        assert box.x0 == pytest.approx(100, abs=1.0)
        assert box.y0 == pytest.approx(100, abs=1.0)
        assert box.x1 == pytest.approx(300, abs=1.0)
        assert box.y1 == pytest.approx(400, abs=1.0)

    def test_returns_table_box(self, monkeypatch, tmp_path):
        """A confident "table" detection comes back with boxclass "table"."""
        boxes = self._detect(
            monkeypatch,
            tmp_path,
            [(5, 50, 50, 400, 300, 0.85)],
            {3: "figure", 5: "table"},
        )
        assert len(boxes) == 1
        assert boxes[0].boxclass == "table"

    def test_filters_low_confidence(self, monkeypatch, tmp_path):
        """A detection with conf 0.1 is dropped.

        NOTE(review): assumes LAYOUT_THRESHOLD is 0.2, as the original author's
        comment states; the threshold is not pinned by this test.
        """
        boxes = self._detect(
            monkeypatch, tmp_path, [(3, 100, 100, 300, 400, 0.1)], {3: "figure"}
        )
        assert boxes == []

    def test_filters_small_box(self, monkeypatch, tmp_path):
        """A box that maps back to 5x5 pt is dropped as too small.

        NOTE(review): the original comment cites ``_MIN_BOX_SIZE`` (20) as the
        limit; that constant is not visible from this file.
        """
        boxes = self._detect(
            monkeypatch, tmp_path, [(3, 100, 100, 105, 105, 0.9)], {3: "figure"}
        )
        assert boxes == []

    def test_mixed_picture_and_table(self, monkeypatch, tmp_path):
        """One figure plus one table on the same page yields one box of each class."""
        boxes = self._detect(
            monkeypatch,
            tmp_path,
            [
                (3, 100, 100, 300, 400, 0.9),
                (5, 50, 500, 400, 700, 0.8),
            ],
            {3: "figure", 5: "table"},
        )
        classes = sorted(b.boxclass for b in boxes)
        assert classes == ["picture", "table"]

    def test_empty_output(self, monkeypatch, tmp_path):
        """A model output with no rows yields no layout boxes."""
        boxes = self._detect(monkeypatch, tmp_path, [], {3: "figure"})
        assert boxes == []

    def test_ignored_class_skipped(self, monkeypatch, tmp_path):
        """A "title" detection (cls_id 0) produces no LayoutBox; the figure does."""
        boxes = self._detect(
            monkeypatch,
            tmp_path,
            [(0, 100, 100, 400, 150, 0.9), (3, 100, 200, 300, 400, 0.9)],
            {0: "title", 3: "figure"},
        )
        assert len(boxes) == 1
        assert boxes[0].boxclass == "picture"