feat: add concurrency safety, caption detection, admin enhancements, and performance improvements

This commit is contained in:
2026-06-14 22:20:02 +08:00
parent 8f13c31991
commit 29fb20828e
23 changed files with 1782 additions and 114 deletions
+160 -4
View File
@@ -7,6 +7,8 @@
from __future__ import annotations
import json
import threading
import time
from unittest.mock import MagicMock
import numpy as np
@@ -166,10 +168,10 @@ class TestClassMapping:
def test_table(self):
assert _map_class_to_boxclass(5, {5: "table"}) == "table"
def test_caption_ignored(self):
def test_caption_classes(self):
names = {4: "figure_caption", 6: "table_caption"}
assert _map_class_to_boxclass(4, names) is None
assert _map_class_to_boxclass(6, names) is None
assert _map_class_to_boxclass(4, names) == "figure_caption"
assert _map_class_to_boxclass(6, names) == "table_caption"
def test_other_classes_ignored(self):
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
@@ -240,9 +242,11 @@ class TestPostprocessOutput:
class TestDetectPage:
@pytest.fixture(autouse=True)
def _reset_detector(self):
"""每个测试前重置模块级单例,避免复用上个测试的 mock session。"""
"""每个测试前重建单例(带新锁 + 空 session),避免复用上个测试的 mock session。"""
mod._LayoutDetector.reset_instance()
mod._detector = mod._LayoutDetector()
yield
mod._LayoutDetector.reset_instance()
mod._detector = mod._LayoutDetector()
@staticmethod
@@ -338,6 +342,20 @@ class TestDetectPage:
assert len(boxes) == 1
assert boxes[0].boxclass == "table"
def test_returns_caption_box_with_small_height(self, monkeypatch, tmp_path):
names = {4: "figure_caption"}
sess, (pw, ph) = self._build_mock_session(
595, 842, [(4, 100, 405, 300, 417, 0.9)], names
)
self._setup(monkeypatch, tmp_path, sess)
page = self._make_mock_page(595, 842, pw, ph)
boxes = detect_page_layout(page)
assert len(boxes) == 1
assert boxes[0].boxclass == "figure_caption"
assert boxes[0].y1 - boxes[0].y0 == pytest.approx(12, abs=1.0)
def test_filters_low_confidence(self, monkeypatch, tmp_path):
names = {3: "figure"}
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
@@ -401,3 +419,141 @@ class TestDetectPage:
boxes = detect_page_layout(page)
assert len(boxes) == 1
assert boxes[0].boxclass == "picture"
# ═══════════════════════════════════════════════════════════════════════
# 并发安全:锁串行化推理 + 单例 session 只初始化一次
# ═══════════════════════════════════════════════════════════════════════
class TestDetectPageConcurrency:
"""锁包裹整段 detect_page 后,并发调用的安全性。"""
@pytest.fixture(autouse=True)
def _reset_detector(self):
"""重建单例(带新锁),避免跨测试锁状态污染。"""
mod._LayoutDetector.reset_instance()
mod._detector = mod._LayoutDetector()
yield
mod._LayoutDetector.reset_instance()
mod._detector = mod._LayoutDetector()
@staticmethod
def _build_mock_session(page_w, page_h, boxes, names):
"""同 TestDetectPage._build_mock_session,额外返回 fake_output 供 side_effect。"""
ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
rows = []
for cls_id, x0, y0, x1, y1, conf in boxes:
rows.append(
[
x0 * ratio + dw,
y0 * ratio + dh,
x1 * ratio + dw,
y1 * ratio + dh,
conf,
cls_id,
]
)
fake_output = (
np.array([rows], dtype=np.float32)
if rows
else np.zeros((1, 0, 6), dtype=np.float32)
)
sess = MagicMock()
inp = MagicMock()
inp.name = "images"
sess.get_inputs.return_value = [inp]
sess.run.return_value = [fake_output]
sess.get_providers.return_value = ["CPUExecutionProvider"]
meta = MagicMock()
meta.custom_metadata_map = {
"names": json.dumps({str(k): v for k, v in names.items()})
}
sess.get_modelmeta.return_value = meta
return sess, (pix_w, pix_h), fake_output
@staticmethod
def _make_mock_page(page_w, page_h, pix_w, pix_h):
pix = MagicMock()
pix.width = pix_w
pix.height = pix_h
pix.n = 3
pix.samples = bytes([128] * (pix_w * pix_h * 3))
page = MagicMock()
page.rect.width = page_w
page.rect.height = page_h
page.get_pixmap.return_value = pix
return page
def _setup(self, monkeypatch, tmp_path, sess):
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
(tmp_path / "m.onnx").write_bytes(b"x")
monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)
def test_detect_page_serializes_concurrent_calls(self, monkeypatch, tmp_path):
"""多线程并发调 detect_page_layoutsession.run 临界区同时只有一个。"""
sess, (pw, ph), fake_output = self._build_mock_session(
595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
)
in_critical = 0
max_concurrent = 0
counter_lock = threading.Lock()
def counting_run(*args, **kwargs):
nonlocal in_critical, max_concurrent
with counter_lock:
in_critical += 1
max_concurrent = max(max_concurrent, in_critical)
time.sleep(0.02) # 放大竞争窗口,让并发线程有机会重叠
try:
return [fake_output]
finally:
with counter_lock:
in_critical -= 1
sess.run.side_effect = counting_run
self._setup(monkeypatch, tmp_path, sess)
pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(8)]
threads = [
threading.Thread(target=detect_page_layout, args=(p,)) for p in pages
]
for t in threads:
t.start()
for t in threads:
t.join()
# 锁生效 → 临界区同时只有一个;不加锁时此值会 > 1(回归保护)
assert max_concurrent == 1
def test_session_created_once_under_concurrency(self, monkeypatch, tmp_path):
"""多线程并发首次调用,InferenceSession 只创建一次(锁间接保护 _init_session)。"""
sess, (pw, ph), _fake_output = self._build_mock_session(
595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
)
create_count = 0
create_lock = threading.Lock()
def counting_init(*args, **kwargs):
nonlocal create_count
with create_lock:
create_count += 1
time.sleep(0.02) # 放大窗口,让并发首调都来抢
return sess
monkeypatch.setattr(ort, "InferenceSession", counting_init)
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
(tmp_path / "m.onnx").write_bytes(b"x")
pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(6)]
threads = [
threading.Thread(target=detect_page_layout, args=(p,)) for p in pages
]
for t in threads:
t.start()
for t in threads:
t.join()
assert create_count == 1