Files
daily-paper/app/services/pi_client.py
T

165 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""pi CLI 后端 — 调用 pi 子进程生成总结。
General-purpose helpers (prompt building, PDF text extraction, JSON extraction, meta.json) have been moved to summary_utils.py.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from pathlib import Path
from app.config import settings
from app.utils import truncate_error
from app.services.summary_utils import (
build_prompt,
extract_pdf_text,
)
logger = logging.getLogger(__name__)
# Character cap for injecting the full PDF text into the prompt. In "auto" mode,
# text longer than this switches the call to "search" mode; in "inject" mode the
# text is truncated to this many characters.
_PDF_MAX_CHARS = 80_000
# ── Custom exceptions ───────────────────────────────────────────────────
class PiTimeoutError(Exception):
    """Raised when the pi subprocess does not finish within the allowed time."""
class PiProcessError(Exception):
    """Raised when the pi subprocess exits with a non-zero return code.

    The exception message embeds the return code and a shortened form of
    stderr (via ``truncate_error``); the unshortened values stay available
    on the ``returncode`` and ``stderr`` attributes.
    """

    def __init__(self, returncode: int, stderr: str):
        message = f"pi exited with code {returncode}: {truncate_error(stderr)}"
        super().__init__(message)
        # Keep the raw values so callers can inspect them without parsing the message.
        self.returncode = returncode
        self.stderr = stderr
# ── pi CLI invocation ──────────────────────────────────────────────────
async def call_pi(
    meta_path: Path,
    pdf_path: Path,
    fix_errors: list[str] | None = None,
    session_id: str | None = None,
    pdf_mode: str = "inject",
) -> tuple[str, str]:
    """Run the pi CLI in non-interactive mode and return ``(stdout_text, session_id)``.

    Args:
        meta_path: Path to the paper's meta file. The name of its parent
            directory is used as the arXiv id.
        pdf_path: Path to the paper's PDF. Its text is extracted with
            ``extract_pdf_text``; ``summary.json`` is looked up next to it
            after the run (for logging only).
        fix_errors: Validation errors from a previous attempt. When non-empty
            the command is built with ``--session <id> --continue`` and no
            ``@file`` attachments; otherwise ``--session-id <id>`` is used.
        session_id: Session identifier to use. When ``None`` a fresh
            ``summary-<arxiv_id>-<8 hex chars>`` id is generated.
            NOTE(review): if ``fix_errors`` is given without a ``session_id``,
            ``--continue`` is issued against a freshly generated id — confirm
            callers always pass the previous id on fix-up calls.
        pdf_mode: ``"inject"`` attaches the extracted text (truncated to
            ``_PDF_MAX_CHARS``) as an ``@file`` argument; ``"search"`` attaches
            nothing and enables the ``read`` tool instead; ``"auto"`` picks
            ``"search"`` when the text exceeds ``_PDF_MAX_CHARS`` and
            ``"inject"`` otherwise.

    Returns:
        The decoded stdout of the pi process and the session id that was used.

    Raises:
        PiTimeoutError: The process did not finish within
            ``settings.SUMMARY_TIMEOUT_SECONDS``; it is killed first.
        PiProcessError: The process exited with a non-zero return code.
    """
    # Function-scope import, as in the original, so the module-level import
    # block does not need to change.
    import time as _time

    arxiv_id = meta_path.parent.name

    # Extract the full PDF text (no truncation) and read it once; the same
    # string is reused for the size check and for truncation below.
    txt_path = extract_pdf_text(pdf_path, max_chars=None)
    body = txt_path.read_text(encoding="utf-8")
    txt_size = len(body)

    actual_mode = pdf_mode
    if pdf_mode == "auto":
        if txt_size > _PDF_MAX_CHARS:
            actual_mode = "search"
            logger.info(
                "Auto mode: %s text=%d chars > %dk → search",
                arxiv_id,
                txt_size,
                _PDF_MAX_CHARS // 1000,
            )
        else:
            actual_mode = "inject"
            logger.info(
                "Auto mode: %s text=%d chars ≤ %dk → inject",
                arxiv_id,
                txt_size,
                _PDF_MAX_CHARS // 1000,
            )

    # Inject mode: cap over-long text so it does not blow up the model context.
    # The truncated text overwrites the extracted text file in place.
    if actual_mode == "inject" and txt_size > _PDF_MAX_CHARS:
        trimmed = body[:_PDF_MAX_CHARS].rstrip()
        txt_path.write_text(trimmed, encoding="utf-8")
        logger.info(
            "Truncated %s for inject: %d → %d chars",
            arxiv_id,
            txt_size,
            len(trimmed),
        )

    prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)

    # One independent session per paper unless the caller supplies an id.
    if session_id is None:
        session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"

    # Search mode additionally needs the `read` tool so pi can open the files itself.
    tools = "bash,write_file,read" if actual_mode == "search" else "bash,write_file"

    cmd = [settings.PI_BIN, "-p", "--tools", tools]
    if fix_errors:
        cmd += ["--session", session_id, "--continue"]
    else:
        cmd += ["--session-id", session_id]
    cmd += ["--skill", settings.SUMMARY_SKILL, prompt_text]
    if not fix_errors and actual_mode != "search":
        # Inject mode, first call only: attach the files via @file arguments.
        # In search mode pi reads them on its own, so nothing is attached.
        cmd += [f"@{meta_path}", f"@{txt_path}"]

    logger.info(
        "Calling pi for %s (fix=%s, session=%s, mode=%s)",
        arxiv_id,
        bool(fix_errors),
        session_id,
        actual_mode,
    )

    started = _time.monotonic()
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(),
            timeout=settings.SUMMARY_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError as exc:
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited between the timeout firing and the kill;
            # nothing is left to terminate.
            pass
        await proc.wait()
        raise PiTimeoutError(
            f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
        ) from exc
    elapsed = _time.monotonic() - started

    # Diagnostic only: report the size of summary.json next to the PDF if it
    # exists after the run. A failing stat must never fail the call itself.
    summary_file = pdf_path.parent / "summary.json"
    try:
        file_info = f" summary.json={summary_file.stat().st_size}B"
    except OSError:
        file_info = ""
    logger.info(
        "pi subprocess for %s: %.2fs%s",
        arxiv_id,
        elapsed,
        file_info,
    )

    if proc.returncode != 0:
        raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
    return stdout.decode("utf-8", errors="replace"), session_id