feat: add claude backend, refactor summary utilities, improve batch worker pattern, add pymupdf4llm

This commit is contained in:
2026-06-12 22:25:57 +08:00
parent b42e9149e5
commit e2f0e1a8be
13 changed files with 1350 additions and 1010 deletions
+43 -263
View File
@@ -1,17 +1,38 @@
"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。"""
"""pi CLI 后端 — 调用 pi 子进程生成总结
通用工具函数(prompt 构建、PDF 提取、JSON 提取、meta.json)已移至 summary_utils.py。
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import uuid
from pathlib import Path
from app.config import settings
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
extract_json,
extract_pdf_text,
write_meta_json,
)
logger = logging.getLogger(__name__)
# 重新导出,保持向后兼容
__all__ = [
"PiTimeoutError",
"PiProcessError",
"JsonNotFoundError",
"call_pi",
"write_meta_json",
"extract_pdf_text",
"build_prompt",
"extract_json",
]
# ── 自定义异常 ──────────────────────────────────────────────────────────
@@ -27,201 +48,6 @@ class PiProcessError(Exception):
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
class JsonNotFoundError(Exception):
pass
# ── meta.json ───────────────────────────────────────────────────────────
def write_meta_json(paper) -> Path:
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
from app.services.pdf_downloader import paper_dir
d = paper_dir(paper.arxiv_id)
d.mkdir(parents=True, exist_ok=True)
meta_path = d / "meta.json"
authors = [a.name for a in paper.authors]
tags = [t.tag for t in paper.tags]
meta = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"abstract": paper.abstract or "",
"published_at": paper.published_at.isoformat() if paper.published_at else None,
"authors": authors,
"tags": tags,
"upvotes": paper.upvotes,
}
meta_path.write_text(
json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8"
)
return meta_path
# ── PDF 文本提取 ────────────────────────────────────────────────────────
def _trim_body(text: str, max_chars: int | None = None) -> str:
"""去除参考文献,保留正文+附录,超长时从末尾截断。
策略:
1. 去掉 References/Bibliography 段落(纯引用列表,对解读无用)
2. 正文 + 附录全部保留
3. 如果指定了 max_chars 且总长超过,从末尾截断(附录靠后,优先保留正文)
"""
import re
# 找 References 段落的位置(在 Appendix 之后的那个)
# 简单策略:找到 References 标题,如果后面没有 Appendix 就全删
# 如果后面还有 Appendix,只删 References 到 Appendix 之间的内容
ref_match = re.search(r"(?m)^(?:References|Bibliography|参考文献)\s*$", text)
if ref_match:
ref_start = ref_match.start()
# 看 References 之后有没有 Appendix
after_ref = text[ref_start:]
app_match = re.search(
r"(?m)^(?:A\s+(?:Appendix|Supplementary)|Appendix|附录)\s*$", after_ref
)
if app_match:
# References 之后有 Appendix:只删 References 段
ref_end = ref_start + app_match.start()
text = text[:ref_start] + text[ref_end:]
else:
# References 之后没有 Appendix:删掉从 References 到结尾
text = text[:ref_start].rstrip()
# 去掉 Acknowledgments(对解读无用)
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
if ack_match:
# 只删 Acknowledgments 本身,不删后面的内容
next_section = re.search(r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start():])
if next_section:
text = text[:ack_match.start()] + text[ack_match.start() + next_section.start():]
else:
text = text[:ack_match.start()].rstrip()
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
if max_chars is not None and len(text) > max_chars:
text = text[:max_chars].rstrip()
return text
def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
"""用 pymupdf 提取 PDF 正文文本,保存为 .txt。
max_chars=None 时不截断,给 search/auto 模式保留完整内容。
"""
import pymupdf
txt_path = pdf_path.with_suffix(".txt")
if txt_path.exists():
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
return txt_path
doc = pymupdf.open(str(pdf_path))
raw_text = "\n\n".join(page.get_text() for page in doc)
doc.close()
body = _trim_body(raw_text, max_chars=max_chars)
txt_path.write_text(body, encoding="utf-8")
logger.info(
"Extracted PDF text: %s (%d -> %d chars, -%d%%)",
txt_path,
len(raw_text),
len(body),
(1 - len(body) / len(raw_text)) * 100 if raw_text else 0,
)
return txt_path
# ── Prompt 构建 ─────────────────────────────────────────────────────────
def _build_prompt(
arxiv_id: str,
meta_path: Path,
txt_path: Path,
pdf_mode: str,
fix_errors: list[str] | None = None,
) -> str:
"""根据模式构建 pi prompt。
inject: 全量注入,prompt 末尾包含论文全文内容
search: pi 自主 read 文件,prompt 只包含工作流指令
"""
json_schema = (
"## 必须包含以下字段(不要自创字段名):\n"
'{"arxiv_id": "...", '
'"title_zh": "中文标题", '
'"one_line": "一句话概括(≤50字)", '
'"tags": ["标签1","标签2"], '
'"difficulty": "入门/进阶/前沿", '
'"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
'"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
'"goal": "详细段落:本文的具体目标", '
'"gap": "详细段落:本文的独特切入角度"}, '
'"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
'"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
'"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
'"novelty": "详细段落:技术新颖性分析"}, '
'"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
'"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
'"limitations": "详细段落:局限性分析(作者承认的+你自己的观察")}, '
'"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
'"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要","section":"method"},'
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要","section":"results"}]'
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Tableid 严格使用 \"Figure N\"\"Table N\" 格式。"
"section 必须是 motivation/method/results/limitations 之一,表示该图最适合展示在哪个章节。"
"}"
)
writing_requirements = (
"## 写作要求\n"
"- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
"- 必须包含论文中的具体数据、数字、实验指标\n"
"- 像资深同事给同事讲论文一样,专业但易懂\n"
"- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
" 例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n"
)
if fix_errors:
error_list = "\n".join(f"- {e}" for e in fix_errors)
return (
"你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
f"data/papers/{arxiv_id}/summary.json\n\n"
f"{error_list}\n\n"
"注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
"修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
)
if pdf_mode == "search":
return (
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
"## 工作流程\n"
f"1. 先用 read 工具读取 {meta_path} 了解论文元信息(标题、作者、摘要)\n"
f"2. 再用 read 工具阅读 {txt_path}(论文正文全文),可以多次读取定位关键段落\n"
f"3. 充分理解后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n\n"
+ writing_requirements
+ "\n"
+ json_schema
)
else:
return (
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
"## 工作流程\n"
"论文元信息和正文全文已在上文提供,请仔细阅读。\n"
f"1. 充分理解论文后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n"
"2. 用 bash 运行 python scripts/validate_summary.py 验证\n\n"
+ writing_requirements
+ "\n"
+ json_schema
)
# ── pi CLI 调用 ────────────────────────────────────────────────────────
@@ -264,12 +90,10 @@ async def call_pi(
txt_path.write_text(trimmed, encoding="utf-8")
logger.info("Truncated %s for inject: %d%d chars", arxiv_id, txt_size, len(trimmed))
prompt_text = _build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
# 构建 session ID(每篇论文一个独立 session)
if session_id is None:
import uuid
session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"
# 工具列表:search 模式需要 read 工具
@@ -297,6 +121,9 @@ async def call_pi(
arxiv_id, bool(fix_errors), session_id, actual_mode,
)
import time as _time
_t_sub_start = _time.monotonic()
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
@@ -312,69 +139,22 @@ async def call_pi(
await proc.wait()
raise PiTimeoutError(f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s")
_t_sub_end = _time.monotonic()
# 检查 summary.json 是否由 pi 子进程写入
_summary_file = pdf_path.parent / "summary.json"
_file_info = ""
if _summary_file.exists():
_file_mtime = _summary_file.stat().st_mtime
_file_size = _summary_file.stat().st_size
_file_info = f" summary.json={_file_size}B"
logger.info(
"pi subprocess for %s: %.2fs%s",
arxiv_id, _t_sub_end - _t_sub_start, _file_info,
)
if proc.returncode != 0:
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
return stdout.decode("utf-8", errors="replace"), session_id
# ── JSON 提取 ──────────────────────────────────────────────────────────
def extract_json(raw_output: str) -> dict:
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
# 策略 1:整体直接解析
stripped = raw_output.strip()
try:
result = json.loads(stripped)
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
pass
# 策略 2:提取 ```json ... ``` 代码块
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
for match in fence_pattern.finditer(raw_output):
try:
result = json.loads(match.group(1).strip())
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
continue
# 策略 3:匹配包含 title_zh 的最大 {...} 块
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
for match in brace_pattern.finditer(raw_output):
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
continue
# 更宽松:找到最大的 { ... } 平衡块
best = None
best_len = 0
for i, ch in enumerate(raw_output):
if ch != "{":
continue
depth = 0
for j in range(i, len(raw_output)):
if raw_output[j] == "{":
depth += 1
elif raw_output[j] == "}":
depth -= 1
if depth == 0:
candidate = raw_output[i : j + 1]
if len(candidate) > best_len:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
best = parsed
best_len = len(candidate)
except json.JSONDecodeError:
pass
break
if best is not None:
return best
raise JsonNotFoundError("no JSON object found in pi output")