feat: add concurrency safety, caption detection, admin enhancements, and performance improvements

This commit is contained in:
2026-06-14 22:20:02 +08:00
parent 8f13c31991
commit 29fb20828e
23 changed files with 1782 additions and 114 deletions
+73 -11
View File
@@ -23,8 +23,10 @@ from __future__ import annotations
import json
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import onnxruntime as ort
@@ -47,14 +49,18 @@ _FALLBACK_NAMES: dict[int, str] = {
8: "isolate_formula",
9: "formula_caption",
}
# 下游需 picture/table —— 按 class name 字符串动态匹配(不依赖 class index
# 下游需 picture/table 及其 caption —— 按 class name 字符串动态匹配(不依赖 class index
# 规避 DocStructBench 不同发布的类别顺序差异)
_PICTURE_NAMES = {"figure", "figure_group"}
_TABLE_NAMES = {"table", "table_group"}
_FIGURE_CAPTION_NAMES = {"figure_caption"}
_TABLE_CAPTION_NAMES = {"table_caption"}
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
_PAD_VALUE = 114
# 最小 bbox 尺寸(PDF 点)
_MIN_BOX_SIZE = 20
_MIN_CAPTION_BOX_WIDTH = 30
_MIN_CAPTION_BOX_HEIGHT = 6
# device → ExecutionProvider 映射
_PROVIDER_MAP: dict[str, str] = {
@@ -72,7 +78,7 @@ _AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
@dataclass
class LayoutBox:
"""检测到的布局区域,坐标为 PDF 点boxclass ∈ {"picture", "table"}"""
"""检测到的布局区域,坐标为 PDF 点。"""
x0: float
y0: float
@@ -191,13 +197,17 @@ def _postprocess_output(
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
"""按 class name 匹配 figure→picture / table→table,其余返回 None。"""
"""按 class name 匹配下游关心的布局类别,其余返回 None。"""
name = names.get(cls_id, "")
n = name.strip().lower()
if n in _PICTURE_NAMES:
return "picture"
if n in _TABLE_NAMES:
return "table"
if n in _FIGURE_CAPTION_NAMES:
return "figure_caption"
if n in _TABLE_CAPTION_NAMES:
return "table_caption"
return None
@@ -220,15 +230,50 @@ def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
# ── 检测器单例 ──────────────────────────────────────────────────────────
class _LayoutDetector:
"""单例:管理 ONNX InferenceSession 生命周期。"""
class _Singleton(type):
"""元类单例:``cls()`` 永远返回同一实例;``reset_instance()`` 清缓存以便重建。
生产代码只应在模块级 ``_detector = _LayoutDetector()`` 创建一次。任何第二处
``_LayoutDetector()`` 都会拿到同一实例(含同一 ONNX session + 同一锁),杜绝
并发推理时各建一份 session 导致内存峰值翻倍(8GB 机器崩溃根因)。双检锁保证
首次实例化线程安全。
"""
_instances: dict[type, Any] = {}
_lock = threading.Lock()
def __call__(cls, *args, **kwargs):
if cls in _Singleton._instances:
return _Singleton._instances[cls]
with _Singleton._lock:
if cls not in _Singleton._instances:
_Singleton._instances[cls] = super().__call__(*args, **kwargs)
return _Singleton._instances[cls]
class _LayoutDetector(metaclass=_Singleton):
"""强约束单例:管理 ONNX InferenceSession 生命周期。
由 ``_Singleton`` 元类保证全进程唯一实例 —— 重复 ``_LayoutDetector()`` 只会返回
已有实例(含已加载的 session 和锁),不会新建。``reset_instance()`` 清缓存,仅供
测试隔离用。
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._session: ort.InferenceSession | None = None
self._names: dict[int, str] = {}
self._input_name: str = ""
self._imgsz: int = settings.LAYOUT_IMGSZ
@classmethod
def reset_instance(cls) -> None:
"""清空单例缓存,下次 ``_LayoutDetector()`` 重建新实例(含新锁 + 空 session)。
仅用于测试隔离 —— 生产代码永远不该调用(否则会丢掉已加载的模型 session)。
"""
_Singleton._instances.pop(cls, None)
def _init_session(self) -> ort.InferenceSession:
if self._session is not None:
return self._session
@@ -275,7 +320,7 @@ class _LayoutDetector:
self._imgsz = settings.LAYOUT_IMGSZ
return self._session
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
def _detect_page_impl(self, page: pymupdf.Page) -> list[LayoutBox]:
"""检测单页 PDF 的 figure / table 区域。
流程:
@@ -323,21 +368,38 @@ class _LayoutDetector:
y0 = max(0.0, min(y0, page_h))
x1 = max(0.0, min(x1, page_w))
y1 = max(0.0, min(y1, page_h))
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
continue
if boxclass in ("figure_caption", "table_caption"):
if (x1 - x0) < _MIN_CAPTION_BOX_WIDTH or (
y1 - y0
) < _MIN_CAPTION_BOX_HEIGHT:
continue
else:
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
continue
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
return result
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
"""公共入口:加锁串行化推理。
# 模块级单例
包裹整段 _detect_page_impl(含 pixmap 渲染 + tensor 构造 + session.run),
保证同一时刻只有一个推理在跑——避免 SUMMARY_CONCURRENCY>1 时多个 to_thread
线程并发推理导致内存峰值翻倍(8GB 机器崩溃根因)。锁由 _detect_page_impl
间接保护 _init_session,首次加载也串行,杜绝并发各建一份 session。
"""
with self._lock:
return self._detect_page_impl(page)
# 模块级单例 —— 生产代码唯一的实例化点(_Singleton 元类保证不会再有第二个)
_detector = _LayoutDetector()
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
"""检测 PDF 页面中的 figure / table 区域。
"""检测 PDF 页面中的 figure / table / caption 区域。
Returns:
LayoutBox 列表,坐标为 PDF 点,仅含 picture/table。
LayoutBox 列表,坐标为 PDF 点,仅含 picture/table 及其 caption
"""
return _detector.detect_page(page)