161 lines
5.2 KiB
Python
161 lines
5.2 KiB
Python
"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
from pathlib import Path
|
|
|
|
from app.config import settings
|
|
|
|
# Module-level logger named after this module; used below for pi call progress.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ── Custom exceptions ───────────────────────────────────────────────────


class PiTimeoutError(Exception):
    """Raised when the pi CLI does not finish within the configured timeout."""
|
|
|
|
|
|
class PiProcessError(Exception):
    """Raised when the pi CLI exits with a non-zero status.

    Attributes:
        returncode: exit status of the pi process.
        stderr: the complete decoded stderr text. Only its first 500
            characters are included in the exception message.
    """

    def __init__(self, returncode: int, stderr: str):
        message = f"pi exited with code {returncode}: {stderr[:500]}"
        super().__init__(message)
        self.returncode = returncode
        self.stderr = stderr
|
|
|
|
|
|
class JsonNotFoundError(Exception):
    """Raised when no JSON object can be extracted from pi's output."""
|
|
|
|
|
|
# ── meta.json ───────────────────────────────────────────────────────────


def write_meta_json(paper) -> Path:
    """Serialize a paper's metadata to ``meta.json`` in its paper directory.

    The directory is ``paper_dir(paper.arxiv_id)`` (typically
    ``data/papers/{arxiv_id}/``) and is created if it does not exist. The
    file is pretty-printed UTF-8 JSON, and non-ASCII text is kept unescaped.

    Args:
        paper: a paper record exposing ``arxiv_id``, ``title_en``,
            ``abstract``, ``published_at`` (presumably a date/datetime or
            None), iterable ``authors`` (items with ``.name``), iterable
            ``tags`` (items with ``.tag``) and ``upvotes``.

    Returns:
        Path of the written ``meta.json``.
    """
    # Deferred import, kept as in the original; presumably avoids an import cycle.
    from app.services.pdf_downloader import paper_dir

    target_dir = paper_dir(paper.arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    published = paper.published_at
    payload = {
        "arxiv_id": paper.arxiv_id,
        "title_en": paper.title_en,
        "abstract": paper.abstract or "",
        "published_at": published.isoformat() if published else None,
        "authors": [author.name for author in paper.authors],
        "tags": [item.tag for item in paper.tags],
        "upvotes": paper.upvotes,
    }
    text = json.dumps(payload, ensure_ascii=False, indent=2)

    out_path = target_dir / "meta.json"
    out_path.write_text(text, encoding="utf-8")
    return out_path
|
|
|
|
|
|
# ── pi CLI invocation ──────────────────────────────────────────────────


async def _kill_and_reap(proc: asyncio.subprocess.Process) -> None:
    """Kill *proc* if it has not exited yet, then wait for it so it is reaped."""
    if proc.returncode is None:
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited on its own between the check and kill().
            pass
    await proc.wait()


async def call_pi(meta_path: Path, pdf_path: Path) -> str:
    """Run the pi CLI non-interactively on one paper and return its stdout.

    pi is started as ``settings.PI_BIN -p --no-tools --skill
    settings.SUMMARY_SKILL <prompt> @<meta_path> @<pdf_path>``. The paper id
    used in log messages is the name of ``meta_path``'s parent directory.

    Args:
        meta_path: path to the paper's ``meta.json``.
        pdf_path: path to the paper's PDF.

    Returns:
        pi's stdout decoded as UTF-8, with undecodable bytes replaced.

    Raises:
        PiTimeoutError: pi ran longer than ``settings.SUMMARY_TIMEOUT_SECONDS``.
            The process is killed and reaped before this is raised.
        PiProcessError: pi exited with a non-zero status.

    If the calling task is cancelled while pi is running, the process is
    killed before the cancellation propagates, so no orphaned pi is left.
    """
    arxiv_id = meta_path.parent.name
    cmd = [
        settings.PI_BIN,
        "-p",
        "--no-tools",
        "--skill",
        settings.SUMMARY_SKILL,
        "请深度解读以下论文,并按指定 JSON schema 输出:",
        f"@{meta_path}",
        f"@{pdf_path}",
    ]
    logger.info("Calling pi for %s", arxiv_id)

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    timeout = settings.SUMMARY_TIMEOUT_SECONDS
    try:
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        await _kill_and_reap(proc)
        logger.warning("pi timed out for %s after %ss", arxiv_id, timeout)
        # The internal wait_for timeout context adds nothing for callers.
        raise PiTimeoutError(f"pi timed out after {timeout}s") from None
    except asyncio.CancelledError:
        # The awaiting task was cancelled: don't leave pi running unattended.
        await _kill_and_reap(proc)
        raise

    if proc.returncode != 0:
        logger.warning("pi exited with code %s for %s", proc.returncode, arxiv_id)
        raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))

    return stdout.decode("utf-8", errors="replace")
|
|
|
|
|
|
# ── JSON extraction ────────────────────────────────────────────────────

# Markdown code fence, optionally tagged ``json``: ```json\n ... ```
_FENCE_RE = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
# A flat (no nested braces) {...} span that mentions "title_zh".
_FLAT_TITLE_RE = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)


def _is_summary(obj: object) -> bool:
    """Return True if *obj* is a dict containing a ``title_zh`` key."""
    return isinstance(obj, dict) and "title_zh" in obj


def _largest_json_object(text: str) -> dict | None:
    """Return the longest ``{...}`` span of *text* that decodes to a dict, or None.

    Each ``{`` is handed to ``JSONDecoder.raw_decode``, which understands
    string literals. Braces inside JSON strings (e.g. LaTeX in a summary)
    therefore cannot derail the match, unlike naive brace counting. After a
    successful decode, scanning resumes past the decoded object, because any
    object nested inside it is strictly shorter. Each failed start typically
    errors out quickly, so runs of stray ``{`` no longer cost O(n²).
    """
    decoder = json.JSONDecoder()
    best = None
    best_len = 0
    start = text.find("{")
    while start != -1:
        try:
            obj, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        if isinstance(obj, dict) and end - start > best_len:
            best, best_len = obj, end - start
        start = text.find("{", end)
    return best


def extract_json(raw_output: str) -> dict:
    """Extract the structured summary dict from pi's raw output.

    Strategies, tried in order:
      1. Parse the whole stripped output and accept a dict with ``title_zh``.
      2. Parse each Markdown code fence (```json ... ``` or bare ```). Accept
         the first dict with ``title_zh``.
      3. Parse each flat ``{...}`` span mentioning ``"title_zh"``.
      4. Fall back to the longest ``{...}`` span that decodes to any dict.
         This last step does NOT require ``title_zh``.

    Raises:
        JsonNotFoundError: no strategy produced a dict.
    """
    # Strategy 1: the whole output is the JSON object.
    try:
        whole = json.loads(raw_output.strip())
    except json.JSONDecodeError:
        pass
    else:
        if _is_summary(whole):
            return whole

    # Strategy 2: JSON inside a fenced code block.
    for match in _FENCE_RE.finditer(raw_output):
        try:
            fenced = json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            continue
        if _is_summary(fenced):
            return fenced

    # Strategy 3: a flat object that mentions title_zh.
    for match in _FLAT_TITLE_RE.finditer(raw_output):
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            continue

    # Strategy 4 (loosest): the largest decodable object anywhere in the text.
    best = _largest_json_object(raw_output)
    if best is not None:
        return best

    raise JsonNotFoundError("no JSON object found in pi output")
|