"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。""" from __future__ import annotations import asyncio import json import logging import re from pathlib import Path from app.config import settings logger = logging.getLogger(__name__) # ── 自定义异常 ────────────────────────────────────────────────────────── class PiTimeoutError(Exception): pass class PiProcessError(Exception): def __init__(self, returncode: int, stderr: str): self.returncode = returncode self.stderr = stderr super().__init__(f"pi exited with code {returncode}: {stderr[:500]}") class JsonNotFoundError(Exception): pass # ── meta.json ─────────────────────────────────────────────────────────── def write_meta_json(paper) -> Path: """写入 data/papers/{arxiv_id}/meta.json,返回路径。""" from app.services.pdf_downloader import paper_dir d = paper_dir(paper.arxiv_id) d.mkdir(parents=True, exist_ok=True) meta_path = d / "meta.json" authors = [a.name for a in paper.authors] tags = [t.tag for t in paper.tags] meta = { "arxiv_id": paper.arxiv_id, "title_en": paper.title_en, "abstract": paper.abstract or "", "published_at": paper.published_at.isoformat() if paper.published_at else None, "authors": authors, "tags": tags, "upvotes": paper.upvotes, } meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8") return meta_path # ── pi CLI 调用 ──────────────────────────────────────────────────────── async def call_pi(meta_path: Path, pdf_path: Path) -> str: """调用 pi CLI 非交互模式,返回 stdout 文本。""" arxiv_id = meta_path.parent.name cmd = [ settings.PI_BIN, "-p", "--no-tools", "--skill", settings.SUMMARY_SKILL, "请深度解读以下论文,并按指定 JSON schema 输出:", f"@{meta_path}", f"@{pdf_path}", ] logger.info("Calling pi for %s", arxiv_id) proc = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) try: stdout, stderr = await asyncio.wait_for( proc.communicate(), timeout=settings.SUMMARY_TIMEOUT_SECONDS, ) except asyncio.TimeoutError: proc.kill() await proc.wait() raise PiTimeoutError( f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s" ) if proc.returncode != 0: raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace")) return stdout.decode("utf-8", errors="replace") # ── JSON 提取 ────────────────────────────────────────────────────────── def extract_json(raw_output: str) -> dict: """从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。""" # 策略 1:整体直接解析 stripped = raw_output.strip() try: result = json.loads(stripped) if isinstance(result, dict) and "title_zh" in result: return result except json.JSONDecodeError: pass # 策略 2:提取 ```json ... ``` 代码块 fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL) for match in fence_pattern.finditer(raw_output): try: result = json.loads(match.group(1).strip()) if isinstance(result, dict) and "title_zh" in result: return result except json.JSONDecodeError: continue # 策略 3:匹配包含 title_zh 的最大 {...} 块 brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL) for match in brace_pattern.finditer(raw_output): try: return json.loads(match.group(0)) except json.JSONDecodeError: continue # 更宽松:找到最大的 { ... } 平衡块 best = None best_len = 0 for i, ch in enumerate(raw_output): if ch != "{": continue depth = 0 for j in range(i, len(raw_output)): if raw_output[j] == "{": depth += 1 elif raw_output[j] == "}": depth -= 1 if depth == 0: candidate = raw_output[i : j + 1] if len(candidate) > best_len: try: parsed = json.loads(candidate) if isinstance(parsed, dict): best = parsed best_len = len(candidate) except json.JSONDecodeError: pass break if best is not None: return best raise JsonNotFoundError("no JSON object found in pi output")