refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO
- 核心变更: - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024) - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备 - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式 - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls] - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table) - 新增文件: - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行) - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例) - 配置更新: - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID) - app/config.py: Settings 类对应字段 - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等) - 删除旧文件: - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本 - 文档更新: - README.md: 更新环境变量说明 - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py) 此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
This commit is contained in:
+7
-2
@@ -60,8 +60,13 @@ class Settings(BaseSettings):
|
||||
EMBED_DIMENSIONS: int = 0
|
||||
|
||||
# 布局检测
|
||||
LAYOUT_MODEL_PATH: str = "data/models/picodet_layout_3cls.onnx"
|
||||
LAYOUT_THRESHOLD: float = 0.5
|
||||
LAYOUT_MODEL_PATH: str = "data/models/doclayout_yolo_docstructbench_imgsz1024.onnx"
|
||||
LAYOUT_IMGSZ: int = 1024
|
||||
LAYOUT_THRESHOLD: float = 0.2
|
||||
# 推理设备:auto|cpu|cuda|directml|openvino|cann|tensorrt|qnn
|
||||
# auto = probe in priority order [CUDA, DirectML, OpenVINO, CANN, TensorRT, QNN]; CPU is always the final fallback
|
||||
LAYOUT_DEVICE: str = "auto"
|
||||
LAYOUT_DEVICE_ID: int = 0
|
||||
|
||||
model_config = {
|
||||
"env_file": str(BASE_DIR / ".env"),
|
||||
|
||||
+8
-1
@@ -10,7 +10,14 @@ from fastapi.staticfiles import StaticFiles
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.exceptions import AppError, ConflictError, ExternalAPIError, NotFoundError, PdfProcessError, ValidationError
|
||||
from app.exceptions import (
|
||||
AppError,
|
||||
ConflictError,
|
||||
ExternalAPIError,
|
||||
NotFoundError,
|
||||
PdfProcessError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.database import engine, init_db
|
||||
from app.routes.admin import router as admin_router
|
||||
from app.routes.compare import router as compare_router
|
||||
|
||||
+9
-1
@@ -6,7 +6,15 @@ import hashlib
|
||||
import hmac
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, Query, Request
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
Form,
|
||||
HTTPException,
|
||||
Query,
|
||||
Request,
|
||||
)
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
|
||||
|
||||
@@ -141,7 +141,9 @@ async def run_job(db: Session, job_id: int) -> dict:
|
||||
status=JobEventStatus.SUCCESS,
|
||||
payload=result if isinstance(result, dict) else {"result": result},
|
||||
)
|
||||
return result if isinstance(result, dict) else {"status": "success", "result": result}
|
||||
return (
|
||||
result if isinstance(result, dict) else {"status": "success", "result": result}
|
||||
)
|
||||
|
||||
|
||||
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
|
||||
|
||||
+255
-86
@@ -1,19 +1,27 @@
|
||||
"""PicoDet-S_layout_3cls 布局检测 — 纯 ONNX Runtime 推理.
|
||||
"""DocLayout-YOLO 布局检测 — ONNX Runtime 推理,支持 CPU/GPU/NPU 多设备.
|
||||
|
||||
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。
|
||||
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框。
|
||||
用 onnxruntime 加载 DocLayout-YOLO(DocStructBench, imgsz=1024)ONNX 模型,
|
||||
检测 PDF 页面中的 figure / table 区域。
|
||||
|
||||
预处理:letterbox(保比例缩放 + 灰边 padding 到 imgsz×imgsz),RGB,仅 /255 归一化
|
||||
(不做 ImageNet mean/std)。缩放由 pymupdf Matrix 完成,不依赖 OpenCV。
|
||||
后处理:YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls](已内置 NMS)。
|
||||
坐标还原:(model_coord - padding) / ratio —— 渲染缩放与 letterbox 缩放在 pymupdf
|
||||
渲染阶段合二为一,故只需一次除法。
|
||||
|
||||
设备:resolve_providers() 按 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表;
|
||||
_init_session() 逐个 try,首个不可用则降级,CPU 永远兜底。
|
||||
|
||||
输入:
|
||||
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图片
|
||||
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
|
||||
images: (1, 3, imgsz, imgsz) float32 —— letterbox + /255 后的图
|
||||
|
||||
输出:
|
||||
fetch_name_0: (N, 6) float32 — [xmin, ymin, xmax, ymax, score, class_id]
|
||||
fetch_name_1: (1,) int32 — 有效框数量 N
|
||||
output0: (1, N, 6) float32 —— [x1, y1, x2, y2, conf, cls],已 NMS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -26,30 +34,190 @@ from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 模型输入尺寸
|
||||
_MODEL_SIZE = 480
|
||||
# ImageNet normalize
|
||||
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
||||
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||
# PicoDet label → 内部 boxclass
|
||||
_LABEL_MAP: dict[int, str] = {
|
||||
0: "picture", # PicoDet "image" → "picture"
|
||||
1: "table",
|
||||
# 2: seal — 忽略
|
||||
# DocLayout-YOLO DocStructBench 标准 10 类(ONNX metadata 读不到时的兜底,以实际为准)
|
||||
_FALLBACK_NAMES: dict[int, str] = {
|
||||
0: "title",
|
||||
1: "plain text",
|
||||
2: "abandon",
|
||||
3: "figure",
|
||||
4: "figure_caption",
|
||||
5: "table",
|
||||
6: "table_caption",
|
||||
7: "table_footnote",
|
||||
8: "isolate_formula",
|
||||
9: "formula_caption",
|
||||
}
|
||||
# 下游只需 picture/table —— 按 class name 字符串动态匹配(不依赖 class index,
|
||||
# 规避 DocStructBench 不同发布的类别顺序差异)
|
||||
_PICTURE_NAMES = {"figure", "figure_group"}
|
||||
_TABLE_NAMES = {"table", "table_group"}
|
||||
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
|
||||
_PAD_VALUE = 114
|
||||
# 最小 bbox 尺寸(PDF 点)
|
||||
_MIN_BOX_SIZE = 20
|
||||
|
||||
# device → ExecutionProvider 映射
|
||||
_PROVIDER_MAP: dict[str, str] = {
|
||||
"cpu": "CPUExecutionProvider",
|
||||
"cuda": "CUDAExecutionProvider",
|
||||
"directml": "DmlExecutionProvider",
|
||||
"openvino": "OpenVINOExecutionProvider",
|
||||
"cann": "CannExecutionProvider",
|
||||
"tensorrt": "TensorrtExecutionProvider",
|
||||
"qnn": "QNNExecutionProvider",
|
||||
}
|
||||
# auto 探测优先级(不含 cpu,cpu 永远兜底)
|
||||
_AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayoutBox:
|
||||
"""检测到的布局区域,兼容现有 _process_page 代码。"""
|
||||
"""检测到的布局区域,坐标为 PDF 点,boxclass ∈ {"picture", "table"}。"""
|
||||
|
||||
x0: float
|
||||
y0: float
|
||||
x1: float
|
||||
y1: float
|
||||
boxclass: str # "picture" | "table"
|
||||
boxclass: str
|
||||
|
||||
|
||||
# ── 设备选择 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def resolve_providers(device: str, device_id: int) -> list[tuple[str, dict]]:
    """Build the ordered list of candidate ExecutionProviders for a device setting.

    This function only produces candidates; it never creates a session. The
    caller is expected to try the entries in order, which is why every
    accelerator result is followed by a CPU entry as the last-resort fallback.

    Args:
        device: Value of LAYOUT_DEVICE: ``"auto"``, ``"cpu"`` or a key of
            ``_PROVIDER_MAP``. Matching ignores case and surrounding
            whitespace, so ``"CUDA"`` and ``" cuda "`` behave like ``"cuda"``.
        device_id: Device index, forwarded to the accelerator provider as the
            string-valued ``device_id`` option.

    Returns:
        ``[(provider_name, provider_options), ...]`` with the preferred
        provider first. ``"cpu"`` and unknown device names yield the CPU
        provider only.
    """
    # Settings usually come from environment variables, where casing and stray
    # whitespace are easy to get wrong; normalise so a valid device name is
    # not silently downgraded to CPU.
    key = device.strip().lower() if isinstance(device, str) else device

    if key == "cpu":
        return [("CPUExecutionProvider", {})]

    opts = {"device_id": str(device_id)}

    if key == "auto":
        available = set(ort.get_available_providers())
        for dev in _AUTO_PRIORITY:
            ep = _PROVIDER_MAP[dev]
            if ep in available:
                logger.info("auto: selected provider %s", ep)
                return [(ep, opts), ("CPUExecutionProvider", {})]
        logger.info("auto: no GPU/NPU provider available, using CPU")
        return [("CPUExecutionProvider", {})]

    ep = _PROVIDER_MAP.get(key)
    if ep is None:
        # Log the value exactly as configured so the operator can spot the typo.
        logger.warning("Unknown LAYOUT_DEVICE=%r, falling back to CPU", device)
        return [("CPUExecutionProvider", {})]
    return [(ep, opts), ("CPUExecutionProvider", {})]
|
||||
|
||||
|
||||
# ── 预处理:渲染几何与 letterbox ────────────────────────────────────────
|
||||
|
||||
|
||||
def _compute_render_geometry(page_w: float, page_h: float, imgsz: int) -> float:
|
||||
"""letterbox 渲染缩放 ratio = min(imgsz/page_w, imgsz/page_h)。
|
||||
|
||||
pymupdf 以 Matrix(ratio, ratio) 渲染,长边贴到 imgsz,短边留灰边。
|
||||
"""
|
||||
return min(imgsz / page_w, imgsz / page_h)
|
||||
|
||||
|
||||
def _letterbox_padding(
|
||||
content_w: float, content_h: float, imgsz: int
|
||||
) -> tuple[float, float]:
|
||||
"""居中 padding:(imgsz - content) / 2。content 为实际 pixmap 尺寸(已取整)。"""
|
||||
return (imgsz - content_w) / 2.0, (imgsz - content_h) / 2.0
|
||||
|
||||
|
||||
def _padded_nchw_from_pixmap(
    pix: pymupdf.Pixmap, imgsz: int, dw: float, dh: float
) -> np.ndarray:
    """Convert a rendered pixmap into a letterboxed model input tensor.

    The pixmap pixels are pasted onto an ``imgsz x imgsz`` canvas pre-filled
    with ``_PAD_VALUE``, offset by the rounded padding, then scaled to
    ``[0, 1]`` by dividing by 255 and rearranged to shape
    ``(1, 3, imgsz, imgsz)`` as float32.

    Args:
        pix: Rendered page; ``samples`` is read as ``height x width x n`` uint8.
        imgsz: Side length of the square canvas.
        dw: Horizontal padding; rounded to get the left paste offset.
        dh: Vertical padding; rounded to get the top paste offset.
    """
    pixels = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
        pix.height, pix.width, pix.n
    )
    # Defensive: keep only the first three channels if a fourth one is present.
    if pixels.shape[2] == 4:
        pixels = pixels[..., :3]

    board = np.full((imgsz, imgsz, 3), _PAD_VALUE, dtype=np.uint8)
    row0 = int(round(dh))
    col0 = int(round(dw))
    board[row0 : row0 + pix.height, col0 : col0 + pix.width] = pixels

    scaled = board.astype(np.float32) / 255.0
    # HWC -> CHW, then add the leading batch axis.
    return np.transpose(scaled, (2, 0, 1))[np.newaxis]
