feat: add claude backend, refactor summary utilities, improve batch worker pattern, add pymupdf4llm
This commit is contained in:
+43
-263
@@ -1,17 +1,38 @@
|
||||
"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。"""
|
||||
"""pi CLI 后端 — 调用 pi 子进程生成总结。
|
||||
|
||||
通用工具函数(prompt 构建、PDF 提取、JSON 提取、meta.json)已移至 summary_utils.py。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from app.config import settings
|
||||
from app.services.summary_utils import (
|
||||
JsonNotFoundError,
|
||||
build_prompt,
|
||||
extract_json,
|
||||
extract_pdf_text,
|
||||
write_meta_json,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 重新导出,保持向后兼容
|
||||
__all__ = [
|
||||
"PiTimeoutError",
|
||||
"PiProcessError",
|
||||
"JsonNotFoundError",
|
||||
"call_pi",
|
||||
"write_meta_json",
|
||||
"extract_pdf_text",
|
||||
"build_prompt",
|
||||
"extract_json",
|
||||
]
|
||||
|
||||
|
||||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -27,201 +48,6 @@ class PiProcessError(Exception):
|
||||
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
|
||||
|
||||
|
||||
class JsonNotFoundError(Exception):
    """Raised when no JSON object can be extracted from the pi output."""
|
||||
|
||||
|
||||
# ── meta.json ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def write_meta_json(paper) -> Path:
    """Serialize the paper's metadata to ``meta.json`` and return the file path.

    The file is written inside the directory returned by
    ``paper_dir(paper.arxiv_id)``, which is created if it does not exist.

    Args:
        paper: Object exposing ``arxiv_id``, ``title_en``, ``abstract``,
            ``published_at``, ``upvotes``, an ``authors`` iterable whose items
            have ``.name`` and a ``tags`` iterable whose items have ``.tag``.

    Returns:
        Path of the written ``meta.json``.
    """
    # Imported lazily, as in the original (presumably to avoid an import
    # cycle with the downloader module -- confirm before hoisting).
    from app.services.pdf_downloader import paper_dir

    target_dir = paper_dir(paper.arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = target_dir / "meta.json"

    author_names = [author.name for author in paper.authors]
    tag_names = [entry.tag for entry in paper.tags]

    if paper.published_at:
        published = paper.published_at.isoformat()
    else:
        published = None

    payload = {
        "arxiv_id": paper.arxiv_id,
        "title_en": paper.title_en,
        # A missing abstract is stored as an empty string, not null.
        "abstract": paper.abstract or "",
        "published_at": published,
        "authors": author_names,
        "tags": tag_names,
        "upvotes": paper.upvotes,
    }

    # ensure_ascii=False keeps non-ASCII text readable in the file.
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    destination.write_text(serialized, encoding="utf-8")
    return destination
|
||||
|
||||
|
||||
# ── PDF 文本提取 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _trim_body(text: str, max_chars: int | None = None) -> str:
|
||||
"""去除参考文献,保留正文+附录,超长时从末尾截断。
|
||||
|
||||
策略:
|
||||
1. 去掉 References/Bibliography 段落(纯引用列表,对解读无用)
|
||||
2. 正文 + 附录全部保留
|
||||
3. 如果指定了 max_chars 且总长超过,从末尾截断(附录靠后,优先保留正文)
|
||||
"""
|
||||
import re
|
||||
|
||||
# 找 References 段落的位置(在 Appendix 之后的那个)
|
||||
# 简单策略:找到 References 标题,如果后面没有 Appendix 就全删
|
||||
# 如果后面还有 Appendix,只删 References 到 Appendix 之间的内容
|
||||
ref_match = re.search(r"(?m)^(?:References|Bibliography|参考文献)\s*$", text)
|
||||
if ref_match:
|
||||
ref_start = ref_match.start()
|
||||
# 看 References 之后有没有 Appendix
|
||||
after_ref = text[ref_start:]
|
||||
app_match = re.search(
|
||||
r"(?m)^(?:A\s+(?:Appendix|Supplementary)|Appendix|附录)\s*$", after_ref
|
||||
)
|
||||
if app_match:
|
||||
# References 之后有 Appendix:只删 References 段
|
||||
ref_end = ref_start + app_match.start()
|
||||
text = text[:ref_start] + text[ref_end:]
|
||||
else:
|
||||
# References 之后没有 Appendix:删掉从 References 到结尾
|
||||
text = text[:ref_start].rstrip()
|
||||
|
||||
# 去掉 Acknowledgments(对解读无用)
|
||||
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
|
||||
if ack_match:
|
||||
# 只删 Acknowledgments 本身,不删后面的内容
|
||||
next_section = re.search(r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start():])
|
||||
if next_section:
|
||||
text = text[:ack_match.start()] + text[ack_match.start() + next_section.start():]
|
||||
else:
|
||||
text = text[:ack_match.start()].rstrip()
|
||||
|
||||
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
|
||||
if max_chars is not None and len(text) > max_chars:
|
||||
text = text[:max_chars].rstrip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
    """Extract the body text of a PDF with pymupdf and save it as ``.txt``.

    The text file lives next to the PDF (same name, ``.txt`` suffix). If it
    already exists it is returned as-is: the cache takes priority and
    ``max_chars`` is NOT re-applied, so callers that need a different limit
    must delete the old file first.

    Args:
        pdf_path: Path of the source PDF.
        max_chars: Optional length cap forwarded to ``_trim_body``; ``None``
            keeps the full text (used by the search/auto modes).

    Returns:
        Path of the ``.txt`` file.
    """
    txt_path = pdf_path.with_suffix(".txt")
    if txt_path.exists():
        # Cache hit: no need to touch the PDF (or import pymupdf) at all.
        return txt_path

    # Deferred so a cache hit does not require the PDF library.
    import pymupdf

    doc = pymupdf.open(str(pdf_path))
    try:
        raw_text = "\n\n".join(page.get_text() for page in doc)
    finally:
        # Always release the document handle, even if a page fails to parse.
        doc.close()

    body = _trim_body(raw_text, max_chars=max_chars)
    txt_path.write_text(body, encoding="utf-8")

    # Guard against division by zero for PDFs with no extractable text.
    reduction_pct = (1 - len(body) / len(raw_text)) * 100 if raw_text else 0
    logger.info(
        "Extracted PDF text: %s (%d -> %d chars, -%d%%)",
        txt_path,
        len(raw_text),
        len(body),
        reduction_pct,
    )
    return txt_path
|
||||
|
||||
|
||||
# ── Prompt 构建 ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_prompt(
|
||||
arxiv_id: str,
|
||||
meta_path: Path,
|
||||
txt_path: Path,
|
||||
pdf_mode: str,
|
||||
fix_errors: list[str] | None = None,
|
||||
) -> str:
|
||||
"""根据模式构建 pi prompt。
|
||||
|
||||
inject: 全量注入,prompt 末尾包含论文全文内容
|
||||
search: pi 自主 read 文件,prompt 只包含工作流指令
|
||||
"""
|
||||
json_schema = (
|
||||
"## 必须包含以下字段(不要自创字段名):\n"
|
||||
'{"arxiv_id": "...", '
|
||||
'"title_zh": "中文标题", '
|
||||
'"one_line": "一句话概括(≤50字)", '
|
||||
'"tags": ["标签1","标签2"], '
|
||||
'"difficulty": "入门/进阶/前沿", '
|
||||
'"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
|
||||
'"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
|
||||
'"goal": "详细段落:本文的具体目标", '
|
||||
'"gap": "详细段落:本文的独特切入角度"}, '
|
||||
'"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
|
||||
'"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
|
||||
'"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
|
||||
'"novelty": "详细段落:技术新颖性分析"}, '
|
||||
'"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
|
||||
'"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
|
||||
'"limitations": "详细段落:局限性分析(作者承认的+你自己的观察")}, '
|
||||
'"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
|
||||
'"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
|
||||
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
|
||||
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要","section":"method"},'
|
||||
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要","section":"results"}]'
|
||||
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table,id 严格使用 \"Figure N\" 或 \"Table N\" 格式。"
|
||||
"section 必须是 motivation/method/results/limitations 之一,表示该图最适合展示在哪个章节。"
|
||||
"}"
|
||||
)
|
||||
|
||||
writing_requirements = (
|
||||
"## 写作要求\n"
|
||||
"- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
|
||||
"- 必须包含论文中的具体数据、数字、实验指标\n"
|
||||
"- 像资深同事给同事讲论文一样,专业但易懂\n"
|
||||
"- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
|
||||
" 例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n"
|
||||
)
|
||||
|
||||
if fix_errors:
|
||||
error_list = "\n".join(f"- {e}" for e in fix_errors)
|
||||
return (
|
||||
"你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
|
||||
f"data/papers/{arxiv_id}/summary.json:\n\n"
|
||||
f"{error_list}\n\n"
|
||||
"注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
|
||||
"修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
|
||||
)
|
||||
|
||||
if pdf_mode == "search":
|
||||
return (
|
||||
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
|
||||
"## 工作流程\n"
|
||||
f"1. 先用 read 工具读取 {meta_path} 了解论文元信息(标题、作者、摘要)\n"
|
||||
f"2. 再用 read 工具阅读 {txt_path}(论文正文全文),可以多次读取定位关键段落\n"
|
||||
f"3. 充分理解后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n\n"
|
||||
+ writing_requirements
|
||||
+ "\n"
|
||||
+ json_schema
|
||||
)
|
||||
else:
|
||||
return (
|
||||
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
|
||||
"## 工作流程\n"
|
||||
"论文元信息和正文全文已在上文提供,请仔细阅读。\n"
|
||||
f"1. 充分理解论文后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n"
|
||||
"2. 用 bash 运行 python scripts/validate_summary.py 验证\n\n"
|
||||
+ writing_requirements
|
||||
+ "\n"
|
||||
+ json_schema
|
||||
)
|
||||
|
||||
|
||||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -264,12 +90,10 @@ async def call_pi(
|
||||
txt_path.write_text(trimmed, encoding="utf-8")
|
||||
logger.info("Truncated %s for inject: %d → %d chars", arxiv_id, txt_size, len(trimmed))
|
||||
|
||||
prompt_text = _build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
|
||||
prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
|
||||
|
||||
# 构建 session ID(每篇论文一个独立 session)
|
||||
if session_id is None:
|
||||
import uuid
|
||||
|
||||
session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 工具列表:search 模式需要 read 工具
|
||||
@@ -297,6 +121,9 @@ async def call_pi(
|
||||
arxiv_id, bool(fix_errors), session_id, actual_mode,
|
||||
)
|
||||
|
||||
import time as _time
|
||||
_t_sub_start = _time.monotonic()
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
@@ -312,69 +139,22 @@ async def call_pi(
|
||||
await proc.wait()
|
||||
raise PiTimeoutError(f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s")
|
||||
|
||||
_t_sub_end = _time.monotonic()
|
||||
|
||||
# 检查 summary.json 是否由 pi 子进程写入
|
||||
_summary_file = pdf_path.parent / "summary.json"
|
||||
_file_info = ""
|
||||
if _summary_file.exists():
|
||||
_file_mtime = _summary_file.stat().st_mtime
|
||||
_file_size = _summary_file.stat().st_size
|
||||
_file_info = f" summary.json={_file_size}B"
|
||||
|
||||
logger.info(
|
||||
"pi subprocess for %s: %.2fs%s",
|
||||
arxiv_id, _t_sub_end - _t_sub_start, _file_info,
|
||||
)
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
|
||||
|
||||
return stdout.decode("utf-8", errors="replace"), session_id
|
||||
|
||||
|
||||
# ── JSON 提取 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def extract_json(raw_output: str) -> dict:
    """Extract the summary JSON object from raw pi output.

    Strategies are tried in order; the first success wins:

      1. Parse the whole (stripped) output; accept a dict with ``title_zh``.
      2. Parse each fenced code block; accept a dict with ``title_zh``.
      3. Parse each flat ``{...}`` span (no nested braces) that contains
         ``"title_zh"``.
      4. Fallback: decode a JSON object at every ``{`` and return the one
         spanning the most characters, regardless of its keys.

    Step 4 uses ``json.JSONDecoder.raw_decode`` rather than counting braces,
    so braces inside string values (e.g. LaTeX or prose containing ``}``)
    no longer cut a valid object short.

    Args:
        raw_output: Text produced by the pi process.

    Returns:
        The extracted dict.

    Raises:
        JsonNotFoundError: If no JSON object can be decoded from the output.
    """
    # Strategy 1: the output is itself the JSON document.
    try:
        result = json.loads(raw_output.strip())
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(result, dict) and "title_zh" in result:
            return result

    # Strategy 2: fenced code blocks (with or without a "json" tag).
    fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
    for match in fence_pattern.finditer(raw_output):
        try:
            result = json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            continue
        if isinstance(result, dict) and "title_zh" in result:
            return result

    # Strategy 3: a flat {...} span mentioning "title_zh".
    brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
    for match in brace_pattern.finditer(raw_output):
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            continue

    # Strategy 4: the largest decodable object anywhere in the output.
    decoder = json.JSONDecoder()
    best = None
    best_len = 0
    pos = raw_output.find("{")
    while pos != -1:
        try:
            # Decoding starts at "{", so a successful result is a dict.
            parsed, end = decoder.raw_decode(raw_output, pos)
        except json.JSONDecodeError:
            pos = raw_output.find("{", pos + 1)
            continue
        if end - pos > best_len:
            best = parsed
            best_len = end - pos
        # Anything starting inside the decoded span is nested and shorter,
        # so resume the scan after it.
        pos = raw_output.find("{", end)

    if best is not None:
        return best

    raise JsonNotFoundError("no JSON object found in pi output")
|
||||
|
||||
Reference in New Issue
Block a user