refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO
- 核心变更: - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024) - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备 - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式 - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls] - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table) - 新增文件: - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行) - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例) - 配置更新: - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID) - app/config.py: Settings 类对应字段 - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等) - 删除旧文件: - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本 - 文档更新: - README.md: 更新环境变量说明 - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py) 此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
This commit is contained in:
+15
-3
@@ -48,6 +48,18 @@ EMBED_MODEL=Qwen/Qwen3-Embedding-4B
|
||||
EMBED_DIMENSIONS=2560
|
||||
|
||||
# ─── 布局检测 ─────────────────────────────
|
||||
# ONNX 模型路径(首次运行前执行 scripts/export_picodet_onnx.py 导出)
|
||||
# LAYOUT_MODEL_PATH=data/models/picodet_layout_3cls.onnx
|
||||
# LAYOUT_THRESHOLD=0.5
|
||||
# DocLayout-YOLO ONNX 模型(首次运行前执行 scripts/export_doclayout_yolo_onnx.py 导出)
|
||||
# LAYOUT_MODEL_PATH=data/models/doclayout_yolo_docstructbench_imgsz1024.onnx
|
||||
# 模型输入尺寸(DocLayout-YOLO 推荐 1024)
|
||||
# LAYOUT_IMGSZ=1024
|
||||
# 检测置信度阈值(DocLayout-YOLO 推荐 0.2)
|
||||
# LAYOUT_THRESHOLD=0.2
|
||||
# 推理设备:auto|cpu|cuda|directml|openvino|cann|tensorrt|qnn
|
||||
# auto = 按优先级 [CUDA, DirectML, OpenVINO, CPU] 自动探测,失败降级 CPU
|
||||
# LAYOUT_DEVICE=auto
|
||||
# 设备 ID(GPU 序号)
|
||||
# LAYOUT_DEVICE_ID=0
|
||||
#
|
||||
# GPU 用户:onnxruntime 与 onnxruntime-gpu/-directml 同环境冲突,需手动二选一:
|
||||
# pip uninstall onnxruntime && pip install onnxruntime-gpu # NVIDIA CUDA
|
||||
# pip uninstall onnxruntime && pip install onnxruntime-directml # Windows 任意 GPU
|
||||
|
||||
@@ -125,7 +125,7 @@ paper/
|
||||
├── scripts/
|
||||
│ ├── init_db.py # 数据库初始化
|
||||
│ ├── manual_crawl.py # 手动抓取脚本
|
||||
│ ├── export_picodet_onnx.py # 导出布局检测 ONNX 模型
|
||||
│ ├── export_doclayout_yolo_onnx.py # 导出布局检测 ONNX 模型
|
||||
│ ├── reextract_images.py # 批量重新提取图片
|
||||
│ └── validate_summary.py # 校验总结 JSON 结构
|
||||
│
|
||||
@@ -198,8 +198,10 @@ SECRET_KEY=your_random_secret_key
|
||||
| `EMBED_API_KEY` | — | Embedding API Key |
|
||||
| `EMBED_MODEL` | — | Embedding 模型名 |
|
||||
| `EMBED_DIMENSIONS` | `0` | 向量维度 |
|
||||
| `LAYOUT_MODEL_PATH` | `data/models/picodet_layout_3cls.onnx` | ONNX 布局检测模型路径(可选) |
|
||||
| `LAYOUT_THRESHOLD` | `0.5` | 布局检测置信度阈值(可选) |
|
||||
| `LAYOUT_MODEL_PATH` | `data/models/doclayout_yolo_docstructbench_imgsz1024.onnx` | DocLayout-YOLO ONNX 模型路径(可选) |
|
||||
| `LAYOUT_IMGSZ` | `1024` | 模型输入尺寸 |
|
||||
| `LAYOUT_THRESHOLD` | `0.2` | 布局检测置信度阈值(可选) |
|
||||
| `LAYOUT_DEVICE` | `auto` | 推理设备:auto/cpu/cuda/directml/openvino/...(可选) |
|
||||
|
||||
### 4. 初始化数据库
|
||||
|
||||
|
||||
+7
-2
@@ -60,8 +60,13 @@ class Settings(BaseSettings):
|
||||
EMBED_DIMENSIONS: int = 0
|
||||
|
||||
# 布局检测
|
||||
LAYOUT_MODEL_PATH: str = "data/models/picodet_layout_3cls.onnx"
|
||||
LAYOUT_THRESHOLD: float = 0.5
|
||||
LAYOUT_MODEL_PATH: str = "data/models/doclayout_yolo_docstructbench_imgsz1024.onnx"
|
||||
LAYOUT_IMGSZ: int = 1024
|
||||
LAYOUT_THRESHOLD: float = 0.2
|
||||
# 推理设备:auto|cpu|cuda|directml|openvino|cann|tensorrt|qnn
|
||||
# auto = 按优先级 [CUDA, DirectML, OpenVINO, CPU] 自动探测
|
||||
LAYOUT_DEVICE: str = "auto"
|
||||
LAYOUT_DEVICE_ID: int = 0
|
||||
|
||||
model_config = {
|
||||
"env_file": str(BASE_DIR / ".env"),
|
||||
|
||||
+8
-1
@@ -10,7 +10,14 @@ from fastapi.staticfiles import StaticFiles
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.exceptions import AppError, ConflictError, ExternalAPIError, NotFoundError, PdfProcessError, ValidationError
|
||||
from app.exceptions import (
|
||||
AppError,
|
||||
ConflictError,
|
||||
ExternalAPIError,
|
||||
NotFoundError,
|
||||
PdfProcessError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.database import engine, init_db
|
||||
from app.routes.admin import router as admin_router
|
||||
from app.routes.compare import router as compare_router
|
||||
|
||||
+9
-1
@@ -6,7 +6,15 @@ import hashlib
|
||||
import hmac
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, Query, Request
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
Form,
|
||||
HTTPException,
|
||||
Query,
|
||||
Request,
|
||||
)
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
|
||||
|
||||
@@ -141,7 +141,9 @@ async def run_job(db: Session, job_id: int) -> dict:
|
||||
status=JobEventStatus.SUCCESS,
|
||||
payload=result if isinstance(result, dict) else {"result": result},
|
||||
)
|
||||
return result if isinstance(result, dict) else {"status": "success", "result": result}
|
||||
return (
|
||||
result if isinstance(result, dict) else {"status": "success", "result": result}
|
||||
)
|
||||
|
||||
|
||||
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
|
||||
|
||||
+255
-86
@@ -1,19 +1,27 @@
|
||||
"""PicoDet-S_layout_3cls 布局检测 — 纯 ONNX Runtime 推理.
|
||||
"""DocLayout-YOLO 布局检测 — ONNX Runtime 推理,支持 CPU/GPU/NPU 多设备.
|
||||
|
||||
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。
|
||||
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框。
|
||||
用 onnxruntime 加载 DocLayout-YOLO(DocStructBench, imgsz=1024)ONNX 模型,
|
||||
检测 PDF 页面中的 figure / table 区域。
|
||||
|
||||
预处理:letterbox(保比例缩放 + 灰边 padding 到 imgsz×imgsz),RGB,仅 /255 归一化
|
||||
(不做 ImageNet mean/std)。缩放由 pymupdf Matrix 完成,不依赖 OpenCV。
|
||||
后处理:YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls](已内置 NMS)。
|
||||
坐标还原:(model_coord - padding) / ratio —— 渲染缩放与 letterbox 缩放在 pymupdf
|
||||
渲染阶段合二为一,故只需一次除法。
|
||||
|
||||
设备:resolve_providers() 按 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表;
|
||||
_init_session() 逐个 try,首个不可用则降级,CPU 永远兜底。
|
||||
|
||||
输入:
|
||||
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图片
|
||||
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
|
||||
images: (1, 3, imgsz, imgsz) float32 —— letterbox + /255 后的图
|
||||
|
||||
输出:
|
||||
fetch_name_0: (N, 6) float32 — [xmin, ymin, xmax, ymax, score, class_id]
|
||||
fetch_name_1: (1,) int32 — 有效框数量 N
|
||||
output0: (1, N, 6) float32 —— [x1, y1, x2, y2, conf, cls],已 NMS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -26,30 +34,190 @@ from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 模型输入尺寸
|
||||
_MODEL_SIZE = 480
|
||||
# ImageNet normalize
|
||||
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
||||
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||
# PicoDet label → 内部 boxclass
|
||||
_LABEL_MAP: dict[int, str] = {
|
||||
0: "picture", # PicoDet "image" → "picture"
|
||||
1: "table",
|
||||
# 2: seal — 忽略
|
||||
# DocLayout-YOLO DocStructBench 标准 10 类(ONNX metadata 读不到时的兜底,以实际为准)
|
||||
_FALLBACK_NAMES: dict[int, str] = {
|
||||
0: "title",
|
||||
1: "plain text",
|
||||
2: "abandon",
|
||||
3: "figure",
|
||||
4: "figure_caption",
|
||||
5: "table",
|
||||
6: "table_caption",
|
||||
7: "table_footnote",
|
||||
8: "isolate_formula",
|
||||
9: "formula_caption",
|
||||
}
|
||||
# 下游只需 picture/table —— 按 class name 字符串动态匹配(不依赖 class index,
|
||||
# 规避 DocStructBench 不同发布的类别顺序差异)
|
||||
_PICTURE_NAMES = {"figure", "figure_group"}
|
||||
_TABLE_NAMES = {"table", "table_group"}
|
||||
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
|
||||
_PAD_VALUE = 114
|
||||
# 最小 bbox 尺寸(PDF 点)
|
||||
_MIN_BOX_SIZE = 20
|
||||
|
||||
# device → ExecutionProvider 映射
|
||||
_PROVIDER_MAP: dict[str, str] = {
|
||||
"cpu": "CPUExecutionProvider",
|
||||
"cuda": "CUDAExecutionProvider",
|
||||
"directml": "DmlExecutionProvider",
|
||||
"openvino": "OpenVINOExecutionProvider",
|
||||
"cann": "CannExecutionProvider",
|
||||
"tensorrt": "TensorrtExecutionProvider",
|
||||
"qnn": "QNNExecutionProvider",
|
||||
}
|
||||
# auto 探测优先级(不含 cpu,cpu 永远兜底)
|
||||
_AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayoutBox:
|
||||
"""检测到的布局区域,兼容现有 _process_page 代码。"""
|
||||
"""检测到的布局区域,坐标为 PDF 点,boxclass ∈ {"picture", "table"}。"""
|
||||
|
||||
x0: float
|
||||
y0: float
|
||||
x1: float
|
||||
y1: float
|
||||
boxclass: str # "picture" | "table"
|
||||
boxclass: str
|
||||
|
||||
|
||||
# ── 设备选择 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def resolve_providers(device: str, device_id: int) -> list[tuple[str, dict]]:
|
||||
"""根据 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表(首选在前,均带 CPU 兜底)。
|
||||
|
||||
返回 list[tuple[ep_name, provider_options]],供 _init_session() 逐个 try。
|
||||
onnxruntime 创建 session 时若指定 EP 在本机变体里未注册会直接抛错,
|
||||
故降级逻辑由 _init_session() 完成,这里只产出候选。
|
||||
"""
|
||||
if device == "cpu":
|
||||
return [("CPUExecutionProvider", {})]
|
||||
|
||||
opts = {"device_id": str(device_id)}
|
||||
|
||||
if device == "auto":
|
||||
available = set(ort.get_available_providers())
|
||||
for dev in _AUTO_PRIORITY:
|
||||
ep = _PROVIDER_MAP[dev]
|
||||
if ep in available:
|
||||
logger.info("auto: selected provider %s", ep)
|
||||
return [(ep, opts), ("CPUExecutionProvider", {})]
|
||||
logger.info("auto: no GPU/NPU provider available, using CPU")
|
||||
return [("CPUExecutionProvider", {})]
|
||||
|
||||
ep = _PROVIDER_MAP.get(device)
|
||||
if ep is None:
|
||||
logger.warning("Unknown LAYOUT_DEVICE=%r, falling back to CPU", device)
|
||||
return [("CPUExecutionProvider", {})]
|
||||
return [(ep, opts), ("CPUExecutionProvider", {})]
|
||||
|
||||
|
||||
# ── 预处理:渲染几何与 letterbox ────────────────────────────────────────
|
||||
|
||||
|
||||
def _compute_render_geometry(page_w: float, page_h: float, imgsz: int) -> float:
|
||||
"""letterbox 渲染缩放 ratio = min(imgsz/page_w, imgsz/page_h)。
|
||||
|
||||
pymupdf 以 Matrix(ratio, ratio) 渲染,长边贴到 imgsz,短边留灰边。
|
||||
"""
|
||||
return min(imgsz / page_w, imgsz / page_h)
|
||||
|
||||
|
||||
def _letterbox_padding(
|
||||
content_w: float, content_h: float, imgsz: int
|
||||
) -> tuple[float, float]:
|
||||
"""居中 padding:(imgsz - content) / 2。content 为实际 pixmap 尺寸(已取整)。"""
|
||||
return (imgsz - content_w) / 2.0, (imgsz - content_h) / 2.0
|
||||
|
||||
|
||||
def _padded_nchw_from_pixmap(
|
||||
pix: pymupdf.Pixmap, imgsz: int, dw: float, dh: float
|
||||
) -> np.ndarray:
|
||||
"""pixmap → letterbox padded (1, 3, imgsz, imgsz) float32,灰边=114,/255 归一化。"""
|
||||
arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
|
||||
pix.height, pix.width, pix.n
|
||||
)
|
||||
if arr.shape[2] == 4: # 去 alpha(csRGB alpha=False 一般不会,防御性)
|
||||
arr = arr[:, :, :3]
|
||||
|
||||
canvas = np.full((imgsz, imgsz, 3), _PAD_VALUE, dtype=np.uint8)
|
||||
top = int(round(dh))
|
||||
left = int(round(dw))
|
||||
canvas[top : top + pix.height, left : left + pix.width] = arr
|
||||
|
||||
out = canvas.astype(np.float32) / 255.0
|
||||
return out.transpose(2, 0, 1)[np.newaxis] # (1, 3, imgsz, imgsz)
|
||||
|
||||
|
||||
def _model_to_pdf(
|
||||
model_x: float, model_y: float, dw: float, dh: float, ratio: float
|
||||
) -> tuple[float, float]:
|
||||
"""模型 imgsz 空间坐标 → PDF 点:(model - padding) / ratio。"""
|
||||
return (model_x - dw) / ratio, (model_y - dh) / ratio
|
||||
|
||||
|
||||
# ── 后处理 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _postprocess_output(
|
||||
output: np.ndarray, threshold: float, names: dict[int, str]
|
||||
) -> list[tuple[int, float, float, float, float]]:
|
||||
"""解析 YOLOv10 end-to-end 输出,过滤 conf < threshold。
|
||||
|
||||
Args:
|
||||
output: session.run 返回的第一个输出,shape [1, N, 6]
|
||||
threshold: 置信度阈值
|
||||
names: class id → name(仅用于日志,过滤不依赖)
|
||||
|
||||
Returns:
|
||||
[(cls_id, x1, y1, x2, y2), ...],坐标为模型 imgsz padded 空间。
|
||||
"""
|
||||
out = output[0] # 去 batch 维
|
||||
if out.ndim != 2 or out.shape[1] != 6:
|
||||
logger.warning(
|
||||
"Unexpected DocLayout-YOLO output shape %s (expected [N,6]); skip page",
|
||||
tuple(out.shape),
|
||||
)
|
||||
return []
|
||||
|
||||
results: list[tuple[int, float, float, float, float]] = []
|
||||
for row in out:
|
||||
x1, y1, x2, y2, conf, cls = row.tolist()
|
||||
if conf < threshold:
|
||||
continue
|
||||
results.append((int(cls), x1, y1, x2, y2))
|
||||
return results
|
||||
|
||||
|
||||
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
|
||||
"""按 class name 匹配 figure→picture / table→table,其余返回 None。"""
|
||||
name = names.get(cls_id, "")
|
||||
n = name.strip().lower()
|
||||
if n in _PICTURE_NAMES:
|
||||
return "picture"
|
||||
if n in _TABLE_NAMES:
|
||||
return "table"
|
||||
return None
|
||||
|
||||
|
||||
def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
|
||||
"""从 ONNX metadata 读 names(ultralytics 导出写入的 JSON),读不到用兜底。"""
|
||||
raw = None
|
||||
try:
|
||||
raw = session.get_modelmeta().custom_metadata_map.get("names")
|
||||
except Exception:
|
||||
raw = None
|
||||
if raw:
|
||||
try:
|
||||
d = json.loads(raw)
|
||||
return {int(k): str(v) for k, v in d.items()}
|
||||
except Exception:
|
||||
logger.warning("Failed to parse ONNX names metadata; using fallback")
|
||||
return dict(_FALLBACK_NAMES)
|
||||
|
||||
|
||||
# ── 检测器单例 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _LayoutDetector:
|
||||
@@ -57,6 +225,9 @@ class _LayoutDetector:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._session: ort.InferenceSession | None = None
|
||||
self._names: dict[int, str] = {}
|
||||
self._input_name: str = ""
|
||||
self._imgsz: int = settings.LAYOUT_IMGSZ
|
||||
|
||||
def _init_session(self) -> ort.InferenceSession:
|
||||
if self._session is not None:
|
||||
@@ -66,97 +237,95 @@ class _LayoutDetector:
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Layout model not found: {model_path}. "
|
||||
"Run scripts/export_picodet_onnx.py first."
|
||||
"Run scripts/export_doclayout_yolo_onnx.py first."
|
||||
)
|
||||
|
||||
logger.info("Loading ONNX layout model: %s", model_path)
|
||||
self._session = ort.InferenceSession(
|
||||
str(model_path), providers=["CPUExecutionProvider"]
|
||||
eps = resolve_providers(settings.LAYOUT_DEVICE, settings.LAYOUT_DEVICE_ID)
|
||||
logger.info(
|
||||
"Loading layout model %s, candidate providers: %s",
|
||||
model_path,
|
||||
[ep[0] for ep in eps],
|
||||
)
|
||||
logger.info("ONNX layout model loaded")
|
||||
|
||||
# 逐个 EP 尝试,首个不可用则降级
|
||||
last_err: Exception | None = None
|
||||
for idx, (ep_name, ep_opts) in enumerate(eps):
|
||||
try:
|
||||
self._session = ort.InferenceSession(
|
||||
str(model_path), providers=[(ep_name, ep_opts)]
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
if idx < len(eps) - 1:
|
||||
logger.warning(
|
||||
"Provider %s unavailable (%s); falling back to %s",
|
||||
ep_name,
|
||||
e,
|
||||
eps[idx + 1][0],
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Failed to create layout session: {last_err}")
|
||||
|
||||
logger.info(
|
||||
"Layout session active providers: %s", self._session.get_providers()
|
||||
)
|
||||
self._input_name = self._session.get_inputs()[0].name
|
||||
self._names = _parse_names_from_meta(self._session)
|
||||
self._imgsz = settings.LAYOUT_IMGSZ
|
||||
return self._session
|
||||
|
||||
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
"""检测单页 PDF 的 figure / table 区域。
|
||||
|
||||
流程:
|
||||
1. pymupdf 以 480×480 渲染页面
|
||||
2. ImageNet normalize → NCHW
|
||||
3. ONNX 推理 → 得到已解码+NMS 的检测框
|
||||
4. 像素坐标 → PDF 点坐标
|
||||
5. 过滤 seal 类和低置信度框
|
||||
|
||||
Args:
|
||||
page: pymupdf Page 对象
|
||||
1. letterbox 渲染:保比例缩放到长边=imgsz,短边留灰边
|
||||
2. /255 + NCHW → ONNX 推理
|
||||
3. YOLOv10 end-to-end 后处理(已 NMS)
|
||||
4. 模型坐标 → PDF 点
|
||||
5. 过滤非 figure/table 类、极小框、越界 clip
|
||||
|
||||
Returns:
|
||||
LayoutBox 列表,坐标为 PDF 点
|
||||
LayoutBox 列表,坐标为 PDF 点。
|
||||
"""
|
||||
session = self._init_session()
|
||||
|
||||
page_w = page.rect.width
|
||||
page_h = page.rect.height
|
||||
ratio = _compute_render_geometry(page_w, page_h, self._imgsz)
|
||||
|
||||
# 1. 渲染页面到 _MODEL_SIZE × _MODEL_SIZE
|
||||
zoom_x = _MODEL_SIZE / page_w
|
||||
zoom_y = _MODEL_SIZE / page_h
|
||||
mat = pymupdf.Matrix(zoom_x, zoom_y)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
# 2. 预处理
|
||||
img = (
|
||||
np.frombuffer(pix.samples, dtype=np.uint8)
|
||||
.reshape(pix.height, pix.width, pix.n)
|
||||
.astype(np.float32)
|
||||
/ 255.0
|
||||
# 1. 保比例渲染(长边贴 imgsz)
|
||||
pix = page.get_pixmap(
|
||||
matrix=pymupdf.Matrix(ratio, ratio),
|
||||
colorspace=pymupdf.csRGB,
|
||||
alpha=False,
|
||||
)
|
||||
# 去掉 alpha 通道(如有)
|
||||
if img.shape[2] == 4:
|
||||
img = img[:, :, :3]
|
||||
img = (img - _MEAN) / _STD
|
||||
img = img.transpose(2, 0, 1)[np.newaxis] # (1, 3, H, W)
|
||||
# 用 pixmap 实际尺寸(已取整)算 padding,消除取整导致的坐标偏移
|
||||
dw, dh = _letterbox_padding(pix.width, pix.height, self._imgsz)
|
||||
tensor = _padded_nchw_from_pixmap(pix, self._imgsz, dw, dh)
|
||||
|
||||
# scale_factor 用于坐标还原(模型内部可能用)
|
||||
scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)
|
||||
|
||||
# 3. 推理
|
||||
input_names = [i.name for i in session.get_inputs()]
|
||||
feed = {input_names[0]: img}
|
||||
if len(input_names) > 1:
|
||||
feed[input_names[1]] = scale_factor
|
||||
|
||||
outputs = session.run(None, feed)
|
||||
boxes_raw = outputs[0] # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
|
||||
num_boxes = int(outputs[1][0]) # 有效框数
|
||||
|
||||
if num_boxes == 0:
|
||||
return []
|
||||
|
||||
# 4. 像素 → PDF 点坐标
|
||||
sx = page_w / _MODEL_SIZE
|
||||
sy = page_h / _MODEL_SIZE
|
||||
# 2. 推理
|
||||
outputs = session.run(None, {self._input_name: tensor})
|
||||
detections = _postprocess_output(
|
||||
outputs[0], settings.LAYOUT_THRESHOLD, self._names
|
||||
)
|
||||
|
||||
# 3. 坐标还原 + 过滤
|
||||
result: list[LayoutBox] = []
|
||||
for i in range(min(num_boxes, len(boxes_raw))):
|
||||
cls_id, score, xmin, ymin, xmax, ymax = boxes_raw[i]
|
||||
cls_id = int(cls_id)
|
||||
|
||||
# 跳过 seal 类和低置信度
|
||||
if cls_id not in _LABEL_MAP:
|
||||
for cls_id, x1m, y1m, x2m, y2m in detections:
|
||||
boxclass = _map_class_to_boxclass(cls_id, self._names)
|
||||
if boxclass is None:
|
||||
continue
|
||||
if score < settings.LAYOUT_THRESHOLD:
|
||||
continue
|
||||
|
||||
x0, y0 = xmin * sx, ymin * sy
|
||||
x1, y1 = xmax * sx, ymax * sy
|
||||
|
||||
# 跳过极小区域
|
||||
x0, y0 = _model_to_pdf(x1m, y1m, dw, dh, ratio)
|
||||
x1, y1 = _model_to_pdf(x2m, y2m, dw, dh, ratio)
|
||||
# clip 到页面范围
|
||||
x0 = max(0.0, min(x0, page_w))
|
||||
y0 = max(0.0, min(y0, page_h))
|
||||
x1 = max(0.0, min(x1, page_w))
|
||||
y1 = max(0.0, min(y1, page_h))
|
||||
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||
continue
|
||||
|
||||
result.append(
|
||||
LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=_LABEL_MAP[cls_id])
|
||||
)
|
||||
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -71,7 +71,9 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||
|
||||
try:
|
||||
session = _get_session()
|
||||
resp = session.get(pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True)
|
||||
resp = session.get(
|
||||
pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True
|
||||
)
|
||||
resp.raise_for_status()
|
||||
dest.write_bytes(resp.content)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""PDF 图片与表格提取 — 两阶段流水线。
|
||||
|
||||
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
|
||||
Phase 1: DocLayout-YOLO 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
|
||||
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
|
||||
|
||||
相比旧方案(正则匹配 caption):
|
||||
|
||||
@@ -80,13 +80,17 @@ async def call_pi(
|
||||
actual_mode = "search"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars > %dk → search",
|
||||
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||
arxiv_id,
|
||||
txt_size,
|
||||
_PDF_MAX_CHARS // 1000,
|
||||
)
|
||||
else:
|
||||
actual_mode = "inject"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars ≤ %dk → inject",
|
||||
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||
arxiv_id,
|
||||
txt_size,
|
||||
_PDF_MAX_CHARS // 1000,
|
||||
)
|
||||
|
||||
# inject 模式需要截断过长的文本(避免撑爆 context)
|
||||
|
||||
@@ -225,7 +225,7 @@ def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
|
||||
"""从 PDF 提取图片和表格(失败不影响总结)。
|
||||
|
||||
两阶段流水线:
|
||||
1. PicoDet 检测 + 渲染截图(通用标签)
|
||||
1. DocLayout-YOLO 检测 + 渲染截图(通用标签)
|
||||
2. 用 summary 的 figures ID 在 PDF 中搜索定位 → 重命名
|
||||
"""
|
||||
try:
|
||||
|
||||
+24
-19
@@ -4,37 +4,42 @@ data = {
|
||||
"arxiv_id": "2602.21760",
|
||||
"title_zh": "基于条件引导调度的混合数据-流水线并行加速扩散模型",
|
||||
"one_line": "提出混合并行框架,通过条件划分与自适应流水线切换加速扩散推理,实现2.31倍提速。",
|
||||
"tags": ["Diffusion Models", "Distributed Inference", "Parallel Computing", "Image Generation"],
|
||||
"tags": [
|
||||
"Diffusion Models",
|
||||
"Distributed Inference",
|
||||
"Parallel Computing",
|
||||
"Image Generation",
|
||||
],
|
||||
"difficulty": "进阶",
|
||||
"prerequisites": {
|
||||
"concepts": [
|
||||
{
|
||||
"term": "Diffusion Models",
|
||||
"explanation": "扩散模型是一类基于去噪过程的生成模型。在正向过程中,它逐渐向数据添加高斯噪声直到变成纯噪声;在反向过程中,模型学习逐步去噪以恢复原始数据。这种迭代特性虽然能生成高质量的样本,但也导致了高昂的推理计算成本。",
|
||||
"why_matters": "理解扩散模型的迭代去噪机制是理解本文如何通过并行化减少推理延迟的基础。"
|
||||
"why_matters": "理解扩散模型的迭代去噪机制是理解本文如何通过并行化减少推理延迟的基础。",
|
||||
},
|
||||
{
|
||||
"term": "Classifier-Free Guidance (CFG)",
|
||||
"explanation": "无分类器引导是一种在推理时提升生成样本与文本条件一致性的技术。模型同时预测有条件噪声(给定文本提示)和无条件噪声(不给定提示),最终通过加权组合两者来获得最终预测。公式为 $\epsilon_{cfg} = \epsilon_\theta(x_t, c, t) + w (\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t))$,其中 $w$ 是引导强度。",
|
||||
"why_matters": "本文的核心创新点在于利用CFG中存在的有条件和无条件双路径作为数据划分的基础。"
|
||||
"why_matters": "本文的核心创新点在于利用CFG中存在的有条件和无条件双路径作为数据划分的基础。",
|
||||
},
|
||||
{
|
||||
"term": "Distributed Inference",
|
||||
"explanation": "分布式推理利用多个GPU并行处理计算任务以减少延迟。主要分为数据并行(如将图像切片处理)和流水线并行(如将模型层切分)。然而,现有的分布式方法在扩散模型中往往面临通信开销大或生成图像出现拼接伪影的问题。",
|
||||
"why_matters": "本文提出的混合并行框架正是为了解决现有分布式推理方法中的这些痛点。"
|
||||
}
|
||||
"why_matters": "本文提出的混合并行框架正是为了解决现有分布式推理方法中的这些痛点。",
|
||||
},
|
||||
]
|
||||
},
|
||||
"motivation": {
|
||||
"problem": "现有的扩散模型加速方法,无论是单卡优化(如减少采样步数、模型剪枝)还是多卡分布式并行(如DistriFusion和AsyncDiff),都存在明显的局限性。单卡优化受限于硬件算力上限,而现有多卡并行方法通常只能实现次线性的加速比。例如,DistriFusion将图像切片并行处理,容易在拼接处产生明显的伪影;AsyncDiff采用异步流水线,虽然加速了但会引入估计误差,且通信开销巨大(在SDXL上高达9.83GB)。",
|
||||
"goal": "本文旨在提出一种新颖的混合并行框架,在仅使用两张GPU的情况下,不仅能实现超过线性的加速比(即 $>2\times$),还要严格保持甚至提升生成图像的质量,同时将通信开销降到最低。",
|
||||
"gap": "与以往将图像空间切片(Patch-based)的思路不同,本文独辟蹊径,利用无分类器引导(CFG)中天然存在的“有条件”和“无条件”两条路径作为新的数据划分维度(Condition-based Partitioning)。同时,作者发现这两条路径的预测误差差异在整个去噪过程中呈现出先大后小再变大的U型曲线,因此引入了自适应的并行切换策略,只在误差差异最小时才进行并行流水线处理。"
|
||||
"gap": "与以往将图像空间切片(Patch-based)的思路不同,本文独辟蹊径,利用无分类器引导(CFG)中天然存在的“有条件”和“无条件”两条路径作为新的数据划分维度(Condition-based Partitioning)。同时,作者发现这两条路径的预测误差差异在整个去噪过程中呈现出先大后小再变大的U型曲线,因此引入了自适应的并行切换策略,只在误差差异最小时才进行并行流水线处理。",
|
||||
},
|
||||
"method": {
|
||||
"overview": "该框架的核心思想是将扩散推理过程划分为三个阶段:预热阶段(Warm-Up)、并行阶段(Parallelism)和完全连接阶段(Fully-Connecting)。在预热和完全连接阶段,使用“基于条件的划分”策略,即一张GPU处理有条件的预测,另一张处理无条件的预测。而在中间的并行阶段,由于两个预测结果非常接近,框架切换到“自适应流水线并行”,利用两张GPU交替执行推理步骤,从而大幅压缩时间。",
|
||||
"key_idea": "核心创新在于不再将图片在空间上切片,而是沿“条件”维度切分数据。这保证了每个GPU都能看到整张图片的全局信息,从而避免了拼接伪影。此外,引入了“去噪差异度”(Denoising Discrepancy,即 rel-MAE)这一指标来动态评估两条路径的相似性,并以此自动决定何时开启和关闭流水线并行,实现了最优的加速-质量平衡。",
|
||||
"steps": "1. 数据划分:输入潜变量同时送入GPU 1(有条件预测 $\epsilon_\theta(x_t, c, t)$)和GPU 2(无条件预测 $\epsilon_\theta(x_t, t)$)。2. 阶段判断:根据实时计算的“去噪差异度” $G_t$ 与阈值 $g_{slope}$ 的关系,确定切换点 $\tau_1$ 和 $\tau_2$。3. 混合执行:在 $[T, \tau_1]$ 阶段同步运行;在 $[\tau_1, \tau_2]$ 阶段启用流水线并行(如GPU 1处理 $t-1$ 步时GPU 2处理 $t$ 步);在 $[\tau_2, 0]$ 阶段重新恢复同步以精细调整细节。",
|
||||
"novelty": "该方法的另一大新颖之处在于其“安全性”设计:通过设置 $\tau_{cap}$ 作为安全上限,确保即使自动算法失效,也不会在错误的时间点引入并行,从而保证了算法的鲁棒性。此外,该框架对U-Net(如SDXL)和DiT(如SD3)架构均具有良好的泛化性。"
|
||||
"novelty": "该方法的另一大新颖之处在于其“安全性”设计:通过设置 $\tau_{cap}$ 作为安全上限,确保即使自动算法失效,也不会在错误的时间点引入并行,从而保证了算法的鲁棒性。此外,该框架对U-Net(如SDXL)和DiT(如SD3)架构均具有良好的泛化性。",
|
||||
},
|
||||
"results": {
|
||||
"main_findings": "实验在SDXL和SD3模型上进行,使用MS-COCO 2014验证集。结果显示,在SDXL上,该方法实现了2.31倍加速,延迟从16.49秒降至7.12秒,且FID指标与原始单卡模型持平(甚至略优)。相比此前最强的DistriFusion(1.22倍)和AsyncDiff(1.31倍),提速效果显著。在通信开销方面,本方法仅为0.516GB,比AsyncDiff的9.83GB降低了19.6倍。在SD3模型上,同样实现了2.07倍的加速。",
|
||||
@@ -44,29 +49,29 @@ data = {
|
||||
"metric": "Speed-Up",
|
||||
"this_work": "2.31x",
|
||||
"baseline": "1.31x (AsyncDiff)",
|
||||
"improvement": "1.0x (Extra speed)"
|
||||
"improvement": "1.0x (Extra speed)",
|
||||
},
|
||||
{
|
||||
"task": "Text-to-Image (SDXL)",
|
||||
"metric": "Comm. (GB)",
|
||||
"this_work": "0.516",
|
||||
"baseline": "9.830 (AsyncDiff)",
|
||||
"improvement": "Reduced by 19.6x"
|
||||
"improvement": "Reduced by 19.6x",
|
||||
},
|
||||
{
|
||||
"task": "Text-to-Image (SD3)",
|
||||
"metric": "Speed-Up",
|
||||
"this_work": "2.07x",
|
||||
"baseline": "1.97x (AsyncDiff)",
|
||||
"improvement": "0.1x (Extra speed)"
|
||||
}
|
||||
"improvement": "0.1x (Extra speed)",
|
||||
},
|
||||
],
|
||||
"limitations": "尽管该方法在通用性上表现出色,但在处理极高分辨率(如4K以上)时,加速比会随分辨率提升而有所下降(从2.72x降至1.62x)。此外,目前的实现仅针对两张GPU进行了深度优化,虽然文中提出了多卡扩展策略,但在单个样本推理场景下,如何高效地扩展到四卡或更多卡仍是一个挑战。最后,参数 $k$ 的选取目前仍需人工根据经验设定。"
|
||||
"limitations": "尽管该方法在通用性上表现出色,但在处理极高分辨率(如4K以上)时,加速比会随分辨率提升而有所下降(从2.72x降至1.62x)。此外,目前的实现仅针对两张GPU进行了深度优化,虽然文中提出了多卡扩展策略,但在单个样本推理场景下,如何高效地扩展到四卡或更多卡仍是一个挑战。最后,参数 $k$ 的选取目前仍需人工根据经验设定。",
|
||||
},
|
||||
"improvements": {
|
||||
"weaknesses": "主要弱点在于自适应切换参数(如 $k$ 和 $\tau_{cap}$)的确定目前仍偏向经验性,缺乏完全自动化的端到端学习机制。此外,虽然避免了图像切片,但条件分支的“信息量”并不总是完全对等的,特别是在极早期的噪声阶段,可能导致其中一张GPU负载不均衡。改进方向可以是结合动态负载均衡算法,根据当前步骤的预测难度动态分配计算资源。",
|
||||
"future_work": "未来的研究方向包括:1. 将该混合并行策略扩展到视频生成模型(Video Diffusion)中,利用时间轴上的相关性进行更细粒度的流水线调度。2. 结合模型量化(Quantization)和蒸馏技术,在多卡并行的基础上进一步压缩单步推理时间。3. 探索在“去噪差异度”指标指导下自动学习最优的 $k$ 值和切换点。",
|
||||
"reproducibility": "代码已在GitHub开源(https://github.com/kaist-dmlab/Hybridiff)。实验环境基于PyTorch,使用的GPU为NVIDIA GeForce 3090,硬件门槛相对较低。文中详细列出了关键超参数(如SDXL上的 $L=12, k=5, \tau_{cap}=15$),使得复现结果的难度较低。"
|
||||
"reproducibility": "代码已在GitHub开源(https://github.com/kaist-dmlab/Hybridiff)。实验环境基于PyTorch,使用的GPU为NVIDIA GeForce 3090,硬件门槛相对较低。文中详细列出了关键超参数(如SDXL上的 $L=12, k=5, \tau_{cap}=15$),使得复现结果的难度较低。",
|
||||
},
|
||||
"figures": [
|
||||
{
|
||||
@@ -74,30 +79,30 @@ data = {
|
||||
"caption": "Summary of the proposed hybrid data-pipeline parallelism",
|
||||
"description": "五维雷达图展示了该方法在速度、图像质量、通用性、高分辨率能力和通信开销五个方面均优于现有分布式框架。",
|
||||
"reason": "直观概括了本文的核心优势,即全方位的性能提升。",
|
||||
"section": "results"
|
||||
"section": "results",
|
||||
},
|
||||
{
|
||||
"id": "Figure 2",
|
||||
"caption": "Comparison of parallel strategies",
|
||||
"description": "对比了三种并行策略:(a)基于切片的数据并行容易产生伪影,(b)流水线并行通信开销大,(c)本文提出的混合并行既保留全局一致性又实现了高效并行。",
|
||||
"reason": "通过对比展示了本文方法设计的合理性和必要性。",
|
||||
"section": "method"
|
||||
"section": "method",
|
||||
},
|
||||
{
|
||||
"id": "Figure 3",
|
||||
"caption": "Overview of the hybrid parallel framework",
|
||||
"description": "详细展示了三个阶段(Warm-Up, Parallelism, Fully-Connecting)的数据流和通信模式,清晰地说明了自适应切换的动态过程。",
|
||||
"reason": "这是理解整个算法执行流程的关键示意图。",
|
||||
"section": "method"
|
||||
"section": "method",
|
||||
},
|
||||
{
|
||||
"id": "Table 1",
|
||||
"caption": "Quantitative comparison on SDXL and SD3",
|
||||
"description": "表格列出了该方法与基线方法在延迟、加速比、通信开销及生成质量指标(FID, LPIPS, PSNR)上的详细对比数据。",
|
||||
"reason": "提供了最核心的定量证据,证明了该方法的有效性。",
|
||||
"section": "results"
|
||||
}
|
||||
]
|
||||
"section": "results",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
with open("data/papers/2602.21760/summary.json", "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -27,6 +27,18 @@ dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.24",
|
||||
]
|
||||
# 导出 DocLayout-YOLO ONNX 用(一次性脚本 scripts/export_doclayout_yolo_onnx.py,独立 venv 运行)
|
||||
# GPU 推理:onnxruntime 与 onnxruntime-gpu/-directml 同环境冲突,不在此声明,
|
||||
# 需手动二选一(见 .env.example 布局检测段说明)
|
||||
export = [
|
||||
"torch>=2.0",
|
||||
"torchvision>=0.15",
|
||||
"doclayout-yolo",
|
||||
"onnx>=1.14",
|
||||
"onnxscript", # torch 2.12+ 的 onnx exporter 需要
|
||||
"onnxsim",
|
||||
"huggingface-hub>=0.20",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
"""导出 DocLayout-YOLO (DocStructBench, imgsz=1024) 为 ONNX 格式.
|
||||
|
||||
一次性脚本,在独立 venv 中运行(不进运行时依赖):
|
||||
python -m venv .venv-export && source .venv-export/bin/activate
|
||||
pip install torch torchvision onnx onnxscript onnxsim onnxruntime huggingface-hub doclayout-yolo
|
||||
|
||||
两种权重来源:
|
||||
# 1) 用本地已下载的 .pt(推荐,省下载)
|
||||
.venv/bin/python scripts/export_doclayout_yolo_onnx.py \
|
||||
--weights /path/to/doclayout_yolo_docstructbench_imgsz1024.pt
|
||||
|
||||
# 2) 从 HuggingFace 下载(不传 --weights)
|
||||
HF_ENDPOINT=https://hf-mirror.com .venv/bin/python scripts/export_doclayout_yolo_onnx.py
|
||||
|
||||
输出:
|
||||
data/models/doclayout_yolo_docstructbench_imgsz1024.onnx
|
||||
|
||||
注意:model.export(simplify=True) 会清空 ONNX metadata,本脚本在导出后
|
||||
用 onnx 包把 names 重新写回 metadata,供运行时 _parse_names_from_meta 读取。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# hf-mirror(国内加速,仅 --weights 未传、走 HF 下载时生效)
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
MODEL_DIR = PROJECT_ROOT / "data" / "models"
|
||||
DEFAULT_OUTPUT = MODEL_DIR / "doclayout_yolo_docstructbench_imgsz1024.onnx"
|
||||
REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
|
||||
PT_FILENAME = "doclayout_yolo_docstructbench.pt"
|
||||
IMGSZ = 1024
|
||||
|
||||
|
||||
def resolve_weights(arg: str | None) -> Path:
|
||||
"""返回 .pt 路径:传 --weights 用本地,否则从 HuggingFace 下载。"""
|
||||
if arg:
|
||||
p = Path(arg)
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"--weights not found: {p}")
|
||||
print(f"[1/5] Using local weights: {p}")
|
||||
return p
|
||||
|
||||
print(f"[1/5] Downloading .pt from HuggingFace ({REPO_ID}) ...")
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
pt_path = Path(hf_hub_download(repo_id=REPO_ID, filename=PT_FILENAME))
|
||||
print(f" ✓ {pt_path}")
|
||||
return pt_path
|
||||
|
||||
|
||||
def export_onnx(pt_path: Path, output: Path) -> None:
|
||||
print("\n[2/5] Loading model with doclayout_yolo ...")
|
||||
from doclayout_yolo import YOLOv10
|
||||
|
||||
model = YOLOv10(str(pt_path))
|
||||
names = model.names # dict[int, str],与 model.model.names 等价
|
||||
print(f" ✓ Loaded. names = {names}")
|
||||
|
||||
print(f"\n[3/5] Exporting ONNX (imgsz={IMGSZ}, opset=12, simplify=True) ...")
|
||||
try:
|
||||
exported = model.export(
|
||||
format="onnx",
|
||||
imgsz=IMGSZ,
|
||||
opset=12,
|
||||
simplify=True, # 需要 onnxsim;失败则下面回退
|
||||
dynamic=False, # 固定 batch=1 + 固定 1024,部署最稳
|
||||
half=False, # FP32,保证 CPU 推理精度
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" ⚠ export with simplify failed ({e}); retrying without simplify")
|
||||
exported = model.export(
|
||||
format="onnx", imgsz=IMGSZ, opset=12, dynamic=False, half=False
|
||||
)
|
||||
exported_path = Path(exported)
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(exported_path), str(output))
|
||||
print(f" ✓ Exported → {output} ({output.stat().st_size / 1024 / 1024:.1f} MB)")
|
||||
|
||||
print("\n[4/5] Re-writing names metadata (simplify may have dropped it) ...")
|
||||
write_names_metadata(output, names)
|
||||
|
||||
|
||||
def write_names_metadata(onnx_path: Path, names: dict) -> None:
|
||||
"""把 names dict 写入 ONNX model.metadata_props(simplify 后通常丢失)。"""
|
||||
import onnx
|
||||
|
||||
m = onnx.load(str(onnx_path))
|
||||
keep = [p for p in m.metadata_props if p.key != "names"]
|
||||
del m.metadata_props[:]
|
||||
m.metadata_props.extend(keep)
|
||||
names_json = json.dumps({str(k): v for k, v in names.items()}, ensure_ascii=False)
|
||||
m.metadata_props.append(onnx.StringStringEntryProto(key="names", value=names_json))
|
||||
onnx.save(m, str(onnx_path))
|
||||
print(f" ✓ names metadata written: {names_json}")
|
||||
|
||||
|
||||
def inspect_onnx(onnx_path: Path) -> None:
|
||||
"""用 onnxruntime 加载模型,打印输入输出 + names metadata + 试推理。"""
|
||||
print("\n[5/5] Verifying with onnxruntime ...")
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
|
||||
print(" Inputs:")
|
||||
for inp in session.get_inputs():
|
||||
print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
|
||||
print(" Outputs:")
|
||||
for out in session.get_outputs():
|
||||
print(f" {out.name}: shape={out.shape}, dtype={out.type}")
|
||||
|
||||
meta = session.get_modelmeta()
|
||||
print(f" metadata keys: {list(meta.custom_metadata_map.keys())}")
|
||||
print(f" names: {meta.custom_metadata_map.get('names')}")
|
||||
|
||||
# dummy 推理
|
||||
input_info = session.get_inputs()[0]
|
||||
h = input_info.shape[2] if isinstance(input_info.shape[2], int) else IMGSZ
|
||||
w = input_info.shape[3] if isinstance(input_info.shape[3], int) else IMGSZ
|
||||
dummy = np.random.rand(1, 3, h, w).astype(np.float32)
|
||||
outputs = session.run(None, {input_info.name: dummy})
|
||||
print(f" Inference test: output[0] shape = {outputs[0].shape}")
|
||||
|
||||
out_shape = outputs[0].shape
|
||||
if len(out_shape) == 3 and out_shape[2] == 6:
|
||||
print(" ✓ output is [1, N, 6] (YOLOv10 end-to-end, NMS applied)")
|
||||
else:
|
||||
print(
|
||||
f" ⚠️ output shape {out_shape} ≠ [1, N, 6]; "
|
||||
"layout_detector._postprocess_output will warn and skip pages — "
|
||||
"adjust export (e.g. end2end/nms) or postprocess.",
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
ap.add_argument("--weights", help="本地 .pt 路径(不传则从 HuggingFace 下载)")
|
||||
ap.add_argument(
|
||||
"--output",
|
||||
default=str(DEFAULT_OUTPUT),
|
||||
help=f"输出 ONNX 路径(默认 {DEFAULT_OUTPUT})",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
output = Path(args.output)
|
||||
pt_path = resolve_weights(args.weights)
|
||||
export_onnx(pt_path, output)
|
||||
inspect_onnx(output)
|
||||
print(f"\n✓ Done! ONNX model saved to {output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,172 +0,0 @@
|
||||
"""导出 PicoDet-S_layout_3cls 为 ONNX 格式.
|
||||
|
||||
一次性脚本,在独立 venv 中运行:
|
||||
python -m venv .venv-export && source .venv-export/bin/activate
|
||||
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle paddleocr paddle2onnx onnxruntime opencv-python-headless
|
||||
HF_ENDPOINT=https://hf-mirror.com python scripts/export_picodet_onnx.py
|
||||
|
||||
输出:
|
||||
data/models/picodet_layout_3cls.onnx
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# hf-mirror
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
MODEL_DIR = PROJECT_ROOT / "data" / "models"
|
||||
OUTPUT_PATH = MODEL_DIR / "picodet_layout_3cls.onnx"
|
||||
MODEL_NAME = "PicoDet-S_layout_3cls"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ── Step 1: 用 PaddleOCR paddle_static 引擎加载模型,触发下载 ──
|
||||
print(f"[1/4] Loading model '{MODEL_NAME}' (paddle_static engine, triggers download) ...")
|
||||
from paddleocr import LayoutDetection
|
||||
|
||||
model = LayoutDetection(
|
||||
model_name=MODEL_NAME,
|
||||
engine="paddle_static",
|
||||
device="cpu",
|
||||
)
|
||||
print(" ✓ Model loaded and cached")
|
||||
|
||||
# ── Step 2: 找到 PaddleX 缓存的 Paddle 模型文件 ────────────────
|
||||
paddlex_cache = Path.home() / ".paddlex"
|
||||
print(f"\n[2/4] Searching Paddle model cache in {paddlex_cache} ...")
|
||||
|
||||
# 搜索 layout 相关的缓存目录
|
||||
candidates = []
|
||||
for d in paddlex_cache.rglob("*"):
|
||||
if d.is_dir() and (d / "inference.pdiparams").exists():
|
||||
# 检查是否是 layout 模型
|
||||
marker = d.name
|
||||
parent_name = d.parent.name
|
||||
if "layout" in marker.lower() or "layout" in parent_name.lower() or "picodet" in marker.lower():
|
||||
candidates.append(d)
|
||||
elif "PicoDet" in str(d):
|
||||
candidates.append(d)
|
||||
|
||||
if not candidates:
|
||||
# 如果没找到明确的 layout 目录,列出所有含 inference.pdiparams 的目录
|
||||
all_model_dirs = [d for d in paddlex_cache.rglob("*") if d.is_dir() and (d / "inference.pdiparams").exists()]
|
||||
print(" No layout-specific dir found. All model dirs with inference.pdiparams:")
|
||||
for d in all_model_dirs:
|
||||
files = [f.name for f in d.iterdir()]
|
||||
print(f" {d} ({', '.join(files)})")
|
||||
if all_model_dirs:
|
||||
# 取最新的(刚下载的)
|
||||
candidates = sorted(all_model_dirs, key=lambda d: (d / "inference.pdiparams").stat().st_mtime, reverse=True)[:1]
|
||||
|
||||
if not candidates:
|
||||
print(" ✗ No cached model found")
|
||||
sys.exit(1)
|
||||
|
||||
model_cache_dir = candidates[0]
|
||||
files_in_dir = list(model_cache_dir.iterdir())
|
||||
print(f" Using: {model_cache_dir}")
|
||||
for f in files_in_dir:
|
||||
print(f" {f.name} ({f.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
# ── Step 3: 用 paddle2onnx 转换 ─────────────────────────────────
|
||||
print("\n[3/4] Converting to ONNX with paddle2onnx ...")
|
||||
tmp_onnx = OUTPUT_PATH.with_suffix(".tmp.onnx")
|
||||
|
||||
# 确定 model_filename
|
||||
pdmodel = model_cache_dir / "inference.pdmodel"
|
||||
has_pdmodel = pdmodel.exists()
|
||||
|
||||
cmd = [
|
||||
sys.executable, "-m", "paddle2onnx",
|
||||
"--model_dir", str(model_cache_dir),
|
||||
"--save_file", str(tmp_onnx),
|
||||
"--opset_version", "11",
|
||||
"--enable_onnx_checker", "True",
|
||||
]
|
||||
if has_pdmodel:
|
||||
cmd.extend(["--model_filename", "inference.pdmodel"])
|
||||
cmd.extend(["--params_filename", "inference.pdiparams"])
|
||||
|
||||
print(f" Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.stdout:
|
||||
print(f" stdout: {result.stdout[:500]}")
|
||||
if result.returncode != 0:
|
||||
print(f" ✗ paddle2onnx failed (exit {result.returncode})")
|
||||
print(f" stderr: {result.stderr[:500]}")
|
||||
|
||||
# 尝试不带 model_filename(combined format)
|
||||
if has_pdmodel:
|
||||
print(" Retrying without explicit model_filename ...")
|
||||
cmd2 = [
|
||||
sys.executable, "-m", "paddle2onnx",
|
||||
"--model_dir", str(model_cache_dir),
|
||||
"--params_filename", "inference.pdiparams",
|
||||
"--save_file", str(tmp_onnx),
|
||||
"--opset_version", "11",
|
||||
]
|
||||
result2 = subprocess.run(cmd2, capture_output=True, text=True)
|
||||
if result2.returncode != 0:
|
||||
print(f" ✗ Retry also failed: {result2.stderr[:500]}")
|
||||
sys.exit(1)
|
||||
|
||||
if not tmp_onnx.exists() or tmp_onnx.stat().st_size < 1000:
|
||||
print(" ✗ ONNX file not created or too small")
|
||||
sys.exit(1)
|
||||
|
||||
shutil.move(str(tmp_onnx), str(OUTPUT_PATH))
|
||||
print(f" ✓ ONNX saved ({OUTPUT_PATH.stat().st_size / 1024 / 1024:.2f} MB)")
|
||||
|
||||
# ── Step 4: 用 onnxruntime 验证 ─────────────────────────────────
|
||||
print("\n[4/4] Verifying with onnxruntime ...")
|
||||
_inspect_onnx(OUTPUT_PATH)
|
||||
|
||||
print(f"\n✓ Done! ONNX model saved to {OUTPUT_PATH}")
|
||||
|
||||
|
||||
def _inspect_onnx(onnx_path: Path) -> None:
|
||||
"""用 onnxruntime 加载模型,打印输入输出信息."""
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
|
||||
|
||||
print(" Inputs:")
|
||||
for inp in session.get_inputs():
|
||||
print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
|
||||
|
||||
print(" Outputs:")
|
||||
for out in session.get_outputs():
|
||||
print(f" {out.name}: shape={out.shape}, dtype={out.type}")
|
||||
|
||||
# 试推理
|
||||
input_info = session.get_inputs()[0]
|
||||
input_name = input_info.name
|
||||
batch_size = input_info.shape[0] if isinstance(input_info.shape[0], int) else 1
|
||||
channels = input_info.shape[1] if isinstance(input_info.shape[1], int) else 3
|
||||
height = input_info.shape[2] if isinstance(input_info.shape[2], int) else 480
|
||||
width = input_info.shape[3] if isinstance(input_info.shape[3], int) else 480
|
||||
|
||||
dummy_input = np.random.rand(batch_size, channels, height, width).astype(np.float32)
|
||||
outputs = session.run(None, {input_name: dummy_input})
|
||||
|
||||
print(" Inference test outputs:")
|
||||
for i, (out_info, out_val) in enumerate(zip(session.get_outputs(), outputs)):
|
||||
print(f" output[{i}] '{out_info.name}': shape={out_val.shape}, dtype={out_val.dtype}")
|
||||
if out_val.size <= 20:
|
||||
print(f" values: {out_val}")
|
||||
|
||||
print(" ✓ Inference OK")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,4 +1,4 @@
|
||||
"""批量重新提取所有论文的图片 — 下载 PDF + PicoDet 检测 + caption 匹配.
|
||||
"""批量重新提取所有论文的图片 — 下载 PDF + DocLayout-YOLO 检测 + caption 匹配.
|
||||
|
||||
用法:
|
||||
PROXY_SERVER=http://... uv run python scripts/reextract_images.py
|
||||
|
||||
+68
-38
@@ -3,8 +3,19 @@ import sys
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"required": ["arxiv_id", "title_zh", "one_line", "tags", "difficulty",
|
||||
"prerequisites", "motivation", "method", "results", "improvements", "figures"],
|
||||
"required": [
|
||||
"arxiv_id",
|
||||
"title_zh",
|
||||
"one_line",
|
||||
"tags",
|
||||
"difficulty",
|
||||
"prerequisites",
|
||||
"motivation",
|
||||
"method",
|
||||
"results",
|
||||
"improvements",
|
||||
"figures",
|
||||
],
|
||||
"properties": {
|
||||
"arxiv_id": {"type": "string"},
|
||||
"title_zh": {"type": "string"},
|
||||
@@ -15,16 +26,19 @@ schema = {
|
||||
"type": "object",
|
||||
"required": ["concepts"],
|
||||
"properties": {
|
||||
"concepts": {"type": "array", "items": {
|
||||
"type": "object",
|
||||
"required": ["term", "explanation", "why_matters"],
|
||||
"properties": {
|
||||
"term": {"type": "string"},
|
||||
"explanation": {"type": "string"},
|
||||
"why_matters": {"type": "string"}
|
||||
}
|
||||
}}
|
||||
}
|
||||
"concepts": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["term", "explanation", "why_matters"],
|
||||
"properties": {
|
||||
"term": {"type": "string"},
|
||||
"explanation": {"type": "string"},
|
||||
"why_matters": {"type": "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
"motivation": {
|
||||
"type": "object",
|
||||
@@ -32,8 +46,8 @@ schema = {
|
||||
"properties": {
|
||||
"problem": {"type": "string"},
|
||||
"goal": {"type": "string"},
|
||||
"gap": {"type": "string"}
|
||||
}
|
||||
"gap": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"method": {
|
||||
"type": "object",
|
||||
@@ -42,27 +56,36 @@ schema = {
|
||||
"overview": {"type": "string"},
|
||||
"key_idea": {"type": "string"},
|
||||
"steps": {"type": "string"},
|
||||
"novelty": {"type": "string"}
|
||||
}
|
||||
"novelty": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"results": {
|
||||
"type": "object",
|
||||
"required": ["main_findings", "benchmarks", "limitations"],
|
||||
"properties": {
|
||||
"main_findings": {"type": "string"},
|
||||
"benchmarks": {"type": "array", "items": {
|
||||
"type": "object",
|
||||
"required": ["task", "metric", "this_work", "baseline", "improvement"],
|
||||
"properties": {
|
||||
"task": {"type": "string"},
|
||||
"metric": {"type": "string"},
|
||||
"this_work": {"type": "string"},
|
||||
"baseline": {"type": "string"},
|
||||
"improvement": {"type": "string"}
|
||||
}
|
||||
}},
|
||||
"limitations": {"type": "string"}
|
||||
}
|
||||
"benchmarks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"task",
|
||||
"metric",
|
||||
"this_work",
|
||||
"baseline",
|
||||
"improvement",
|
||||
],
|
||||
"properties": {
|
||||
"task": {"type": "string"},
|
||||
"metric": {"type": "string"},
|
||||
"this_work": {"type": "string"},
|
||||
"baseline": {"type": "string"},
|
||||
"improvement": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"limitations": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"improvements": {
|
||||
"type": "object",
|
||||
@@ -70,8 +93,8 @@ schema = {
|
||||
"properties": {
|
||||
"weaknesses": {"type": "string"},
|
||||
"future_work": {"type": "string"},
|
||||
"reproducibility": {"type": "string"}
|
||||
}
|
||||
"reproducibility": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"figures": {
|
||||
"type": "array",
|
||||
@@ -83,16 +106,20 @@ schema = {
|
||||
"caption": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"reason": {"type": "string"},
|
||||
"section": {"type": "string", "enum": ["motivation", "method", "results", "limitations"]}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"section": {
|
||||
"type": "string",
|
||||
"enum": ["motivation", "method", "results", "limitations"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def validate_file(filepath):
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Check required fields
|
||||
@@ -139,6 +166,9 @@ def validate_file(filepath):
|
||||
print(f"❌ Validation error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
filepath = sys.argv[1] if len(sys.argv) > 1 else "data/papers/2601.10592/summary.json"
|
||||
filepath = (
|
||||
sys.argv[1] if len(sys.argv) > 1 else "data/papers/2601.10592/summary.json"
|
||||
)
|
||||
validate_file(filepath)
|
||||
|
||||
+1
-4
@@ -301,10 +301,7 @@ class TestAdminPapers:
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["count"] == 2
|
||||
remaining = db_session.execute(
|
||||
text(
|
||||
"SELECT rowid FROM papers_fts "
|
||||
"WHERE rowid IN (:id1, :id2)"
|
||||
),
|
||||
text("SELECT rowid FROM papers_fts WHERE rowid IN (:id1, :id2)"),
|
||||
{"id1": target_ids[0], "id2": target_ids[1]},
|
||||
).fetchall()
|
||||
assert remaining == []
|
||||
|
||||
@@ -54,7 +54,9 @@ class TestReindexChroma:
|
||||
def test_reindex_chroma_indexes_only_summarized_papers(
|
||||
self, db_session, sample_papers_with_summary
|
||||
):
|
||||
with patch("app.services.embedder.index_paper", return_value=True) as mock_index:
|
||||
with patch(
|
||||
"app.services.embedder.index_paper", return_value=True
|
||||
) as mock_index:
|
||||
result = reindex_chroma(db_session)
|
||||
|
||||
assert result["status"] == "success"
|
||||
|
||||
@@ -0,0 +1,403 @@
|
||||
"""layout_detector 测试 — 坐标还原数学、设备探测、类别映射、端到端 detect_page.
|
||||
|
||||
纯函数测试不依赖真实模型;TestDetectPage 用 MagicMock mock ort.InferenceSession
|
||||
与 pymupdf.Page,参考 test_summary_utils.py 的 mock 模式。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import pytest
|
||||
|
||||
from app.config import settings
|
||||
from app.services import layout_detector as mod
|
||||
from app.services.layout_detector import (
|
||||
LayoutBox,
|
||||
_compute_render_geometry,
|
||||
_FALLBACK_NAMES,
|
||||
_letterbox_padding,
|
||||
_map_class_to_boxclass,
|
||||
_model_to_pdf,
|
||||
_parse_names_from_meta,
|
||||
_postprocess_output,
|
||||
detect_page_layout,
|
||||
resolve_providers,
|
||||
)
|
||||
|
||||
IMGSZ = 1024
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 渲染几何与 letterbox padding
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestComputeRenderGeometry:
|
||||
def test_a4_portrait_short_edge_pads(self):
|
||||
# A4 595×842,高度贴边,宽度方向留灰边
|
||||
ratio = _compute_render_geometry(595, 842, IMGSZ)
|
||||
assert ratio == pytest.approx(min(IMGSZ / 595, IMGSZ / 842))
|
||||
assert ratio == pytest.approx(IMGSZ / 842) # 高度方向贴边
|
||||
|
||||
def test_wide_page_width_pads(self):
|
||||
# 1600×900 横向,宽度贴边
|
||||
ratio = _compute_render_geometry(1600, 900, IMGSZ)
|
||||
assert ratio == pytest.approx(IMGSZ / 1600)
|
||||
|
||||
def test_square_no_letterbox(self):
|
||||
ratio = _compute_render_geometry(100, 100, IMGSZ)
|
||||
assert ratio == pytest.approx(10.24)
|
||||
|
||||
|
||||
class TestLetterboxPadding:
|
||||
def test_centered_padding(self):
|
||||
# pixmap 723×1024 贴满高度,宽度两侧各 (1024-723)/2
|
||||
dw, dh = _letterbox_padding(723, 1024, IMGSZ)
|
||||
assert dw == pytest.approx((IMGSZ - 723) / 2)
|
||||
assert dh == pytest.approx(0.0)
|
||||
|
||||
def test_square_no_padding(self):
|
||||
dw, dh = _letterbox_padding(IMGSZ, IMGSZ, IMGSZ)
|
||||
assert dw == 0.0
|
||||
assert dh == 0.0
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 坐标还原(核心)—— pdf = (model - padding) / ratio
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestModelToPdf:
|
||||
def test_padding_corner_maps_to_origin(self):
|
||||
# 模型空间左上角 (dw, dh) → PDF (0, 0)
|
||||
dw, dh, ratio = 150.5, 0.0, 1.2157
|
||||
x, y = _model_to_pdf(dw, dh, dw, dh, ratio)
|
||||
assert x == pytest.approx(0.0, abs=0.01)
|
||||
assert y == pytest.approx(0.0, abs=0.01)
|
||||
|
||||
def test_round_trip(self):
|
||||
dw, dh, ratio = 150.5, 0.0, 1.2157
|
||||
# PDF (100, 200) → 模型空间 → 再还原回 PDF
|
||||
mx, my = 100 * ratio + dw, 200 * ratio + dh
|
||||
px, py = _model_to_pdf(mx, my, dw, dh, ratio)
|
||||
assert px == pytest.approx(100, abs=0.01)
|
||||
assert py == pytest.approx(200, abs=0.01)
|
||||
|
||||
def test_full_a4_page_box(self):
|
||||
# 整页框在模型空间为 (dw,dh)-(dw+pix_w, dh+pix_h),还原回页面尺寸
|
||||
ratio = _compute_render_geometry(595, 842, IMGSZ)
|
||||
pix_w, pix_h = round(595 * ratio), round(842 * ratio)
|
||||
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
|
||||
x0, y0 = _model_to_pdf(dw, dh, dw, dh, ratio)
|
||||
x1, y1 = _model_to_pdf(dw + pix_w, dh + pix_h, dw, dh, ratio)
|
||||
assert x0 == pytest.approx(0.0, abs=1.0)
|
||||
assert y0 == pytest.approx(0.0, abs=1.0)
|
||||
assert x1 == pytest.approx(595, abs=1.0)
|
||||
assert y1 == pytest.approx(842, abs=1.0)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 设备探测
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestResolveProviders:
|
||||
def test_cpu(self):
|
||||
assert resolve_providers("cpu", 0) == [("CPUExecutionProvider", {})]
|
||||
|
||||
def test_cuda_with_cpu_fallback(self):
|
||||
eps = resolve_providers("cuda", 0)
|
||||
assert eps[0] == ("CUDAExecutionProvider", {"device_id": "0"})
|
||||
assert eps[1] == ("CPUExecutionProvider", {})
|
||||
|
||||
def test_directml_device_id(self):
|
||||
eps = resolve_providers("directml", 2)
|
||||
assert eps[0] == ("DmlExecutionProvider", {"device_id": "2"})
|
||||
|
||||
def test_auto_picks_cuda_if_available(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
ort,
|
||||
"get_available_providers",
|
||||
lambda: ["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
)
|
||||
eps = resolve_providers("auto", 0)
|
||||
assert eps[0][0] == "CUDAExecutionProvider"
|
||||
assert eps[-1] == ("CPUExecutionProvider", {})
|
||||
|
||||
def test_auto_falls_back_to_cpu(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
ort, "get_available_providers", lambda: ["CPUExecutionProvider"]
|
||||
)
|
||||
assert resolve_providers("auto", 0) == [("CPUExecutionProvider", {})]
|
||||
|
||||
def test_auto_prefers_cuda_over_directml(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
ort,
|
||||
"get_available_providers",
|
||||
lambda: [
|
||||
"DmlExecutionProvider",
|
||||
"CUDAExecutionProvider",
|
||||
"CPUExecutionProvider",
|
||||
],
|
||||
)
|
||||
eps = resolve_providers("auto", 0)
|
||||
assert eps[0][0] == "CUDAExecutionProvider"
|
||||
|
||||
def test_unknown_device_falls_back(self):
|
||||
assert resolve_providers("tpu", 0) == [("CPUExecutionProvider", {})]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 类别映射与 names 解析
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestClassMapping:
|
||||
def test_figure_to_picture(self):
|
||||
assert _map_class_to_boxclass(3, {3: "figure"}) == "picture"
|
||||
|
||||
def test_figure_group_to_picture(self):
|
||||
assert _map_class_to_boxclass(0, {0: "figure_group"}) == "picture"
|
||||
|
||||
def test_table(self):
|
||||
assert _map_class_to_boxclass(5, {5: "table"}) == "table"
|
||||
|
||||
def test_caption_ignored(self):
|
||||
names = {4: "figure_caption", 6: "table_caption"}
|
||||
assert _map_class_to_boxclass(4, names) is None
|
||||
assert _map_class_to_boxclass(6, names) is None
|
||||
|
||||
def test_other_classes_ignored(self):
|
||||
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
|
||||
for k in names:
|
||||
assert _map_class_to_boxclass(k, names) is None
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _map_class_to_boxclass(0, {0: "Figure"}) == "picture"
|
||||
assert _map_class_to_boxclass(0, {0: "TABLE"}) == "table"
|
||||
|
||||
def test_unknown_class_id(self):
|
||||
assert _map_class_to_boxclass(99, {0: "figure"}) is None
|
||||
|
||||
|
||||
class TestParseNamesFromMeta:
|
||||
def test_reads_json_metadata(self):
|
||||
sess = MagicMock()
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {
|
||||
"names": '{"0": "title", "3": "figure", "5": "table"}'
|
||||
}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
assert _parse_names_from_meta(sess) == {0: "title", 3: "figure", 5: "table"}
|
||||
|
||||
def test_fallback_when_missing(self):
|
||||
sess = MagicMock()
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
assert _parse_names_from_meta(sess) == _FALLBACK_NAMES
|
||||
|
||||
def test_fallback_on_garbage(self):
|
||||
sess = MagicMock()
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {"names": "not json"}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
assert _parse_names_from_meta(sess) == _FALLBACK_NAMES
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 后处理
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestPostprocessOutput:
|
||||
def test_parses_end_to_end_filters_by_conf(self):
|
||||
out = np.array(
|
||||
[[[10, 20, 30, 40, 0.9, 3], [50, 60, 70, 80, 0.1, 5]]],
|
||||
dtype=np.float32,
|
||||
)
|
||||
res = _postprocess_output(out, 0.2, {3: "figure", 5: "table"})
|
||||
assert res == [(3, 10.0, 20.0, 30.0, 40.0)]
|
||||
|
||||
def test_empty_output(self):
|
||||
out = np.zeros((1, 0, 6), dtype=np.float32)
|
||||
assert _postprocess_output(out, 0.2, {}) == []
|
||||
|
||||
def test_unexpected_shape_returns_empty(self):
|
||||
out = np.zeros((1, 84, 8400), dtype=np.float32)
|
||||
assert _postprocess_output(out, 0.2, {}) == []
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# detect_page 端到端(mock ort.InferenceSession + pymupdf.Page)
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestDetectPage:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_detector(self):
|
||||
"""每个测试前重置模块级单例,避免复用上个测试的 mock session。"""
|
||||
mod._detector = mod._LayoutDetector()
|
||||
yield
|
||||
mod._detector = mod._LayoutDetector()
|
||||
|
||||
@staticmethod
|
||||
def _build_mock_session(page_w, page_h, boxes, names):
|
||||
"""构造 mock InferenceSession。
|
||||
|
||||
boxes: list of (cls_id, pdf_x0, pdf_y0, pdf_x1, pdf_y1, conf)
|
||||
坐标为 PDF 点,内部转成模型空间坐标塞进 output。
|
||||
names: dict[int, str] —— 写入 metadata 供 _parse_names_from_meta 读取。
|
||||
"""
|
||||
ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
|
||||
pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
|
||||
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
|
||||
|
||||
rows = []
|
||||
for cls_id, x0, y0, x1, y1, conf in boxes:
|
||||
rows.append(
|
||||
[
|
||||
x0 * ratio + dw,
|
||||
y0 * ratio + dh,
|
||||
x1 * ratio + dw,
|
||||
y1 * ratio + dh,
|
||||
conf,
|
||||
cls_id,
|
||||
]
|
||||
)
|
||||
fake_output = (
|
||||
np.array([rows], dtype=np.float32)
|
||||
if rows
|
||||
else np.zeros((1, 0, 6), dtype=np.float32)
|
||||
)
|
||||
|
||||
sess = MagicMock()
|
||||
inp = MagicMock()
|
||||
inp.name = "images"
|
||||
sess.get_inputs.return_value = [inp]
|
||||
sess.run.return_value = [fake_output]
|
||||
sess.get_providers.return_value = ["CPUExecutionProvider"]
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {
|
||||
"names": json.dumps({str(k): v for k, v in names.items()})
|
||||
}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
return sess, (pix_w, pix_h)
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_page(page_w, page_h, pix_w, pix_h):
|
||||
pix = MagicMock()
|
||||
pix.width = pix_w
|
||||
pix.height = pix_h
|
||||
pix.n = 3
|
||||
pix.samples = bytes([128] * (pix_w * pix_h * 3))
|
||||
page = MagicMock()
|
||||
page.rect.width = page_w
|
||||
page.rect.height = page_h
|
||||
page.get_pixmap.return_value = pix
|
||||
return page
|
||||
|
||||
def _setup(self, monkeypatch, tmp_path, sess):
|
||||
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
|
||||
(tmp_path / "m.onnx").write_bytes(b"x")
|
||||
monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)
|
||||
|
||||
def test_returns_picture_box(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure", 5: "table"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 300, 400, 0.9)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
|
||||
assert len(boxes) == 1
|
||||
b = boxes[0]
|
||||
assert isinstance(b, LayoutBox)
|
||||
assert b.boxclass == "picture"
|
||||
assert b.x0 == pytest.approx(100, abs=1.0)
|
||||
assert b.y0 == pytest.approx(100, abs=1.0)
|
||||
assert b.x1 == pytest.approx(300, abs=1.0)
|
||||
assert b.y1 == pytest.approx(400, abs=1.0)
|
||||
|
||||
def test_returns_table_box(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure", 5: "table"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(5, 50, 50, 400, 300, 0.85)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "table"
|
||||
|
||||
def test_filters_low_confidence(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure"}
|
||||
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 300, 400, 0.1)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
assert detect_page_layout(page) == []
|
||||
|
||||
def test_filters_small_box(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure"}
|
||||
# 还原后 5×5 pt < _MIN_BOX_SIZE(20) → 过滤
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 105, 105, 0.9)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
assert detect_page_layout(page) == []
|
||||
|
||||
def test_mixed_picture_and_table(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure", 5: "table"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595,
|
||||
842,
|
||||
[
|
||||
(3, 100, 100, 300, 400, 0.9),
|
||||
(5, 50, 500, 400, 700, 0.8),
|
||||
],
|
||||
names,
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
classes = sorted(b.boxclass for b in boxes)
|
||||
assert classes == ["picture", "table"]
|
||||
|
||||
def test_empty_output(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure"}
|
||||
sess, (pw, ph) = self._build_mock_session(595, 842, [], names)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
assert detect_page_layout(page) == []
|
||||
|
||||
def test_ignored_class_skipped(self, monkeypatch, tmp_path):
|
||||
# title 类(cls_id=0)不应产出 LayoutBox
|
||||
names = {0: "title", 3: "figure"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595,
|
||||
842,
|
||||
[(0, 100, 100, 400, 150, 0.9), (3, 100, 200, 300, 400, 0.9)],
|
||||
names,
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "picture"
|
||||
Reference in New Issue
Block a user