Files
daily-paper/app/services/layout_detector.py
T
Rain-Bus 21f16e6756 feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
2026-06-13 13:16:47 +08:00

175 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""PicoDet-S_layout_3cls 布局检测 — 纯 ONNX Runtime 推理.
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框。
输入:
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图片
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
输出:
fetch_name_0: (N, 6) float32 — [xmin, ymin, xmax, ymax, score, class_id]
fetch_name_1: (1,) int32 — 有效框数量 N
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import onnxruntime as ort
import pymupdf
from app.config import settings
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Model input size in pixels; the same value is used for width and height.
_MODEL_SIZE = 480
# ImageNet normalization statistics (per-channel mean and std).
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# PicoDet label id -> internal boxclass name. Ids missing from this map are
# skipped by the detector.
_LABEL_MAP: dict[int, str] = {
    0: "picture",  # PicoDet "image" -> "picture"
    1: "table",
    # 2: seal — ignored
}
# Minimum bbox width/height (in PDF points) for a detection to be kept.
_MIN_BOX_SIZE = 20
@dataclass
class LayoutBox:
    """A detected layout region, compatible with the existing _process_page code.

    Coordinates are expressed in PDF points.
    """

    x0: float  # left edge
    y0: float  # top edge
    x1: float  # right edge
    y1: float  # bottom edge
    boxclass: str  # "picture" | "table"
class _LayoutDetector:
    """Owns the ONNX InferenceSession and runs per-page layout detection.

    Intended to be used as a singleton: the session is created lazily on the
    first detection call and then reused for every later call.
    """

    def __init__(self) -> None:
        # None until _init_session() has loaded the model.
        self._session: ort.InferenceSession | None = None

    def _init_session(self) -> ort.InferenceSession:
        """Return the cached InferenceSession, creating it on first use.

        Returns:
            The loaded ONNX Runtime session (CPU execution provider).

        Raises:
            FileNotFoundError: If settings.LAYOUT_MODEL_PATH does not exist.
        """
        if self._session is not None:
            return self._session

        model_path = Path(settings.LAYOUT_MODEL_PATH)
        if not model_path.exists():
            raise FileNotFoundError(
                f"Layout model not found: {model_path}. "
                "Run scripts/export_picodet_onnx.py first."
            )

        logger.info("Loading ONNX layout model: %s", model_path)
        self._session = ort.InferenceSession(
            str(model_path), providers=["CPUExecutionProvider"]
        )
        logger.info("ONNX layout model loaded")
        return self._session

    @staticmethod
    def _render_input(
        page: pymupdf.Page, page_w: float, page_h: float
    ) -> np.ndarray:
        """Render *page* and return the normalized (1, 3, H, W) float32 tensor."""
        # Independent x/y zoom factors stretch the page onto the square
        # _MODEL_SIZE x _MODEL_SIZE model input.
        mat = pymupdf.Matrix(_MODEL_SIZE / page_w, _MODEL_SIZE / page_h)
        pix = page.get_pixmap(matrix=mat)

        img = (
            np.frombuffer(pix.samples, dtype=np.uint8)
            .reshape(pix.height, pix.width, pix.n)
            .astype(np.float32)
            / 255.0
        )
        # Drop the alpha channel, if present.
        if img.shape[2] == 4:
            img = img[:, :, :3]
        img = (img - _MEAN) / _STD
        # HWC -> CHW, then add the batch axis.
        return img.transpose(2, 0, 1)[np.newaxis]

    @staticmethod
    def _to_layout_boxes(
        rows: np.ndarray, page_w: float, page_h: float
    ) -> list[LayoutBox]:
        """Convert raw detection rows into filtered LayoutBox objects.

        Each row is [class_id, score, xmin, ymin, xmax, ymax] in model-input
        pixel coordinates. Rows whose class is not in _LABEL_MAP, whose score
        is below settings.LAYOUT_THRESHOLD, or whose scaled width/height is
        below _MIN_BOX_SIZE are dropped.
        """
        # Pixel -> PDF point scale factors.
        sx = page_w / _MODEL_SIZE
        sy = page_h / _MODEL_SIZE
        threshold = settings.LAYOUT_THRESHOLD

        result: list[LayoutBox] = []
        for cls_id, score, xmin, ymin, xmax, ymax in rows:
            label = _LABEL_MAP.get(int(cls_id))
            # Skip unmapped classes (e.g. seal) and low-confidence boxes.
            if label is None or score < threshold:
                continue

            # Cast to built-in float so LayoutBox fields match their `float`
            # annotations instead of carrying NumPy scalar types.
            x0, y0 = float(xmin) * sx, float(ymin) * sy
            x1, y1 = float(xmax) * sx, float(ymax) * sy

            # Skip very small regions.
            if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
                continue

            result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=label))
        return result

    def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
        """Detect figure / table regions on a single PDF page.

        Steps:
            1. Render the page with pymupdf, scaled to the model input size.
            2. ImageNet-normalize and reorder to NCHW.
            3. Run ONNX inference to get already decoded + NMS-filtered boxes.
            4. Convert pixel coordinates to PDF point coordinates.
            5. Drop unmapped classes (seal), low-confidence and tiny boxes.

        Args:
            page: pymupdf Page object.

        Returns:
            List of LayoutBox in PDF point coordinates. Empty when the page
            has a non-positive width or height, or when nothing is detected.

        Raises:
            FileNotFoundError: If the layout model file is missing.
        """
        session = self._init_session()

        page_w = page.rect.width
        page_h = page.rect.height
        if page_w <= 0 or page_h <= 0:
            # The zoom factors divide by the page size, so a degenerate page
            # cannot be rendered; report no regions instead of crashing.
            return []

        img = self._render_input(page, page_w, page_h)
        # Passed to the model's second input when it has one; 1.0 keeps the
        # boxes in model-input pixel space — presumably, TODO confirm against
        # the exported model.
        scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)

        # Inputs are fed by position: first the image, then scale_factor.
        input_names = [i.name for i in session.get_inputs()]
        feed = {input_names[0]: img}
        if len(input_names) > 1:
            feed[input_names[1]] = scale_factor
        outputs = session.run(None, feed)

        boxes_raw = outputs[0]  # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
        num_boxes = int(outputs[1][0])  # number of valid boxes
        if num_boxes <= 0:
            return []

        # Only the first num_boxes rows are valid; slicing also caps at the
        # number of rows actually present.
        return self._to_layout_boxes(boxes_raw[:num_boxes], page_w, page_h)
# Module-level singleton; constructing it does not load the model (the ONNX
# session is created on the first detect_page call).
_detector = _LayoutDetector()
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
    """Detect figure / table regions on a PDF page.

    Delegates to the module-level detector instance.

    Args:
        page: pymupdf Page object to analyze.

    Returns:
        List of LayoutBox in PDF point coordinates, containing only
        picture/table regions.
    """
    return _detector.detect_page(page)