Files
daily-paper/app/services/pi_client.py
T

380 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。"""
from __future__ import annotations
import asyncio
import json
import logging
import re
from pathlib import Path
from app.config import settings
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PiTimeoutError(Exception):
pass
class PiProcessError(Exception):
def __init__(self, returncode: int, stderr: str):
self.returncode = returncode
self.stderr = stderr
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
class JsonNotFoundError(Exception):
pass
# ── meta.json ───────────────────────────────────────────────────────────
def write_meta_json(paper) -> Path:
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
from app.services.pdf_downloader import paper_dir
d = paper_dir(paper.arxiv_id)
d.mkdir(parents=True, exist_ok=True)
meta_path = d / "meta.json"
authors = [a.name for a in paper.authors]
tags = [t.tag for t in paper.tags]
meta = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"abstract": paper.abstract or "",
"published_at": paper.published_at.isoformat() if paper.published_at else None,
"authors": authors,
"tags": tags,
"upvotes": paper.upvotes,
}
meta_path.write_text(
json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8"
)
return meta_path
# ── PDF 文本提取 ────────────────────────────────────────────────────────
def _trim_body(text: str, max_chars: int | None = None) -> str:
"""去除参考文献,保留正文+附录,超长时从末尾截断。
策略:
1. 去掉 References/Bibliography 段落(纯引用列表,对解读无用)
2. 正文 + 附录全部保留
3. 如果指定了 max_chars 且总长超过,从末尾截断(附录靠后,优先保留正文)
"""
import re
# 找 References 段落的位置(在 Appendix 之后的那个)
# 简单策略:找到 References 标题,如果后面没有 Appendix 就全删
# 如果后面还有 Appendix,只删 References 到 Appendix 之间的内容
ref_match = re.search(r"(?m)^(?:References|Bibliography|参考文献)\s*$", text)
if ref_match:
ref_start = ref_match.start()
# 看 References 之后有没有 Appendix
after_ref = text[ref_start:]
app_match = re.search(
r"(?m)^(?:A\s+(?:Appendix|Supplementary)|Appendix|附录)\s*$", after_ref
)
if app_match:
# References 之后有 Appendix:只删 References 段
ref_end = ref_start + app_match.start()
text = text[:ref_start] + text[ref_end:]
else:
# References 之后没有 Appendix:删掉从 References 到结尾
text = text[:ref_start].rstrip()
# 去掉 Acknowledgments(对解读无用)
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
if ack_match:
# 只删 Acknowledgments 本身,不删后面的内容
next_section = re.search(r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start():])
if next_section:
text = text[:ack_match.start()] + text[ack_match.start() + next_section.start():]
else:
text = text[:ack_match.start()].rstrip()
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
if max_chars is not None and len(text) > max_chars:
text = text[:max_chars].rstrip()
return text
def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
"""用 pymupdf 提取 PDF 正文文本,保存为 .txt。
max_chars=None 时不截断,给 search/auto 模式保留完整内容。
"""
import pymupdf
txt_path = pdf_path.with_suffix(".txt")
if txt_path.exists():
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
return txt_path
doc = pymupdf.open(str(pdf_path))
raw_text = "\n\n".join(page.get_text() for page in doc)
doc.close()
body = _trim_body(raw_text, max_chars=max_chars)
txt_path.write_text(body, encoding="utf-8")
logger.info(
"Extracted PDF text: %s (%d -> %d chars, -%d%%)",
txt_path,
len(raw_text),
len(body),
(1 - len(body) / len(raw_text)) * 100 if raw_text else 0,
)
return txt_path
# ── Prompt 构建 ─────────────────────────────────────────────────────────
def _build_prompt(
arxiv_id: str,
meta_path: Path,
txt_path: Path,
pdf_mode: str,
fix_errors: list[str] | None = None,
) -> str:
"""根据模式构建 pi prompt。
inject: 全量注入,prompt 末尾包含论文全文内容
search: pi 自主 read 文件,prompt 只包含工作流指令
"""
json_schema = (
"## 必须包含以下字段(不要自创字段名):\n"
'{"arxiv_id": "...", '
'"title_zh": "中文标题", '
'"one_line": "一句话概括(≤50字)", '
'"tags": ["标签1","标签2"], '
'"difficulty": "入门/进阶/前沿", '
'"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
'"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
'"goal": "详细段落:本文的具体目标", '
'"gap": "详细段落:本文的独特切入角度"}, '
'"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
'"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
'"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
'"novelty": "详细段落:技术新颖性分析"}, '
'"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
'"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
'"limitations": "详细段落:局限性分析(作者承认的+你自己的观察")}, '
'"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
'"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要"},'
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要"}]'
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Tableid 严格使用 \"Figure N\"\"Table N\" 格式。"
"}"
)
writing_requirements = (
"## 写作要求\n"
"- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
"- 必须包含论文中的具体数据、数字、实验指标\n"
"- 像资深同事给同事讲论文一样,专业但易懂\n"
"- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
" 例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n"
)
if fix_errors:
error_list = "\n".join(f"- {e}" for e in fix_errors)
return (
"你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
f"data/papers/{arxiv_id}/summary.json\n\n"
f"{error_list}\n\n"
"注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
"修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
)
if pdf_mode == "search":
return (
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
"## 工作流程\n"
f"1. 先用 read 工具读取 {meta_path} 了解论文元信息(标题、作者、摘要)\n"
f"2. 再用 read 工具阅读 {txt_path}(论文正文全文),可以多次读取定位关键段落\n"
f"3. 充分理解后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n\n"
+ writing_requirements
+ "\n"
+ json_schema
)
else:
return (
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
"## 工作流程\n"
"论文元信息和正文全文已在上文提供,请仔细阅读。\n"
f"1. 充分理解论文后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n"
"2. 用 bash 运行 python scripts/validate_summary.py 验证\n\n"
+ writing_requirements
+ "\n"
+ json_schema
)
# ── pi CLI 调用 ────────────────────────────────────────────────────────
async def call_pi(
meta_path: Path,
pdf_path: Path,
fix_errors: list[str] | None = None,
session_id: str | None = None,
pdf_mode: str = "inject",
) -> tuple[str, str]:
"""调用 pi CLI 非交互模式,返回 (stdout 文本, session_id)。
fix_errors: 如果非空,表示上一次验证失败的错误列表,pi 需要修正这些问题。
session_id: 如果非空,用 --continue 延续该 session;否则创建新 session。
pdf_mode: "inject" = 全量注入 prompt@file),"search" = pi 自主 read 文件。
"""
arxiv_id = meta_path.parent.name
# 提取 PDF 全文(不截断),根据实际大小自动选择模式
txt_path = extract_pdf_text(pdf_path, max_chars=None)
txt_size = len(txt_path.read_text(encoding="utf-8"))
actual_mode = pdf_mode
if pdf_mode == "auto":
if txt_size > 80_000:
actual_mode = "search"
logger.info(
"Auto mode: %s text=%d chars > 80k → search", arxiv_id, txt_size
)
else:
actual_mode = "inject"
logger.info(
"Auto mode: %s text=%d chars ≤ 80k → inject", arxiv_id, txt_size
)
# inject 模式需要截断过长的文本(避免撑爆 context)
if actual_mode == "inject" and txt_size > 80_000:
body = txt_path.read_text(encoding="utf-8")
trimmed = body[:80_000].rstrip()
txt_path.write_text(trimmed, encoding="utf-8")
logger.info("Truncated %s for inject: %d%d chars", arxiv_id, txt_size, len(trimmed))
prompt_text = _build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
# 构建 session ID(每篇论文一个独立 session)
if session_id is None:
import uuid
session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"
# 工具列表:search 模式需要 read 工具
tools = "bash,write_file" if actual_mode != "search" else "bash,write_file,read"
cmd = [
settings.PI_BIN,
"-p",
"--tools", tools,
]
if fix_errors:
cmd += ["--session", session_id, "--continue"]
else:
cmd += ["--session-id", session_id]
cmd += [
"--skill",
settings.SUMMARY_SKILL,
prompt_text,
]
if not fix_errors and actual_mode != "search":
# inject 模式:首次调用传 @filesearch 模式 pi 自己 read,不注入
cmd += [f"@{meta_path}", f"@{txt_path}"]
logger.info(
"Calling pi for %s (fix=%s, session=%s, mode=%s)",
arxiv_id, bool(fix_errors), session_id, actual_mode,
)
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(),
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise PiTimeoutError(f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s")
if proc.returncode != 0:
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
return stdout.decode("utf-8", errors="replace"), session_id
# ── JSON 提取 ──────────────────────────────────────────────────────────
def extract_json(raw_output: str) -> dict:
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
# 策略 1:整体直接解析
stripped = raw_output.strip()
try:
result = json.loads(stripped)
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
pass
# 策略 2:提取 ```json ... ``` 代码块
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
for match in fence_pattern.finditer(raw_output):
try:
result = json.loads(match.group(1).strip())
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
continue
# 策略 3:匹配包含 title_zh 的最大 {...} 块
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
for match in brace_pattern.finditer(raw_output):
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
continue
# 更宽松:找到最大的 { ... } 平衡块
best = None
best_len = 0
for i, ch in enumerate(raw_output):
if ch != "{":
continue
depth = 0
for j in range(i, len(raw_output)):
if raw_output[j] == "{":
depth += 1
elif raw_output[j] == "}":
depth -= 1
if depth == 0:
candidate = raw_output[i : j + 1]
if len(candidate) > best_len:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
best = parsed
best_len = len(candidate)
except json.JSONDecodeError:
pass
break
if best is not None:
return best
raise JsonNotFoundError("no JSON object found in pi output")