refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO

- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
This commit is contained in:
2026-06-14 10:41:44 +08:00
parent 743d69efd0
commit 90fe705e8f
22 changed files with 2220 additions and 356 deletions
+15 -3
View File
@@ -48,6 +48,18 @@ EMBED_MODEL=Qwen/Qwen3-Embedding-4B
EMBED_DIMENSIONS=2560 EMBED_DIMENSIONS=2560
# ─── 布局检测 ───────────────────────────── # ─── 布局检测 ─────────────────────────────
# ONNX 模型路径(首次运行前执行 scripts/export_picodet_onnx.py 导出) # DocLayout-YOLO ONNX 模型(首次运行前执行 scripts/export_doclayout_yolo_onnx.py 导出)
# LAYOUT_MODEL_PATH=data/models/picodet_layout_3cls.onnx # LAYOUT_MODEL_PATH=data/models/doclayout_yolo_docstructbench_imgsz1024.onnx
# LAYOUT_THRESHOLD=0.5 # 模型输入尺寸(DocLayout-YOLO 推荐 1024
# LAYOUT_IMGSZ=1024
# 检测置信度阈值(DocLayout-YOLO 推荐 0.2
# LAYOUT_THRESHOLD=0.2
# 推理设备:auto|cpu|cuda|directml|openvino|cann|tensorrt|qnn
# auto = 按优先级 [CUDA, DirectML, OpenVINO, CPU] 自动探测,失败降级 CPU
# LAYOUT_DEVICE=auto
# 设备 IDGPU 序号)
# LAYOUT_DEVICE_ID=0
#
# GPU 用户:onnxruntime 与 onnxruntime-gpu/-directml 同环境冲突,需手动二选一:
# pip uninstall onnxruntime && pip install onnxruntime-gpu # NVIDIA CUDA
# pip uninstall onnxruntime && pip install onnxruntime-directml # Windows 任意 GPU
+5 -3
View File
@@ -125,7 +125,7 @@ paper/
├── scripts/ ├── scripts/
│ ├── init_db.py # 数据库初始化 │ ├── init_db.py # 数据库初始化
│ ├── manual_crawl.py # 手动抓取脚本 │ ├── manual_crawl.py # 手动抓取脚本
│ ├── export_picodet_onnx.py # 导出布局检测 ONNX 模型 │ ├── export_doclayout_yolo_onnx.py # 导出布局检测 ONNX 模型
│ ├── reextract_images.py # 批量重新提取图片 │ ├── reextract_images.py # 批量重新提取图片
│ └── validate_summary.py # 校验总结 JSON 结构 │ └── validate_summary.py # 校验总结 JSON 结构
@@ -198,8 +198,10 @@ SECRET_KEY=your_random_secret_key
| `EMBED_API_KEY` | — | Embedding API Key | | `EMBED_API_KEY` | — | Embedding API Key |
| `EMBED_MODEL` | — | Embedding 模型名 | | `EMBED_MODEL` | — | Embedding 模型名 |
| `EMBED_DIMENSIONS` | `0` | 向量维度 | | `EMBED_DIMENSIONS` | `0` | 向量维度 |
| `LAYOUT_MODEL_PATH` | `data/models/picodet_layout_3cls.onnx` | ONNX 布局检测模型路径(可选) | | `LAYOUT_MODEL_PATH` | `data/models/doclayout_yolo_docstructbench_imgsz1024.onnx` | DocLayout-YOLO ONNX 模型路径(可选) |
| `LAYOUT_THRESHOLD` | `0.5` | 布局检测置信度阈值(可选) | | `LAYOUT_IMGSZ` | `1024` | 模型输入尺寸 |
| `LAYOUT_THRESHOLD` | `0.2` | 布局检测置信度阈值(可选) |
| `LAYOUT_DEVICE` | `auto` | 推理设备:auto/cpu/cuda/directml/openvino/...(可选) |
### 4. 初始化数据库 ### 4. 初始化数据库
+7 -2
View File
@@ -60,8 +60,13 @@ class Settings(BaseSettings):
EMBED_DIMENSIONS: int = 0 EMBED_DIMENSIONS: int = 0
# 布局检测 # 布局检测
LAYOUT_MODEL_PATH: str = "data/models/picodet_layout_3cls.onnx" LAYOUT_MODEL_PATH: str = "data/models/doclayout_yolo_docstructbench_imgsz1024.onnx"
LAYOUT_THRESHOLD: float = 0.5 LAYOUT_IMGSZ: int = 1024
LAYOUT_THRESHOLD: float = 0.2
# 推理设备:auto|cpu|cuda|directml|openvino|cann|tensorrt|qnn
# auto = 按优先级 [CUDA, DirectML, OpenVINO, CPU] 自动探测
LAYOUT_DEVICE: str = "auto"
LAYOUT_DEVICE_ID: int = 0
model_config = { model_config = {
"env_file": str(BASE_DIR / ".env"), "env_file": str(BASE_DIR / ".env"),
+8 -1
View File
@@ -10,7 +10,14 @@ from fastapi.staticfiles import StaticFiles
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from app.config import settings from app.config import settings
from app.exceptions import AppError, ConflictError, ExternalAPIError, NotFoundError, PdfProcessError, ValidationError from app.exceptions import (
AppError,
ConflictError,
ExternalAPIError,
NotFoundError,
PdfProcessError,
ValidationError,
)
from app.database import engine, init_db from app.database import engine, init_db
from app.routes.admin import router as admin_router from app.routes.admin import router as admin_router
from app.routes.compare import router as compare_router from app.routes.compare import router as compare_router
+9 -1
View File
@@ -6,7 +6,15 @@ import hashlib
import hmac import hmac
from datetime import date from datetime import date
from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, Query, Request from fastapi import (
APIRouter,
BackgroundTasks,
Depends,
Form,
HTTPException,
Query,
Request,
)
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
-1
View File
@@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
from datetime import date, timedelta from datetime import date, timedelta
+3 -1
View File
@@ -141,7 +141,9 @@ async def run_job(db: Session, job_id: int) -> dict:
status=JobEventStatus.SUCCESS, status=JobEventStatus.SUCCESS,
payload=result if isinstance(result, dict) else {"result": result}, payload=result if isinstance(result, dict) else {"result": result},
) )
return result if isinstance(result, dict) else {"status": "success", "result": result} return (
result if isinstance(result, dict) else {"status": "success", "result": result}
)
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict: async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
+255 -86
View File
@@ -1,19 +1,27 @@
"""PicoDet-S_layout_3cls 布局检测 — ONNX Runtime 推理. """DocLayout-YOLO 布局检测 — ONNX Runtime 推理,支持 CPU/GPU/NPU 多设备.
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。 用 onnxruntime 加载 DocLayout-YOLODocStructBench, imgsz=1024ONNX 模型,
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框 检测 PDF 页面中的 figure / table 区域
预处理:letterbox(保比例缩放 + 灰边 padding 到 imgsz×imgsz),RGB,仅 /255 归一化
(不做 ImageNet mean/std)。缩放由 pymupdf Matrix 完成,不依赖 OpenCV。
后处理:YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls](已内置 NMS)。
坐标还原:(model_coord - padding) / ratio —— 渲染缩放与 letterbox 缩放在 pymupdf
渲染阶段合二为一,故只需一次除法。
设备:resolve_providers() 按 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表;
_init_session() 逐个 try,首个不可用则降级,CPU 永远兜底。
输入: 输入:
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图 images: (1, 3, imgsz, imgsz) float32 — letterbox + /255 后的图
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
输出: 输出:
fetch_name_0: (N, 6) float32 — [xmin, ymin, xmax, ymax, score, class_id] output0: (1, N, 6) float32 — [x1, y1, x2, y2, conf, cls],已 NMS
fetch_name_1: (1,) int32 — 有效框数量 N
""" """
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@@ -26,30 +34,190 @@ from app.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 模型输入尺寸 # DocLayout-YOLO DocStructBench 标准 10 类(ONNX metadata 读不到时的兜底,以实际为准)
_MODEL_SIZE = 480 _FALLBACK_NAMES: dict[int, str] = {
# ImageNet normalize 0: "title",
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32) 1: "plain text",
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32) 2: "abandon",
# PicoDet label → 内部 boxclass 3: "figure",
_LABEL_MAP: dict[int, str] = { 4: "figure_caption",
0: "picture", # PicoDet "image" → "picture" 5: "table",
1: "table", 6: "table_caption",
# 2: seal — 忽略 7: "table_footnote",
8: "isolate_formula",
9: "formula_caption",
} }
# 下游只需 picture/table —— 按 class name 字符串动态匹配(不依赖 class index
# 规避 DocStructBench 不同发布的类别顺序差异)
_PICTURE_NAMES = {"figure", "figure_group"}
_TABLE_NAMES = {"table", "table_group"}
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
_PAD_VALUE = 114
# 最小 bbox 尺寸(PDF 点) # 最小 bbox 尺寸(PDF 点)
_MIN_BOX_SIZE = 20 _MIN_BOX_SIZE = 20
# device → ExecutionProvider 映射
_PROVIDER_MAP: dict[str, str] = {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
"directml": "DmlExecutionProvider",
"openvino": "OpenVINOExecutionProvider",
"cann": "CannExecutionProvider",
"tensorrt": "TensorrtExecutionProvider",
"qnn": "QNNExecutionProvider",
}
# auto 探测优先级(不含 cpu,cpu 永远兜底)
_AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
@dataclass @dataclass
class LayoutBox: class LayoutBox:
"""检测到的布局区域,兼容现有 _process_page 代码""" """检测到的布局区域,坐标为 PDF 点,boxclass ∈ {"picture", "table"}"""
x0: float x0: float
y0: float y0: float
x1: float x1: float
y1: float y1: float
boxclass: str # "picture" | "table" boxclass: str
# ── 设备选择 ────────────────────────────────────────────────────────────
def resolve_providers(device: str, device_id: int) -> list[tuple[str, dict]]:
"""根据 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表(首选在前,均带 CPU 兜底)。
返回 list[tuple[ep_name, provider_options]],供 _init_session() 逐个 try。
onnxruntime 创建 session 时若指定 EP 在本机变体里未注册会直接抛错,
故降级逻辑由 _init_session() 完成,这里只产出候选。
"""
if device == "cpu":
return [("CPUExecutionProvider", {})]
opts = {"device_id": str(device_id)}
if device == "auto":
available = set(ort.get_available_providers())
for dev in _AUTO_PRIORITY:
ep = _PROVIDER_MAP[dev]
if ep in available:
logger.info("auto: selected provider %s", ep)
return [(ep, opts), ("CPUExecutionProvider", {})]
logger.info("auto: no GPU/NPU provider available, using CPU")
return [("CPUExecutionProvider", {})]
ep = _PROVIDER_MAP.get(device)
if ep is None:
logger.warning("Unknown LAYOUT_DEVICE=%r, falling back to CPU", device)
return [("CPUExecutionProvider", {})]
return [(ep, opts), ("CPUExecutionProvider", {})]
# ── 预处理:渲染几何与 letterbox ────────────────────────────────────────
def _compute_render_geometry(page_w: float, page_h: float, imgsz: int) -> float:
"""letterbox 渲染缩放 ratio = min(imgsz/page_w, imgsz/page_h)。
pymupdf 以 Matrix(ratio, ratio) 渲染,长边贴到 imgsz,短边留灰边。
"""
return min(imgsz / page_w, imgsz / page_h)
def _letterbox_padding(
content_w: float, content_h: float, imgsz: int
) -> tuple[float, float]:
"""居中 padding(imgsz - content) / 2。content 为实际 pixmap 尺寸(已取整)。"""
return (imgsz - content_w) / 2.0, (imgsz - content_h) / 2.0
def _padded_nchw_from_pixmap(
pix: pymupdf.Pixmap, imgsz: int, dw: float, dh: float
) -> np.ndarray:
"""pixmap → letterbox padded (1, 3, imgsz, imgsz) float32,灰边=114/255 归一化。"""
arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
pix.height, pix.width, pix.n
)
if arr.shape[2] == 4: # 去 alphacsRGB alpha=False 一般不会,防御性)
arr = arr[:, :, :3]
canvas = np.full((imgsz, imgsz, 3), _PAD_VALUE, dtype=np.uint8)
top = int(round(dh))
left = int(round(dw))
canvas[top : top + pix.height, left : left + pix.width] = arr
out = canvas.astype(np.float32) / 255.0
return out.transpose(2, 0, 1)[np.newaxis] # (1, 3, imgsz, imgsz)
def _model_to_pdf(
model_x: float, model_y: float, dw: float, dh: float, ratio: float
) -> tuple[float, float]:
"""模型 imgsz 空间坐标 → PDF 点:(model - padding) / ratio。"""
return (model_x - dw) / ratio, (model_y - dh) / ratio
# ── 后处理 ──────────────────────────────────────────────────────────────
def _postprocess_output(
output: np.ndarray, threshold: float, names: dict[int, str]
) -> list[tuple[int, float, float, float, float]]:
"""解析 YOLOv10 end-to-end 输出,过滤 conf < threshold。
Args:
output: session.run 返回的第一个输出,shape [1, N, 6]
threshold: 置信度阈值
names: class id → name(仅用于日志,过滤不依赖)
Returns:
[(cls_id, x1, y1, x2, y2), ...],坐标为模型 imgsz padded 空间。
"""
out = output[0] # 去 batch 维
if out.ndim != 2 or out.shape[1] != 6:
logger.warning(
"Unexpected DocLayout-YOLO output shape %s (expected [N,6]); skip page",
tuple(out.shape),
)
return []
results: list[tuple[int, float, float, float, float]] = []
for row in out:
x1, y1, x2, y2, conf, cls = row.tolist()
if conf < threshold:
continue
results.append((int(cls), x1, y1, x2, y2))
return results
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
"""按 class name 匹配 figure→picture / table→table,其余返回 None。"""
name = names.get(cls_id, "")
n = name.strip().lower()
if n in _PICTURE_NAMES:
return "picture"
if n in _TABLE_NAMES:
return "table"
return None
def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
"""从 ONNX metadata 读 namesultralytics 导出写入的 JSON),读不到用兜底。"""
raw = None
try:
raw = session.get_modelmeta().custom_metadata_map.get("names")
except Exception:
raw = None
if raw:
try:
d = json.loads(raw)
return {int(k): str(v) for k, v in d.items()}
except Exception:
logger.warning("Failed to parse ONNX names metadata; using fallback")
return dict(_FALLBACK_NAMES)
# ── 检测器单例 ──────────────────────────────────────────────────────────
class _LayoutDetector: class _LayoutDetector:
@@ -57,6 +225,9 @@ class _LayoutDetector:
def __init__(self) -> None: def __init__(self) -> None:
self._session: ort.InferenceSession | None = None self._session: ort.InferenceSession | None = None
self._names: dict[int, str] = {}
self._input_name: str = ""
self._imgsz: int = settings.LAYOUT_IMGSZ
def _init_session(self) -> ort.InferenceSession: def _init_session(self) -> ort.InferenceSession:
if self._session is not None: if self._session is not None:
@@ -66,97 +237,95 @@ class _LayoutDetector:
if not model_path.exists(): if not model_path.exists():
raise FileNotFoundError( raise FileNotFoundError(
f"Layout model not found: {model_path}. " f"Layout model not found: {model_path}. "
"Run scripts/export_picodet_onnx.py first." "Run scripts/export_doclayout_yolo_onnx.py first."
) )
logger.info("Loading ONNX layout model: %s", model_path) eps = resolve_providers(settings.LAYOUT_DEVICE, settings.LAYOUT_DEVICE_ID)
self._session = ort.InferenceSession( logger.info(
str(model_path), providers=["CPUExecutionProvider"] "Loading layout model %s, candidate providers: %s",
model_path,
[ep[0] for ep in eps],
) )
logger.info("ONNX layout model loaded")
# 逐个 EP 尝试,首个不可用则降级
last_err: Exception | None = None
for idx, (ep_name, ep_opts) in enumerate(eps):
try:
self._session = ort.InferenceSession(
str(model_path), providers=[(ep_name, ep_opts)]
)
break
except Exception as e:
last_err = e
if idx < len(eps) - 1:
logger.warning(
"Provider %s unavailable (%s); falling back to %s",
ep_name,
e,
eps[idx + 1][0],
)
else:
raise RuntimeError(f"Failed to create layout session: {last_err}")
logger.info(
"Layout session active providers: %s", self._session.get_providers()
)
self._input_name = self._session.get_inputs()[0].name
self._names = _parse_names_from_meta(self._session)
self._imgsz = settings.LAYOUT_IMGSZ
return self._session return self._session
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]: def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
"""检测单页 PDF 的 figure / table 区域。 """检测单页 PDF 的 figure / table 区域。
流程: 流程:
1. pymupdf 以 480×480 渲染页面 1. letterbox 渲染:保比例缩放到长边=imgsz,短边留灰边
2. ImageNet normalize → NCHW 2. /255 + NCHW → ONNX 推理
3. ONNX 推理 → 得到已解码+NMS 的检测框 3. YOLOv10 end-to-end 后处理(已 NMS
4. 像素坐标 → PDF 点坐标 4. 模型坐标 → PDF 点
5. 过滤 seal 类和低置信度框 5. 过滤非 figure/table 类、极小框、越界 clip
Args:
page: pymupdf Page 对象
Returns: Returns:
LayoutBox 列表,坐标为 PDF 点 LayoutBox 列表,坐标为 PDF 点
""" """
session = self._init_session() session = self._init_session()
page_w = page.rect.width page_w = page.rect.width
page_h = page.rect.height page_h = page.rect.height
ratio = _compute_render_geometry(page_w, page_h, self._imgsz)
# 1. 渲染页面到 _MODEL_SIZE × _MODEL_SIZE # 1. 保比例渲染(长边贴 imgsz
zoom_x = _MODEL_SIZE / page_w pix = page.get_pixmap(
zoom_y = _MODEL_SIZE / page_h matrix=pymupdf.Matrix(ratio, ratio),
mat = pymupdf.Matrix(zoom_x, zoom_y) colorspace=pymupdf.csRGB,
pix = page.get_pixmap(matrix=mat) alpha=False,
# 2. 预处理
img = (
np.frombuffer(pix.samples, dtype=np.uint8)
.reshape(pix.height, pix.width, pix.n)
.astype(np.float32)
/ 255.0
) )
# 去掉 alpha 通道(如有) # 用 pixmap 实际尺寸(已取整)算 padding,消除取整导致的坐标偏移
if img.shape[2] == 4: dw, dh = _letterbox_padding(pix.width, pix.height, self._imgsz)
img = img[:, :, :3] tensor = _padded_nchw_from_pixmap(pix, self._imgsz, dw, dh)
img = (img - _MEAN) / _STD
img = img.transpose(2, 0, 1)[np.newaxis] # (1, 3, H, W)
# scale_factor 用于坐标还原(模型内部可能用) # 2. 推理
scale_factor = np.array([[1.0, 1.0]], dtype=np.float32) outputs = session.run(None, {self._input_name: tensor})
detections = _postprocess_output(
# 3. 推理 outputs[0], settings.LAYOUT_THRESHOLD, self._names
input_names = [i.name for i in session.get_inputs()] )
feed = {input_names[0]: img}
if len(input_names) > 1:
feed[input_names[1]] = scale_factor
outputs = session.run(None, feed)
boxes_raw = outputs[0] # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
num_boxes = int(outputs[1][0]) # 有效框数
if num_boxes == 0:
return []
# 4. 像素 → PDF 点坐标
sx = page_w / _MODEL_SIZE
sy = page_h / _MODEL_SIZE
# 3. 坐标还原 + 过滤
result: list[LayoutBox] = [] result: list[LayoutBox] = []
for i in range(min(num_boxes, len(boxes_raw))): for cls_id, x1m, y1m, x2m, y2m in detections:
cls_id, score, xmin, ymin, xmax, ymax = boxes_raw[i] boxclass = _map_class_to_boxclass(cls_id, self._names)
cls_id = int(cls_id) if boxclass is None:
# 跳过 seal 类和低置信度
if cls_id not in _LABEL_MAP:
continue continue
if score < settings.LAYOUT_THRESHOLD: x0, y0 = _model_to_pdf(x1m, y1m, dw, dh, ratio)
continue x1, y1 = _model_to_pdf(x2m, y2m, dw, dh, ratio)
# clip 到页面范围
x0, y0 = xmin * sx, ymin * sy x0 = max(0.0, min(x0, page_w))
x1, y1 = xmax * sx, ymax * sy y0 = max(0.0, min(y0, page_h))
x1 = max(0.0, min(x1, page_w))
# 跳过极小区域 y1 = max(0.0, min(y1, page_h))
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE: if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
continue continue
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
result.append(
LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=_LABEL_MAP[cls_id])
)
return result return result
+3 -1
View File
@@ -71,7 +71,9 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
try: try:
session = _get_session() session = _get_session()
resp = session.get(pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True) resp = session.get(
pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True
)
resp.raise_for_status() resp.raise_for_status()
dest.write_bytes(resp.content) dest.write_bytes(resp.content)
except Exception as exc: except Exception as exc:
+1 -1
View File
@@ -1,6 +1,6 @@
"""PDF 图片与表格提取 — 两阶段流水线。 """PDF 图片与表格提取 — 两阶段流水线。
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签) Phase 1: DocLayout-YOLO 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名 Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
相比旧方案(正则匹配 caption): 相比旧方案(正则匹配 caption):
+6 -2
View File
@@ -80,13 +80,17 @@ async def call_pi(
actual_mode = "search" actual_mode = "search"
logger.info( logger.info(
"Auto mode: %s text=%d chars > %dk → search", "Auto mode: %s text=%d chars > %dk → search",
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000, arxiv_id,
txt_size,
_PDF_MAX_CHARS // 1000,
) )
else: else:
actual_mode = "inject" actual_mode = "inject"
logger.info( logger.info(
"Auto mode: %s text=%d chars ≤ %dk → inject", "Auto mode: %s text=%d chars ≤ %dk → inject",
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000, arxiv_id,
txt_size,
_PDF_MAX_CHARS // 1000,
) )
# inject 模式需要截断过长的文本(避免撑爆 context) # inject 模式需要截断过长的文本(避免撑爆 context)
+1 -1
View File
@@ -225,7 +225,7 @@ def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
"""从 PDF 提取图片和表格(失败不影响总结)。 """从 PDF 提取图片和表格(失败不影响总结)。
两阶段流水线: 两阶段流水线:
1. PicoDet 检测 + 渲染截图(通用标签) 1. DocLayout-YOLO 检测 + 渲染截图(通用标签)
2. 用 summary 的 figures ID 在 PDF 中搜索定位 → 重命名 2. 用 summary 的 figures ID 在 PDF 中搜索定位 → 重命名
""" """
try: try:
+24 -19
View File
@@ -4,37 +4,42 @@ data = {
"arxiv_id": "2602.21760", "arxiv_id": "2602.21760",
"title_zh": "基于条件引导调度的混合数据-流水线并行加速扩散模型", "title_zh": "基于条件引导调度的混合数据-流水线并行加速扩散模型",
"one_line": "提出混合并行框架,通过条件划分与自适应流水线切换加速扩散推理,实现2.31倍提速。", "one_line": "提出混合并行框架,通过条件划分与自适应流水线切换加速扩散推理,实现2.31倍提速。",
"tags": ["Diffusion Models", "Distributed Inference", "Parallel Computing", "Image Generation"], "tags": [
"Diffusion Models",
"Distributed Inference",
"Parallel Computing",
"Image Generation",
],
"difficulty": "进阶", "difficulty": "进阶",
"prerequisites": { "prerequisites": {
"concepts": [ "concepts": [
{ {
"term": "Diffusion Models", "term": "Diffusion Models",
"explanation": "扩散模型是一类基于去噪过程的生成模型。在正向过程中,它逐渐向数据添加高斯噪声直到变成纯噪声;在反向过程中,模型学习逐步去噪以恢复原始数据。这种迭代特性虽然能生成高质量的样本,但也导致了高昂的推理计算成本。", "explanation": "扩散模型是一类基于去噪过程的生成模型。在正向过程中,它逐渐向数据添加高斯噪声直到变成纯噪声;在反向过程中,模型学习逐步去噪以恢复原始数据。这种迭代特性虽然能生成高质量的样本,但也导致了高昂的推理计算成本。",
"why_matters": "理解扩散模型的迭代去噪机制是理解本文如何通过并行化减少推理延迟的基础。" "why_matters": "理解扩散模型的迭代去噪机制是理解本文如何通过并行化减少推理延迟的基础。",
}, },
{ {
"term": "Classifier-Free Guidance (CFG)", "term": "Classifier-Free Guidance (CFG)",
"explanation": "无分类器引导是一种在推理时提升生成样本与文本条件一致性的技术。模型同时预测有条件噪声(给定文本提示)和无条件噪声(不给定提示),最终通过加权组合两者来获得最终预测。公式为 $\epsilon_{cfg} = \epsilon_\theta(x_t, c, t) + w (\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t))$,其中 $w$ 是引导强度。", "explanation": "无分类器引导是一种在推理时提升生成样本与文本条件一致性的技术。模型同时预测有条件噪声(给定文本提示)和无条件噪声(不给定提示),最终通过加权组合两者来获得最终预测。公式为 $\epsilon_{cfg} = \epsilon_\theta(x_t, c, t) + w (\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t))$,其中 $w$ 是引导强度。",
"why_matters": "本文的核心创新点在于利用CFG中存在的有条件和无条件双路径作为数据划分的基础。" "why_matters": "本文的核心创新点在于利用CFG中存在的有条件和无条件双路径作为数据划分的基础。",
}, },
{ {
"term": "Distributed Inference", "term": "Distributed Inference",
"explanation": "分布式推理利用多个GPU并行处理计算任务以减少延迟。主要分为数据并行(如将图像切片处理)和流水线并行(如将模型层切分)。然而,现有的分布式方法在扩散模型中往往面临通信开销大或生成图像出现拼接伪影的问题。", "explanation": "分布式推理利用多个GPU并行处理计算任务以减少延迟。主要分为数据并行(如将图像切片处理)和流水线并行(如将模型层切分)。然而,现有的分布式方法在扩散模型中往往面临通信开销大或生成图像出现拼接伪影的问题。",
"why_matters": "本文提出的混合并行框架正是为了解决现有分布式推理方法中的这些痛点。" "why_matters": "本文提出的混合并行框架正是为了解决现有分布式推理方法中的这些痛点。",
} },
] ]
}, },
"motivation": { "motivation": {
"problem": "现有的扩散模型加速方法,无论是单卡优化(如减少采样步数、模型剪枝)还是多卡分布式并行(如DistriFusion和AsyncDiff),都存在明显的局限性。单卡优化受限于硬件算力上限,而现有多卡并行方法通常只能实现次线性的加速比。例如,DistriFusion将图像切片并行处理,容易在拼接处产生明显的伪影;AsyncDiff采用异步流水线,虽然加速了但会引入估计误差,且通信开销巨大(在SDXL上高达9.83GB)。", "problem": "现有的扩散模型加速方法,无论是单卡优化(如减少采样步数、模型剪枝)还是多卡分布式并行(如DistriFusion和AsyncDiff),都存在明显的局限性。单卡优化受限于硬件算力上限,而现有多卡并行方法通常只能实现次线性的加速比。例如,DistriFusion将图像切片并行处理,容易在拼接处产生明显的伪影;AsyncDiff采用异步流水线,虽然加速了但会引入估计误差,且通信开销巨大(在SDXL上高达9.83GB)。",
"goal": "本文旨在提出一种新颖的混合并行框架,在仅使用两张GPU的情况下,不仅能实现超过线性的加速比(即 $>2\times$),还要严格保持甚至提升生成图像的质量,同时将通信开销降到最低。", "goal": "本文旨在提出一种新颖的混合并行框架,在仅使用两张GPU的情况下,不仅能实现超过线性的加速比(即 $>2\times$),还要严格保持甚至提升生成图像的质量,同时将通信开销降到最低。",
"gap": "与以往将图像空间切片(Patch-based)的思路不同,本文独辟蹊径,利用无分类器引导(CFG)中天然存在的“有条件”和“无条件”两条路径作为新的数据划分维度(Condition-based Partitioning)。同时,作者发现这两条路径的预测误差差异在整个去噪过程中呈现出先大后小再变大的U型曲线,因此引入了自适应的并行切换策略,只在误差差异最小时才进行并行流水线处理。" "gap": "与以往将图像空间切片(Patch-based)的思路不同,本文独辟蹊径,利用无分类器引导(CFG)中天然存在的“有条件”和“无条件”两条路径作为新的数据划分维度(Condition-based Partitioning)。同时,作者发现这两条路径的预测误差差异在整个去噪过程中呈现出先大后小再变大的U型曲线,因此引入了自适应的并行切换策略,只在误差差异最小时才进行并行流水线处理。",
}, },
"method": { "method": {
"overview": "该框架的核心思想是将扩散推理过程划分为三个阶段:预热阶段(Warm-Up)、并行阶段(Parallelism)和完全连接阶段(Fully-Connecting)。在预热和完全连接阶段,使用“基于条件的划分”策略,即一张GPU处理有条件的预测,另一张处理无条件的预测。而在中间的并行阶段,由于两个预测结果非常接近,框架切换到“自适应流水线并行”,利用两张GPU交替执行推理步骤,从而大幅压缩时间。", "overview": "该框架的核心思想是将扩散推理过程划分为三个阶段:预热阶段(Warm-Up)、并行阶段(Parallelism)和完全连接阶段(Fully-Connecting)。在预热和完全连接阶段,使用“基于条件的划分”策略,即一张GPU处理有条件的预测,另一张处理无条件的预测。而在中间的并行阶段,由于两个预测结果非常接近,框架切换到“自适应流水线并行”,利用两张GPU交替执行推理步骤,从而大幅压缩时间。",
"key_idea": "核心创新在于不再将图片在空间上切片,而是沿“条件”维度切分数据。这保证了每个GPU都能看到整张图片的全局信息,从而避免了拼接伪影。此外,引入了“去噪差异度”(Denoising Discrepancy,即 rel-MAE)这一指标来动态评估两条路径的相似性,并以此自动决定何时开启和关闭流水线并行,实现了最优的加速-质量平衡。", "key_idea": "核心创新在于不再将图片在空间上切片,而是沿“条件”维度切分数据。这保证了每个GPU都能看到整张图片的全局信息,从而避免了拼接伪影。此外,引入了“去噪差异度”(Denoising Discrepancy,即 rel-MAE)这一指标来动态评估两条路径的相似性,并以此自动决定何时开启和关闭流水线并行,实现了最优的加速-质量平衡。",
"steps": "1. 数据划分:输入潜变量同时送入GPU 1(有条件预测 $\epsilon_\theta(x_t, c, t)$)和GPU 2(无条件预测 $\epsilon_\theta(x_t, t)$)。2. 阶段判断:根据实时计算的“去噪差异度” $G_t$ 与阈值 $g_{slope}$ 的关系,确定切换点 $\tau_1$ 和 $\tau_2$。3. 混合执行:在 $[T, \tau_1]$ 阶段同步运行;在 $[\tau_1, \tau_2]$ 阶段启用流水线并行(如GPU 1处理 $t-1$ 步时GPU 2处理 $t$ 步);在 $[\tau_2, 0]$ 阶段重新恢复同步以精细调整细节。", "steps": "1. 数据划分:输入潜变量同时送入GPU 1(有条件预测 $\epsilon_\theta(x_t, c, t)$)和GPU 2(无条件预测 $\epsilon_\theta(x_t, t)$)。2. 阶段判断:根据实时计算的“去噪差异度” $G_t$ 与阈值 $g_{slope}$ 的关系,确定切换点 $\tau_1$ 和 $\tau_2$。3. 混合执行:在 $[T, \tau_1]$ 阶段同步运行;在 $[\tau_1, \tau_2]$ 阶段启用流水线并行(如GPU 1处理 $t-1$ 步时GPU 2处理 $t$ 步);在 $[\tau_2, 0]$ 阶段重新恢复同步以精细调整细节。",
"novelty": "该方法的另一大新颖之处在于其“安全性”设计:通过设置 $\tau_{cap}$ 作为安全上限,确保即使自动算法失效,也不会在错误的时间点引入并行,从而保证了算法的鲁棒性。此外,该框架对U-Net(如SDXL)和DiT(如SD3)架构均具有良好的泛化性。" "novelty": "该方法的另一大新颖之处在于其“安全性”设计:通过设置 $\tau_{cap}$ 作为安全上限,确保即使自动算法失效,也不会在错误的时间点引入并行,从而保证了算法的鲁棒性。此外,该框架对U-Net(如SDXL)和DiT(如SD3)架构均具有良好的泛化性。",
}, },
"results": { "results": {
"main_findings": "实验在SDXL和SD3模型上进行,使用MS-COCO 2014验证集。结果显示,在SDXL上,该方法实现了2.31倍加速,延迟从16.49秒降至7.12秒,且FID指标与原始单卡模型持平(甚至略优)。相比此前最强的DistriFusion1.22倍)和AsyncDiff(1.31倍),提速效果显著。在通信开销方面,本方法仅为0.516GB,比AsyncDiff的9.83GB降低了19.6倍。在SD3模型上,同样实现了2.07倍的加速。", "main_findings": "实验在SDXL和SD3模型上进行,使用MS-COCO 2014验证集。结果显示,在SDXL上,该方法实现了2.31倍加速,延迟从16.49秒降至7.12秒,且FID指标与原始单卡模型持平(甚至略优)。相比此前最强的DistriFusion1.22倍)和AsyncDiff(1.31倍),提速效果显著。在通信开销方面,本方法仅为0.516GB,比AsyncDiff的9.83GB降低了19.6倍。在SD3模型上,同样实现了2.07倍的加速。",
@@ -44,29 +49,29 @@ data = {
"metric": "Speed-Up", "metric": "Speed-Up",
"this_work": "2.31x", "this_work": "2.31x",
"baseline": "1.31x (AsyncDiff)", "baseline": "1.31x (AsyncDiff)",
"improvement": "1.0x (Extra speed)" "improvement": "1.0x (Extra speed)",
}, },
{ {
"task": "Text-to-Image (SDXL)", "task": "Text-to-Image (SDXL)",
"metric": "Comm. (GB)", "metric": "Comm. (GB)",
"this_work": "0.516", "this_work": "0.516",
"baseline": "9.830 (AsyncDiff)", "baseline": "9.830 (AsyncDiff)",
"improvement": "Reduced by 19.6x" "improvement": "Reduced by 19.6x",
}, },
{ {
"task": "Text-to-Image (SD3)", "task": "Text-to-Image (SD3)",
"metric": "Speed-Up", "metric": "Speed-Up",
"this_work": "2.07x", "this_work": "2.07x",
"baseline": "1.97x (AsyncDiff)", "baseline": "1.97x (AsyncDiff)",
"improvement": "0.1x (Extra speed)" "improvement": "0.1x (Extra speed)",
} },
], ],
"limitations": "尽管该方法在通用性上表现出色,但在处理极高分辨率(如4K以上)时,加速比会随分辨率提升而有所下降(从2.72x降至1.62x)。此外,目前的实现仅针对两张GPU进行了深度优化,虽然文中提出了多卡扩展策略,但在单个样本推理场景下,如何高效地扩展到四卡或更多卡仍是一个挑战。最后,参数 $k$ 的选取目前仍需人工根据经验设定。" "limitations": "尽管该方法在通用性上表现出色,但在处理极高分辨率(如4K以上)时,加速比会随分辨率提升而有所下降(从2.72x降至1.62x)。此外,目前的实现仅针对两张GPU进行了深度优化,虽然文中提出了多卡扩展策略,但在单个样本推理场景下,如何高效地扩展到四卡或更多卡仍是一个挑战。最后,参数 $k$ 的选取目前仍需人工根据经验设定。",
}, },
"improvements": { "improvements": {
"weaknesses": "主要弱点在于自适应切换参数(如 $k$ 和 $\tau_{cap}$)的确定目前仍偏向经验性,缺乏完全自动化的端到端学习机制。此外,虽然避免了图像切片,但条件分支的“信息量”并不总是完全对等的,特别是在极早期的噪声阶段,可能导致其中一张GPU负载不均衡。改进方向可以是结合动态负载均衡算法,根据当前步骤的预测难度动态分配计算资源。", "weaknesses": "主要弱点在于自适应切换参数(如 $k$ 和 $\tau_{cap}$)的确定目前仍偏向经验性,缺乏完全自动化的端到端学习机制。此外,虽然避免了图像切片,但条件分支的“信息量”并不总是完全对等的,特别是在极早期的噪声阶段,可能导致其中一张GPU负载不均衡。改进方向可以是结合动态负载均衡算法,根据当前步骤的预测难度动态分配计算资源。",
"future_work": "未来的研究方向包括:1. 将该混合并行策略扩展到视频生成模型(Video Diffusion)中,利用时间轴上的相关性进行更细粒度的流水线调度。2. 结合模型量化(Quantization)和蒸馏技术,在多卡并行的基础上进一步压缩单步推理时间。3. 探索在“去噪差异度”指标指导下自动学习最优的 $k$ 值和切换点。", "future_work": "未来的研究方向包括:1. 将该混合并行策略扩展到视频生成模型(Video Diffusion)中,利用时间轴上的相关性进行更细粒度的流水线调度。2. 结合模型量化(Quantization)和蒸馏技术,在多卡并行的基础上进一步压缩单步推理时间。3. 探索在“去噪差异度”指标指导下自动学习最优的 $k$ 值和切换点。",
"reproducibility": "代码已在GitHub开源(https://github.com/kaist-dmlab/Hybridiff)。实验环境基于PyTorch,使用的GPU为NVIDIA GeForce 3090,硬件门槛相对较低。文中详细列出了关键超参数(如SDXL上的 $L=12, k=5, \tau_{cap}=15$),使得复现结果的难度较低。" "reproducibility": "代码已在GitHub开源(https://github.com/kaist-dmlab/Hybridiff)。实验环境基于PyTorch,使用的GPU为NVIDIA GeForce 3090,硬件门槛相对较低。文中详细列出了关键超参数(如SDXL上的 $L=12, k=5, \tau_{cap}=15$),使得复现结果的难度较低。",
}, },
"figures": [ "figures": [
{ {
@@ -74,30 +79,30 @@ data = {
"caption": "Summary of the proposed hybrid data-pipeline parallelism", "caption": "Summary of the proposed hybrid data-pipeline parallelism",
"description": "五维雷达图展示了该方法在速度、图像质量、通用性、高分辨率能力和通信开销五个方面均优于现有分布式框架。", "description": "五维雷达图展示了该方法在速度、图像质量、通用性、高分辨率能力和通信开销五个方面均优于现有分布式框架。",
"reason": "直观概括了本文的核心优势,即全方位的性能提升。", "reason": "直观概括了本文的核心优势,即全方位的性能提升。",
"section": "results" "section": "results",
}, },
{ {
"id": "Figure 2", "id": "Figure 2",
"caption": "Comparison of parallel strategies", "caption": "Comparison of parallel strategies",
"description": "对比了三种并行策略:(a)基于切片的数据并行容易产生伪影,(b)流水线并行通信开销大,(c)本文提出的混合并行既保留全局一致性又实现了高效并行。", "description": "对比了三种并行策略:(a)基于切片的数据并行容易产生伪影,(b)流水线并行通信开销大,(c)本文提出的混合并行既保留全局一致性又实现了高效并行。",
"reason": "通过对比展示了本文方法设计的合理性和必要性。", "reason": "通过对比展示了本文方法设计的合理性和必要性。",
"section": "method" "section": "method",
}, },
{ {
"id": "Figure 3", "id": "Figure 3",
"caption": "Overview of the hybrid parallel framework", "caption": "Overview of the hybrid parallel framework",
"description": "详细展示了三个阶段(Warm-Up, Parallelism, Fully-Connecting)的数据流和通信模式,清晰地说明了自适应切换的动态过程。", "description": "详细展示了三个阶段(Warm-Up, Parallelism, Fully-Connecting)的数据流和通信模式,清晰地说明了自适应切换的动态过程。",
"reason": "这是理解整个算法执行流程的关键示意图。", "reason": "这是理解整个算法执行流程的关键示意图。",
"section": "method" "section": "method",
}, },
{ {
"id": "Table 1", "id": "Table 1",
"caption": "Quantitative comparison on SDXL and SD3", "caption": "Quantitative comparison on SDXL and SD3",
"description": "表格列出了该方法与基线方法在延迟、加速比、通信开销及生成质量指标(FID, LPIPS, PSNR)上的详细对比数据。", "description": "表格列出了该方法与基线方法在延迟、加速比、通信开销及生成质量指标(FID, LPIPS, PSNR)上的详细对比数据。",
"reason": "提供了最核心的定量证据,证明了该方法的有效性。", "reason": "提供了最核心的定量证据,证明了该方法的有效性。",
"section": "results" "section": "results",
} },
] ],
} }
with open("data/papers/2602.21760/summary.json", "w", encoding="utf-8") as f: with open("data/papers/2602.21760/summary.json", "w", encoding="utf-8") as f:
+12
View File
@@ -27,6 +27,18 @@ dev = [
"pytest>=8.0", "pytest>=8.0",
"pytest-asyncio>=0.24", "pytest-asyncio>=0.24",
] ]
# 导出 DocLayout-YOLO ONNX 用(一次性脚本 scripts/export_doclayout_yolo_onnx.py,独立 venv 运行)
# GPU 推理:onnxruntime 与 onnxruntime-gpu/-directml 同环境冲突,不在此声明,
# 需手动二选一(见 .env.example 布局检测段说明)
export = [
"torch>=2.0",
"torchvision>=0.15",
"doclayout-yolo",
"onnx>=1.14",
"onnxscript", # torch 2.12+ 的 onnx exporter 需要
"onnxsim",
"huggingface-hub>=0.20",
]
[build-system] [build-system]
requires = ["hatchling"] requires = ["hatchling"]
+161
View File
@@ -0,0 +1,161 @@
"""导出 DocLayout-YOLO (DocStructBench, imgsz=1024) 为 ONNX 格式.
一次性脚本,在独立 venv 中运行(不进运行时依赖):
python -m venv .venv-export && source .venv-export/bin/activate
pip install torch torchvision onnx onnxscript onnxsim onnxruntime huggingface-hub doclayout-yolo
两种权重来源:
# 1) 用本地已下载的 .pt(推荐,省下载)
.venv/bin/python scripts/export_doclayout_yolo_onnx.py \
--weights /path/to/doclayout_yolo_docstructbench_imgsz1024.pt
# 2) 从 HuggingFace 下载(不传 --weights
HF_ENDPOINT=https://hf-mirror.com .venv/bin/python scripts/export_doclayout_yolo_onnx.py
输出:
data/models/doclayout_yolo_docstructbench_imgsz1024.onnx
注意:model.export(simplify=True) 会清空 ONNX metadata,本脚本在导出后
用 onnx 包把 names 重新写回 metadata,供运行时 _parse_names_from_meta 读取。
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
from pathlib import Path
# hf-mirror(国内加速,仅 --weights 未传、走 HF 下载时生效)
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
PROJECT_ROOT = Path(__file__).resolve().parent.parent
MODEL_DIR = PROJECT_ROOT / "data" / "models"
DEFAULT_OUTPUT = MODEL_DIR / "doclayout_yolo_docstructbench_imgsz1024.onnx"
REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
PT_FILENAME = "doclayout_yolo_docstructbench.pt"
IMGSZ = 1024
def resolve_weights(arg: str | None) -> Path:
"""返回 .pt 路径:传 --weights 用本地,否则从 HuggingFace 下载。"""
if arg:
p = Path(arg)
if not p.exists():
raise FileNotFoundError(f"--weights not found: {p}")
print(f"[1/5] Using local weights: {p}")
return p
print(f"[1/5] Downloading .pt from HuggingFace ({REPO_ID}) ...")
from huggingface_hub import hf_hub_download
pt_path = Path(hf_hub_download(repo_id=REPO_ID, filename=PT_FILENAME))
print(f"{pt_path}")
return pt_path
def export_onnx(pt_path: Path, output: Path) -> None:
print("\n[2/5] Loading model with doclayout_yolo ...")
from doclayout_yolo import YOLOv10
model = YOLOv10(str(pt_path))
names = model.names # dict[int, str],与 model.model.names 等价
print(f" ✓ Loaded. names = {names}")
print(f"\n[3/5] Exporting ONNX (imgsz={IMGSZ}, opset=12, simplify=True) ...")
try:
exported = model.export(
format="onnx",
imgsz=IMGSZ,
opset=12,
simplify=True, # 需要 onnxsim;失败则下面回退
dynamic=False, # 固定 batch=1 + 固定 1024,部署最稳
half=False, # FP32,保证 CPU 推理精度
)
except Exception as e:
print(f" ⚠ export with simplify failed ({e}); retrying without simplify")
exported = model.export(
format="onnx", imgsz=IMGSZ, opset=12, dynamic=False, half=False
)
exported_path = Path(exported)
output.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(exported_path), str(output))
print(f" ✓ Exported → {output} ({output.stat().st_size / 1024 / 1024:.1f} MB)")
print("\n[4/5] Re-writing names metadata (simplify may have dropped it) ...")
write_names_metadata(output, names)
def write_names_metadata(onnx_path: Path, names: dict) -> None:
"""把 names dict 写入 ONNX model.metadata_propssimplify 后通常丢失)。"""
import onnx
m = onnx.load(str(onnx_path))
keep = [p for p in m.metadata_props if p.key != "names"]
del m.metadata_props[:]
m.metadata_props.extend(keep)
names_json = json.dumps({str(k): v for k, v in names.items()}, ensure_ascii=False)
m.metadata_props.append(onnx.StringStringEntryProto(key="names", value=names_json))
onnx.save(m, str(onnx_path))
print(f" ✓ names metadata written: {names_json}")
def inspect_onnx(onnx_path: Path) -> None:
"""用 onnxruntime 加载模型,打印输入输出 + names metadata + 试推理。"""
print("\n[5/5] Verifying with onnxruntime ...")
import numpy as np
import onnxruntime as ort
session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
print(" Inputs:")
for inp in session.get_inputs():
print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
print(" Outputs:")
for out in session.get_outputs():
print(f" {out.name}: shape={out.shape}, dtype={out.type}")
meta = session.get_modelmeta()
print(f" metadata keys: {list(meta.custom_metadata_map.keys())}")
print(f" names: {meta.custom_metadata_map.get('names')}")
# dummy 推理
input_info = session.get_inputs()[0]
h = input_info.shape[2] if isinstance(input_info.shape[2], int) else IMGSZ
w = input_info.shape[3] if isinstance(input_info.shape[3], int) else IMGSZ
dummy = np.random.rand(1, 3, h, w).astype(np.float32)
outputs = session.run(None, {input_info.name: dummy})
print(f" Inference test: output[0] shape = {outputs[0].shape}")
out_shape = outputs[0].shape
if len(out_shape) == 3 and out_shape[2] == 6:
print(" ✓ output is [1, N, 6] (YOLOv10 end-to-end, NMS applied)")
else:
print(
f" ⚠️ output shape {out_shape} ≠ [1, N, 6]; "
"layout_detector._postprocess_output will warn and skip pages — "
"adjust export (e.g. end2end/nms) or postprocess.",
)
def main() -> None:
ap = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
ap.add_argument("--weights", help="本地 .pt 路径(不传则从 HuggingFace 下载)")
ap.add_argument(
"--output",
default=str(DEFAULT_OUTPUT),
help=f"输出 ONNX 路径(默认 {DEFAULT_OUTPUT}",
)
args = ap.parse_args()
output = Path(args.output)
pt_path = resolve_weights(args.weights)
export_onnx(pt_path, output)
inspect_onnx(output)
print(f"\n✓ Done! ONNX model saved to {output}")
if __name__ == "__main__":
main()
-172
View File
@@ -1,172 +0,0 @@
"""导出 PicoDet-S_layout_3cls 为 ONNX 格式.
一次性脚本,在独立 venv 中运行:
python -m venv .venv-export && source .venv-export/bin/activate
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle paddleocr paddle2onnx onnxruntime opencv-python-headless
HF_ENDPOINT=https://hf-mirror.com python scripts/export_picodet_onnx.py
输出:
data/models/picodet_layout_3cls.onnx
"""
from __future__ import annotations
import os
import shutil
import subprocess
import sys
from pathlib import Path
# hf-mirror
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
PROJECT_ROOT = Path(__file__).resolve().parent.parent
MODEL_DIR = PROJECT_ROOT / "data" / "models"
OUTPUT_PATH = MODEL_DIR / "picodet_layout_3cls.onnx"
MODEL_NAME = "PicoDet-S_layout_3cls"
def main() -> None:
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# ── Step 1: 用 PaddleOCR paddle_static 引擎加载模型,触发下载 ──
print(f"[1/4] Loading model '{MODEL_NAME}' (paddle_static engine, triggers download) ...")
from paddleocr import LayoutDetection
model = LayoutDetection(
model_name=MODEL_NAME,
engine="paddle_static",
device="cpu",
)
print(" ✓ Model loaded and cached")
# ── Step 2: 找到 PaddleX 缓存的 Paddle 模型文件 ────────────────
paddlex_cache = Path.home() / ".paddlex"
print(f"\n[2/4] Searching Paddle model cache in {paddlex_cache} ...")
# 搜索 layout 相关的缓存目录
candidates = []
for d in paddlex_cache.rglob("*"):
if d.is_dir() and (d / "inference.pdiparams").exists():
# 检查是否是 layout 模型
marker = d.name
parent_name = d.parent.name
if "layout" in marker.lower() or "layout" in parent_name.lower() or "picodet" in marker.lower():
candidates.append(d)
elif "PicoDet" in str(d):
candidates.append(d)
if not candidates:
# 如果没找到明确的 layout 目录,列出所有含 inference.pdiparams 的目录
all_model_dirs = [d for d in paddlex_cache.rglob("*") if d.is_dir() and (d / "inference.pdiparams").exists()]
print(" No layout-specific dir found. All model dirs with inference.pdiparams:")
for d in all_model_dirs:
files = [f.name for f in d.iterdir()]
print(f" {d} ({', '.join(files)})")
if all_model_dirs:
# 取最新的(刚下载的)
candidates = sorted(all_model_dirs, key=lambda d: (d / "inference.pdiparams").stat().st_mtime, reverse=True)[:1]
if not candidates:
print(" ✗ No cached model found")
sys.exit(1)
model_cache_dir = candidates[0]
files_in_dir = list(model_cache_dir.iterdir())
print(f" Using: {model_cache_dir}")
for f in files_in_dir:
print(f" {f.name} ({f.stat().st_size / 1024:.1f} KB)")
# ── Step 3: 用 paddle2onnx 转换 ─────────────────────────────────
print("\n[3/4] Converting to ONNX with paddle2onnx ...")
tmp_onnx = OUTPUT_PATH.with_suffix(".tmp.onnx")
# 确定 model_filename
pdmodel = model_cache_dir / "inference.pdmodel"
has_pdmodel = pdmodel.exists()
cmd = [
sys.executable, "-m", "paddle2onnx",
"--model_dir", str(model_cache_dir),
"--save_file", str(tmp_onnx),
"--opset_version", "11",
"--enable_onnx_checker", "True",
]
if has_pdmodel:
cmd.extend(["--model_filename", "inference.pdmodel"])
cmd.extend(["--params_filename", "inference.pdiparams"])
print(f" Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.stdout:
print(f" stdout: {result.stdout[:500]}")
if result.returncode != 0:
print(f" ✗ paddle2onnx failed (exit {result.returncode})")
print(f" stderr: {result.stderr[:500]}")
# 尝试不带 model_filenamecombined format
if has_pdmodel:
print(" Retrying without explicit model_filename ...")
cmd2 = [
sys.executable, "-m", "paddle2onnx",
"--model_dir", str(model_cache_dir),
"--params_filename", "inference.pdiparams",
"--save_file", str(tmp_onnx),
"--opset_version", "11",
]
result2 = subprocess.run(cmd2, capture_output=True, text=True)
if result2.returncode != 0:
print(f" ✗ Retry also failed: {result2.stderr[:500]}")
sys.exit(1)
if not tmp_onnx.exists() or tmp_onnx.stat().st_size < 1000:
print(" ✗ ONNX file not created or too small")
sys.exit(1)
shutil.move(str(tmp_onnx), str(OUTPUT_PATH))
print(f" ✓ ONNX saved ({OUTPUT_PATH.stat().st_size / 1024 / 1024:.2f} MB)")
# ── Step 4: 用 onnxruntime 验证 ─────────────────────────────────
print("\n[4/4] Verifying with onnxruntime ...")
_inspect_onnx(OUTPUT_PATH)
print(f"\n✓ Done! ONNX model saved to {OUTPUT_PATH}")
def _inspect_onnx(onnx_path: Path) -> None:
"""用 onnxruntime 加载模型,打印输入输出信息."""
import numpy as np
import onnxruntime as ort
session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
print(" Inputs:")
for inp in session.get_inputs():
print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
print(" Outputs:")
for out in session.get_outputs():
print(f" {out.name}: shape={out.shape}, dtype={out.type}")
# 试推理
input_info = session.get_inputs()[0]
input_name = input_info.name
batch_size = input_info.shape[0] if isinstance(input_info.shape[0], int) else 1
channels = input_info.shape[1] if isinstance(input_info.shape[1], int) else 3
height = input_info.shape[2] if isinstance(input_info.shape[2], int) else 480
width = input_info.shape[3] if isinstance(input_info.shape[3], int) else 480
dummy_input = np.random.rand(batch_size, channels, height, width).astype(np.float32)
outputs = session.run(None, {input_name: dummy_input})
print(" Inference test outputs:")
for i, (out_info, out_val) in enumerate(zip(session.get_outputs(), outputs)):
print(f" output[{i}] '{out_info.name}': shape={out_val.shape}, dtype={out_val.dtype}")
if out_val.size <= 20:
print(f" values: {out_val}")
print(" ✓ Inference OK")
if __name__ == "__main__":
main()
+1 -1
View File
@@ -1,4 +1,4 @@
"""批量重新提取所有论文的图片 — 下载 PDF + PicoDet 检测 + caption 匹配. """批量重新提取所有论文的图片 — 下载 PDF + DocLayout-YOLO 检测 + caption 匹配.
用法: 用法:
PROXY_SERVER=http://... uv run python scripts/reextract_images.py PROXY_SERVER=http://... uv run python scripts/reextract_images.py
+68 -38
View File
@@ -3,8 +3,19 @@ import sys
schema = { schema = {
"type": "object", "type": "object",
"required": ["arxiv_id", "title_zh", "one_line", "tags", "difficulty", "required": [
"prerequisites", "motivation", "method", "results", "improvements", "figures"], "arxiv_id",
"title_zh",
"one_line",
"tags",
"difficulty",
"prerequisites",
"motivation",
"method",
"results",
"improvements",
"figures",
],
"properties": { "properties": {
"arxiv_id": {"type": "string"}, "arxiv_id": {"type": "string"},
"title_zh": {"type": "string"}, "title_zh": {"type": "string"},
@@ -15,16 +26,19 @@ schema = {
"type": "object", "type": "object",
"required": ["concepts"], "required": ["concepts"],
"properties": { "properties": {
"concepts": {"type": "array", "items": { "concepts": {
"type": "object", "type": "array",
"required": ["term", "explanation", "why_matters"], "items": {
"properties": { "type": "object",
"term": {"type": "string"}, "required": ["term", "explanation", "why_matters"],
"explanation": {"type": "string"}, "properties": {
"why_matters": {"type": "string"} "term": {"type": "string"},
} "explanation": {"type": "string"},
}} "why_matters": {"type": "string"},
} },
},
}
},
}, },
"motivation": { "motivation": {
"type": "object", "type": "object",
@@ -32,8 +46,8 @@ schema = {
"properties": { "properties": {
"problem": {"type": "string"}, "problem": {"type": "string"},
"goal": {"type": "string"}, "goal": {"type": "string"},
"gap": {"type": "string"} "gap": {"type": "string"},
} },
}, },
"method": { "method": {
"type": "object", "type": "object",
@@ -42,27 +56,36 @@ schema = {
"overview": {"type": "string"}, "overview": {"type": "string"},
"key_idea": {"type": "string"}, "key_idea": {"type": "string"},
"steps": {"type": "string"}, "steps": {"type": "string"},
"novelty": {"type": "string"} "novelty": {"type": "string"},
} },
}, },
"results": { "results": {
"type": "object", "type": "object",
"required": ["main_findings", "benchmarks", "limitations"], "required": ["main_findings", "benchmarks", "limitations"],
"properties": { "properties": {
"main_findings": {"type": "string"}, "main_findings": {"type": "string"},
"benchmarks": {"type": "array", "items": { "benchmarks": {
"type": "object", "type": "array",
"required": ["task", "metric", "this_work", "baseline", "improvement"], "items": {
"properties": { "type": "object",
"task": {"type": "string"}, "required": [
"metric": {"type": "string"}, "task",
"this_work": {"type": "string"}, "metric",
"baseline": {"type": "string"}, "this_work",
"improvement": {"type": "string"} "baseline",
} "improvement",
}}, ],
"limitations": {"type": "string"} "properties": {
} "task": {"type": "string"},
"metric": {"type": "string"},
"this_work": {"type": "string"},
"baseline": {"type": "string"},
"improvement": {"type": "string"},
},
},
},
"limitations": {"type": "string"},
},
}, },
"improvements": { "improvements": {
"type": "object", "type": "object",
@@ -70,8 +93,8 @@ schema = {
"properties": { "properties": {
"weaknesses": {"type": "string"}, "weaknesses": {"type": "string"},
"future_work": {"type": "string"}, "future_work": {"type": "string"},
"reproducibility": {"type": "string"} "reproducibility": {"type": "string"},
} },
}, },
"figures": { "figures": {
"type": "array", "type": "array",
@@ -83,16 +106,20 @@ schema = {
"caption": {"type": "string"}, "caption": {"type": "string"},
"description": {"type": "string"}, "description": {"type": "string"},
"reason": {"type": "string"}, "reason": {"type": "string"},
"section": {"type": "string", "enum": ["motivation", "method", "results", "limitations"]} "section": {
} "type": "string",
} "enum": ["motivation", "method", "results", "limitations"],
} },
} },
},
},
},
} }
def validate_file(filepath): def validate_file(filepath):
try: try:
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
# Check required fields # Check required fields
@@ -139,6 +166,9 @@ def validate_file(filepath):
print(f"❌ Validation error: {e}") print(f"❌ Validation error: {e}")
return False return False
if __name__ == "__main__": if __name__ == "__main__":
filepath = sys.argv[1] if len(sys.argv) > 1 else "data/papers/2601.10592/summary.json" filepath = (
sys.argv[1] if len(sys.argv) > 1 else "data/papers/2601.10592/summary.json"
)
validate_file(filepath) validate_file(filepath)
+1 -4
View File
@@ -301,10 +301,7 @@ class TestAdminPapers:
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["count"] == 2 assert resp.json()["count"] == 2
remaining = db_session.execute( remaining = db_session.execute(
text( text("SELECT rowid FROM papers_fts WHERE rowid IN (:id1, :id2)"),
"SELECT rowid FROM papers_fts "
"WHERE rowid IN (:id1, :id2)"
),
{"id1": target_ids[0], "id2": target_ids[1]}, {"id1": target_ids[0], "id2": target_ids[1]},
).fetchall() ).fetchall()
assert remaining == [] assert remaining == []
+3 -1
View File
@@ -54,7 +54,9 @@ class TestReindexChroma:
def test_reindex_chroma_indexes_only_summarized_papers( def test_reindex_chroma_indexes_only_summarized_papers(
self, db_session, sample_papers_with_summary self, db_session, sample_papers_with_summary
): ):
with patch("app.services.embedder.index_paper", return_value=True) as mock_index: with patch(
"app.services.embedder.index_paper", return_value=True
) as mock_index:
result = reindex_chroma(db_session) result = reindex_chroma(db_session)
assert result["status"] == "success" assert result["status"] == "success"
+403
View File
@@ -0,0 +1,403 @@
"""layout_detector 测试 — 坐标还原数学、设备探测、类别映射、端到端 detect_page.
纯函数测试不依赖真实模型;TestDetectPage 用 MagicMock mock ort.InferenceSession
与 pymupdf.Page,参考 test_summary_utils.py 的 mock 模式。
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock
import numpy as np
import onnxruntime as ort
import pytest
from app.config import settings
from app.services import layout_detector as mod
from app.services.layout_detector import (
LayoutBox,
_compute_render_geometry,
_FALLBACK_NAMES,
_letterbox_padding,
_map_class_to_boxclass,
_model_to_pdf,
_parse_names_from_meta,
_postprocess_output,
detect_page_layout,
resolve_providers,
)
IMGSZ = 1024
# ═══════════════════════════════════════════════════════════════════════
# 渲染几何与 letterbox padding
# ═══════════════════════════════════════════════════════════════════════
class TestComputeRenderGeometry:
def test_a4_portrait_short_edge_pads(self):
# A4 595×842,高度贴边,宽度方向留灰边
ratio = _compute_render_geometry(595, 842, IMGSZ)
assert ratio == pytest.approx(min(IMGSZ / 595, IMGSZ / 842))
assert ratio == pytest.approx(IMGSZ / 842) # 高度方向贴边
def test_wide_page_width_pads(self):
# 1600×900 横向,宽度贴边
ratio = _compute_render_geometry(1600, 900, IMGSZ)
assert ratio == pytest.approx(IMGSZ / 1600)
def test_square_no_letterbox(self):
ratio = _compute_render_geometry(100, 100, IMGSZ)
assert ratio == pytest.approx(10.24)
class TestLetterboxPadding:
def test_centered_padding(self):
# pixmap 723×1024 贴满高度,宽度两侧各 (1024-723)/2
dw, dh = _letterbox_padding(723, 1024, IMGSZ)
assert dw == pytest.approx((IMGSZ - 723) / 2)
assert dh == pytest.approx(0.0)
def test_square_no_padding(self):
dw, dh = _letterbox_padding(IMGSZ, IMGSZ, IMGSZ)
assert dw == 0.0
assert dh == 0.0
# ═══════════════════════════════════════════════════════════════════════
# 坐标还原(核心)—— pdf = (model - padding) / ratio
# ═══════════════════════════════════════════════════════════════════════
class TestModelToPdf:
def test_padding_corner_maps_to_origin(self):
# 模型空间左上角 (dw, dh) → PDF (0, 0)
dw, dh, ratio = 150.5, 0.0, 1.2157
x, y = _model_to_pdf(dw, dh, dw, dh, ratio)
assert x == pytest.approx(0.0, abs=0.01)
assert y == pytest.approx(0.0, abs=0.01)
def test_round_trip(self):
dw, dh, ratio = 150.5, 0.0, 1.2157
# PDF (100, 200) → 模型空间 → 再还原回 PDF
mx, my = 100 * ratio + dw, 200 * ratio + dh
px, py = _model_to_pdf(mx, my, dw, dh, ratio)
assert px == pytest.approx(100, abs=0.01)
assert py == pytest.approx(200, abs=0.01)
def test_full_a4_page_box(self):
# 整页框在模型空间为 (dw,dh)-(dw+pix_w, dh+pix_h),还原回页面尺寸
ratio = _compute_render_geometry(595, 842, IMGSZ)
pix_w, pix_h = round(595 * ratio), round(842 * ratio)
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
x0, y0 = _model_to_pdf(dw, dh, dw, dh, ratio)
x1, y1 = _model_to_pdf(dw + pix_w, dh + pix_h, dw, dh, ratio)
assert x0 == pytest.approx(0.0, abs=1.0)
assert y0 == pytest.approx(0.0, abs=1.0)
assert x1 == pytest.approx(595, abs=1.0)
assert y1 == pytest.approx(842, abs=1.0)
# ═══════════════════════════════════════════════════════════════════════
# 设备探测
# ═══════════════════════════════════════════════════════════════════════
class TestResolveProviders:
def test_cpu(self):
assert resolve_providers("cpu", 0) == [("CPUExecutionProvider", {})]
def test_cuda_with_cpu_fallback(self):
eps = resolve_providers("cuda", 0)
assert eps[0] == ("CUDAExecutionProvider", {"device_id": "0"})
assert eps[1] == ("CPUExecutionProvider", {})
def test_directml_device_id(self):
eps = resolve_providers("directml", 2)
assert eps[0] == ("DmlExecutionProvider", {"device_id": "2"})
def test_auto_picks_cuda_if_available(self, monkeypatch):
monkeypatch.setattr(
ort,
"get_available_providers",
lambda: ["CUDAExecutionProvider", "CPUExecutionProvider"],
)
eps = resolve_providers("auto", 0)
assert eps[0][0] == "CUDAExecutionProvider"
assert eps[-1] == ("CPUExecutionProvider", {})
def test_auto_falls_back_to_cpu(self, monkeypatch):
monkeypatch.setattr(
ort, "get_available_providers", lambda: ["CPUExecutionProvider"]
)
assert resolve_providers("auto", 0) == [("CPUExecutionProvider", {})]
def test_auto_prefers_cuda_over_directml(self, monkeypatch):
monkeypatch.setattr(
ort,
"get_available_providers",
lambda: [
"DmlExecutionProvider",
"CUDAExecutionProvider",
"CPUExecutionProvider",
],
)
eps = resolve_providers("auto", 0)
assert eps[0][0] == "CUDAExecutionProvider"
def test_unknown_device_falls_back(self):
assert resolve_providers("tpu", 0) == [("CPUExecutionProvider", {})]
# ═══════════════════════════════════════════════════════════════════════
# 类别映射与 names 解析
# ═══════════════════════════════════════════════════════════════════════
class TestClassMapping:
def test_figure_to_picture(self):
assert _map_class_to_boxclass(3, {3: "figure"}) == "picture"
def test_figure_group_to_picture(self):
assert _map_class_to_boxclass(0, {0: "figure_group"}) == "picture"
def test_table(self):
assert _map_class_to_boxclass(5, {5: "table"}) == "table"
def test_caption_ignored(self):
names = {4: "figure_caption", 6: "table_caption"}
assert _map_class_to_boxclass(4, names) is None
assert _map_class_to_boxclass(6, names) is None
def test_other_classes_ignored(self):
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
for k in names:
assert _map_class_to_boxclass(k, names) is None
def test_case_insensitive(self):
assert _map_class_to_boxclass(0, {0: "Figure"}) == "picture"
assert _map_class_to_boxclass(0, {0: "TABLE"}) == "table"
def test_unknown_class_id(self):
assert _map_class_to_boxclass(99, {0: "figure"}) is None
class TestParseNamesFromMeta:
def test_reads_json_metadata(self):
sess = MagicMock()
meta = MagicMock()
meta.custom_metadata_map = {
"names": '{"0": "title", "3": "figure", "5": "table"}'
}
sess.get_modelmeta.return_value = meta
assert _parse_names_from_meta(sess) == {0: "title", 3: "figure", 5: "table"}
def test_fallback_when_missing(self):
sess = MagicMock()
meta = MagicMock()
meta.custom_metadata_map = {}
sess.get_modelmeta.return_value = meta
assert _parse_names_from_meta(sess) == _FALLBACK_NAMES
def test_fallback_on_garbage(self):
sess = MagicMock()
meta = MagicMock()
meta.custom_metadata_map = {"names": "not json"}
sess.get_modelmeta.return_value = meta
assert _parse_names_from_meta(sess) == _FALLBACK_NAMES
# ═══════════════════════════════════════════════════════════════════════
# 后处理
# ═══════════════════════════════════════════════════════════════════════
class TestPostprocessOutput:
def test_parses_end_to_end_filters_by_conf(self):
out = np.array(
[[[10, 20, 30, 40, 0.9, 3], [50, 60, 70, 80, 0.1, 5]]],
dtype=np.float32,
)
res = _postprocess_output(out, 0.2, {3: "figure", 5: "table"})
assert res == [(3, 10.0, 20.0, 30.0, 40.0)]
def test_empty_output(self):
out = np.zeros((1, 0, 6), dtype=np.float32)
assert _postprocess_output(out, 0.2, {}) == []
def test_unexpected_shape_returns_empty(self):
out = np.zeros((1, 84, 8400), dtype=np.float32)
assert _postprocess_output(out, 0.2, {}) == []
# ═══════════════════════════════════════════════════════════════════════
# detect_page 端到端(mock ort.InferenceSession + pymupdf.Page
# ═══════════════════════════════════════════════════════════════════════
class TestDetectPage:
@pytest.fixture(autouse=True)
def _reset_detector(self):
"""每个测试前重置模块级单例,避免复用上个测试的 mock session。"""
mod._detector = mod._LayoutDetector()
yield
mod._detector = mod._LayoutDetector()
@staticmethod
def _build_mock_session(page_w, page_h, boxes, names):
"""构造 mock InferenceSession。
boxes: list of (cls_id, pdf_x0, pdf_y0, pdf_x1, pdf_y1, conf)
坐标为 PDF 点,内部转成模型空间坐标塞进 output。
names: dict[int, str] —— 写入 metadata 供 _parse_names_from_meta 读取。
"""
ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
rows = []
for cls_id, x0, y0, x1, y1, conf in boxes:
rows.append(
[
x0 * ratio + dw,
y0 * ratio + dh,
x1 * ratio + dw,
y1 * ratio + dh,
conf,
cls_id,
]
)
fake_output = (
np.array([rows], dtype=np.float32)
if rows
else np.zeros((1, 0, 6), dtype=np.float32)
)
sess = MagicMock()
inp = MagicMock()
inp.name = "images"
sess.get_inputs.return_value = [inp]
sess.run.return_value = [fake_output]
sess.get_providers.return_value = ["CPUExecutionProvider"]
meta = MagicMock()
meta.custom_metadata_map = {
"names": json.dumps({str(k): v for k, v in names.items()})
}
sess.get_modelmeta.return_value = meta
return sess, (pix_w, pix_h)
@staticmethod
def _make_mock_page(page_w, page_h, pix_w, pix_h):
pix = MagicMock()
pix.width = pix_w
pix.height = pix_h
pix.n = 3
pix.samples = bytes([128] * (pix_w * pix_h * 3))
page = MagicMock()
page.rect.width = page_w
page.rect.height = page_h
page.get_pixmap.return_value = pix
return page
def _setup(self, monkeypatch, tmp_path, sess):
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
(tmp_path / "m.onnx").write_bytes(b"x")
monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)
def test_returns_picture_box(self, monkeypatch, tmp_path):
names = {3: "figure", 5: "table"}
sess, (pw, ph) = self._build_mock_session(
595, 842, [(3, 100, 100, 300, 400, 0.9)], names
)
self._setup(monkeypatch, tmp_path, sess)
page = self._make_mock_page(595, 842, pw, ph)
boxes = detect_page_layout(page)
assert len(boxes) == 1
b = boxes[0]
assert isinstance(b, LayoutBox)
assert b.boxclass == "picture"
assert b.x0 == pytest.approx(100, abs=1.0)
assert b.y0 == pytest.approx(100, abs=1.0)
assert b.x1 == pytest.approx(300, abs=1.0)
assert b.y1 == pytest.approx(400, abs=1.0)
def test_returns_table_box(self, monkeypatch, tmp_path):
names = {3: "figure", 5: "table"}
sess, (pw, ph) = self._build_mock_session(
595, 842, [(5, 50, 50, 400, 300, 0.85)], names
)
self._setup(monkeypatch, tmp_path, sess)
page = self._make_mock_page(595, 842, pw, ph)
boxes = detect_page_layout(page)
assert len(boxes) == 1
assert boxes[0].boxclass == "table"
def test_filters_low_confidence(self, monkeypatch, tmp_path):
names = {3: "figure"}
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
sess, (pw, ph) = self._build_mock_session(
595, 842, [(3, 100, 100, 300, 400, 0.1)], names
)
self._setup(monkeypatch, tmp_path, sess)
page = self._make_mock_page(595, 842, pw, ph)
assert detect_page_layout(page) == []
def test_filters_small_box(self, monkeypatch, tmp_path):
names = {3: "figure"}
# 还原后 5×5 pt < _MIN_BOX_SIZE(20) → 过滤
sess, (pw, ph) = self._build_mock_session(
595, 842, [(3, 100, 100, 105, 105, 0.9)], names
)
self._setup(monkeypatch, tmp_path, sess)
page = self._make_mock_page(595, 842, pw, ph)
assert detect_page_layout(page) == []
def test_mixed_picture_and_table(self, monkeypatch, tmp_path):
names = {3: "figure", 5: "table"}
sess, (pw, ph) = self._build_mock_session(
595,
842,
[
(3, 100, 100, 300, 400, 0.9),
(5, 50, 500, 400, 700, 0.8),
],
names,
)
self._setup(monkeypatch, tmp_path, sess)
page = self._make_mock_page(595, 842, pw, ph)
boxes = detect_page_layout(page)
classes = sorted(b.boxclass for b in boxes)
assert classes == ["picture", "table"]
def test_empty_output(self, monkeypatch, tmp_path):
names = {3: "figure"}
sess, (pw, ph) = self._build_mock_session(595, 842, [], names)
self._setup(monkeypatch, tmp_path, sess)
page = self._make_mock_page(595, 842, pw, ph)
assert detect_page_layout(page) == []
def test_ignored_class_skipped(self, monkeypatch, tmp_path):
# title 类(cls_id=0)不应产出 LayoutBox
names = {0: "title", 3: "figure"}
sess, (pw, ph) = self._build_mock_session(
595,
842,
[(0, 100, 100, 400, 150, 0.9), (3, 100, 200, 300, 400, 0.9)],
names,
)
self._setup(monkeypatch, tmp_path, sess)
page = self._make_mock_page(595, 842, pw, ph)
boxes = detect_page_layout(page)
assert len(boxes) == 1
assert boxes[0].boxclass == "picture"
Generated
+1229 -13
View File
File diff suppressed because it is too large Load Diff