feat: add concurrency safety, caption detection, admin enhancements, and performance improvements
This commit is contained in:
@@ -23,8 +23,10 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
@@ -47,14 +49,18 @@ _FALLBACK_NAMES: dict[int, str] = {
|
||||
8: "isolate_formula",
|
||||
9: "formula_caption",
|
||||
}
|
||||
# Downstream needs picture/table regions and their captions. Classes are
# matched dynamically by class-name string instead of by class index, which
# sidesteps class-order differences between DocStructBench releases.
_PICTURE_NAMES = {"figure", "figure_group"}
_TABLE_NAMES = {"table", "table_group"}
_FIGURE_CAPTION_NAMES = {"figure_caption"}
_TABLE_CAPTION_NAMES = {"table_caption"}
# Letterbox padding gray value (the ultralytics training standard; do not
# change it to 0/128, accuracy drops otherwise).
_PAD_VALUE = 114
# Minimum bbox size (PDF points) for picture/table boxes.
_MIN_BOX_SIZE = 20
# Caption boxes are filtered with their own, smaller width/height minimums.
_MIN_CAPTION_BOX_WIDTH = 30
_MIN_CAPTION_BOX_HEIGHT = 6
|
||||
|
||||
# device → ExecutionProvider 映射
|
||||
_PROVIDER_MAP: dict[str, str] = {
|
||||
@@ -72,7 +78,7 @@ _AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
|
||||
|
||||
@dataclass
|
||||
class LayoutBox:
|
||||
"""检测到的布局区域,坐标为 PDF 点,boxclass ∈ {"picture", "table"}。"""
|
||||
"""检测到的布局区域,坐标为 PDF 点。"""
|
||||
|
||||
x0: float
|
||||
y0: float
|
||||
@@ -191,13 +197,17 @@ def _postprocess_output(
|
||||
|
||||
|
||||
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
    """Translate a detector class id into the layout category used downstream.

    The class name is looked up in ``names`` (an unknown id counts as an
    empty name), stripped of surrounding whitespace and lower-cased, then
    compared against the module-level name sets. Matching is done on the
    name rather than the index so that a different class order in the model
    metadata does not change the result.

    Args:
        cls_id: Class index produced by the detector.
        names: Mapping from class index to class name.

    Returns:
        ``"picture"``, ``"table"``, ``"figure_caption"`` or
        ``"table_caption"`` for a recognised name, otherwise ``None``.
    """
    key = names.get(cls_id, "").strip().lower()
    # Checked in this order; the first set containing the name wins.
    categories = (
        (_PICTURE_NAMES, "picture"),
        (_TABLE_NAMES, "table"),
        (_FIGURE_CAPTION_NAMES, "figure_caption"),
        (_TABLE_CAPTION_NAMES, "table_caption"),
    )
    for members, label in categories:
        if key in members:
            return label
    return None
|
||||
|
||||
|
||||
@@ -220,15 +230,50 @@ def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
|
||||
# ── 检测器单例 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _LayoutDetector:
|
||||
"""单例:管理 ONNX InferenceSession 生命周期。"""
|
||||
class _Singleton(type):
    """Metaclass that makes ``cls()`` return one shared instance per class.

    The first call to a class using this metaclass builds the instance; every
    later call returns that same object and ignores its arguments. Instances
    are cached in ``_instances`` keyed by class, so removing a class's entry
    (as ``_LayoutDetector.reset_instance()`` does) makes the next call build a
    fresh instance.

    Production code is expected to instantiate the detector once at module
    level. Any further ``_LayoutDetector()`` call gets the same object (same
    ONNX session, same lock), so concurrent callers cannot each build their
    own session and multiply peak memory.

    Instance creation is serialized by ``_lock``; the cache is re-checked
    after the lock is acquired so only one thread ever constructs the object.
    """

    _instances: dict[type, Any] = {}
    # Re-entrant on purpose: the lock is shared by every singleton class and
    # is held while __init__ runs, so a plain Lock would deadlock as soon as
    # one singleton's __init__ constructed another singleton on the same
    # thread.
    _lock = threading.RLock()

    def __call__(cls, *args, **kwargs):
        """Return the cached instance of ``cls``, creating it on first use."""
        # Lock-free fast path. A single indexing operation is used instead of
        # "check membership, then index": the entry may be popped by another
        # thread (reset_instance takes no lock) between those two steps.
        try:
            return _Singleton._instances[cls]
        except KeyError:
            pass

        with _Singleton._lock:
            # Re-check under the lock: another thread may have created the
            # instance while this one was waiting.
            try:
                return _Singleton._instances[cls]
            except KeyError:
                pass
            instance = super().__call__(*args, **kwargs)
            _Singleton._instances[cls] = instance
            # Return the local reference rather than re-reading the cache, so
            # a concurrent reset cannot turn a successful build into KeyError.
            return instance
|
||||
|
||||
|
||||
class _LayoutDetector(metaclass=_Singleton):
|
||||
"""强约束单例:管理 ONNX InferenceSession 生命周期。
|
||||
|
||||
由 ``_Singleton`` 元类保证全进程唯一实例 —— 重复 ``_LayoutDetector()`` 只会返回
|
||||
已有实例(含已加载的 session 和锁),不会新建。``reset_instance()`` 清缓存,仅供
|
||||
测试隔离用。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    """Set up an empty detector; no ONNX session is created here."""
    # Held by detect_page() for the whole duration of a detection call.
    self._lock = threading.Lock()
    # Stays None until a session exists; _init_session() returns it as-is
    # once it is set.
    self._session: ort.InferenceSession | None = None
    # Class index -> class name table.
    # NOTE(review): presumably filled from model metadata during session
    # init — confirm against _init_session.
    self._names: dict[int, str] = {}
    # NOTE(review): presumably the model's input tensor name, set during
    # session init — confirm against _init_session.
    self._input_name: str = ""
    # Inference image size, read from settings.LAYOUT_IMGSZ.
    self._imgsz: int = settings.LAYOUT_IMGSZ
|
||||
|
||||
@classmethod
def reset_instance(cls) -> None:
    """Drop the cached singleton so the next call to the class rebuilds it.

    The rebuilt instance starts with a new lock and no session. Intended for
    test isolation only: production code should never call this, because the
    already-loaded model session is discarded together with the old instance.
    """
    # pop() with a default makes this a no-op when nothing is cached yet.
    _Singleton._instances.pop(cls, None)
|
||||
|
||||
def _init_session(self) -> ort.InferenceSession:
|
||||
if self._session is not None:
|
||||
return self._session
|
||||
@@ -275,7 +320,7 @@ class _LayoutDetector:
|
||||
self._imgsz = settings.LAYOUT_IMGSZ
|
||||
return self._session
|
||||
|
||||
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
def _detect_page_impl(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
"""检测单页 PDF 的 figure / table 区域。
|
||||
|
||||
流程:
|
||||
@@ -323,21 +368,38 @@ class _LayoutDetector:
|
||||
y0 = max(0.0, min(y0, page_h))
|
||||
x1 = max(0.0, min(x1, page_w))
|
||||
y1 = max(0.0, min(y1, page_h))
|
||||
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||
continue
|
||||
if boxclass in ("figure_caption", "table_caption"):
|
||||
if (x1 - x0) < _MIN_CAPTION_BOX_WIDTH or (
|
||||
y1 - y0
|
||||
) < _MIN_CAPTION_BOX_HEIGHT:
|
||||
continue
|
||||
else:
|
||||
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||
continue
|
||||
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
|
||||
|
||||
return result
|
||||
|
||||
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
    """Public entry point: run layout detection for one page under the lock.

    The entire ``_detect_page_impl`` call (pixmap rendering, tensor
    construction and ``session.run``) executes while ``self._lock`` is held,
    so only one inference runs at a time on this detector. The design
    motivation is memory: with SUMMARY_CONCURRENCY > 1, several ``to_thread``
    workers inferring at once multiplied the peak memory, which was the root
    cause of crashes on 8 GB machines. Because session initialisation is
    reached through ``_detect_page_impl``, the first model load is serialized
    by the same lock and concurrent callers cannot each build a session.

    Args:
        page: The PDF page to analyse.

    Returns:
        The layout boxes produced by ``_detect_page_impl`` for ``page``.
    """
    with self._lock:
        boxes = self._detect_page_impl(page)
    return boxes
|
||||
|
||||
|
||||
# Module-level singleton: the only place production code instantiates the
# detector (the _Singleton metaclass guarantees there is never a second one).
_detector = _LayoutDetector()
|
||||
|
||||
|
||||
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
    """Detect figure, table and caption regions on a PDF page.

    Delegates to ``detect_page`` on the module-level detector instance.

    Args:
        page: The PDF page to analyse.

    Returns:
        A list of LayoutBox in PDF-point coordinates, containing only
        picture/table boxes and their captions.
    """
    return _detector.detect_page(page)
|
||||
|
||||
Reference in New Issue
Block a user