90fe705e8f
- 核心变更: - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024) - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备 - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式 - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls] - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table) - 新增文件: - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行) - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例) - 配置更新: - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID) - app/config.py: Settings 类对应字段 - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等) - 删除旧文件: - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本 - 文档更新: - README.md: 更新环境变量说明 - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py) 此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
180 lines
5.5 KiB
Python
180 lines
5.5 KiB
Python
"""pi CLI 后端 — 调用 pi 子进程生成总结。
|
||
|
||
通用工具函数(prompt 构建、PDF 提取、JSON 提取、meta.json)已移至 summary_utils.py。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import uuid
|
||
from pathlib import Path
|
||
|
||
from app.config import settings
|
||
from app.utils import truncate_error
|
||
from app.services.summary_utils import (
|
||
JsonNotFoundError,
|
||
build_prompt,
|
||
extract_json,
|
||
extract_pdf_text,
|
||
write_meta_json,
|
||
)
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Character budget for putting the full extracted PDF text into the prompt.
# In "auto" mode, text longer than this makes call_pi switch to "search" mode;
# in "inject" mode, longer text is truncated down to this many characters.
_PDF_MAX_CHARS = 80 * 1000
|
||
|
||
# Public API. The helpers imported from summary_utils are re-exported here so
# that code which still imports them from this module keeps working
# (backward compatibility).
__all__ = [
    "PiTimeoutError", "PiProcessError", "JsonNotFoundError", "call_pi",
    "write_meta_json", "extract_pdf_text", "build_prompt", "extract_json",
]
|
||
|
||
|
||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class PiTimeoutError(Exception):
    """Raised by call_pi when the pi subprocess exceeds the configured timeout."""
|
||
|
||
|
||
class PiProcessError(Exception):
    """Raised when the pi subprocess exits with a non-zero status.

    Attributes:
        returncode: the exit status of the subprocess.
        stderr: the stderr text exactly as passed in (only the exception
            message uses a truncated copy, via ``truncate_error``).
    """

    def __init__(self, returncode: int, stderr: str):
        message = "pi exited with code {}: {}".format(returncode, truncate_error(stderr))
        super().__init__(message)
        self.returncode = returncode
        self.stderr = stderr
|
||
|
||
|
||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||
|
||
|
||
async def call_pi(
    meta_path: Path,
    pdf_path: Path,
    fix_errors: list[str] | None = None,
    session_id: str | None = None,
    pdf_mode: str = "inject",
) -> tuple[str, str]:
    """Run the pi CLI in non-interactive mode and return ``(stdout text, session_id)``.

    Args:
        meta_path: path to the paper's meta.json; the name of its parent
            directory is used as the arxiv id.
        pdf_path: path to the paper PDF. Its text is extracted with
            ``extract_pdf_text``; after the run, ``summary.json`` next to it is
            checked for logging purposes only.
        fix_errors: if non-empty, the validation errors of the previous attempt
            that pi must fix. Such calls pass ``--session <id> --continue`` and
            do not re-attach the @file inputs.
        session_id: session to use; when None a new id of the form
            ``summary-<arxiv_id>-<8 hex chars>`` is generated.
        pdf_mode: ``"inject"`` attaches meta.json and the extracted text as
            @file arguments (text truncated to ``_PDF_MAX_CHARS``);
            ``"search"`` attaches nothing and gives pi the ``read`` tool;
            ``"auto"`` picks ``"search"`` when the extracted text is longer than
            ``_PDF_MAX_CHARS`` and ``"inject"`` otherwise.

    Returns:
        The subprocess stdout decoded as UTF-8 (invalid bytes replaced) and the
        session id actually used.

    Raises:
        ValueError: ``pdf_mode`` is not "inject", "search" or "auto".
        PiTimeoutError: pi ran longer than ``settings.SUMMARY_TIMEOUT_SECONDS``
            (the process is killed first).
        PiProcessError: pi exited with a non-zero status.

    NOTE(review): whether the session is continued depends on ``fix_errors``,
    not on ``session_id``. A fix call without ``session_id`` therefore
    "continues" a freshly generated session, and a ``session_id`` without
    ``fix_errors`` starts a new session under that id. Confirm that callers
    always pass both together on retries.
    """
    import time

    # Validate first: previously an unknown mode injected the @files (any mode
    # != "search") but skipped truncation (only == "inject"), silently
    # bypassing the context-size guard.
    if pdf_mode not in ("inject", "search", "auto"):
        raise ValueError(
            f"unknown pdf_mode {pdf_mode!r}; expected 'inject', 'search' or 'auto'"
        )

    arxiv_id = meta_path.parent.name

    # Extract the full text (no truncation here); read it once and reuse it
    # both for sizing and for truncation.
    txt_path = extract_pdf_text(pdf_path, max_chars=None)
    body = txt_path.read_text(encoding="utf-8")
    txt_size = len(body)

    actual_mode = _resolve_pdf_mode(pdf_mode, txt_size, arxiv_id)

    # Inject mode must not overflow the model context: truncate the text file
    # in place. NOTE(review): if extract_pdf_text reuses an existing file,
    # later calls will see the truncated text — confirm this is intended.
    if actual_mode == "inject" and txt_size > _PDF_MAX_CHARS:
        trimmed = body[:_PDF_MAX_CHARS].rstrip()
        txt_path.write_text(trimmed, encoding="utf-8")
        logger.info(
            "Truncated %s for inject: %d → %d chars", arxiv_id, txt_size, len(trimmed)
        )

    prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)

    # One independent session per paper.
    if session_id is None:
        session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"

    cmd = _build_pi_command(
        prompt_text, session_id, actual_mode, fix_errors, meta_path, txt_path
    )

    logger.info(
        "Calling pi for %s (fix=%s, session=%s, mode=%s)",
        arxiv_id,
        bool(fix_errors),
        session_id,
        actual_mode,
    )

    started = time.monotonic()
    returncode, stdout, stderr = await _run_pi(cmd)
    elapsed = time.monotonic() - started

    # Report whether pi wrote summary.json (informational only).
    file_info = ""
    try:
        summary_size = (pdf_path.parent / "summary.json").stat().st_size
    except FileNotFoundError:
        pass
    else:
        file_info = f" summary.json={summary_size}B"

    logger.info("pi subprocess for %s: %.2fs%s", arxiv_id, elapsed, file_info)

    if returncode != 0:
        raise PiProcessError(returncode, stderr.decode("utf-8", errors="replace"))

    return stdout.decode("utf-8", errors="replace"), session_id


def _resolve_pdf_mode(pdf_mode: str, text_len: int, arxiv_id: str) -> str:
    """Map a validated pdf_mode to the concrete mode ("inject" or "search")."""
    if pdf_mode != "auto":
        return pdf_mode
    if text_len > _PDF_MAX_CHARS:
        logger.info(
            "Auto mode: %s text=%d chars > %dk → search",
            arxiv_id,
            text_len,
            _PDF_MAX_CHARS // 1000,
        )
        return "search"
    logger.info(
        "Auto mode: %s text=%d chars ≤ %dk → inject",
        arxiv_id,
        text_len,
        _PDF_MAX_CHARS // 1000,
    )
    return "inject"


def _build_pi_command(
    prompt_text: str,
    session_id: str,
    mode: str,
    fix_errors: list[str] | None,
    meta_path: Path,
    txt_path: Path,
) -> list[str]:
    """Build the argv list for the pi CLI invocation."""
    # Search mode needs the read tool so pi can open the files itself.
    tools = "bash,write_file,read" if mode == "search" else "bash,write_file"
    cmd = [settings.PI_BIN, "-p", "--tools", tools]
    if fix_errors:
        cmd += ["--session", session_id, "--continue"]
    else:
        cmd += ["--session-id", session_id]
    cmd += ["--skill", settings.SUMMARY_SKILL, prompt_text]
    # Only the first, non-search call attaches the files; follow-up fix calls
    # continue the session and do not re-send them.
    if not fix_errors and mode != "search":
        cmd += [f"@{meta_path}", f"@{txt_path}"]
    return cmd


def _kill_quietly(proc: asyncio.subprocess.Process) -> None:
    """Kill *proc* if it is still running, ignoring an already-exited process."""
    if proc.returncode is not None:
        return
    try:
        proc.kill()
    except ProcessLookupError:
        # The process exited between the returncode check and kill().
        pass


async def _run_pi(cmd: list[str]) -> tuple[int, bytes, bytes]:
    """Run *cmd* and return ``(returncode, stdout, stderr)``, enforcing the timeout."""
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(),
            timeout=settings.SUMMARY_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        _kill_quietly(proc)
        await proc.wait()
        raise PiTimeoutError(
            f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
        ) from None
    except BaseException:
        # Task cancelled (or another unexpected error): do not leave an
        # orphaned pi process running in the background.
        _kill_quietly(proc)
        raise
    return proc.returncode, stdout, stderr
|