refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO

- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
This commit is contained in:
2026-06-14 10:41:44 +08:00
parent 743d69efd0
commit 90fe705e8f
22 changed files with 2220 additions and 356 deletions
+7 -2
View File
@@ -60,8 +60,13 @@ class Settings(BaseSettings):
EMBED_DIMENSIONS: int = 0
# 布局检测
LAYOUT_MODEL_PATH: str = "data/models/picodet_layout_3cls.onnx"
LAYOUT_THRESHOLD: float = 0.5
LAYOUT_MODEL_PATH: str = "data/models/doclayout_yolo_docstructbench_imgsz1024.onnx"
LAYOUT_IMGSZ: int = 1024
LAYOUT_THRESHOLD: float = 0.2
# 推理设备:auto|cpu|cuda|directml|openvino|cann|tensorrt|qnn
# auto = 按优先级 [CUDA, DirectML, OpenVINO, CPU] 自动探测
LAYOUT_DEVICE: str = "auto"
LAYOUT_DEVICE_ID: int = 0
model_config = {
"env_file": str(BASE_DIR / ".env"),
+8 -1
View File
@@ -10,7 +10,14 @@ from fastapi.staticfiles import StaticFiles
from starlette.middleware.sessions import SessionMiddleware
from app.config import settings
from app.exceptions import AppError, ConflictError, ExternalAPIError, NotFoundError, PdfProcessError, ValidationError
from app.exceptions import (
AppError,
ConflictError,
ExternalAPIError,
NotFoundError,
PdfProcessError,
ValidationError,
)
from app.database import engine, init_db
from app.routes.admin import router as admin_router
from app.routes.compare import router as compare_router
+9 -1
View File
@@ -6,7 +6,15 @@ import hashlib
import hmac
from datetime import date
from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, Query, Request
from fastapi import (
APIRouter,
BackgroundTasks,
Depends,
Form,
HTTPException,
Query,
Request,
)
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, field_validator
from sqlalchemy.orm import Session
-1
View File
@@ -2,7 +2,6 @@
from __future__ import annotations
import json
import logging
from datetime import date, timedelta
+3 -1
View File
@@ -141,7 +141,9 @@ async def run_job(db: Session, job_id: int) -> dict:
status=JobEventStatus.SUCCESS,
payload=result if isinstance(result, dict) else {"result": result},
)
return result if isinstance(result, dict) else {"status": "success", "result": result}
return (
result if isinstance(result, dict) else {"status": "success", "result": result}
)
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
+255 -86
View File
@@ -1,19 +1,27 @@
"""PicoDet-S_layout_3cls 布局检测 — ONNX Runtime 推理.
"""DocLayout-YOLO 布局检测 — ONNX Runtime 推理,支持 CPU/GPU/NPU 多设备.
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框
用 onnxruntime 加载 DocLayout-YOLODocStructBench, imgsz=1024ONNX 模型,
检测 PDF 页面中的 figure / table 区域
预处理:letterbox(保比例缩放 + 灰边 padding 到 imgsz×imgsz),RGB,仅 /255 归一化
(不做 ImageNet mean/std)。缩放由 pymupdf Matrix 完成,不依赖 OpenCV。
后处理:YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls](已内置 NMS)。
坐标还原:(model_coord - padding) / ratio —— 渲染缩放与 letterbox 缩放在 pymupdf
渲染阶段合二为一,故只需一次除法。
设备:resolve_providers() 按 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表;
_init_session() 逐个 try,首个不可用则降级,CPU 永远兜底。
输入:
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
images: (1, 3, imgsz, imgsz) float32 — letterbox + /255 后的图
输出:
fetch_name_0: (N, 6) float32 — [xmin, ymin, xmax, ymax, score, class_id]
fetch_name_1: (1,) int32 — 有效框数量 N
output0: (1, N, 6) float32 — [x1, y1, x2, y2, conf, cls],已 NMS
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from pathlib import Path
@@ -26,30 +34,190 @@ from app.config import settings
logger = logging.getLogger(__name__)
# 模型输入尺寸
_MODEL_SIZE = 480
# ImageNet normalize
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# PicoDet label → 内部 boxclass
_LABEL_MAP: dict[int, str] = {
0: "picture", # PicoDet "image" → "picture"
1: "table",
# 2: seal — 忽略
# DocLayout-YOLO DocStructBench 标准 10 类(ONNX metadata 读不到时的兜底,以实际为准)
_FALLBACK_NAMES: dict[int, str] = {
0: "title",
1: "plain text",
2: "abandon",
3: "figure",
4: "figure_caption",
5: "table",
6: "table_caption",
7: "table_footnote",
8: "isolate_formula",
9: "formula_caption",
}
# 下游只需 picture/table —— 按 class name 字符串动态匹配(不依赖 class index
# 规避 DocStructBench 不同发布的类别顺序差异)
_PICTURE_NAMES = {"figure", "figure_group"}
_TABLE_NAMES = {"table", "table_group"}
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
_PAD_VALUE = 114
# 最小 bbox 尺寸(PDF 点)
_MIN_BOX_SIZE = 20
# device → ExecutionProvider 映射
_PROVIDER_MAP: dict[str, str] = {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
"directml": "DmlExecutionProvider",
"openvino": "OpenVINOExecutionProvider",
"cann": "CannExecutionProvider",
"tensorrt": "TensorrtExecutionProvider",
"qnn": "QNNExecutionProvider",
}
# auto 探测优先级(不含 cpu,cpu 永远兜底)
_AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
@dataclass
class LayoutBox:
"""检测到的布局区域,兼容现有 _process_page 代码"""
"""检测到的布局区域,坐标为 PDF 点,boxclass ∈ {"picture", "table"}"""
x0: float
y0: float
x1: float
y1: float
boxclass: str # "picture" | "table"
boxclass: str
# ── 设备选择 ────────────────────────────────────────────────────────────
def resolve_providers(device: str, device_id: int) -> list[tuple[str, dict]]:
"""根据 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表(首选在前,均带 CPU 兜底)。
返回 list[tuple[ep_name, provider_options]],供 _init_session() 逐个 try。
onnxruntime 创建 session 时若指定 EP 在本机变体里未注册会直接抛错,
故降级逻辑由 _init_session() 完成,这里只产出候选。
"""
if device == "cpu":
return [("CPUExecutionProvider", {})]
opts = {"device_id": str(device_id)}
if device == "auto":
available = set(ort.get_available_providers())
for dev in _AUTO_PRIORITY:
ep = _PROVIDER_MAP[dev]
if ep in available:
logger.info("auto: selected provider %s", ep)
return [(ep, opts), ("CPUExecutionProvider", {})]
logger.info("auto: no GPU/NPU provider available, using CPU")
return [("CPUExecutionProvider", {})]
ep = _PROVIDER_MAP.get(device)
if ep is None:
logger.warning("Unknown LAYOUT_DEVICE=%r, falling back to CPU", device)
return [("CPUExecutionProvider", {})]
return [(ep, opts), ("CPUExecutionProvider", {})]
# ── 预处理:渲染几何与 letterbox ────────────────────────────────────────
def _compute_render_geometry(page_w: float, page_h: float, imgsz: int) -> float:
"""letterbox 渲染缩放 ratio = min(imgsz/page_w, imgsz/page_h)。
pymupdf 以 Matrix(ratio, ratio) 渲染,长边贴到 imgsz,短边留灰边。
"""
return min(imgsz / page_w, imgsz / page_h)
def _letterbox_padding(
content_w: float, content_h: float, imgsz: int
) -> tuple[float, float]:
"""居中 padding(imgsz - content) / 2。content 为实际 pixmap 尺寸(已取整)。"""
return (imgsz - content_w) / 2.0, (imgsz - content_h) / 2.0
def _padded_nchw_from_pixmap(
pix: pymupdf.Pixmap, imgsz: int, dw: float, dh: float
) -> np.ndarray:
"""pixmap → letterbox padded (1, 3, imgsz, imgsz) float32,灰边=114/255 归一化。"""
arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
pix.height, pix.width, pix.n
)
if arr.shape[2] == 4: # 去 alphacsRGB alpha=False 一般不会,防御性)
arr = arr[:, :, :3]
canvas = np.full((imgsz, imgsz, 3), _PAD_VALUE, dtype=np.uint8)
top = int(round(dh))
left = int(round(dw))
canvas[top : top + pix.height, left : left + pix.width] = arr
out = canvas.astype(np.float32) / 255.0
return out.transpose(2, 0, 1)[np.newaxis] # (1, 3, imgsz, imgsz)
def _model_to_pdf(
model_x: float, model_y: float, dw: float, dh: float, ratio: float
) -> tuple[float, float]:
"""模型 imgsz 空间坐标 → PDF 点:(model - padding) / ratio。"""
return (model_x - dw) / ratio, (model_y - dh) / ratio
# ── 后处理 ──────────────────────────────────────────────────────────────
def _postprocess_output(
output: np.ndarray, threshold: float, names: dict[int, str]
) -> list[tuple[int, float, float, float, float]]:
"""解析 YOLOv10 end-to-end 输出,过滤 conf < threshold。
Args:
output: session.run 返回的第一个输出,shape [1, N, 6]
threshold: 置信度阈值
names: class id → name(仅用于日志,过滤不依赖)
Returns:
[(cls_id, x1, y1, x2, y2), ...],坐标为模型 imgsz padded 空间。
"""
out = output[0] # 去 batch 维
if out.ndim != 2 or out.shape[1] != 6:
logger.warning(
"Unexpected DocLayout-YOLO output shape %s (expected [N,6]); skip page",
tuple(out.shape),
)
return []
results: list[tuple[int, float, float, float, float]] = []
for row in out:
x1, y1, x2, y2, conf, cls = row.tolist()
if conf < threshold:
continue
results.append((int(cls), x1, y1, x2, y2))
return results
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
"""按 class name 匹配 figure→picture / table→table,其余返回 None。"""
name = names.get(cls_id, "")
n = name.strip().lower()
if n in _PICTURE_NAMES:
return "picture"
if n in _TABLE_NAMES:
return "table"
return None
def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
"""从 ONNX metadata 读 namesultralytics 导出写入的 JSON),读不到用兜底。"""
raw = None
try:
raw = session.get_modelmeta().custom_metadata_map.get("names")
except Exception:
raw = None
if raw:
try:
d = json.loads(raw)
return {int(k): str(v) for k, v in d.items()}
except Exception:
logger.warning("Failed to parse ONNX names metadata; using fallback")
return dict(_FALLBACK_NAMES)
# ── 检测器单例 ──────────────────────────────────────────────────────────
class _LayoutDetector:
@@ -57,6 +225,9 @@ class _LayoutDetector:
def __init__(self) -> None:
self._session: ort.InferenceSession | None = None
self._names: dict[int, str] = {}
self._input_name: str = ""
self._imgsz: int = settings.LAYOUT_IMGSZ
def _init_session(self) -> ort.InferenceSession:
if self._session is not None:
@@ -66,97 +237,95 @@ class _LayoutDetector:
if not model_path.exists():
raise FileNotFoundError(
f"Layout model not found: {model_path}. "
"Run scripts/export_picodet_onnx.py first."
"Run scripts/export_doclayout_yolo_onnx.py first."
)
logger.info("Loading ONNX layout model: %s", model_path)
self._session = ort.InferenceSession(
str(model_path), providers=["CPUExecutionProvider"]
eps = resolve_providers(settings.LAYOUT_DEVICE, settings.LAYOUT_DEVICE_ID)
logger.info(
"Loading layout model %s, candidate providers: %s",
model_path,
[ep[0] for ep in eps],
)
logger.info("ONNX layout model loaded")
# 逐个 EP 尝试,首个不可用则降级
last_err: Exception | None = None
for idx, (ep_name, ep_opts) in enumerate(eps):
try:
self._session = ort.InferenceSession(
str(model_path), providers=[(ep_name, ep_opts)]
)
break
except Exception as e:
last_err = e
if idx < len(eps) - 1:
logger.warning(
"Provider %s unavailable (%s); falling back to %s",
ep_name,
e,
eps[idx + 1][0],
)
else:
raise RuntimeError(f"Failed to create layout session: {last_err}")
logger.info(
"Layout session active providers: %s", self._session.get_providers()
)
self._input_name = self._session.get_inputs()[0].name
self._names = _parse_names_from_meta(self._session)
self._imgsz = settings.LAYOUT_IMGSZ
return self._session
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
"""检测单页 PDF 的 figure / table 区域。
流程:
1. pymupdf 以 480×480 渲染页面
2. ImageNet normalize → NCHW
3. ONNX 推理 → 得到已解码+NMS 的检测框
4. 像素坐标 → PDF 点坐标
5. 过滤 seal 类和低置信度框
Args:
page: pymupdf Page 对象
1. letterbox 渲染:保比例缩放到长边=imgsz,短边留灰边
2. /255 + NCHW → ONNX 推理
3. YOLOv10 end-to-end 后处理(已 NMS
4. 模型坐标 → PDF 点
5. 过滤非 figure/table 类、极小框、越界 clip
Returns:
LayoutBox 列表,坐标为 PDF 点
LayoutBox 列表,坐标为 PDF 点
"""
session = self._init_session()
page_w = page.rect.width
page_h = page.rect.height
ratio = _compute_render_geometry(page_w, page_h, self._imgsz)
# 1. 渲染页面到 _MODEL_SIZE × _MODEL_SIZE
zoom_x = _MODEL_SIZE / page_w
zoom_y = _MODEL_SIZE / page_h
mat = pymupdf.Matrix(zoom_x, zoom_y)
pix = page.get_pixmap(matrix=mat)
# 2. 预处理
img = (
np.frombuffer(pix.samples, dtype=np.uint8)
.reshape(pix.height, pix.width, pix.n)
.astype(np.float32)
/ 255.0
# 1. 保比例渲染(长边贴 imgsz
pix = page.get_pixmap(
matrix=pymupdf.Matrix(ratio, ratio),
colorspace=pymupdf.csRGB,
alpha=False,
)
# 去掉 alpha 通道(如有)
if img.shape[2] == 4:
img = img[:, :, :3]
img = (img - _MEAN) / _STD
img = img.transpose(2, 0, 1)[np.newaxis] # (1, 3, H, W)
# 用 pixmap 实际尺寸(已取整)算 padding,消除取整导致的坐标偏移
dw, dh = _letterbox_padding(pix.width, pix.height, self._imgsz)
tensor = _padded_nchw_from_pixmap(pix, self._imgsz, dw, dh)
# scale_factor 用于坐标还原(模型内部可能用)
scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)
# 3. 推理
input_names = [i.name for i in session.get_inputs()]
feed = {input_names[0]: img}
if len(input_names) > 1:
feed[input_names[1]] = scale_factor
outputs = session.run(None, feed)
boxes_raw = outputs[0] # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
num_boxes = int(outputs[1][0]) # 有效框数
if num_boxes == 0:
return []
# 4. 像素 → PDF 点坐标
sx = page_w / _MODEL_SIZE
sy = page_h / _MODEL_SIZE
# 2. 推理
outputs = session.run(None, {self._input_name: tensor})
detections = _postprocess_output(
outputs[0], settings.LAYOUT_THRESHOLD, self._names
)
# 3. 坐标还原 + 过滤
result: list[LayoutBox] = []
for i in range(min(num_boxes, len(boxes_raw))):
cls_id, score, xmin, ymin, xmax, ymax = boxes_raw[i]
cls_id = int(cls_id)
# 跳过 seal 类和低置信度
if cls_id not in _LABEL_MAP:
for cls_id, x1m, y1m, x2m, y2m in detections:
boxclass = _map_class_to_boxclass(cls_id, self._names)
if boxclass is None:
continue
if score < settings.LAYOUT_THRESHOLD:
continue
x0, y0 = xmin * sx, ymin * sy
x1, y1 = xmax * sx, ymax * sy
# 跳过极小区域
x0, y0 = _model_to_pdf(x1m, y1m, dw, dh, ratio)
x1, y1 = _model_to_pdf(x2m, y2m, dw, dh, ratio)
# clip 到页面范围
x0 = max(0.0, min(x0, page_w))
y0 = max(0.0, min(y0, page_h))
x1 = max(0.0, min(x1, page_w))
y1 = max(0.0, min(y1, page_h))
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
continue
result.append(
LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=_LABEL_MAP[cls_id])
)
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
return result
+3 -1
View File
@@ -71,7 +71,9 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
try:
session = _get_session()
resp = session.get(pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True)
resp = session.get(
pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True
)
resp.raise_for_status()
dest.write_bytes(resp.content)
except Exception as exc:
+1 -1
View File
@@ -1,6 +1,6 @@
"""PDF 图片与表格提取 — 两阶段流水线。
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
Phase 1: DocLayout-YOLO 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
相比旧方案(正则匹配 caption):
+6 -2
View File
@@ -80,13 +80,17 @@ async def call_pi(
actual_mode = "search"
logger.info(
"Auto mode: %s text=%d chars > %dk → search",
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
arxiv_id,
txt_size,
_PDF_MAX_CHARS // 1000,
)
else:
actual_mode = "inject"
logger.info(
"Auto mode: %s text=%d chars ≤ %dk → inject",
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
arxiv_id,
txt_size,
_PDF_MAX_CHARS // 1000,
)
# inject 模式需要截断过长的文本(避免撑爆 context)
+1 -1
View File
@@ -225,7 +225,7 @@ def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
"""从 PDF 提取图片和表格(失败不影响总结)。
两阶段流水线:
1. PicoDet 检测 + 渲染截图(通用标签)
1. DocLayout-YOLO 检测 + 渲染截图(通用标签)
2. 用 summary 的 figures ID 在 PDF 中搜索定位 → 重命名
"""
try: