21f16e6756
- Split summarizer into summary_generator and summary_persister modules - Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection - Add layout_detector service for PicoDet-S_layout_3cls integration - Add exceptions module with ConflictError and NotFoundError - Improve admin dashboard with better statistics and task management - Add design review document with system optimization suggestions - Add new tests for crawler, pdf_downloader, pipeline, and summary_utils - Update dependencies and configuration - Clean up dead code and improve error handling
276 lines
12 KiB
Python
276 lines
12 KiB
Python
"""总结工具函数 — PDF 文本提取、prompt 构建、JSON 提取、meta.json 写入。
|
||
|
||
与后端无关的通用逻辑,pi 和 claude 后端共享。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from pathlib import Path
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class JsonNotFoundError(Exception):
|
||
pass
|
||
|
||
|
||
# ── meta.json ───────────────────────────────────────────────────────────
|
||
|
||
|
||
def write_meta_json(paper) -> Path:
|
||
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
|
||
from app.services.pdf_downloader import paper_dir
|
||
|
||
d = paper_dir(paper.arxiv_id)
|
||
d.mkdir(parents=True, exist_ok=True)
|
||
meta_path = d / "meta.json"
|
||
|
||
authors = [a.name for a in paper.authors]
|
||
tags = [t.tag for t in paper.tags]
|
||
meta = {
|
||
"arxiv_id": paper.arxiv_id,
|
||
"title_en": paper.title_en,
|
||
"abstract": paper.abstract or "",
|
||
"published_at": paper.published_at.isoformat() if paper.published_at else None,
|
||
"authors": authors,
|
||
"tags": tags,
|
||
"upvotes": paper.upvotes,
|
||
}
|
||
meta_path.write_text(
|
||
json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8"
|
||
)
|
||
return meta_path
|
||
|
||
|
||
# ── PDF 文本提取 ────────────────────────────────────────────────────────
|
||
|
||
|
||
def _trim_body(text: str, max_chars: int | None = None) -> str:
|
||
"""去除参考文献,保留正文+附录,超长时从末尾截断。
|
||
|
||
策略:
|
||
1. 去掉 References/Bibliography 段落(纯引用列表,对解读无用)
|
||
2. 正文 + 附录全部保留
|
||
3. 如果指定了 max_chars 且总长超过,从末尾截断(附录靠后,优先保留正文)
|
||
"""
|
||
# 找 References 段落的位置(在 Appendix 之后的那个)
|
||
ref_match = re.search(r"(?m)^(?:References|Bibliography|参考文献)\s*$", text)
|
||
if ref_match:
|
||
ref_start = ref_match.start()
|
||
# 看 References 之后有没有 Appendix
|
||
after_ref = text[ref_start:]
|
||
app_match = re.search(
|
||
r"(?m)^(?:A\s+(?:Appendix|Supplementary)|Appendix|附录)\s*$", after_ref
|
||
)
|
||
if app_match:
|
||
# References 之后有 Appendix:只删 References 段
|
||
ref_end = ref_start + app_match.start()
|
||
text = text[:ref_start] + text[ref_end:]
|
||
else:
|
||
# References 之后没有 Appendix:删掉从 References 到结尾
|
||
text = text[:ref_start].rstrip()
|
||
|
||
# 去掉 Acknowledgments(对解读无用)
|
||
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
|
||
if ack_match:
|
||
# 只删 Acknowledgments 本身,不删后面的内容
|
||
next_section = re.search(
|
||
r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start() :]
|
||
)
|
||
if next_section:
|
||
text = (
|
||
text[: ack_match.start()]
|
||
+ text[ack_match.start() + next_section.start() :]
|
||
)
|
||
else:
|
||
text = text[: ack_match.start()].rstrip()
|
||
|
||
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
|
||
if max_chars is not None and len(text) > max_chars:
|
||
text = text[:max_chars].rstrip()
|
||
|
||
return text
|
||
|
||
|
||
def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
|
||
"""用 pymupdf 提取 PDF 正文文本,保存为 .txt。
|
||
|
||
max_chars=None 时不截断,给 search/auto 模式保留完整内容。
|
||
"""
|
||
import pymupdf
|
||
|
||
txt_path = pdf_path.with_suffix(".txt")
|
||
if txt_path.exists():
|
||
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
|
||
return txt_path
|
||
|
||
with pymupdf.open(str(pdf_path)) as doc:
|
||
# sort=True 启用阅读顺序检测,避免双栏论文中跨栏错位
|
||
raw_text = "\n\n".join(page.get_text(sort=True) for page in doc)
|
||
|
||
body = _trim_body(raw_text, max_chars=max_chars)
|
||
txt_path.write_text(body, encoding="utf-8")
|
||
logger.info(
|
||
"Extracted PDF text: %s (%d -> %d chars, -%d%%)",
|
||
txt_path,
|
||
len(raw_text),
|
||
len(body),
|
||
(1 - len(body) / len(raw_text)) * 100 if raw_text else 0,
|
||
)
|
||
return txt_path
|
||
|
||
|
||
# ── Prompt 构建 ─────────────────────────────────────────────────────────
|
||
|
||
|
||
def build_prompt(
|
||
arxiv_id: str,
|
||
meta_path: Path,
|
||
txt_path: Path,
|
||
pdf_mode: str,
|
||
fix_errors: list[str] | None = None,
|
||
) -> str:
|
||
"""根据模式构建 prompt。
|
||
|
||
inject: 全量注入,prompt 末尾包含论文全文内容
|
||
search: pi 自主 read 文件,prompt 只包含工作流指令
|
||
"""
|
||
json_schema = (
|
||
"## 必须包含以下字段(不要自创字段名):\n"
|
||
'{"arxiv_id": "...", '
|
||
'"title_zh": "中文标题", '
|
||
'"one_line": "一句话概括(≤50字)", '
|
||
'"tags": ["标签1","标签2"], '
|
||
'"difficulty": "入门/进阶/前沿", '
|
||
'"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
|
||
'"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
|
||
'"goal": "详细段落:本文的具体目标", '
|
||
'"gap": "详细段落:本文的独特切入角度"}, '
|
||
'"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
|
||
'"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
|
||
'"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
|
||
'"novelty": "详细段落:技术新颖性分析"}, '
|
||
'"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
|
||
'"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
|
||
'"limitations": "详细段落:局限性分析(作者承认的+你自己的观察")}, '
|
||
'"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
|
||
'"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
|
||
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
|
||
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要","section":"method"},'
|
||
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要","section":"results"}]'
|
||
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table。"
|
||
'id 必须严格复用论文原文的写法(原文用 "Fig. 1" 就写 "Fig. 1",用 "Figure A1" 就写 "Figure A1",用 "Table 1" 就写 "Table 1")。'
|
||
"section 必须是 motivation/method/results/limitations 之一,表示该图最适合展示在哪个章节。"
|
||
"}"
|
||
)
|
||
|
||
writing_requirements = (
|
||
"## 写作要求\n"
|
||
"- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
|
||
"- 必须包含论文中的具体数据、数字、实验指标\n"
|
||
"- 像资深同事给同事讲论文一样,专业但易懂\n"
|
||
"- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
|
||
" 例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n"
|
||
)
|
||
|
||
if fix_errors:
|
||
error_list = "\n".join(f"- {e}" for e in fix_errors)
|
||
return (
|
||
"你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
|
||
f"data/papers/{arxiv_id}/summary.json:\n\n"
|
||
f"{error_list}\n\n"
|
||
"注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
|
||
"修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
|
||
)
|
||
|
||
if pdf_mode == "search":
|
||
return (
|
||
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
|
||
"## 工作流程\n"
|
||
f"1. 先用 read 工具读取 {meta_path} 了解论文元信息(标题、作者、摘要)\n"
|
||
f"2. 再用 read 工具阅读 {txt_path}(论文正文全文),可以多次读取定位关键段落\n"
|
||
f"3. 充分理解后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n\n"
|
||
+ writing_requirements
|
||
+ "\n"
|
||
+ json_schema
|
||
)
|
||
else:
|
||
return (
|
||
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
|
||
"## 工作流程\n"
|
||
"论文元信息和正文全文已在上文提供,请仔细阅读。\n"
|
||
f"1. 充分理解论文后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n"
|
||
"2. 用 bash 运行 python scripts/validate_summary.py 验证\n\n"
|
||
+ writing_requirements
|
||
+ "\n"
|
||
+ json_schema
|
||
)
|
||
|
||
|
||
# ── JSON 提取 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
def extract_json(raw_output: str) -> dict:
|
||
"""从输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
|
||
# 策略 1:整体直接解析
|
||
stripped = raw_output.strip()
|
||
try:
|
||
result = json.loads(stripped)
|
||
if isinstance(result, dict) and "title_zh" in result:
|
||
return result
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
# 策略 2:提取 ```json ... ``` 代码块
|
||
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
|
||
for match in fence_pattern.finditer(raw_output):
|
||
try:
|
||
result = json.loads(match.group(1).strip())
|
||
if isinstance(result, dict) and "title_zh" in result:
|
||
return result
|
||
except json.JSONDecodeError:
|
||
continue
|
||
|
||
# 策略 3:匹配包含 title_zh 的最大 {...} 块
|
||
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
|
||
for match in brace_pattern.finditer(raw_output):
|
||
try:
|
||
return json.loads(match.group(0))
|
||
except json.JSONDecodeError:
|
||
continue
|
||
|
||
# 更宽松:找到最大的 { ... } 平衡块
|
||
best = None
|
||
best_len = 0
|
||
for i, ch in enumerate(raw_output):
|
||
if ch != "{":
|
||
continue
|
||
depth = 0
|
||
for j in range(i, len(raw_output)):
|
||
if raw_output[j] == "{":
|
||
depth += 1
|
||
elif raw_output[j] == "}":
|
||
depth -= 1
|
||
if depth == 0:
|
||
candidate = raw_output[i : j + 1]
|
||
if len(candidate) > best_len:
|
||
try:
|
||
parsed = json.loads(candidate)
|
||
if isinstance(parsed, dict):
|
||
best = parsed
|
||
best_len = len(candidate)
|
||
except json.JSONDecodeError:
|
||
pass
|
||
break
|
||
|
||
if best is not None:
|
||
return best
|
||
|
||
raise JsonNotFoundError("no JSON object found in output")
|