Files
daily-paper/app/services/pi_client.py
T
Rain-Bus 90fe705e8f refactor: 将布局检测模型从 PicoDet 迁移到 DocLayout-YOLO
- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
2026-06-14 10:41:44 +08:00

180 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""pi CLI 后端 — 调用 pi 子进程生成总结。
通用工具函数(prompt 构建、PDF 提取、JSON 提取、meta.json)已移至 summary_utils.py。
"""
from __future__ import annotations

import asyncio
import contextlib
import logging
import time
import uuid
from pathlib import Path

from app.config import settings
from app.utils import truncate_error
from app.services.summary_utils import (
    JsonNotFoundError,
    build_prompt,
    extract_json,
    extract_pdf_text,
    write_meta_json,
)
logger: logging.Logger = logging.getLogger(__name__)

# Character limit for the full-text "inject" mode. In "auto" mode a longer
# text switches to "search" mode; an explicit "inject" call truncates the
# extracted text to this many characters.
_PDF_MAX_CHARS: int = 80_000

# Public names of this module. The summary_utils helpers are re-exported here
# so that older imports from pi_client keep working (backward compatibility).
__all__ = [
    # exceptions
    "PiTimeoutError",
    "PiProcessError",
    "JsonNotFoundError",
    # pi invocation
    "call_pi",
    # helpers re-exported from summary_utils
    "write_meta_json",
    "extract_pdf_text",
    "build_prompt",
    "extract_json",
]
# ── Custom exceptions ───────────────────────────────────────────────────
class PiTimeoutError(Exception):
    """Raised when the pi subprocess does not finish within the configured timeout."""
class PiProcessError(Exception):
    """Raised when the pi subprocess exits with a non-zero return code.

    Attributes:
        returncode: The exit status reported for the pi process.
        stderr: The full stderr text of the process (the exception message
            only contains it after passing through ``truncate_error``,
            presumably to keep the message short).
    """

    def __init__(self, returncode: int, stderr: str):
        message = f"pi exited with code {returncode}: {truncate_error(stderr)}"
        super().__init__(message)
        self.returncode = returncode
        self.stderr = stderr
# ── pi CLI invocation ───────────────────────────────────────────────────
def _prepare_pdf_text(arxiv_id: str, txt_path: Path, pdf_mode: str) -> str:
    """Resolve the effective PDF mode and cap the text file for inject mode.

    ``"auto"`` resolves to ``"search"`` when the extracted text is longer than
    ``_PDF_MAX_CHARS`` characters and to ``"inject"`` otherwise; any other
    value is returned unchanged. When the effective mode is ``"inject"`` and
    the text is too long, *txt_path* is overwritten in place with the first
    ``_PDF_MAX_CHARS`` characters (trailing whitespace stripped).

    Returns:
        The effective mode.
    """
    # Read once; the text only lives in this frame, not across the subprocess run.
    text = txt_path.read_text(encoding="utf-8")
    txt_size = len(text)
    limit_k = _PDF_MAX_CHARS // 1000

    if pdf_mode != "auto":
        mode = pdf_mode
    elif txt_size > _PDF_MAX_CHARS:
        mode = "search"
        logger.info(
            "Auto mode: %s text=%d chars > %dk → search",
            arxiv_id,
            txt_size,
            limit_k,
        )
    else:
        mode = "inject"
        logger.info(
            "Auto mode: %s text=%d chars ≤ %dk → inject",
            arxiv_id,
            txt_size,
            limit_k,
        )

    # inject mode puts the whole text into pi's context, so cap its length.
    # NOTE(review): the extracted text file is rewritten in place; whether a
    # later extract_pdf_text call re-extracts or reuses it depends on
    # summary_utils — verify.
    if mode == "inject" and txt_size > _PDF_MAX_CHARS:
        trimmed = text[:_PDF_MAX_CHARS].rstrip()
        txt_path.write_text(trimmed, encoding="utf-8")
        logger.info(
            "Truncated %s for inject: %d → %d chars", arxiv_id, txt_size, len(trimmed)
        )
    return mode


async def _kill_and_reap(proc: asyncio.subprocess.Process) -> None:
    """Kill *proc* if it is still running, then wait for it to exit."""
    if proc.returncode is None:
        # The process may exit between the check and kill(); that is fine.
        with contextlib.suppress(ProcessLookupError):
            proc.kill()
    await proc.wait()


async def call_pi(
    meta_path: Path,
    pdf_path: Path,
    fix_errors: list[str] | None = None,
    session_id: str | None = None,
    pdf_mode: str = "inject",
) -> tuple[str, str]:
    """Run the pi CLI in non-interactive mode and return ``(stdout text, session_id)``.

    Args:
        meta_path: Path to the paper's meta file. The name of its parent
            directory is used as the arXiv ID.
        pdf_path: Path to the paper PDF. Its full text is extracted with
            ``extract_pdf_text``; a ``summary.json`` next to it is only
            inspected for logging.
        fix_errors: If non-empty, the validation errors from the previous
            attempt that pi should correct. The existing session is then
            continued with ``--continue`` and no ``@file`` arguments are sent.
        session_id: Session to use. When ``None`` a new ID of the form
            ``summary-<arxiv_id>-<8 hex chars>`` is generated (one session
            per paper).
        pdf_mode: ``"inject"`` passes the meta file and the extracted text
            (truncated to ``_PDF_MAX_CHARS`` characters) as ``@file``
            arguments on the first call. ``"search"`` enables pi's ``read``
            tool and lets pi read the file itself. ``"auto"`` picks
            ``"search"`` when the text is longer than ``_PDF_MAX_CHARS`` and
            ``"inject"`` otherwise. Other values are passed through to
            ``build_prompt`` unchanged and otherwise behave like ``"inject"``
            without truncation.

    Returns:
        The pi process stdout decoded as UTF-8 (invalid bytes replaced), and
        the session ID that was used.

    Raises:
        PiTimeoutError: pi did not finish within
            ``settings.SUMMARY_TIMEOUT_SECONDS``; the process is killed.
        PiProcessError: pi exited with a non-zero return code.
    """
    arxiv_id = meta_path.parent.name

    # Extract the full text (no truncation) so the mode can be chosen by size.
    txt_path = extract_pdf_text(pdf_path, max_chars=None)
    actual_mode = _prepare_pdf_text(arxiv_id, txt_path, pdf_mode)

    prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)

    if session_id is None:
        session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"

    # search mode needs the read tool so pi can open the text file itself.
    tools = "bash,write_file,read" if actual_mode == "search" else "bash,write_file"
    cmd = [settings.PI_BIN, "-p", "--tools", tools]
    if fix_errors:
        # Retry: continue the given session.
        # NOTE(review): if fix_errors is passed without a session_id, a freshly
        # generated ID is "continued" — confirm that is intended.
        cmd += ["--session", session_id, "--continue"]
    else:
        cmd += ["--session-id", session_id]
    cmd += ["--skill", settings.SUMMARY_SKILL, prompt_text]
    if not fix_errors and actual_mode != "search":
        # First call in inject mode: attach the meta and text files as @file.
        # search mode injects nothing; pi reads the file on its own.
        cmd += [f"@{meta_path}", f"@{txt_path}"]

    logger.info(
        "Calling pi for %s (fix=%s, session=%s, mode=%s)",
        arxiv_id,
        bool(fix_errors),
        session_id,
        actual_mode,
    )

    started = time.monotonic()
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(),
            timeout=settings.SUMMARY_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        await _kill_and_reap(proc)
        raise PiTimeoutError(
            f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
        ) from None
    except asyncio.CancelledError:
        # Don't leave an orphaned pi process behind when the caller is cancelled.
        await _kill_and_reap(proc)
        raise
    elapsed = time.monotonic() - started

    # Diagnostic only: report whether pi wrote summary.json next to the PDF.
    # Any filesystem error here must not fail the call.
    summary_file = pdf_path.parent / "summary.json"
    try:
        file_info = f" summary.json={summary_file.stat().st_size}B"
    except OSError:
        file_info = ""
    logger.info(
        "pi subprocess for %s: %.2fs%s",
        arxiv_id,
        elapsed,
        file_info,
    )

    if proc.returncode != 0:
        raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
    return stdout.decode("utf-8", errors="replace"), session_id