|
||||
|
||||
|
||||
def _model_to_pdf(
|
||||
model_x: float, model_y: float, dw: float, dh: float, ratio: float
|
||||
) -> tuple[float, float]:
|
||||
"""模型 imgsz 空间坐标 → PDF 点:(model - padding) / ratio。"""
|
||||
return (model_x - dw) / ratio, (model_y - dh) / ratio
|
||||
|
||||
|
||||
# ── 后处理 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _postprocess_output(
|
||||
output: np.ndarray, threshold: float, names: dict[int, str]
|
||||
) -> list[tuple[int, float, float, float, float]]:
|
||||
"""解析 YOLOv10 end-to-end 输出,过滤 conf < threshold。
|
||||
|
||||
Args:
|
||||
output: session.run 返回的第一个输出,shape [1, N, 6]
|
||||
threshold: 置信度阈值
|
||||
names: class id → name(仅用于日志,过滤不依赖)
|
||||
|
||||
Returns:
|
||||
[(cls_id, x1, y1, x2, y2), ...],坐标为模型 imgsz padded 空间。
|
||||
"""
|
||||
out = output[0] # 去 batch 维
|
||||
if out.ndim != 2 or out.shape[1] != 6:
|
||||
logger.warning(
|
||||
"Unexpected DocLayout-YOLO output shape %s (expected [N,6]); skip page",
|
||||
tuple(out.shape),
|
||||
)
|
||||
return []
|
||||
|
||||
results: list[tuple[int, float, float, float, float]] = []
|
||||
for row in out:
|
||||
x1, y1, x2, y2, conf, cls = row.tolist()
|
||||
if conf < threshold:
|
||||
continue
|
||||
results.append((int(cls), x1, y1, x2, y2))
|
||||
return results
|
||||
|
||||
|
||||
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
    """Translate a detector class id into ``"picture"``, ``"table"`` or ``None``.

    The class name is looked up in ``names`` (missing ids count as an empty
    name), trimmed and lower-cased, then matched against ``_PICTURE_NAMES``
    and ``_TABLE_NAMES``. Any other name yields ``None``.
    """
    key = names.get(cls_id, "").strip().lower()
    if key in _PICTURE_NAMES:
        return "picture"
    return "table" if key in _TABLE_NAMES else None
