feat: add concurrency safety, caption detection, admin enhancements, and performance improvements
This commit is contained in:
@@ -7,6 +7,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
@@ -166,10 +168,10 @@ class TestClassMapping:
|
||||
def test_table(self):
|
||||
assert _map_class_to_boxclass(5, {5: "table"}) == "table"
|
||||
|
||||
def test_caption_ignored(self):
|
||||
def test_caption_classes(self):
|
||||
names = {4: "figure_caption", 6: "table_caption"}
|
||||
assert _map_class_to_boxclass(4, names) is None
|
||||
assert _map_class_to_boxclass(6, names) is None
|
||||
assert _map_class_to_boxclass(4, names) == "figure_caption"
|
||||
assert _map_class_to_boxclass(6, names) == "table_caption"
|
||||
|
||||
def test_other_classes_ignored(self):
|
||||
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
|
||||
@@ -240,9 +242,11 @@ class TestPostprocessOutput:
|
||||
class TestDetectPage:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_detector(self):
|
||||
"""每个测试前重置模块级单例,避免复用上个测试的 mock session。"""
|
||||
"""每个测试前重建单例(带新锁 + 空 session),避免复用上个测试的 mock session。"""
|
||||
mod._LayoutDetector.reset_instance()
|
||||
mod._detector = mod._LayoutDetector()
|
||||
yield
|
||||
mod._LayoutDetector.reset_instance()
|
||||
mod._detector = mod._LayoutDetector()
|
||||
|
||||
@staticmethod
|
||||
@@ -338,6 +342,20 @@ class TestDetectPage:
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "table"
|
||||
|
||||
def test_returns_caption_box_with_small_height(self, monkeypatch, tmp_path):
|
||||
names = {4: "figure_caption"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(4, 100, 405, 300, 417, 0.9)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "figure_caption"
|
||||
assert boxes[0].y1 - boxes[0].y0 == pytest.approx(12, abs=1.0)
|
||||
|
||||
def test_filters_low_confidence(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure"}
|
||||
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
|
||||
@@ -401,3 +419,141 @@ class TestDetectPage:
|
||||
boxes = detect_page_layout(page)
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "picture"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 并发安全:锁串行化推理 + 单例 session 只初始化一次
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestDetectPageConcurrency:
|
||||
"""锁包裹整段 detect_page 后,并发调用的安全性。"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_detector(self):
|
||||
"""重建单例(带新锁),避免跨测试锁状态污染。"""
|
||||
mod._LayoutDetector.reset_instance()
|
||||
mod._detector = mod._LayoutDetector()
|
||||
yield
|
||||
mod._LayoutDetector.reset_instance()
|
||||
mod._detector = mod._LayoutDetector()
|
||||
|
||||
@staticmethod
|
||||
def _build_mock_session(page_w, page_h, boxes, names):
|
||||
"""同 TestDetectPage._build_mock_session,额外返回 fake_output 供 side_effect。"""
|
||||
ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
|
||||
pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
|
||||
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
|
||||
rows = []
|
||||
for cls_id, x0, y0, x1, y1, conf in boxes:
|
||||
rows.append(
|
||||
[
|
||||
x0 * ratio + dw,
|
||||
y0 * ratio + dh,
|
||||
x1 * ratio + dw,
|
||||
y1 * ratio + dh,
|
||||
conf,
|
||||
cls_id,
|
||||
]
|
||||
)
|
||||
fake_output = (
|
||||
np.array([rows], dtype=np.float32)
|
||||
if rows
|
||||
else np.zeros((1, 0, 6), dtype=np.float32)
|
||||
)
|
||||
sess = MagicMock()
|
||||
inp = MagicMock()
|
||||
inp.name = "images"
|
||||
sess.get_inputs.return_value = [inp]
|
||||
sess.run.return_value = [fake_output]
|
||||
sess.get_providers.return_value = ["CPUExecutionProvider"]
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {
|
||||
"names": json.dumps({str(k): v for k, v in names.items()})
|
||||
}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
return sess, (pix_w, pix_h), fake_output
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_page(page_w, page_h, pix_w, pix_h):
|
||||
pix = MagicMock()
|
||||
pix.width = pix_w
|
||||
pix.height = pix_h
|
||||
pix.n = 3
|
||||
pix.samples = bytes([128] * (pix_w * pix_h * 3))
|
||||
page = MagicMock()
|
||||
page.rect.width = page_w
|
||||
page.rect.height = page_h
|
||||
page.get_pixmap.return_value = pix
|
||||
return page
|
||||
|
||||
def _setup(self, monkeypatch, tmp_path, sess):
|
||||
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
|
||||
(tmp_path / "m.onnx").write_bytes(b"x")
|
||||
monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)
|
||||
|
||||
def test_detect_page_serializes_concurrent_calls(self, monkeypatch, tmp_path):
|
||||
"""多线程并发调 detect_page_layout,session.run 临界区同时只有一个。"""
|
||||
sess, (pw, ph), fake_output = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
|
||||
)
|
||||
in_critical = 0
|
||||
max_concurrent = 0
|
||||
counter_lock = threading.Lock()
|
||||
|
||||
def counting_run(*args, **kwargs):
|
||||
nonlocal in_critical, max_concurrent
|
||||
with counter_lock:
|
||||
in_critical += 1
|
||||
max_concurrent = max(max_concurrent, in_critical)
|
||||
time.sleep(0.02) # 放大竞争窗口,让并发线程有机会重叠
|
||||
try:
|
||||
return [fake_output]
|
||||
finally:
|
||||
with counter_lock:
|
||||
in_critical -= 1
|
||||
|
||||
sess.run.side_effect = counting_run
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
|
||||
pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(8)]
|
||||
threads = [
|
||||
threading.Thread(target=detect_page_layout, args=(p,)) for p in pages
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# 锁生效 → 临界区同时只有一个;不加锁时此值会 > 1(回归保护)
|
||||
assert max_concurrent == 1
|
||||
|
||||
def test_session_created_once_under_concurrency(self, monkeypatch, tmp_path):
|
||||
"""多线程并发首次调用,InferenceSession 只创建一次(锁间接保护 _init_session)。"""
|
||||
sess, (pw, ph), _fake_output = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
|
||||
)
|
||||
create_count = 0
|
||||
create_lock = threading.Lock()
|
||||
|
||||
def counting_init(*args, **kwargs):
|
||||
nonlocal create_count
|
||||
with create_lock:
|
||||
create_count += 1
|
||||
time.sleep(0.02) # 放大窗口,让并发首调都来抢
|
||||
return sess
|
||||
|
||||
monkeypatch.setattr(ort, "InferenceSession", counting_init)
|
||||
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
|
||||
(tmp_path / "m.onnx").write_bytes(b"x")
|
||||
|
||||
pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(6)]
|
||||
threads = [
|
||||
threading.Thread(target=detect_page_layout, args=(p,)) for p in pages
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert create_count == 1
|
||||
|
||||
Reference in New Issue
Block a user