feat: enhance UI, refactor services, improve templates and tests

- Replace image_extractor with pdf_image_extractor service
- Enhance pi_client with expanded API capabilities
- Improve summarizer service with additional features
- Update admin routes with more endpoints
- Add login page template
- Enhance detail page with comprehensive layout
- Improve search and trends pages
- Update base template with additional elements
- Refactor tests for better coverage
- Add validate_summary script
- Update project configuration and dependencies
This commit is contained in:
2026-06-07 19:38:58 +08:00
parent 4a72c35452
commit 0d293422ac
32 changed files with 2003 additions and 586 deletions
-83
View File
@@ -1,83 +0,0 @@
"""LaTeX 图片提取 — 从 arXiv 源码中扫描 \\includegraphics 并提取图片文件。"""
from __future__ import annotations
import logging
import re
import shutil
from pathlib import Path
from app.services.pdf_downloader import download_source_zip, paper_dir, tmp_dir
logger = logging.getLogger(__name__)
_INCLUDEGRAPHICS_RE = re.compile(
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
async def extract_images_from_source(arxiv_id: str) -> int:
"""从 LaTeX 源码中提取图片文件。
流程:
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
2. 扫描 .tex 文件中的 \\includegraphics
3. 复制图片到 data/papers/{arxiv_id}/images/
4. 清理源码临时文件
Returns:
提取的图片数量
"""
tmp_source = tmp_dir(arxiv_id) / "source"
images_dest = paper_dir(arxiv_id) / "images"
try:
# 下载源码 zip(如果还没下载)
if not tmp_source.exists():
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
await download_source_zip(arxiv_id, source_url, tmp_source)
if not tmp_source.exists():
return 0
# 扫描 .tex 文件,收集图片路径
image_paths: set[str] = set()
for tex_file in tmp_source.rglob("*.tex"):
try:
content = tex_file.read_text(encoding="utf-8", errors="replace")
for match in _INCLUDEGRAPHICS_RE.finditer(content):
img_path = match.group(1).strip()
image_paths.add(img_path)
except Exception:
continue
if not image_paths:
return 0
# 查找并复制图片
images_dest.mkdir(parents=True, exist_ok=True)
copied = 0
for img_rel in image_paths:
# 尝试在源码目录中找到文件
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
candidate = tmp_source / (img_rel + ext)
if candidate.is_file():
dest_name = candidate.name
# 避免文件名冲突
dest = images_dest / dest_name
if dest.exists():
stem = dest.stem
suffix = dest.suffix
dest = images_dest / f"{stem}_{copied}{suffix}"
shutil.copy2(candidate, dest)
copied += 1
break
if copied > 0:
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
return copied
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
return 0
+261
View File
@@ -0,0 +1,261 @@
"""PDF 图片与表格提取 — 从 PDF 中提取嵌入图片和表格截图。
策略:
1. 提取 PDF 中嵌入的图片(图表、插图等)
2. 检测表格区域,渲染为截图
3. 同时搜索页面中的 Figure/Table 标注,记录到 manifest
4. 过滤掉过小的图片
5. 保存到 data/papers/{arxiv_id}/images/
"""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from app.services.pdf_downloader import paper_dir
logger = logging.getLogger(__name__)
# 最小面积阈值(像素),小于此值的图片视为图标/装饰
_MIN_AREA = 10_000 # ~100x100
_MIN_DIM = 80
# Figure/Table 标注与图片/表格的最大垂直距离(点)
_MAX_LABEL_DISTANCE = 120
# Figure/Table 标注的正则
_FIGURE_RE = re.compile(r'\b(?:Fig\.?|Figure)\s*(\d+)\b', re.IGNORECASE)
_TABLE_RE = re.compile(r'\bTable\s*(\d+)\b', re.IGNORECASE)
def _find_nearby_labels(
rects: list, labels: dict[str, list[tuple[int, float]]], page_num: int
) -> list[str]:
"""查找与给定矩形区域在位置上接近的 Figure/Table 标注。
匹配逻辑:标注的垂直位置 (y) 需在图片/表格的上下 _MAX_LABEL_DISTANCE 点范围内。
"""
matched: list[str] = []
for rect in rects:
if isinstance(rect, (list, tuple)):
y_min, y_max = rect[1], rect[3]
else:
y_min, y_max = rect.y0, rect.y1
for label_key, positions in labels.items():
for label_page, label_y in positions:
if label_page == page_num:
# 标注在图片/表格上方或下方的距离
distance = min(abs(label_y - y_min), abs(label_y - y_max))
if distance <= _MAX_LABEL_DISTANCE:
if label_key not in matched:
matched.append(label_key)
return matched
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
"""从 PDF 提取嵌入图片和表格截图,同时生成 manifest。
Args:
arxiv_id: 论文 ID
pdf_path: PDF 路径,默认 data/tmp/{arxiv_id}/paper.pdf
Returns:
提取的图片+表格数量
"""
import pymupdf
if pdf_path is None:
pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf"
if not pdf_path.exists():
logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
return 0
images_dest = paper_dir(arxiv_id) / "images"
images_dest.mkdir(parents=True, exist_ok=True)
doc = pymupdf.open(str(pdf_path))
extracted = 0
seen_hashes: set[int] = set()
# 扫描每页的 Figure/Table 标注位置
# figure_labels: {key: [(page_num, y_center)]} — 记录标注在页面中的垂直位置
figure_labels: dict[str, list[tuple[int, float]]] = {}
table_labels: dict[str, list[tuple[int, float]]] = {}
for page_num in range(len(doc)):
page = doc[page_num]
text_dict = page.get_text("dict")
for block in text_dict.get("blocks", []):
if block.get("type") != 0: # 只看文本块
continue
block_text = ""
for line in block.get("lines", []):
for span in line.get("spans", []):
block_text += span.get("text", "")
for m in _FIGURE_RE.finditer(block_text):
key = f"Figure {m.group(1)}"
bbox = block.get("bbox", [0, 0, 0, 0])
y_center = (bbox[1] + bbox[3]) / 2
figure_labels.setdefault(key, []).append((page_num, y_center))
for m in _TABLE_RE.finditer(block_text):
key = f"Table {m.group(1)}"
bbox = block.get("bbox", [0, 0, 0, 0])
y_center = (bbox[1] + bbox[3]) / 2
table_labels.setdefault(key, []).append((page_num, y_center))
# 记录每个提取文件的元信息
manifest: dict[str, dict] = {}
for page_num in range(len(doc)):
page = doc[page_num]
# ── 1. 提取嵌入图片 ──
image_list = page.get_images(full=True)
for img_index, img_info in enumerate(image_list):
xref = img_info[0]
try:
pix = pymupdf.Pixmap(doc, xref)
except Exception:
continue
if pix.width < _MIN_DIM or pix.height < _MIN_DIM:
continue
if pix.width * pix.height < _MIN_AREA:
continue
img_hash = hash(pix.tobytes()[:1024])
if img_hash in seen_hashes:
continue
seen_hashes.add(img_hash)
if pix.n >= 5:
try:
pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
except Exception:
continue
filename = f"page{page_num + 1}_img{img_index + 1}.png"
pix.save(str(images_dest / filename))
extracted += 1
logger.debug("Image: %s (%dx%d)", filename, pix.width, pix.height)
# 查找该图片位置附近的 Figure 标注
img_rects = page.get_image_rects(xref)
matched = _find_nearby_labels(img_rects, figure_labels, page_num)
manifest[filename] = {"page": page_num + 1, "type": "image", "figures": matched}
# ── 2. 提取表格截图 ──
try:
tables = page.find_tables()
except Exception:
tables = None
if tables and tables.tables:
for table_index, table in enumerate(tables.tables):
bbox = table.bbox
if not bbox:
continue
margin = 5
if isinstance(bbox, (list, tuple)):
x0, y0, x1, y1 = bbox
else:
x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
clip_rect = pymupdf.Rect(x0 - margin, y0 - margin, x1 + margin, y1 + margin)
zoom = 2
mat = pymupdf.Matrix(zoom, zoom)
try:
pix = page.get_pixmap(matrix=mat, clip=clip_rect)
except Exception:
continue
if pix.width < _MIN_DIM * 2 or pix.height < 30 * 2:
continue
filename = f"page{page_num + 1}_table{table_index + 1}.png"
pix.save(str(images_dest / filename))
extracted += 1
logger.debug("Table: %s (%dx%d)", filename, pix.width, pix.height)
# 查找该表格位置附近的 Table 标注
table_rect = pymupdf.Rect(x0, y0, x1, y1)
matched = _find_nearby_labels([table_rect], table_labels, page_num)
manifest[filename] = {"page": page_num + 1, "type": "table", "tables": matched}
doc.close()
# 保存 manifest
manifest_path = images_dest / "manifest.json"
manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2))
if extracted > 0:
logger.info("Extracted %d images+tables from PDF for %s", extracted, arxiv_id)
return extracted
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
"""根据 summary 中的 figures 字段过滤提取的图片/表格。
用 manifest.json 匹配,不需要 PDF 文件。
"""
if not figures:
return 0
images_dir = paper_dir(arxiv_id) / "images"
manifest_path = images_dir / "manifest.json"
if not images_dir.exists() or not manifest_path.exists():
return 0
all_files = [f for f in images_dir.iterdir() if f.suffix == ".png"]
if not all_files:
return 0
manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))
# 收集 summary 中引用的所有 Figure/Table ID(归一化)
referenced_ids: set[str] = set()
for fig in figures:
fig_id = fig.get("id", "")
m = re.match(r'(?:Fig\.?|Figure)\s*(\d+)', fig_id, re.IGNORECASE)
if m:
referenced_ids.add(f"Figure {m.group(1)}")
m2 = re.match(r'Table\s*(\d+)', fig_id, re.IGNORECASE)
if m2:
referenced_ids.add(f"Table {m2.group(1)}")
if not referenced_ids:
logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
return len(all_files)
# 根据 manifest 判断每个文件是否被引用
keep_filenames: set[str] = set()
for filename, info in manifest.items():
file_refs = info.get("figures", []) + info.get("tables", [])
for ref in file_refs:
if ref in referenced_ids:
keep_filenames.add(filename)
break
if not keep_filenames:
logger.warning(
"No manifest matches for %s (refs=%s), keeping all",
arxiv_id, referenced_ids,
)
return len(all_files)
removed = 0
for f in all_files:
if f.name not in keep_filenames:
f.unlink()
removed += 1
kept = len(all_files) - removed
logger.info("Filtered images for %s: kept %d, removed %d (refs=%s)", arxiv_id, kept, removed, referenced_ids)
return kept
+164 -8
View File
@@ -59,23 +59,179 @@ def write_meta_json(paper) -> Path:
return meta_path
# ── PDF 文本提取 ────────────────────────────────────────────────────────
def _trim_body(text: str, max_chars: int = 80_000) -> str:
"""去除参考文献,保留正文+附录,超长时从末尾截断。
策略:
1. 去掉 References/Bibliography 段落(纯引用列表,对解读无用)
2. 正文 + 附录全部保留
3. 如果总长超过 max_chars,从末尾截断(附录靠后,优先保留正文)
"""
import re
# 找 References 段落的位置(在 Appendix 之后的那个)
# 有些论文结构:正文 -> Appendix -> References
# 也可能是:正文 -> References -> Appendix
# 策略:只删除明确的 References 块
ref_pattern = re.compile(
r"(?m)^(?:References|Bibliography|参考文献)\s*$\n"
r"(?s:.*?)" # References 内容
r"(?=\n(?:A\s|Appendix|Supplementary|Acknowledgment|致谢)\s|\Z)",
)
# 简单策略:找到 References 标题,如果后面没有 Appendix 就全删
# 如果后面还有 Appendix,只删 References 到 Appendix 之间的内容
ref_match = re.search(r"(?m)^(?:References|Bibliography|参考文献)\s*$", text)
if ref_match:
ref_start = ref_match.start()
# 看 References 之后有没有 Appendix
after_ref = text[ref_start:]
app_match = re.search(
r"(?m)^(?:A\s+(?:Appendix|Supplementary)|Appendix|附录)\s*$", after_ref
)
if app_match:
# References 之后有 Appendix:只删 References 段
ref_end = ref_start + app_match.start()
text = text[:ref_start] + text[ref_end:]
else:
# References 之后没有 Appendix:删掉从 References 到结尾
text = text[:ref_start].rstrip()
# 去掉 Acknowledgments(对解读无用)
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
if ack_match:
# 只删 Acknowledgments 本身,不删后面的内容
next_section = re.search(r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start():])
if next_section:
text = text[:ack_match.start()] + text[ack_match.start() + next_section.start():]
else:
text = text[:ack_match.start()].rstrip()
# 最后:如果还超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
if len(text) > max_chars:
text = text[:max_chars].rstrip()
return text
def extract_pdf_text(pdf_path: Path) -> Path:
"""用 pymupdf 提取 PDF 正文文本(自动截断参考文献和附录),保存为 .txt。"""
import pymupdf
txt_path = pdf_path.with_suffix(".txt")
if txt_path.exists():
return txt_path
doc = pymupdf.open(str(pdf_path))
raw_text = "\n\n".join(page.get_text() for page in doc)
doc.close()
body = _trim_body(raw_text)
txt_path.write_text(body, encoding="utf-8")
logger.info(
"Extracted PDF text: %s (%d -> %d chars, -%d%%)",
txt_path,
len(raw_text),
len(body),
(1 - len(body) / len(raw_text)) * 100 if raw_text else 0,
)
return txt_path
# ── pi CLI 调用 ────────────────────────────────────────────────────────
async def call_pi(meta_path: Path, pdf_path: Path) -> str:
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
async def call_pi(
meta_path: Path,
pdf_path: Path,
fix_errors: list[str] | None = None,
session_id: str | None = None,
) -> tuple[str, str]:
"""调用 pi CLI 非交互模式,返回 (stdout 文本, session_id)。
fix_errors: 如果非空,表示上一次验证失败的错误列表,pi 需要修正这些问题。
session_id: 如果非空,用 --continue 延续该 session;否则创建新 session。
"""
arxiv_id = meta_path.parent.name
# 将 PDF 转为文本文件,以 @txt 方式传给 pi
txt_path = extract_pdf_text(pdf_path)
if fix_errors:
# 验证失败后的修正提示(同一 session 内,pi 能看到之前写的文件)
error_list = "\n".join(f"- {e}" for e in fix_errors)
prompt_text = (
"你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
f"data/papers/{arxiv_id}/summary.json\n\n"
f"{error_list}\n\n"
"注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
"修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
)
else:
prompt_text = (
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。"
"只输出一个 JSON 对象,不要输出其他内容。\n\n"
"## 写作要求\n"
"- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
"- 必须包含论文中的具体数据、数字、实验指标\n"
"- 像资深同事给同事讲论文一样,专业但易懂\n"
"- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
" 例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n\n"
"## 必须包含以下字段(不要自创字段名):\n"
'{"arxiv_id": "...", '
'"title_zh": "中文标题", '
'"one_line": "一句话概括(≤50字)", '
'"tags": ["标签1","标签2"], '
'"difficulty": "入门/进阶/前沿", '
'"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
'"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
'"goal": "详细段落:本文的具体目标", '
'"gap": "详细段落:本文的独特切入角度"}, '
'"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
'"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
'"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
'"novelty": "详细段落:技术新颖性分析"}, '
'"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
'"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
'"limitations": "详细段落:局限性分析(作者承认的+你自己的观察)"}, '
'"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
'"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度)"}, '
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要"},'
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要"}]'
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Tableid 严格使用 \"Figure N\"\"Table N\" 格式。"
"}\n\n"
"请深度解读以下论文:"
)
# 构建 session ID(每篇论文一个独立 session)
if session_id is None:
import uuid
session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"
cmd = [
settings.PI_BIN,
"-p",
"--no-tools",
"--tools", "bash,write_file",
]
if fix_errors:
cmd += ["--session", session_id, "--continue"]
else:
cmd += ["--session-id", session_id]
cmd += [
"--skill",
settings.SUMMARY_SKILL,
"请深度解读以下论文,并按指定 JSON schema 输出:",
f"@{meta_path}",
f"@{pdf_path}",
prompt_text,
]
logger.info("Calling pi for %s", arxiv_id)
if not fix_errors:
# 首次调用传文件,后续 --continue 不需要(session 内已有)
cmd += [f"@{meta_path}", f"@{txt_path}"]
logger.info("Calling pi for %s (fix=%s, session=%s)", arxiv_id, bool(fix_errors), session_id)
proc = await asyncio.create_subprocess_exec(
*cmd,
@@ -95,7 +251,7 @@ async def call_pi(meta_path: Path, pdf_path: Path) -> str:
if proc.returncode != 0:
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
return stdout.decode("utf-8", errors="replace")
return stdout.decode("utf-8", errors="replace"), session_id
# ── JSON 提取 ──────────────────────────────────────────────────────────
+15 -20
View File
@@ -12,8 +12,7 @@ from pydantic import BaseModel, Field, ValidationError, field_validator
class PrerequisitesSchema(BaseModel):
concepts: list[str] = Field(default_factory=list)
level: str = ""
concepts: list[dict] = Field(default_factory=list)
class MotivationSchema(BaseModel):
@@ -32,7 +31,7 @@ class MotivationSchema(BaseModel):
class MethodSchema(BaseModel):
overview: str = ""
key_idea: str
steps: list[str] = Field(default_factory=list)
steps: str = ""
novelty: str = ""
@field_validator("key_idea")
@@ -44,14 +43,14 @@ class MethodSchema(BaseModel):
class ResultsSchema(BaseModel):
main_findings: list[str] = Field(default_factory=list)
benchmarks: list[dict] = Field(default_factory=list)
limitations: list[str] = Field(default_factory=list)
main_findings: str = ""
benchmarks: list[str | dict] = Field(default_factory=list)
limitations: str = ""
class ImprovementsSchema(BaseModel):
weaknesses: list[str] = Field(default_factory=list)
future_work: list[str] = Field(default_factory=list)
weaknesses: str = ""
future_work: str = ""
reproducibility: str = ""
@@ -71,6 +70,7 @@ class SummarySchema(BaseModel):
method: MethodSchema
results: ResultsSchema = Field(default_factory=ResultsSchema)
improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema)
figures: list[dict] = Field(default_factory=list)
@field_validator("title_zh", "one_line")
@classmethod
@@ -116,7 +116,7 @@ def assess_quality(schema: SummarySchema) -> str:
missing_important += 1
if not schema.method.overview.strip():
missing_important += 1
if not schema.results.main_findings:
if not schema.results.main_findings.strip():
missing_important += 1
if missing_important == 0:
@@ -140,22 +140,17 @@ def flatten_for_db(schema: SummarySchema) -> dict:
"motivation_gap": schema.motivation.gap,
"method_overview": schema.method.overview,
"method_key_idea": schema.method.key_idea,
"method_steps_json": json.dumps(schema.method.steps, ensure_ascii=False),
"method_steps_json": schema.method.steps,
"method_novelty": schema.method.novelty,
"results_main_json": json.dumps(
schema.results.main_findings, ensure_ascii=False
),
"results_main_json": schema.results.main_findings,
"results_benchmarks_json": json.dumps(
schema.results.benchmarks, ensure_ascii=False
),
"limitations_json": json.dumps(schema.results.limitations, ensure_ascii=False),
"weaknesses_json": json.dumps(
schema.improvements.weaknesses, ensure_ascii=False
),
"future_work_json": json.dumps(
schema.improvements.future_work, ensure_ascii=False
),
"limitations_json": schema.results.limitations,
"weaknesses_json": schema.improvements.weaknesses,
"future_work_json": schema.improvements.future_work,
"reproducibility": schema.improvements.reproducibility,
"figures_json": json.dumps(schema.figures, ensure_ascii=False),
"full_json": schema.model_dump_json(ensure_ascii=False),
"updated_at": datetime.now(timezone.utc),
}
+141 -11
View File
@@ -22,7 +22,6 @@ from app.models import (
SummaryStatus,
TaskLock,
)
from app.services.image_extractor import extract_images_from_source
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp,
@@ -77,10 +76,9 @@ def _build_fts_summary_text(schema: SummarySchema) -> str:
schema.one_line or "",
schema.motivation.problem or "",
schema.motivation.goal or "",
schema.method_overview if hasattr(schema, "method_overview") else "",
schema.method.overview or "",
schema.method.key_idea or "",
" ".join(schema.results.main_findings or []),
schema.results.main_findings or "",
]
return " ".join(p for p in parts if p)
@@ -141,6 +139,77 @@ def _update_summary_in_db(
logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── JSON 验证 ──────────────────────────────────────────────────────────
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
errors: list[str] = []
if not isinstance(json_data, dict):
return ["顶层必须是 JSON 对象"]
# 必填字段
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
if f not in json_data or not json_data[f]:
errors.append(f"缺少必填字段: {f}")
# tags 必须是非空数组
tags = json_data.get("tags")
if not isinstance(tags, list) or len(tags) == 0:
errors.append("tags 必须是非空数组")
# 字符串段落字段(必须是 str 且 ≥50 字)
string_fields = [
("motivation", "problem"), ("motivation", "goal"), ("motivation", "gap"),
("method", "overview"), ("method", "key_idea"), ("method", "steps"),
("method", "novelty"),
("results", "main_findings"), ("results", "limitations"),
("improvements", "weaknesses"), ("improvements", "future_work"),
("improvements", "reproducibility"),
]
for section, field in string_fields:
val = json_data.get(section, {}).get(field)
if isinstance(val, list):
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
elif not isinstance(val, str) or len(val.strip()) < 50:
errors.append(
f"{section}.{field} 必须是详细段落(≥50字),"
f"当前: {type(val).__name__} ({len(str(val))}字)"
)
# benchmarks 必须是数组
benchmarks = json_data.get("results", {}).get("benchmarks")
if benchmarks is not None and not isinstance(benchmarks, list):
errors.append("results.benchmarks 必须是数组")
# prerequisites.concepts 必须是对象数组,每个有 term
concepts = json_data.get("prerequisites", {}).get("concepts")
if concepts is not None:
if not isinstance(concepts, list):
errors.append("prerequisites.concepts 必须是数组")
elif len(concepts) == 0:
errors.append("prerequisites.concepts 不能为空")
else:
for i, c in enumerate(concepts):
if isinstance(c, str):
errors.append(f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串")
elif isinstance(c, dict) and not c.get("term"):
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
# figures 必须是数组,每个元素应有 id
figures = json_data.get("figures")
if figures is not None:
if not isinstance(figures, list):
errors.append("figures 必须是数组")
else:
for i, fig in enumerate(figures):
if isinstance(fig, dict) and not fig.get("id"):
errors.append(f"figures[{i}] 缺少 id 字段")
return errors
# ── 文件操作 ────────────────────────────────────────────────────────────
@@ -227,11 +296,64 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
# 下载 PDF
await download_pdf(arxiv_id, paper.pdf_url)
# 调用 pi
raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf")
# 带验证的生成循环:最多 4 轮,同一 session 内 pi 可看到之前写的文件
json_data = None
validation_errors = []
session_id = None
for attempt in range(1, 5):
# 清理上一轮 pi 通过 write_file 写的不完整文件
stale = paper_dir(arxiv_id) / "summary.json"
if stale.exists():
stale.unlink()
# 提取 JSON
json_data = extract_json(raw_output)
if attempt == 1:
raw_output, session_id = await call_pi(
meta_path, Path("data/tmp") / arxiv_id / "paper.pdf"
)
else:
# 验证失败,同一 session 内带着错误信息让 pi 修正
raw_output, session_id = await call_pi(
meta_path,
Path("data/tmp") / arxiv_id / "paper.pdf",
fix_errors=validation_errors,
session_id=session_id,
)
# 优先从 pi write_file 写入的 summary.json 读取,否则从 stdout 提取
# 如果都失败,当作验证错误,继续下一次尝试
json_data = None
summary_file = paper_dir(arxiv_id) / "summary.json"
try:
if summary_file.exists():
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
logger.info("Read summary.json written by pi for %s", arxiv_id)
else:
json_data = extract_json(raw_output)
except (json.JSONDecodeError, JsonNotFoundError) as exc:
logger.warning(
"JSON extraction failed for %s (attempt %d): %s",
arxiv_id,
attempt,
str(exc)[:200],
)
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
continue
# 运行验证脚本
validation_errors = _validate_summary(json_data, arxiv_id)
if not validation_errors:
break
logger.warning(
"Validation failed for %s (attempt %d): %s",
arxiv_id,
attempt,
"; ".join(validation_errors),
)
if validation_errors:
raise ValueError(
f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}"
)
# Pydantic 校验
schema = SummarySchema.model_validate(json_data)
@@ -252,9 +374,17 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
status.raw_output_saved = True
db.commit()
# LaTeX 图片提取(可选增强,失败不影响总结)
# PDF 图片提取(可选增强,失败不影响总结)
try:
await extract_images_from_source(arxiv_id)
from app.services.pdf_image_extractor import (
extract_images_from_pdf,
filter_images_by_summary,
)
pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf"
extract_images_from_pdf(arxiv_id, pdf_path)
# 根据 summary 中 figures 字段过滤,只保留被引用的图表
if schema.figures:
filter_images_by_summary(arxiv_id, schema.figures)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
@@ -268,8 +398,8 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"one_line": schema.one_line or "",
"motivation_problem": schema.motivation_problem or "",
"method_key_idea": schema.method_key_idea or "",
"motivation_problem": schema.motivation.problem or "",
"method_key_idea": schema.method.key_idea or "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
index_paper(arxiv_id, texts_dict)