|
||||
|
||||
|
||||
def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
|
||||
"""从 ONNX metadata 读 names(ultralytics 导出写入的 JSON),读不到用兜底。"""
|
||||
raw = None
|
||||
try:
|
||||
raw = session.get_modelmeta().custom_metadata_map.get("names")
|
||||
except Exception:
|
||||
raw = None
|
||||
if raw:
|
||||
try:
|
||||
d = json.loads(raw)
|
||||
return {int(k): str(v) for k, v in d.items()}
|
||||
except Exception:
|
||||
logger.warning("Failed to parse ONNX names metadata; using fallback")
|
||||
return dict(_FALLBACK_NAMES)
|
||||
|
||||
|
||||
# ── 检测器单例 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _LayoutDetector:
|
||||
@@ -57,6 +225,9 @@ class _LayoutDetector:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._session: ort.InferenceSession | None = None
|
||||
self._names: dict[int, str] = {}
|
||||
self._input_name: str = ""
|
||||
self._imgsz: int = settings.LAYOUT_IMGSZ
|
||||
|
||||
def _init_session(self) -> ort.InferenceSession:
|
||||
if self._session is not None:
|
||||
@@ -66,97 +237,95 @@ class _LayoutDetector:
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Layout model not found: {model_path}. "
|
||||
"Run scripts/export_picodet_onnx.py first."
|
||||
"Run scripts/export_doclayout_yolo_onnx.py first."
|
||||
)
|
||||
|
||||
logger.info("Loading ONNX layout model: %s", model_path)
|
||||
self._session = ort.InferenceSession(
|
||||
str(model_path), providers=["CPUExecutionProvider"]
|
||||
eps = resolve_providers(settings.LAYOUT_DEVICE, settings.LAYOUT_DEVICE_ID)
|
||||
logger.info(
|
||||
"Loading layout model %s, candidate providers: %s",
|
||||
model_path,
|
||||
[ep[0] for ep in eps],
|
||||
)
|
||||
logger.info("ONNX layout model loaded")
|
||||
|
||||
# 逐个 EP 尝试,首个不可用则降级
|
||||
last_err: Exception | None = None
|
||||
for idx, (ep_name, ep_opts) in enumerate(eps):
|
||||
try:
|
||||
self._session = ort.InferenceSession(
|
||||
str(model_path), providers=[(ep_name, ep_opts)]
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
if idx < len(eps) - 1:
|
||||
logger.warning(
|
||||
"Provider %s unavailable (%s); falling back to %s",
|
||||
ep_name,
|
||||
e,
|
||||
eps[idx + 1][0],
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Failed to create layout session: {last_err}")
|
||||
|
||||
logger.info(
|
||||
"Layout session active providers: %s", self._session.get_providers()
|
||||
)
|
||||
self._input_name = self._session.get_inputs()[0].name
|
||||
self._names = _parse_names_from_meta(self._session)
|
||||
self._imgsz = settings.LAYOUT_IMGSZ
|
||||
return self._session
|
||||
|
||||
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
"""检测单页 PDF 的 figure / table 区域。
|
||||
|
||||
流程:
|
||||
1. pymupdf 以 480×480 渲染页面
|
||||
2. ImageNet normalize → NCHW
|
||||
3. ONNX 推理 → 得到已解码+NMS 的检测框
|
||||
4. 像素坐标 → PDF 点坐标
|
||||
5. 过滤 seal 类和低置信度框
|
||||
|
||||
Args:
|
||||
page: pymupdf Page 对象
|
||||
1. letterbox 渲染:保比例缩放到长边=imgsz,短边留灰边
|
||||
2. /255 + NCHW → ONNX 推理
|
||||
3. YOLOv10 end-to-end 后处理(已 NMS)
|
||||
4. 模型坐标 → PDF 点
|
||||
5. 过滤非 figure/table 类、极小框、越界 clip
|
||||
|
||||
Returns:
|
||||
LayoutBox 列表,坐标为 PDF 点
|
||||
LayoutBox 列表,坐标为 PDF 点。
|
||||
"""
|
||||
session = self._init_session()
|
||||
|
||||
page_w = page.rect.width
|
||||
page_h = page.rect.height
|
||||
ratio = _compute_render_geometry(page_w, page_h, self._imgsz)
|
||||
|
||||
# 1. 渲染页面到 _MODEL_SIZE × _MODEL_SIZE
|
||||
zoom_x = _MODEL_SIZE / page_w
|
||||
zoom_y = _MODEL_SIZE / page_h
|
||||
mat = pymupdf.Matrix(zoom_x, zoom_y)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
# 2. 预处理
|
||||
img = (
|
||||
np.frombuffer(pix.samples, dtype=np.uint8)
|
||||
.reshape(pix.height, pix.width, pix.n)
|
||||
.astype(np.float32)
|
||||
/ 255.0
|
||||
# 1. 保比例渲染(长边贴 imgsz)
|
||||
pix = page.get_pixmap(
|
||||
matrix=pymupdf.Matrix(ratio, ratio),
|
||||
colorspace=pymupdf.csRGB,
|
||||
alpha=False,
|
||||
)
|
||||
# 去掉 alpha 通道(如有)
|
||||
if img.shape[2] == 4:
|
||||
img = img[:, :, :3]
|
||||
img = (img - _MEAN) / _STD
|
||||
img = img.transpose(2, 0, 1)[np.newaxis] # (1, 3, H, W)
|
||||
# 用 pixmap 实际尺寸(已取整)算 padding,消除取整导致的坐标偏移
|
||||
dw, dh = _letterbox_padding(pix.width, pix.height, self._imgsz)
|
||||
tensor = _padded_nchw_from_pixmap(pix, self._imgsz, dw, dh)
|
||||
|
||||
# scale_factor 用于坐标还原(模型内部可能用)
|
||||
scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)
|
||||
|
||||
# 3. 推理
|
||||
input_names = [i.name for i in session.get_inputs()]
|
||||
feed = {input_names[0]: img}
|
||||
if len(input_names) > 1:
|
||||
feed[input_names[1]] = scale_factor
|
||||
|
||||
outputs = session.run(None, feed)
|
||||
boxes_raw = outputs[0] # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
|
||||
num_boxes = int(outputs[1][0]) # 有效框数
|
||||
|
||||
if num_boxes == 0:
|
||||
return []
|
||||
|
||||
# 4. 像素 → PDF 点坐标
|
||||
sx = page_w / _MODEL_SIZE
|
||||
sy = page_h / _MODEL_SIZE
|
||||
# 2. 推理
|
||||
outputs = session.run(None, {self._input_name: tensor})
|
||||
detections = _postprocess_output(
|
||||
outputs[0], settings.LAYOUT_THRESHOLD, self._names
|
||||
)
|
||||
|
||||
# 3. 坐标还原 + 过滤
|
||||
result: list[LayoutBox] = []
|
||||
for i in range(min(num_boxes, len(boxes_raw))):
|
||||
cls_id, score, xmin, ymin, xmax, ymax = boxes_raw[i]
|
||||
cls_id = int(cls_id)
|
||||
|
||||
# 跳过 seal 类和低置信度
|
||||
if cls_id not in _LABEL_MAP:
|
||||
for cls_id, x1m, y1m, x2m, y2m in detections:
|
||||
boxclass = _map_class_to_boxclass(cls_id, self._names)
|
||||
if boxclass is None:
|
||||
continue
|
||||
if score < settings.LAYOUT_THRESHOLD:
|
||||
continue
|
||||
|
||||
x0, y0 = xmin * sx, ymin * sy
|
||||
x1, y1 = xmax * sx, ymax * sy
|
||||
|
||||
# 跳过极小区域
|
||||
x0, y0 = _model_to_pdf(x1m, y1m, dw, dh, ratio)
|
||||
x1, y1 = _model_to_pdf(x2m, y2m, dw, dh, ratio)
|
||||
# clip 到页面范围
|
||||
x0 = max(0.0, min(x0, page_w))
|
||||
y0 = max(0.0, min(y0, page_h))
|
||||
x1 = max(0.0, min(x1, page_w))
|
||||
y1 = max(0.0, min(y1, page_h))
|
||||
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||
continue
|
||||
|
||||
result.append(
|
||||
LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=_LABEL_MAP[cls_id])
|
||||
)
|
||||
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -71,7 +71,9 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||
|
||||
try:
|
||||
session = _get_session()
|
||||
resp = session.get(pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True)
|
||||
resp = session.get(
|
||||
pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True
|
||||
)
|
||||
resp.raise_for_status()
|
||||
dest.write_bytes(resp.content)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""PDF 图片与表格提取 — 两阶段流水线。
|
||||
|
||||
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
|
||||
Phase 1: DocLayout-YOLO 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
|
||||
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
|
||||
|
||||
相比旧方案(正则匹配 caption):
|
||||
|
||||
@@ -80,13 +80,17 @@ async def call_pi(
|
||||
actual_mode = "search"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars > %dk → search",
|
||||
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||
arxiv_id,
|
||||
txt_size,
|
||||
_PDF_MAX_CHARS // 1000,
|
||||
)
|
||||
else:
|
||||
actual_mode = "inject"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars ≤ %dk → inject",
|
||||
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||
arxiv_id,
|
||||
txt_size,
|
||||
_PDF_MAX_CHARS // 1000,
|
||||
)
|
||||
|
||||
# inject 模式需要截断过长的文本(避免撑爆 context)
|
||||
|
||||
@@ -225,7 +225,7 @@ def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
|
||||
"""从 PDF 提取图片和表格(失败不影响总结)。
|
||||
|
||||
两阶段流水线:
|
||||
1. PicoDet 检测 + 渲染截图(通用标签)
|
||||
1. DocLayout-YOLO 检测 + 渲染截图(通用标签)
|
||||
2. 用 summary 的 figures ID 在 PDF 中搜索定位 → 重命名
|
||||
"""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user