165 lines
5.2 KiB
Python
165 lines
5.2 KiB
Python
"""pi CLI backend — invokes the pi subprocess to generate summaries.

Shared helpers (prompt building, PDF text extraction, JSON extraction,
meta.json handling) have moved to summary_utils.py.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import uuid
|
||
from pathlib import Path
|
||
|
||
from app.config import settings
|
||
from app.utils import truncate_error
|
||
from app.services.summary_utils import (
|
||
build_prompt,
|
||
extract_pdf_text,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Character cap for injecting the full PDF text into the prompt; in "auto" mode, texts longer than this switch to "search" mode.
|
||
_PDF_MAX_CHARS = 80_000
|
||
|
||
|
||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class PiTimeoutError(Exception):
    """Raised when the pi subprocess does not finish within the configured timeout."""
|
||
|
||
|
||
class PiProcessError(Exception):
    """Raised when the pi subprocess exits with a non-zero return code.

    The raw ``returncode`` and ``stderr`` stay available on the instance; the
    exception message embeds ``truncate_error(stderr)`` instead of the raw text.
    """

    def __init__(self, returncode: int, stderr: str):
        # Store the untouched values first so handlers can inspect them.
        self.returncode, self.stderr = returncode, stderr
        summary = truncate_error(stderr)
        super().__init__(f"pi exited with code {returncode}: {summary}")
|
||
|
||
|
||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||
|
||
|
||
async def call_pi(
    meta_path: Path,
    pdf_path: Path,
    fix_errors: list[str] | None = None,
    session_id: str | None = None,
    pdf_mode: str = "inject",
) -> tuple[str, str]:
    """Run the pi CLI in non-interactive mode and return ``(stdout, session_id)``.

    Args:
        meta_path: Path to the paper's meta file. The name of its parent
            directory is used as the arXiv id in the prompt, the generated
            session id and the log messages.
        pdf_path: Path to the paper's PDF. Its text is extracted with
            ``extract_pdf_text``; ``summary.json`` next to it is only
            inspected for logging.
        fix_errors: Validation errors from a previous attempt. When non-empty
            the session is continued (``--session <id> --continue``) and no
            ``@file`` attachments are passed.
        session_id: Session id to use. When ``None`` a fresh
            ``summary-<arxiv_id>-<8 hex chars>`` id is generated.
        pdf_mode: ``"inject"`` attaches the meta and text files via ``@file``
            (the text file is truncated in place to ``_PDF_MAX_CHARS``);
            ``"search"`` adds the ``read`` tool and attaches nothing;
            ``"auto"`` picks ``"search"`` when the extracted text exceeds
            ``_PDF_MAX_CHARS`` and ``"inject"`` otherwise.

    Returns:
        The decoded stdout of the pi process and the session id that was used.

    Raises:
        PiTimeoutError: pi did not finish within
            ``settings.SUMMARY_TIMEOUT_SECONDS``; the process is killed first.
        PiProcessError: pi exited with a non-zero return code.
    """
    # Function-local import keeps this block independent of module-level imports.
    import time

    arxiv_id = meta_path.parent.name

    # Extract the full PDF text (no truncation) and read it exactly once; the
    # same string is reused for the size check and for the inject truncation.
    txt_path = extract_pdf_text(pdf_path, max_chars=None)
    body = txt_path.read_text(encoding="utf-8")
    txt_size = len(body)

    actual_mode = pdf_mode
    if pdf_mode == "auto":
        if txt_size > _PDF_MAX_CHARS:
            actual_mode = "search"
            logger.info(
                "Auto mode: %s text=%d chars > %dk → search",
                arxiv_id,
                txt_size,
                _PDF_MAX_CHARS // 1000,
            )
        else:
            actual_mode = "inject"
            logger.info(
                "Auto mode: %s text=%d chars ≤ %dk → inject",
                arxiv_id,
                txt_size,
                _PDF_MAX_CHARS // 1000,
            )

    # Inject mode must not blow up the context window, so over-long text is
    # truncated. NOTE: this rewrites the extracted text file in place.
    if actual_mode == "inject" and txt_size > _PDF_MAX_CHARS:
        trimmed = body[:_PDF_MAX_CHARS].rstrip()
        txt_path.write_text(trimmed, encoding="utf-8")
        logger.info(
            "Truncated %s for inject: %d → %d chars", arxiv_id, txt_size, len(trimmed)
        )
    # The full text is no longer needed; do not hold it while pi runs.
    del body

    prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)

    # One independent session per paper unless the caller supplies an id.
    if session_id is None:
        session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"

    # Search mode additionally needs the read tool.
    tools = "bash,write_file" if actual_mode != "search" else "bash,write_file,read"
    cmd = [settings.PI_BIN, "-p", "--tools", tools]
    # NOTE(review): continuation is keyed on fix_errors, not on whether the
    # caller passed a session_id. A fix call without a session_id would
    # "--continue" a freshly generated id — confirm callers always pass both.
    if fix_errors:
        cmd += ["--session", session_id, "--continue"]
    else:
        cmd += ["--session-id", session_id]
    cmd += ["--skill", settings.SUMMARY_SKILL, prompt_text]
    if not fix_errors and actual_mode != "search":
        # Inject mode, first call only: attach the files via @file. In search
        # mode nothing is attached because pi has the read tool instead.
        cmd += [f"@{meta_path}", f"@{txt_path}"]

    logger.info(
        "Calling pi for %s (fix=%s, session=%s, mode=%s)",
        arxiv_id,
        bool(fix_errors),
        session_id,
        actual_mode,
    )

    started = time.monotonic()
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(),
            timeout=settings.SUMMARY_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError as exc:
        await _kill_pi_process(proc)
        raise PiTimeoutError(
            f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
        ) from exc
    except asyncio.CancelledError:
        # Do not leave an orphaned pi process behind when the caller cancels.
        await _kill_pi_process(proc)
        raise
    elapsed = time.monotonic() - started

    # Report the size of summary.json (if present) purely for diagnostics; a
    # single stat() avoids the exists()/stat() race and a failure here must
    # never fail the call.
    summary_file = pdf_path.parent / "summary.json"
    try:
        file_info = f" summary.json={summary_file.stat().st_size}B"
    except OSError:
        file_info = ""

    logger.info(
        "pi subprocess for %s: %.2fs%s",
        arxiv_id,
        elapsed,
        file_info,
    )

    if proc.returncode != 0:
        raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))

    return stdout.decode("utf-8", errors="replace"), session_id


async def _kill_pi_process(proc: asyncio.subprocess.Process) -> None:
    """Kill *proc* and wait for it to exit, tolerating an already-exited process."""
    try:
        proc.kill()
    except ProcessLookupError:
        # The process exited between the timeout/cancel and the kill.
        pass
    await proc.wait()
|