feat: enhance UI, refactor services, improve templates and tests
- Replace image_extractor with pdf_image_extractor service - Enhance pi_client with expanded API capabilities - Improve summarizer service with additional features - Update admin routes with more endpoints - Add login page template - Enhance detail page with comprehensive layout - Improve search and trends pages - Update base template with additional elements - Refactor tests for better coverage - Add validate_summary script - Update project configuration and dependencies
This commit is contained in:
@@ -1,83 +0,0 @@
|
||||
"""LaTeX 图片提取 — 从 arXiv 源码中扫描 \\includegraphics 并提取图片文件。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.pdf_downloader import download_source_zip, paper_dir, tmp_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_INCLUDEGRAPHICS_RE = re.compile(
|
||||
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
|
||||
)
|
||||
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
|
||||
|
||||
|
||||
async def extract_images_from_source(arxiv_id: str) -> int:
|
||||
"""从 LaTeX 源码中提取图片文件。
|
||||
|
||||
流程:
|
||||
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
|
||||
2. 扫描 .tex 文件中的 \\includegraphics
|
||||
3. 复制图片到 data/papers/{arxiv_id}/images/
|
||||
4. 清理源码临时文件
|
||||
|
||||
Returns:
|
||||
提取的图片数量
|
||||
"""
|
||||
tmp_source = tmp_dir(arxiv_id) / "source"
|
||||
images_dest = paper_dir(arxiv_id) / "images"
|
||||
|
||||
try:
|
||||
# 下载源码 zip(如果还没下载)
|
||||
if not tmp_source.exists():
|
||||
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
|
||||
await download_source_zip(arxiv_id, source_url, tmp_source)
|
||||
|
||||
if not tmp_source.exists():
|
||||
return 0
|
||||
|
||||
# 扫描 .tex 文件,收集图片路径
|
||||
image_paths: set[str] = set()
|
||||
for tex_file in tmp_source.rglob("*.tex"):
|
||||
try:
|
||||
content = tex_file.read_text(encoding="utf-8", errors="replace")
|
||||
for match in _INCLUDEGRAPHICS_RE.finditer(content):
|
||||
img_path = match.group(1).strip()
|
||||
image_paths.add(img_path)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not image_paths:
|
||||
return 0
|
||||
|
||||
# 查找并复制图片
|
||||
images_dest.mkdir(parents=True, exist_ok=True)
|
||||
copied = 0
|
||||
for img_rel in image_paths:
|
||||
# 尝试在源码目录中找到文件
|
||||
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
|
||||
candidate = tmp_source / (img_rel + ext)
|
||||
if candidate.is_file():
|
||||
dest_name = candidate.name
|
||||
# 避免文件名冲突
|
||||
dest = images_dest / dest_name
|
||||
if dest.exists():
|
||||
stem = dest.stem
|
||||
suffix = dest.suffix
|
||||
dest = images_dest / f"{stem}_{copied}{suffix}"
|
||||
shutil.copy2(candidate, dest)
|
||||
copied += 1
|
||||
break
|
||||
|
||||
if copied > 0:
|
||||
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
|
||||
return copied
|
||||
|
||||
except Exception:
|
||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
return 0
|
||||
@@ -0,0 +1,261 @@
|
||||
"""PDF 图片与表格提取 — 从 PDF 中提取嵌入图片和表格截图。
|
||||
|
||||
策略:
|
||||
1. 提取 PDF 中嵌入的图片(图表、插图等)
|
||||
2. 检测表格区域,渲染为截图
|
||||
3. 同时搜索页面中的 Figure/Table 标注,记录到 manifest
|
||||
4. 过滤掉过小的图片
|
||||
5. 保存到 data/papers/{arxiv_id}/images/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.pdf_downloader import paper_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 最小面积阈值(像素),小于此值的图片视为图标/装饰
|
||||
_MIN_AREA = 10_000 # ~100x100
|
||||
_MIN_DIM = 80
|
||||
|
||||
# Figure/Table 标注与图片/表格的最大垂直距离(点)
|
||||
_MAX_LABEL_DISTANCE = 120
|
||||
|
||||
# Figure/Table 标注的正则
|
||||
_FIGURE_RE = re.compile(r'\b(?:Fig\.?|Figure)\s*(\d+)\b', re.IGNORECASE)
|
||||
_TABLE_RE = re.compile(r'\bTable\s*(\d+)\b', re.IGNORECASE)
|
||||
|
||||
|
||||
def _find_nearby_labels(
|
||||
rects: list, labels: dict[str, list[tuple[int, float]]], page_num: int
|
||||
) -> list[str]:
|
||||
"""查找与给定矩形区域在位置上接近的 Figure/Table 标注。
|
||||
|
||||
匹配逻辑:标注的垂直位置 (y) 需在图片/表格的上下 _MAX_LABEL_DISTANCE 点范围内。
|
||||
"""
|
||||
matched: list[str] = []
|
||||
for rect in rects:
|
||||
if isinstance(rect, (list, tuple)):
|
||||
y_min, y_max = rect[1], rect[3]
|
||||
else:
|
||||
y_min, y_max = rect.y0, rect.y1
|
||||
|
||||
for label_key, positions in labels.items():
|
||||
for label_page, label_y in positions:
|
||||
if label_page == page_num:
|
||||
# 标注在图片/表格上方或下方的距离
|
||||
distance = min(abs(label_y - y_min), abs(label_y - y_max))
|
||||
if distance <= _MAX_LABEL_DISTANCE:
|
||||
if label_key not in matched:
|
||||
matched.append(label_key)
|
||||
return matched
|
||||
|
||||
|
||||
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
||||
"""从 PDF 提取嵌入图片和表格截图,同时生成 manifest。
|
||||
|
||||
Args:
|
||||
arxiv_id: 论文 ID
|
||||
pdf_path: PDF 路径,默认 data/tmp/{arxiv_id}/paper.pdf
|
||||
|
||||
Returns:
|
||||
提取的图片+表格数量
|
||||
"""
|
||||
import pymupdf
|
||||
|
||||
if pdf_path is None:
|
||||
pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf"
|
||||
|
||||
if not pdf_path.exists():
|
||||
logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
|
||||
return 0
|
||||
|
||||
images_dest = paper_dir(arxiv_id) / "images"
|
||||
images_dest.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
doc = pymupdf.open(str(pdf_path))
|
||||
extracted = 0
|
||||
seen_hashes: set[int] = set()
|
||||
|
||||
# 扫描每页的 Figure/Table 标注位置
|
||||
# figure_labels: {key: [(page_num, y_center)]} — 记录标注在页面中的垂直位置
|
||||
figure_labels: dict[str, list[tuple[int, float]]] = {}
|
||||
table_labels: dict[str, list[tuple[int, float]]] = {}
|
||||
|
||||
for page_num in range(len(doc)):
|
||||
page = doc[page_num]
|
||||
text_dict = page.get_text("dict")
|
||||
for block in text_dict.get("blocks", []):
|
||||
if block.get("type") != 0: # 只看文本块
|
||||
continue
|
||||
block_text = ""
|
||||
for line in block.get("lines", []):
|
||||
for span in line.get("spans", []):
|
||||
block_text += span.get("text", "")
|
||||
for m in _FIGURE_RE.finditer(block_text):
|
||||
key = f"Figure {m.group(1)}"
|
||||
bbox = block.get("bbox", [0, 0, 0, 0])
|
||||
y_center = (bbox[1] + bbox[3]) / 2
|
||||
figure_labels.setdefault(key, []).append((page_num, y_center))
|
||||
for m in _TABLE_RE.finditer(block_text):
|
||||
key = f"Table {m.group(1)}"
|
||||
bbox = block.get("bbox", [0, 0, 0, 0])
|
||||
y_center = (bbox[1] + bbox[3]) / 2
|
||||
table_labels.setdefault(key, []).append((page_num, y_center))
|
||||
|
||||
# 记录每个提取文件的元信息
|
||||
manifest: dict[str, dict] = {}
|
||||
|
||||
for page_num in range(len(doc)):
|
||||
page = doc[page_num]
|
||||
|
||||
# ── 1. 提取嵌入图片 ──
|
||||
image_list = page.get_images(full=True)
|
||||
for img_index, img_info in enumerate(image_list):
|
||||
xref = img_info[0]
|
||||
try:
|
||||
pix = pymupdf.Pixmap(doc, xref)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if pix.width < _MIN_DIM or pix.height < _MIN_DIM:
|
||||
continue
|
||||
if pix.width * pix.height < _MIN_AREA:
|
||||
continue
|
||||
|
||||
img_hash = hash(pix.tobytes()[:1024])
|
||||
if img_hash in seen_hashes:
|
||||
continue
|
||||
seen_hashes.add(img_hash)
|
||||
|
||||
if pix.n >= 5:
|
||||
try:
|
||||
pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
filename = f"page{page_num + 1}_img{img_index + 1}.png"
|
||||
pix.save(str(images_dest / filename))
|
||||
extracted += 1
|
||||
logger.debug("Image: %s (%dx%d)", filename, pix.width, pix.height)
|
||||
|
||||
# 查找该图片位置附近的 Figure 标注
|
||||
img_rects = page.get_image_rects(xref)
|
||||
matched = _find_nearby_labels(img_rects, figure_labels, page_num)
|
||||
manifest[filename] = {"page": page_num + 1, "type": "image", "figures": matched}
|
||||
|
||||
# ── 2. 提取表格截图 ──
|
||||
try:
|
||||
tables = page.find_tables()
|
||||
except Exception:
|
||||
tables = None
|
||||
|
||||
if tables and tables.tables:
|
||||
for table_index, table in enumerate(tables.tables):
|
||||
bbox = table.bbox
|
||||
if not bbox:
|
||||
continue
|
||||
|
||||
margin = 5
|
||||
if isinstance(bbox, (list, tuple)):
|
||||
x0, y0, x1, y1 = bbox
|
||||
else:
|
||||
x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
|
||||
clip_rect = pymupdf.Rect(x0 - margin, y0 - margin, x1 + margin, y1 + margin)
|
||||
|
||||
zoom = 2
|
||||
mat = pymupdf.Matrix(zoom, zoom)
|
||||
try:
|
||||
pix = page.get_pixmap(matrix=mat, clip=clip_rect)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if pix.width < _MIN_DIM * 2 or pix.height < 30 * 2:
|
||||
continue
|
||||
|
||||
filename = f"page{page_num + 1}_table{table_index + 1}.png"
|
||||
pix.save(str(images_dest / filename))
|
||||
extracted += 1
|
||||
logger.debug("Table: %s (%dx%d)", filename, pix.width, pix.height)
|
||||
|
||||
# 查找该表格位置附近的 Table 标注
|
||||
table_rect = pymupdf.Rect(x0, y0, x1, y1)
|
||||
matched = _find_nearby_labels([table_rect], table_labels, page_num)
|
||||
manifest[filename] = {"page": page_num + 1, "type": "table", "tables": matched}
|
||||
|
||||
doc.close()
|
||||
|
||||
# 保存 manifest
|
||||
manifest_path = images_dest / "manifest.json"
|
||||
manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2))
|
||||
|
||||
if extracted > 0:
|
||||
logger.info("Extracted %d images+tables from PDF for %s", extracted, arxiv_id)
|
||||
return extracted
|
||||
|
||||
|
||||
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
|
||||
"""根据 summary 中的 figures 字段过滤提取的图片/表格。
|
||||
|
||||
用 manifest.json 匹配,不需要 PDF 文件。
|
||||
"""
|
||||
if not figures:
|
||||
return 0
|
||||
|
||||
images_dir = paper_dir(arxiv_id) / "images"
|
||||
manifest_path = images_dir / "manifest.json"
|
||||
|
||||
if not images_dir.exists() or not manifest_path.exists():
|
||||
return 0
|
||||
|
||||
all_files = [f for f in images_dir.iterdir() if f.suffix == ".png"]
|
||||
if not all_files:
|
||||
return 0
|
||||
|
||||
manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||
|
||||
# 收集 summary 中引用的所有 Figure/Table ID(归一化)
|
||||
referenced_ids: set[str] = set()
|
||||
for fig in figures:
|
||||
fig_id = fig.get("id", "")
|
||||
m = re.match(r'(?:Fig\.?|Figure)\s*(\d+)', fig_id, re.IGNORECASE)
|
||||
if m:
|
||||
referenced_ids.add(f"Figure {m.group(1)}")
|
||||
m2 = re.match(r'Table\s*(\d+)', fig_id, re.IGNORECASE)
|
||||
if m2:
|
||||
referenced_ids.add(f"Table {m2.group(1)}")
|
||||
|
||||
if not referenced_ids:
|
||||
logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
|
||||
return len(all_files)
|
||||
|
||||
# 根据 manifest 判断每个文件是否被引用
|
||||
keep_filenames: set[str] = set()
|
||||
for filename, info in manifest.items():
|
||||
file_refs = info.get("figures", []) + info.get("tables", [])
|
||||
for ref in file_refs:
|
||||
if ref in referenced_ids:
|
||||
keep_filenames.add(filename)
|
||||
break
|
||||
|
||||
if not keep_filenames:
|
||||
logger.warning(
|
||||
"No manifest matches for %s (refs=%s), keeping all",
|
||||
arxiv_id, referenced_ids,
|
||||
)
|
||||
return len(all_files)
|
||||
|
||||
removed = 0
|
||||
for f in all_files:
|
||||
if f.name not in keep_filenames:
|
||||
f.unlink()
|
||||
removed += 1
|
||||
|
||||
kept = len(all_files) - removed
|
||||
logger.info("Filtered images for %s: kept %d, removed %d (refs=%s)", arxiv_id, kept, removed, referenced_ids)
|
||||
return kept
|
||||
+164
-8
@@ -59,23 +59,179 @@ def write_meta_json(paper) -> Path:
|
||||
return meta_path
|
||||
|
||||
|
||||
# ── PDF 文本提取 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _trim_body(text: str, max_chars: int = 80_000) -> str:
|
||||
"""去除参考文献,保留正文+附录,超长时从末尾截断。
|
||||
|
||||
策略:
|
||||
1. 去掉 References/Bibliography 段落(纯引用列表,对解读无用)
|
||||
2. 正文 + 附录全部保留
|
||||
3. 如果总长超过 max_chars,从末尾截断(附录靠后,优先保留正文)
|
||||
"""
|
||||
import re
|
||||
|
||||
# 找 References 段落的位置(在 Appendix 之后的那个)
|
||||
# 有些论文结构:正文 -> Appendix -> References
|
||||
# 也可能是:正文 -> References -> Appendix
|
||||
# 策略:只删除明确的 References 块
|
||||
ref_pattern = re.compile(
|
||||
r"(?m)^(?:References|Bibliography|参考文献)\s*$\n"
|
||||
r"(?s:.*?)" # References 内容
|
||||
r"(?=\n(?:A\s|Appendix|Supplementary|Acknowledgment|致谢)\s|\Z)",
|
||||
)
|
||||
|
||||
# 简单策略:找到 References 标题,如果后面没有 Appendix 就全删
|
||||
# 如果后面还有 Appendix,只删 References 到 Appendix 之间的内容
|
||||
ref_match = re.search(r"(?m)^(?:References|Bibliography|参考文献)\s*$", text)
|
||||
if ref_match:
|
||||
ref_start = ref_match.start()
|
||||
# 看 References 之后有没有 Appendix
|
||||
after_ref = text[ref_start:]
|
||||
app_match = re.search(
|
||||
r"(?m)^(?:A\s+(?:Appendix|Supplementary)|Appendix|附录)\s*$", after_ref
|
||||
)
|
||||
if app_match:
|
||||
# References 之后有 Appendix:只删 References 段
|
||||
ref_end = ref_start + app_match.start()
|
||||
text = text[:ref_start] + text[ref_end:]
|
||||
else:
|
||||
# References 之后没有 Appendix:删掉从 References 到结尾
|
||||
text = text[:ref_start].rstrip()
|
||||
|
||||
# 去掉 Acknowledgments(对解读无用)
|
||||
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
|
||||
if ack_match:
|
||||
# 只删 Acknowledgments 本身,不删后面的内容
|
||||
next_section = re.search(r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start():])
|
||||
if next_section:
|
||||
text = text[:ack_match.start()] + text[ack_match.start() + next_section.start():]
|
||||
else:
|
||||
text = text[:ack_match.start()].rstrip()
|
||||
|
||||
# 最后:如果还超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
|
||||
if len(text) > max_chars:
|
||||
text = text[:max_chars].rstrip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def extract_pdf_text(pdf_path: Path) -> Path:
|
||||
"""用 pymupdf 提取 PDF 正文文本(自动截断参考文献和附录),保存为 .txt。"""
|
||||
import pymupdf
|
||||
|
||||
txt_path = pdf_path.with_suffix(".txt")
|
||||
if txt_path.exists():
|
||||
return txt_path
|
||||
|
||||
doc = pymupdf.open(str(pdf_path))
|
||||
raw_text = "\n\n".join(page.get_text() for page in doc)
|
||||
doc.close()
|
||||
|
||||
body = _trim_body(raw_text)
|
||||
txt_path.write_text(body, encoding="utf-8")
|
||||
logger.info(
|
||||
"Extracted PDF text: %s (%d -> %d chars, -%d%%)",
|
||||
txt_path,
|
||||
len(raw_text),
|
||||
len(body),
|
||||
(1 - len(body) / len(raw_text)) * 100 if raw_text else 0,
|
||||
)
|
||||
return txt_path
|
||||
|
||||
|
||||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def call_pi(meta_path: Path, pdf_path: Path) -> str:
|
||||
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
|
||||
async def call_pi(
|
||||
meta_path: Path,
|
||||
pdf_path: Path,
|
||||
fix_errors: list[str] | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""调用 pi CLI 非交互模式,返回 (stdout 文本, session_id)。
|
||||
|
||||
fix_errors: 如果非空,表示上一次验证失败的错误列表,pi 需要修正这些问题。
|
||||
session_id: 如果非空,用 --continue 延续该 session;否则创建新 session。
|
||||
"""
|
||||
arxiv_id = meta_path.parent.name
|
||||
|
||||
# 将 PDF 转为文本文件,以 @txt 方式传给 pi
|
||||
txt_path = extract_pdf_text(pdf_path)
|
||||
|
||||
if fix_errors:
|
||||
# 验证失败后的修正提示(同一 session 内,pi 能看到之前写的文件)
|
||||
error_list = "\n".join(f"- {e}" for e in fix_errors)
|
||||
prompt_text = (
|
||||
"你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
|
||||
f"data/papers/{arxiv_id}/summary.json:\n\n"
|
||||
f"{error_list}\n\n"
|
||||
"注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
|
||||
"修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
|
||||
)
|
||||
else:
|
||||
prompt_text = (
|
||||
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。"
|
||||
"只输出一个 JSON 对象,不要输出其他内容。\n\n"
|
||||
"## 写作要求\n"
|
||||
"- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
|
||||
"- 必须包含论文中的具体数据、数字、实验指标\n"
|
||||
"- 像资深同事给同事讲论文一样,专业但易懂\n"
|
||||
"- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
|
||||
" 例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n\n"
|
||||
"## 必须包含以下字段(不要自创字段名):\n"
|
||||
'{"arxiv_id": "...", '
|
||||
'"title_zh": "中文标题", '
|
||||
'"one_line": "一句话概括(≤50字)", '
|
||||
'"tags": ["标签1","标签2"], '
|
||||
'"difficulty": "入门/进阶/前沿", '
|
||||
'"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
|
||||
'"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
|
||||
'"goal": "详细段落:本文的具体目标", '
|
||||
'"gap": "详细段落:本文的独特切入角度"}, '
|
||||
'"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
|
||||
'"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
|
||||
'"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
|
||||
'"novelty": "详细段落:技术新颖性分析"}, '
|
||||
'"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
|
||||
'"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
|
||||
'"limitations": "详细段落:局限性分析(作者承认的+你自己的观察)"}, '
|
||||
'"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
|
||||
'"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
|
||||
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度)"}, '
|
||||
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要"},'
|
||||
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要"}]'
|
||||
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table,id 严格使用 \"Figure N\" 或 \"Table N\" 格式。"
|
||||
"}\n\n"
|
||||
"请深度解读以下论文:"
|
||||
)
|
||||
|
||||
# 构建 session ID(每篇论文一个独立 session)
|
||||
if session_id is None:
|
||||
import uuid
|
||||
|
||||
session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
cmd = [
|
||||
settings.PI_BIN,
|
||||
"-p",
|
||||
"--no-tools",
|
||||
"--tools", "bash,write_file",
|
||||
]
|
||||
if fix_errors:
|
||||
cmd += ["--session", session_id, "--continue"]
|
||||
else:
|
||||
cmd += ["--session-id", session_id]
|
||||
cmd += [
|
||||
"--skill",
|
||||
settings.SUMMARY_SKILL,
|
||||
"请深度解读以下论文,并按指定 JSON schema 输出:",
|
||||
f"@{meta_path}",
|
||||
f"@{pdf_path}",
|
||||
prompt_text,
|
||||
]
|
||||
logger.info("Calling pi for %s", arxiv_id)
|
||||
if not fix_errors:
|
||||
# 首次调用传文件,后续 --continue 不需要(session 内已有)
|
||||
cmd += [f"@{meta_path}", f"@{txt_path}"]
|
||||
|
||||
logger.info("Calling pi for %s (fix=%s, session=%s)", arxiv_id, bool(fix_errors), session_id)
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
@@ -95,7 +251,7 @@ async def call_pi(meta_path: Path, pdf_path: Path) -> str:
|
||||
if proc.returncode != 0:
|
||||
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
|
||||
|
||||
return stdout.decode("utf-8", errors="replace")
|
||||
return stdout.decode("utf-8", errors="replace"), session_id
|
||||
|
||||
|
||||
# ── JSON 提取 ──────────────────────────────────────────────────────────
|
||||
|
||||
+15
-20
@@ -12,8 +12,7 @@ from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
|
||||
|
||||
class PrerequisitesSchema(BaseModel):
|
||||
concepts: list[str] = Field(default_factory=list)
|
||||
level: str = ""
|
||||
concepts: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MotivationSchema(BaseModel):
|
||||
@@ -32,7 +31,7 @@ class MotivationSchema(BaseModel):
|
||||
class MethodSchema(BaseModel):
|
||||
overview: str = ""
|
||||
key_idea: str
|
||||
steps: list[str] = Field(default_factory=list)
|
||||
steps: str = ""
|
||||
novelty: str = ""
|
||||
|
||||
@field_validator("key_idea")
|
||||
@@ -44,14 +43,14 @@ class MethodSchema(BaseModel):
|
||||
|
||||
|
||||
class ResultsSchema(BaseModel):
|
||||
main_findings: list[str] = Field(default_factory=list)
|
||||
benchmarks: list[dict] = Field(default_factory=list)
|
||||
limitations: list[str] = Field(default_factory=list)
|
||||
main_findings: str = ""
|
||||
benchmarks: list[str | dict] = Field(default_factory=list)
|
||||
limitations: str = ""
|
||||
|
||||
|
||||
class ImprovementsSchema(BaseModel):
|
||||
weaknesses: list[str] = Field(default_factory=list)
|
||||
future_work: list[str] = Field(default_factory=list)
|
||||
weaknesses: str = ""
|
||||
future_work: str = ""
|
||||
reproducibility: str = ""
|
||||
|
||||
|
||||
@@ -71,6 +70,7 @@ class SummarySchema(BaseModel):
|
||||
method: MethodSchema
|
||||
results: ResultsSchema = Field(default_factory=ResultsSchema)
|
||||
improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema)
|
||||
figures: list[dict] = Field(default_factory=list)
|
||||
|
||||
@field_validator("title_zh", "one_line")
|
||||
@classmethod
|
||||
@@ -116,7 +116,7 @@ def assess_quality(schema: SummarySchema) -> str:
|
||||
missing_important += 1
|
||||
if not schema.method.overview.strip():
|
||||
missing_important += 1
|
||||
if not schema.results.main_findings:
|
||||
if not schema.results.main_findings.strip():
|
||||
missing_important += 1
|
||||
|
||||
if missing_important == 0:
|
||||
@@ -140,22 +140,17 @@ def flatten_for_db(schema: SummarySchema) -> dict:
|
||||
"motivation_gap": schema.motivation.gap,
|
||||
"method_overview": schema.method.overview,
|
||||
"method_key_idea": schema.method.key_idea,
|
||||
"method_steps_json": json.dumps(schema.method.steps, ensure_ascii=False),
|
||||
"method_steps_json": schema.method.steps,
|
||||
"method_novelty": schema.method.novelty,
|
||||
"results_main_json": json.dumps(
|
||||
schema.results.main_findings, ensure_ascii=False
|
||||
),
|
||||
"results_main_json": schema.results.main_findings,
|
||||
"results_benchmarks_json": json.dumps(
|
||||
schema.results.benchmarks, ensure_ascii=False
|
||||
),
|
||||
"limitations_json": json.dumps(schema.results.limitations, ensure_ascii=False),
|
||||
"weaknesses_json": json.dumps(
|
||||
schema.improvements.weaknesses, ensure_ascii=False
|
||||
),
|
||||
"future_work_json": json.dumps(
|
||||
schema.improvements.future_work, ensure_ascii=False
|
||||
),
|
||||
"limitations_json": schema.results.limitations,
|
||||
"weaknesses_json": schema.improvements.weaknesses,
|
||||
"future_work_json": schema.improvements.future_work,
|
||||
"reproducibility": schema.improvements.reproducibility,
|
||||
"figures_json": json.dumps(schema.figures, ensure_ascii=False),
|
||||
"full_json": schema.model_dump_json(ensure_ascii=False),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
|
||||
+141
-11
@@ -22,7 +22,6 @@ from app.models import (
|
||||
SummaryStatus,
|
||||
TaskLock,
|
||||
)
|
||||
from app.services.image_extractor import extract_images_from_source
|
||||
from app.services.pdf_downloader import (
|
||||
PdfDownloadError,
|
||||
cleanup_tmp,
|
||||
@@ -77,10 +76,9 @@ def _build_fts_summary_text(schema: SummarySchema) -> str:
|
||||
schema.one_line or "",
|
||||
schema.motivation.problem or "",
|
||||
schema.motivation.goal or "",
|
||||
schema.method_overview if hasattr(schema, "method_overview") else "",
|
||||
schema.method.overview or "",
|
||||
schema.method.key_idea or "",
|
||||
" ".join(schema.results.main_findings or []),
|
||||
schema.results.main_findings or "",
|
||||
]
|
||||
return " ".join(p for p in parts if p)
|
||||
|
||||
@@ -141,6 +139,77 @@ def _update_summary_in_db(
|
||||
logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
|
||||
|
||||
|
||||
# ── JSON 验证 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
|
||||
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
|
||||
errors: list[str] = []
|
||||
|
||||
if not isinstance(json_data, dict):
|
||||
return ["顶层必须是 JSON 对象"]
|
||||
|
||||
# 必填字段
|
||||
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
|
||||
if f not in json_data or not json_data[f]:
|
||||
errors.append(f"缺少必填字段: {f}")
|
||||
|
||||
# tags 必须是非空数组
|
||||
tags = json_data.get("tags")
|
||||
if not isinstance(tags, list) or len(tags) == 0:
|
||||
errors.append("tags 必须是非空数组")
|
||||
|
||||
# 字符串段落字段(必须是 str 且 ≥50 字)
|
||||
string_fields = [
|
||||
("motivation", "problem"), ("motivation", "goal"), ("motivation", "gap"),
|
||||
("method", "overview"), ("method", "key_idea"), ("method", "steps"),
|
||||
("method", "novelty"),
|
||||
("results", "main_findings"), ("results", "limitations"),
|
||||
("improvements", "weaknesses"), ("improvements", "future_work"),
|
||||
("improvements", "reproducibility"),
|
||||
]
|
||||
for section, field in string_fields:
|
||||
val = json_data.get(section, {}).get(field)
|
||||
if isinstance(val, list):
|
||||
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
|
||||
elif not isinstance(val, str) or len(val.strip()) < 50:
|
||||
errors.append(
|
||||
f"{section}.{field} 必须是详细段落(≥50字),"
|
||||
f"当前: {type(val).__name__} ({len(str(val))}字)"
|
||||
)
|
||||
|
||||
# benchmarks 必须是数组
|
||||
benchmarks = json_data.get("results", {}).get("benchmarks")
|
||||
if benchmarks is not None and not isinstance(benchmarks, list):
|
||||
errors.append("results.benchmarks 必须是数组")
|
||||
|
||||
# prerequisites.concepts 必须是对象数组,每个有 term
|
||||
concepts = json_data.get("prerequisites", {}).get("concepts")
|
||||
if concepts is not None:
|
||||
if not isinstance(concepts, list):
|
||||
errors.append("prerequisites.concepts 必须是数组")
|
||||
elif len(concepts) == 0:
|
||||
errors.append("prerequisites.concepts 不能为空")
|
||||
else:
|
||||
for i, c in enumerate(concepts):
|
||||
if isinstance(c, str):
|
||||
errors.append(f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串")
|
||||
elif isinstance(c, dict) and not c.get("term"):
|
||||
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
|
||||
|
||||
# figures 必须是数组,每个元素应有 id
|
||||
figures = json_data.get("figures")
|
||||
if figures is not None:
|
||||
if not isinstance(figures, list):
|
||||
errors.append("figures 必须是数组")
|
||||
else:
|
||||
for i, fig in enumerate(figures):
|
||||
if isinstance(fig, dict) and not fig.get("id"):
|
||||
errors.append(f"figures[{i}] 缺少 id 字段")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# ── 文件操作 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -227,11 +296,64 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
||||
# 下载 PDF
|
||||
await download_pdf(arxiv_id, paper.pdf_url)
|
||||
|
||||
# 调用 pi
|
||||
raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf")
|
||||
# 带验证的生成循环:最多 4 轮,同一 session 内 pi 可看到之前写的文件
|
||||
json_data = None
|
||||
validation_errors = []
|
||||
session_id = None
|
||||
for attempt in range(1, 5):
|
||||
# 清理上一轮 pi 通过 write_file 写的不完整文件
|
||||
stale = paper_dir(arxiv_id) / "summary.json"
|
||||
if stale.exists():
|
||||
stale.unlink()
|
||||
|
||||
# 提取 JSON
|
||||
json_data = extract_json(raw_output)
|
||||
if attempt == 1:
|
||||
raw_output, session_id = await call_pi(
|
||||
meta_path, Path("data/tmp") / arxiv_id / "paper.pdf"
|
||||
)
|
||||
else:
|
||||
# 验证失败,同一 session 内带着错误信息让 pi 修正
|
||||
raw_output, session_id = await call_pi(
|
||||
meta_path,
|
||||
Path("data/tmp") / arxiv_id / "paper.pdf",
|
||||
fix_errors=validation_errors,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# 优先从 pi write_file 写入的 summary.json 读取,否则从 stdout 提取
|
||||
# 如果都失败,当作验证错误,继续下一次尝试
|
||||
json_data = None
|
||||
summary_file = paper_dir(arxiv_id) / "summary.json"
|
||||
try:
|
||||
if summary_file.exists():
|
||||
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
|
||||
logger.info("Read summary.json written by pi for %s", arxiv_id)
|
||||
else:
|
||||
json_data = extract_json(raw_output)
|
||||
except (json.JSONDecodeError, JsonNotFoundError) as exc:
|
||||
logger.warning(
|
||||
"JSON extraction failed for %s (attempt %d): %s",
|
||||
arxiv_id,
|
||||
attempt,
|
||||
str(exc)[:200],
|
||||
)
|
||||
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
|
||||
continue
|
||||
|
||||
# 运行验证脚本
|
||||
validation_errors = _validate_summary(json_data, arxiv_id)
|
||||
if not validation_errors:
|
||||
break
|
||||
logger.warning(
|
||||
"Validation failed for %s (attempt %d): %s",
|
||||
arxiv_id,
|
||||
attempt,
|
||||
"; ".join(validation_errors),
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
raise ValueError(
|
||||
f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}"
|
||||
)
|
||||
|
||||
# Pydantic 校验
|
||||
schema = SummarySchema.model_validate(json_data)
|
||||
@@ -252,9 +374,17 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
||||
status.raw_output_saved = True
|
||||
db.commit()
|
||||
|
||||
# LaTeX 图片提取(可选增强,失败不影响总结)
|
||||
# PDF 图片提取(可选增强,失败不影响总结)
|
||||
try:
|
||||
await extract_images_from_source(arxiv_id)
|
||||
from app.services.pdf_image_extractor import (
|
||||
extract_images_from_pdf,
|
||||
filter_images_by_summary,
|
||||
)
|
||||
pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf"
|
||||
extract_images_from_pdf(arxiv_id, pdf_path)
|
||||
# 根据 summary 中 figures 字段过滤,只保留被引用的图表
|
||||
if schema.figures:
|
||||
filter_images_by_summary(arxiv_id, schema.figures)
|
||||
except Exception:
|
||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
|
||||
@@ -268,8 +398,8 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
||||
"title_en": paper.title_en or "",
|
||||
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
|
||||
"one_line": schema.one_line or "",
|
||||
"motivation_problem": schema.motivation_problem or "",
|
||||
"method_key_idea": schema.method_key_idea or "",
|
||||
"motivation_problem": schema.motivation.problem or "",
|
||||
"method_key_idea": schema.method.key_idea or "",
|
||||
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
|
||||
}
|
||||
index_paper(arxiv_id, texts_dict)
|
||||
|
||||
Reference in New Issue
Block a user