feat: add admin dashboard, pipeline service, lightbox, and update dependencies

This commit is contained in:
2026-06-09 09:32:10 +08:00
parent 0d293422ac
commit 32978b3fc5
50 changed files with 4054 additions and 1618 deletions
+109
View File
@@ -0,0 +1,109 @@
"""管理后台服务 — 统计聚合、系统状态。"""
from __future__ import annotations
from datetime import date
from pathlib import Path
from sqlalchemy import func, select, text
from sqlalchemy.orm import Session
from app.config import settings
from app.models import CrawlLog, Paper, SummaryState, TaskLock
from app.services.scheduler import get_scheduler
from app.utils import PAPERS_DIR, TMP_DIR
def _dir_size(path: Path) -> int:
"""递归计算目录总字节数。"""
if not path.exists():
return 0
return sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
def _fmt_size(nbytes: int) -> str:
"""字节数 → 人类可读字符串。"""
for unit in ("B", "KB", "MB", "GB"):
if nbytes < 1024:
return f"{nbytes:.1f} {unit}"
nbytes /= 1024
return f"{nbytes:.1f} TB"
def get_admin_stats(db: Session) -> dict:
"""管理仪表盘统计数据。"""
today = date.today()
# ── 论文统计 ──────────────────────────────────────────────────────
total_papers = db.scalar(select(func.count(Paper.id)))
today_papers = db.scalar(
select(func.count(Paper.id)).where(Paper.paper_date == today)
)
# ── 总结状态分布 ──────────────────────────────────────────────────
summary_rows = db.execute(
text("""
SELECT COALESCE(ss.status, 'none') AS status, COUNT(*) AS cnt
FROM papers p
LEFT JOIN summary_status ss ON ss.paper_id = p.id
GROUP BY status
""")
).fetchall()
status_counts = {row[0]: row[1] for row in summary_rows}
# ── 存储概况 ──────────────────────────────────────────────────────
db_size = _fmt_size(settings.db_path.stat().st_size) if settings.db_path.exists() else "0 B"
papers_size = _fmt_size(_dir_size(PAPERS_DIR))
tmp_size = _fmt_size(_dir_size(TMP_DIR))
# ── 调度器状态 ────────────────────────────────────────────────────
scheduler = get_scheduler()
scheduler_enabled = scheduler is not None
next_run = None
if scheduler_enabled:
for job in scheduler.get_jobs():
if job.id == "daily_pipeline":
next_run = job.next_run_time
break
# ── 最近日志(5 条) ──────────────────────────────────────────────
recent_logs = (
db.execute(
select(CrawlLog)
.order_by(CrawlLog.started_at.desc())
.limit(5)
)
.scalars()
.all()
)
# ── 活跃锁 ────────────────────────────────────────────────────────
active_locks = (
db.execute(
select(TaskLock).where(TaskLock.status == "running")
)
.scalars()
.all()
)
return {
"total_papers": total_papers or 0,
"today_papers": today_papers or 0,
"pending_count": status_counts.get(SummaryState.PENDING, 0),
"failed_count": status_counts.get(SummaryState.FAILED, 0)
+ status_counts.get(SummaryState.PERMANENT_FAILURE, 0),
"done_count": status_counts.get(SummaryState.DONE, 0),
"running_count": status_counts.get("running", 0)
+ status_counts.get(SummaryState.PROCESSING, 0),
"none_count": status_counts.get("none", 0),
"status_counts": status_counts,
"db_size": db_size,
"papers_size": papers_size,
"tmp_size": tmp_size,
"scheduler_enabled": scheduler_enabled,
"schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
"timezone": settings.APP_TIMEZONE,
"next_run": next_run.isoformat() if next_run else None,
"recent_logs": recent_logs,
"active_locks": active_locks,
}
+13 -9
View File
@@ -2,21 +2,20 @@
from __future__ import annotations
import json
import logging
import shutil
from datetime import date, datetime, timezone
from pathlib import Path
from datetime import date
from sqlalchemy import delete, select, text
from sqlalchemy import select, text
from sqlalchemy.orm import Session
from app.models import (
CrawlLog,
DataDeleteJob,
Paper,
TaskLock,
)
from app.utils import PAPERS_DIR, TMP_DIR
from app.utils import PAPERS_DIR, TMP_DIR, utc_now
logger = logging.getLogger(__name__)
@@ -39,7 +38,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
if not TMP_DIR.exists():
return {"scanned": 0, "removed": 0, "errors": []}
now = datetime.now(timezone.utc)
now = utc_now()
cutoff = now.timestamp() - (max_age_hours * 3600)
scanned = 0
removed = 0
@@ -96,7 +95,7 @@ async def delete_papers_by_date_range(
Returns:
删除结果统计
"""
now = datetime.now(timezone.utc)
now = utc_now()
# 查询目标论文
papers = (
@@ -195,7 +194,7 @@ async def delete_papers_by_date_range(
job.status = job_status
job.paper_count = deleted
job.completed_at = datetime.now(timezone.utc)
job.completed_at = utc_now()
if job_error:
job.error = job_error[:4000]
db.commit()
@@ -205,9 +204,14 @@ async def delete_papers_by_date_range(
task="delete",
status=job_status,
started_at=now,
completed_at=datetime.now(timezone.utc),
completed_at=utc_now(),
papers_found=total,
papers_new=deleted,
details_json=json.dumps({
"total_before": total,
"deleted": deleted,
"failed": len(failed_items),
}, ensure_ascii=False),
error=job_error,
)
db.add(log_entry)
+10 -8
View File
@@ -1,8 +1,7 @@
"""爬虫服务 — 从 HuggingFace Daily Papers API 抓取论文元数据。"""
import logging
from datetime import date as date_type
from datetime import datetime, timezone
from datetime import date as date_type, datetime, timezone
import httpx
from sqlalchemy import select, text
@@ -14,9 +13,10 @@ from app.models import (
Paper,
PaperAuthor,
PaperTag,
SummaryState,
SummaryStatus,
)
from app.utils import make_http_client
from app.utils import make_http_client, utc_now
logger = logging.getLogger(__name__)
@@ -131,15 +131,17 @@ def upsert_papers(db: Session, papers_raw: list[dict], paper_date: str) -> list[
db.add(paper)
db.flush()
seen_authors: set[str] = set()
for idx, name in enumerate(meta["authors"]):
if name:
if name and name not in seen_authors:
seen_authors.add(name)
db.add(PaperAuthor(paper_id=paper.id, name=name, position=idx))
for tag_name in meta["tags"]:
if tag_name:
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="hf"))
db.add(SummaryStatus(paper_id=paper.id, status="pending"))
db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
authors_text = ", ".join(meta["authors"])
tags_text = ", ".join(meta["tags"])
@@ -172,7 +174,7 @@ def upsert_papers(db: Session, papers_raw: list[dict], paper_date: str) -> list[
async def crawl_daily(db: Session, target_date: str, top_n: int | None = None) -> dict:
"""完整的抓取流程:获取 + 入库 + 写日志。"""
now = datetime.now(timezone.utc)
now = utc_now()
log_entry = CrawlLog(
task="crawl",
status="running",
@@ -188,7 +190,7 @@ async def crawl_daily(db: Session, target_date: str, top_n: int | None = None) -
log_entry.status = "success"
log_entry.papers_found = len(raw_papers)
log_entry.papers_new = len(new_papers)
log_entry.completed_at = datetime.now(timezone.utc)
log_entry.completed_at = utc_now()
db.commit()
return {
"found": len(raw_papers),
@@ -200,6 +202,6 @@ async def crawl_daily(db: Session, target_date: str, top_n: int | None = None) -
logger.exception("Crawl failed for %s", target_date)
log_entry.status = "failed"
log_entry.error = str(exc)
log_entry.completed_at = datetime.now(timezone.utc)
log_entry.completed_at = utc_now()
db.commit()
return {"found": 0, "new": 0, "status": "failed", "error": str(exc)}
+6 -36
View File
@@ -5,7 +5,8 @@ from __future__ import annotations
import logging
from pathlib import Path
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from app.config import settings
from app.models import Paper
@@ -188,12 +189,11 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
db = SessionLocal()
try:
paper = (
db.query(Paper)
.filter(Paper.arxiv_id == paper_id)
paper = db.execute(
select(Paper)
.where(Paper.arxiv_id == paper_id)
.options(joinedload(Paper.tags), joinedload(Paper.summary))
.first()
)
).unique().scalar_one_or_none()
if not paper:
logger.warning("Paper %s not found for indexing", paper_id)
return False
@@ -242,36 +242,6 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
return False
# ── 批量索引 ────────────────────────────────────────────────────────────
def index_batch(paper_ids: list[str]) -> dict:
"""批量索引论文,单篇失败不影响其他。
Returns:
{"total": int, "success": int, "failed": int}
"""
if not paper_ids:
return {"total": 0, "success": 0, "failed": 0}
col = get_collection()
if col is None:
return {"total": len(paper_ids), "success": 0, "failed": len(paper_ids)}
success = 0
failed = 0
for pid in paper_ids:
if index_paper(pid):
success += 1
else:
failed += 1
logger.info(
"Batch index: total=%d success=%d failed=%d", len(paper_ids), success, failed
)
return {"total": len(paper_ids), "success": success, "failed": failed}
# ── 删除 ────────────────────────────────────────────────────────────────
+1 -40
View File
@@ -1,10 +1,9 @@
"""PDF 下载与源码下载 — 从 arXiv 下载论文 PDF 和 LaTeX 源码包"""
"""PDF 下载 — 从 arXiv 下载论文 PDF。"""
from __future__ import annotations
import logging
import shutil
import zipfile
from pathlib import Path
from app.utils import PAPERS_DIR, TMP_DIR, make_http_client
@@ -54,44 +53,6 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
return dest
# ── 源码下载 ────────────────────────────────────────────────────────────
async def download_source_zip(arxiv_id: str, source_url: str, dest_dir: Path) -> None:
"""下载 arXiv 源码并解压。"""
dest_dir.mkdir(parents=True, exist_ok=True)
zip_path = tmp_dir(arxiv_id) / "source.zip"
try:
async with make_http_client(follow_redirects=True) as client:
resp = await client.get(source_url)
resp.raise_for_status()
zip_path.write_bytes(resp.content)
except Exception as exc:
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
return
try:
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(dest_dir)
logger.debug("Extracted source for %s", arxiv_id)
except zipfile.BadZipFile:
# 可能是 tar.gz
import tarfile
try:
with tarfile.open(zip_path, "r:*") as tf:
tf.extractall(dest_dir, filter="data")
logger.debug("Extracted source (tar) for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
finally:
if zip_path.exists():
zip_path.unlink()
# ── 临时文件清理 ────────────────────────────────────────────────────────
+4 -9
View File
@@ -16,6 +16,7 @@ import re
from pathlib import Path
from app.services.pdf_downloader import paper_dir
from app.utils import TMP_DIR
logger = logging.getLogger(__name__)
@@ -40,10 +41,7 @@ def _find_nearby_labels(
"""
matched: list[str] = []
for rect in rects:
if isinstance(rect, (list, tuple)):
y_min, y_max = rect[1], rect[3]
else:
y_min, y_max = rect.y0, rect.y1
y_min, y_max = rect.y0, rect.y1
for label_key, positions in labels.items():
for label_page, label_y in positions:
@@ -69,7 +67,7 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
import pymupdf
if pdf_path is None:
pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf"
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
if not pdf_path.exists():
logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
@@ -162,10 +160,7 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
continue
margin = 5
if isinstance(bbox, (list, tuple)):
x0, y0, x1, y1 = bbox
else:
x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
clip_rect = pymupdf.Rect(x0 - margin, y0 - margin, x1 + margin, y1 + margin)
zoom = 2
+131 -68
View File
@@ -62,26 +62,17 @@ def write_meta_json(paper) -> Path:
# ── PDF 文本提取 ────────────────────────────────────────────────────────
def _trim_body(text: str, max_chars: int = 80_000) -> str:
def _trim_body(text: str, max_chars: int | None = None) -> str:
"""去除参考文献,保留正文+附录,超长时从末尾截断。
策略:
1. 去掉 References/Bibliography 段落(纯引用列表,对解读无用)
2. 正文 + 附录全部保留
3. 如果总长超过 max_chars,从末尾截断(附录靠后,优先保留正文)
3. 如果指定了 max_chars 且总长超过,从末尾截断(附录靠后,优先保留正文)
"""
import re
# 找 References 段落的位置(在 Appendix 之后的那个)
# 有些论文结构:正文 -> Appendix -> References
# 也可能是:正文 -> References -> Appendix
# 策略:只删除明确的 References 块
ref_pattern = re.compile(
r"(?m)^(?:References|Bibliography|参考文献)\s*$\n"
r"(?s:.*?)" # References 内容
r"(?=\n(?:A\s|Appendix|Supplementary|Acknowledgment|致谢)\s|\Z)",
)
# 简单策略:找到 References 标题,如果后面没有 Appendix 就全删
# 如果后面还有 Appendix,只删 References 到 Appendix 之间的内容
ref_match = re.search(r"(?m)^(?:References|Bibliography|参考文献)\s*$", text)
@@ -110,26 +101,30 @@ def _trim_body(text: str, max_chars: int = 80_000) -> str:
else:
text = text[:ack_match.start()].rstrip()
# 最后:如果超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
if len(text) > max_chars:
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
if max_chars is not None and len(text) > max_chars:
text = text[:max_chars].rstrip()
return text
def extract_pdf_text(pdf_path: Path) -> Path:
"""用 pymupdf 提取 PDF 正文文本(自动截断参考文献和附录),保存为 .txt。"""
def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
"""用 pymupdf 提取 PDF 正文文本,保存为 .txt。
max_chars=None 时不截断,给 search/auto 模式保留完整内容。
"""
import pymupdf
txt_path = pdf_path.with_suffix(".txt")
if txt_path.exists():
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
return txt_path
doc = pymupdf.open(str(pdf_path))
raw_text = "\n\n".join(page.get_text() for page in doc)
doc.close()
body = _trim_body(raw_text)
body = _trim_body(raw_text, max_chars=max_chars)
txt_path.write_text(body, encoding="utf-8")
logger.info(
"Extracted PDF text: %s (%d -> %d chars, -%d%%)",
@@ -141,6 +136,91 @@ def extract_pdf_text(pdf_path: Path) -> Path:
return txt_path
# ── Prompt 构建 ─────────────────────────────────────────────────────────
def _build_prompt(
arxiv_id: str,
meta_path: Path,
txt_path: Path,
pdf_mode: str,
fix_errors: list[str] | None = None,
) -> str:
"""根据模式构建 pi prompt。
inject: 全量注入,prompt 末尾包含论文全文内容
search: pi 自主 read 文件,prompt 只包含工作流指令
"""
json_schema = (
"## 必须包含以下字段(不要自创字段名):\n"
'{"arxiv_id": "...", '
'"title_zh": "中文标题", '
'"one_line": "一句话概括(≤50字)", '
'"tags": ["标签1","标签2"], '
'"difficulty": "入门/进阶/前沿", '
'"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
'"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
'"goal": "详细段落:本文的具体目标", '
'"gap": "详细段落:本文的独特切入角度"}, '
'"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
'"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
'"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
'"novelty": "详细段落:技术新颖性分析"}, '
'"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
'"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
'"limitations": "详细段落:局限性分析(作者承认的+你自己的观察")}, '
'"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
'"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要"},'
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要"}]'
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Tableid 严格使用 \"Figure N\"\"Table N\" 格式。"
"}"
)
writing_requirements = (
"## 写作要求\n"
"- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
"- 必须包含论文中的具体数据、数字、实验指标\n"
"- 像资深同事给同事讲论文一样,专业但易懂\n"
"- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
" 例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n"
)
if fix_errors:
error_list = "\n".join(f"- {e}" for e in fix_errors)
return (
"你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
f"data/papers/{arxiv_id}/summary.json\n\n"
f"{error_list}\n\n"
"注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
"修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
)
if pdf_mode == "search":
return (
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
"## 工作流程\n"
f"1. 先用 read 工具读取 {meta_path} 了解论文元信息(标题、作者、摘要)\n"
f"2. 再用 read 工具阅读 {txt_path}(论文正文全文),可以多次读取定位关键段落\n"
f"3. 充分理解后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n\n"
+ writing_requirements
+ "\n"
+ json_schema
)
else:
return (
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
"## 工作流程\n"
"论文元信息和正文全文已在上文提供,请仔细阅读。\n"
f"1. 充分理解论文后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n"
"2. 用 bash 运行 python scripts/validate_summary.py 验证\n\n"
+ writing_requirements
+ "\n"
+ json_schema
)
# ── pi CLI 调用 ────────────────────────────────────────────────────────
@@ -149,63 +229,41 @@ async def call_pi(
pdf_path: Path,
fix_errors: list[str] | None = None,
session_id: str | None = None,
pdf_mode: str = "inject",
) -> tuple[str, str]:
"""调用 pi CLI 非交互模式,返回 (stdout 文本, session_id)。
fix_errors: 如果非空,表示上一次验证失败的错误列表,pi 需要修正这些问题。
session_id: 如果非空,用 --continue 延续该 session;否则创建新 session。
pdf_mode: "inject" = 全量注入 prompt@file),"search" = pi 自主 read 文件。
"""
arxiv_id = meta_path.parent.name
# PDF 转为文本文件,以 @txt 方式传给 pi
txt_path = extract_pdf_text(pdf_path)
# 提取 PDF 全文(不截断),根据实际大小自动选择模式
txt_path = extract_pdf_text(pdf_path, max_chars=None)
txt_size = len(txt_path.read_text(encoding="utf-8"))
if fix_errors:
# 验证失败后的修正提示(同一 session 内,pi 能看到之前写的文件)
error_list = "\n".join(f"- {e}" for e in fix_errors)
prompt_text = (
"你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
f"data/papers/{arxiv_id}/summary.json\n\n"
f"{error_list}\n\n"
"注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
"修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
)
else:
prompt_text = (
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。"
"只输出一个 JSON 对象,不要输出其他内容。\n\n"
"## 写作要求\n"
"- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
"- 必须包含论文中的具体数据、数字、实验指标\n"
"- 像资深同事给同事讲论文一样,专业但易懂\n"
"- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
" 例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n\n"
"## 必须包含以下字段(不要自创字段名):\n"
'{"arxiv_id": "...", '
'"title_zh": "中文标题", '
'"one_line": "一句话概括(≤50字)", '
'"tags": ["标签1","标签2"], '
'"difficulty": "入门/进阶/前沿", '
'"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
'"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
'"goal": "详细段落:本文的具体目标", '
'"gap": "详细段落:本文的独特切入角度"}, '
'"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
'"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
'"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
'"novelty": "详细段落:技术新颖性分析"}, '
'"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
'"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
'"limitations": "详细段落:局限性分析(作者承认的+你自己的观察)"}, '
'"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
'"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度)"}, '
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要"},'
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要"}]'
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Tableid 严格使用 \"Figure N\"\"Table N\" 格式。"
"}\n\n"
"请深度解读以下论文:"
)
actual_mode = pdf_mode
if pdf_mode == "auto":
if txt_size > 80_000:
actual_mode = "search"
logger.info(
"Auto mode: %s text=%d chars > 80k → search", arxiv_id, txt_size
)
else:
actual_mode = "inject"
logger.info(
"Auto mode: %s text=%d chars ≤ 80k → inject", arxiv_id, txt_size
)
# inject 模式需要截断过长的文本(避免撑爆 context)
if actual_mode == "inject" and txt_size > 80_000:
body = txt_path.read_text(encoding="utf-8")
trimmed = body[:80_000].rstrip()
txt_path.write_text(trimmed, encoding="utf-8")
logger.info("Truncated %s for inject: %d%d chars", arxiv_id, txt_size, len(trimmed))
prompt_text = _build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
# 构建 session ID(每篇论文一个独立 session)
if session_id is None:
@@ -213,10 +271,12 @@ async def call_pi(
session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"
# 工具列表:search 模式需要 read 工具
tools = "bash,write_file" if actual_mode != "search" else "bash,write_file,read"
cmd = [
settings.PI_BIN,
"-p",
"--tools", "bash,write_file",
"--tools", tools,
]
if fix_errors:
cmd += ["--session", session_id, "--continue"]
@@ -227,11 +287,14 @@ async def call_pi(
settings.SUMMARY_SKILL,
prompt_text,
]
if not fix_errors:
# 首次调用传文件,后续 --continue 不需要(session 内已有)
if not fix_errors and actual_mode != "search":
# inject 模式:首次调用传 @filesearch 模式 pi 自己 read,不注入
cmd += [f"@{meta_path}", f"@{txt_path}"]
logger.info("Calling pi for %s (fix=%s, session=%s)", arxiv_id, bool(fix_errors), session_id)
logger.info(
"Calling pi for %s (fix=%s, session=%s, mode=%s)",
arxiv_id, bool(fix_errors), session_id, actual_mode,
)
proc = await asyncio.create_subprocess_exec(
*cmd,
+108
View File
@@ -0,0 +1,108 @@
"""流水线服务 — crawl → summarize → cleanup 的共享编排逻辑。
供 admin 手动触发和 scheduler 定时调度共用。
"""
from __future__ import annotations
import logging
from datetime import date as date_type
from sqlalchemy.orm import Session
from app.config import settings
from app.models import CrawlLog, TaskLock
from app.services.cleaner import cleanup_tmp
from app.services.crawler import crawl_daily
from app.services.summarizer import summarize_batch
from app.utils import utc_now, yesterday_str
logger = logging.getLogger(__name__)
async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
"""执行完整流水线:crawl → summarize → cleanup。
使用 task_locks 防重入,写入 CrawlLog 记录。
Args:
db: 数据库 session
target_date: 目标日期 YYYY-MM-DD
owner: 调用者标识(如 "admin_trigger" / "daily_pipeline"
Returns:
{"status": "success"|"failed", "error": str|None, ...}
"""
now = utc_now()
lock_key = f"pipeline-{target_date}"
# ── 获取锁 ──────────────────────────────────────────────────────────
lock = TaskLock(
task="scheduler",
lock_key=lock_key,
status="running",
owner=owner,
acquired_at=now,
)
try:
db.add(lock)
db.commit()
except Exception:
db.rollback()
raise RuntimeError(f"Pipeline already running for {target_date}")
# ── 写调度日志 ──────────────────────────────────────────────────────
log_entry = CrawlLog(
task="scheduler",
status="running",
date=date_type.fromisoformat(target_date),
started_at=now,
)
db.add(log_entry)
db.commit()
error_msg = None
crawl_result: dict = {}
try:
# Step 1: 抓取(先试今天,无数据则回退昨天)
crawl_result = await crawl_daily(db, target_date)
logger.info("Pipeline [%s]: crawl %s, found=%d new=%d",
owner, target_date,
crawl_result.get("found", 0), crawl_result.get("new", 0))
if crawl_result.get("status") == "success" and crawl_result.get("found") == 0:
yesterday = yesterday_str()
logger.info("Pipeline [%s]: falling back to %s", owner, yesterday)
crawl_result = await crawl_daily(db, yesterday)
# Step 2: 总结
summarize_result = await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
logger.info("Pipeline [%s]: summarize done, result=%s", owner, summarize_result)
# Step 3: 清理
cleanup_result = cleanup_tmp()
logger.info("Pipeline [%s]: cleanup done, removed=%d",
owner, cleanup_result.get("removed", 0))
log_entry.status = "success"
log_entry.papers_found = crawl_result.get("found", 0)
log_entry.papers_new = crawl_result.get("new", 0)
except Exception as exc:
logger.exception("Pipeline [%s] failed", owner)
log_entry.status = "failed"
error_msg = str(exc)[:2000]
finally:
log_entry.completed_at = utc_now()
if error_msg:
log_entry.error = error_msg
db.commit()
lock.status = "finished"
lock.released_at = utc_now()
db.commit()
if error_msg:
return {"status": "failed", "error": error_msg}
return {"status": "success", "message": "Pipeline completed"}
+7 -80
View File
@@ -3,7 +3,6 @@
from __future__ import annotations
import logging
from datetime import datetime, timezone
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
@@ -12,10 +11,8 @@ from zoneinfo import ZoneInfo
from app.config import settings
from app.database import SessionLocal
from app.models import CrawlLog, TaskLock
from app.services.cleaner import cleanup_tmp
from app.services.crawler import crawl_daily
from app.services.summarizer import summarize_batch
from app.services.pipeline import run_pipeline
from app.utils import today_str
logger = logging.getLogger(__name__)
@@ -92,85 +89,15 @@ def stop_scheduler() -> None:
async def _daily_pipeline() -> None:
"""每日流水线:抓取 → 总结 → 清理。
使用 task_locks 表防止重入:同一天的 pipeline 任务只有一个能运行
委托给 pipeline.run_pipeline 执行,使用 task_locks 防重入
"""
tz = ZoneInfo(settings.APP_TIMEZONE)
today = datetime.now(tz).strftime("%Y-%m-%d")
now = datetime.now(timezone.utc)
lock_key = f"pipeline-{today}"
today = today_str()
db: Session = SessionLocal()
try:
# 尝试获取锁
lock = TaskLock(
task="scheduler",
lock_key=lock_key,
status="running",
owner="daily_pipeline",
acquired_at=now,
)
try:
db.add(lock)
db.commit()
except Exception:
db.rollback()
logger.warning("Daily pipeline already running for %s, skipping", today)
return
# 写调度日志
log_entry = CrawlLog(
task="scheduler",
status="running",
date=datetime.now(tz).date(),
started_at=now,
)
db.add(log_entry)
db.commit()
error_msg = None
try:
# Step 1: 抓取
logger.info("Scheduler pipeline: crawl %s", today)
crawl_result = await crawl_daily(db, today)
logger.info(
"Scheduler pipeline: crawl done, found=%d new=%d",
crawl_result.get("found", 0),
crawl_result.get("new", 0),
)
# Step 2: 总结 pending 论文
logger.info("Scheduler pipeline: summarize batch")
summarize_result = await summarize_batch(db)
logger.info(
"Scheduler pipeline: summarize done, result=%s", summarize_result
)
# Step 3: 清理临时文件
logger.info("Scheduler pipeline: cleanup tmp")
cleanup_result = cleanup_tmp()
logger.info(
"Scheduler pipeline: cleanup done, removed=%d",
cleanup_result.get("removed", 0),
)
log_entry.status = "success"
except Exception as exc:
logger.exception("Scheduler pipeline failed for %s", today)
log_entry.status = "failed"
error_msg = str(exc)[:2000]
finally:
log_entry.completed_at = datetime.now(timezone.utc)
if error_msg:
log_entry.error = error_msg
db.commit()
# 释放锁
lock.status = "finished"
lock.released_at = datetime.now(timezone.utc)
db.commit()
await run_pipeline(db, today, owner="daily_pipeline")
except RuntimeError:
logger.warning("Daily pipeline already running for %s, skipping", today)
except Exception:
logger.exception("Unexpected error in daily pipeline")
finally:
+29 -32
View File
@@ -3,10 +3,10 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from pydantic import BaseModel, Field, ValidationError, field_validator
from app.utils import sanitize_html, utc_now
# ── 子模型 ──────────────────────────────────────────────────────────────
@@ -90,18 +90,6 @@ class SummarySchema(BaseModel):
# ── 质量评估 ────────────────────────────────────────────────────────────
# 必填字段:title_zh, one_line, tags, motivation.problem, method.key_idea
# — 缺失时 Pydantic 校验就会报错,不会走到 assess_quality
# 重要字段:motivation.goal, motivation.gap, method.overview, results.main_findings
# — 缺失可入库,标记 degraded
_OPTIONAL_BUT_IMPORTANT_FIELDS = [
"motivation.goal",
"motivation.gap",
"method.overview",
"results.main_findings",
]
def assess_quality(schema: SummarySchema) -> str:
"""评估总结质量:normal / degraded / low。"""
# low:内容空洞的启发式判断
@@ -128,31 +116,40 @@ def assess_quality(schema: SummarySchema) -> str:
def flatten_for_db(schema: SummarySchema) -> dict:
"""将 SummarySchema 展平为 paper_summaries 表的列值 dict。"""
"""将 SummarySchema 展平为 paper_summaries 表的列值 dict。
所有供前端用 |safe 渲染的文本字段均经过 HTML 清洗。
"""
# 清洗 prerequisites 嵌套文本
prereqs = schema.prerequisites.model_dump()
for c in prereqs.get("concepts", []):
if isinstance(c, dict):
for key in ("explanation", "why_matters"):
if key in c and c[key]:
c[key] = sanitize_html(c[key])
return {
"one_line": schema.one_line,
"one_line": sanitize_html(schema.one_line),
"difficulty": schema.difficulty,
"prerequisites_json": json.dumps(
schema.prerequisites.model_dump(), ensure_ascii=False
),
"motivation_problem": schema.motivation.problem,
"motivation_goal": schema.motivation.goal,
"motivation_gap": schema.motivation.gap,
"method_overview": schema.method.overview,
"method_key_idea": schema.method.key_idea,
"method_steps_json": schema.method.steps,
"method_novelty": schema.method.novelty,
"results_main_json": schema.results.main_findings,
"prerequisites_json": json.dumps(prereqs, ensure_ascii=False),
"motivation_problem": sanitize_html(schema.motivation.problem),
"motivation_goal": sanitize_html(schema.motivation.goal),
"motivation_gap": sanitize_html(schema.motivation.gap),
"method_overview": sanitize_html(schema.method.overview),
"method_key_idea": sanitize_html(schema.method.key_idea),
"method_steps_json": sanitize_html(schema.method.steps),
"method_novelty": sanitize_html(schema.method.novelty),
"results_main_json": sanitize_html(schema.results.main_findings),
"results_benchmarks_json": json.dumps(
schema.results.benchmarks, ensure_ascii=False
),
"limitations_json": schema.results.limitations,
"weaknesses_json": schema.improvements.weaknesses,
"future_work_json": schema.improvements.future_work,
"reproducibility": schema.improvements.reproducibility,
"limitations_json": sanitize_html(schema.results.limitations),
"weaknesses_json": sanitize_html(schema.improvements.weaknesses),
"future_work_json": sanitize_html(schema.improvements.future_work),
"reproducibility": sanitize_html(schema.improvements.reproducibility),
"figures_json": json.dumps(schema.figures, ensure_ascii=False),
"full_json": schema.model_dump_json(ensure_ascii=False),
"updated_at": datetime.now(timezone.utc),
"updated_at": utc_now(),
}
+16 -28
View File
@@ -6,11 +6,11 @@ import logging
import math
import re
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import select, text
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Paper
from app.models import PAPER_FULL_LOAD, Paper
logger = logging.getLogger(__name__)
@@ -213,21 +213,15 @@ def _search_semantic(
arxiv_ids = [c["arxiv_id"] for c in candidates]
distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
papers_query = (
db.query(Paper)
.filter(Paper.arxiv_id.in_(arxiv_ids))
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
)
stmt = (
select(Paper)
.where(Paper.arxiv_id.in_(arxiv_ids))
.options(*PAPER_FULL_LOAD)
)
if tag:
papers_query = papers_query.filter(Paper.tags.any(tag=tag))
stmt = stmt.where(Paper.tags.any(tag=tag))
papers = papers_query.all()
papers = db.execute(stmt).unique().scalars().all()
# 按语义距离排序
id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
@@ -257,11 +251,7 @@ def _search_tag_only(
offset: int,
) -> dict:
"""只有标签筛选,无关键词。"""
order = (
"p.paper_date DESC, p.upvotes DESC"
if sort == "date"
else "p.paper_date DESC, p.upvotes DESC"
)
order = "p.paper_date DESC, p.upvotes DESC"
rows_sql = text(f"""
SELECT p.id
@@ -307,15 +297,13 @@ def _load_papers_by_ids(
return []
papers = (
db.query(Paper)
.filter(Paper.id.in_(paper_ids))
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
db.execute(
select(Paper)
.where(Paper.id.in_(paper_ids))
.options(*PAPER_FULL_LOAD)
)
.unique()
.scalars()
.all()
)
+217 -225
View File
@@ -2,23 +2,24 @@
from __future__ import annotations
import asyncio
import json
import logging
import shutil
from datetime import datetime, timezone
from pathlib import Path
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import Session
from app.config import settings
from app.database import SessionLocal
from app.models import (
PAPER_DEFAULT_LOAD,
CrawlLog,
Paper,
PaperSummary,
PaperTag,
SummaryState,
SummaryStatus,
TaskLock,
)
@@ -42,7 +43,7 @@ from app.services.schemas import (
classify_validation_error,
flatten_for_db,
)
from app.utils import PAPERS_DIR, release_lock
from app.utils import TMP_DIR, release_lock, utc_now
logger = logging.getLogger(__name__)
@@ -96,8 +97,6 @@ def _update_summary_in_db(
"""将校验后的总结写入 DBpaper_summaries + papers + paper_tags + FTS5。"""
from sqlalchemy import text
now = datetime.now(timezone.utc)
# 1. paper_summariesupsert
existing = db.get(PaperSummary, paper.id)
flat = flatten_for_db(schema)
@@ -213,21 +212,14 @@ def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
"""保存 summary.json 和 raw_output.txt。"""
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
encoding="utf-8",
)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
"""仅保存 raw_output.txt(失败时)。"""
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
if schema:
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
encoding="utf-8",
)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
@@ -240,26 +232,25 @@ async def summarize_one(
semaphore: asyncio.Semaphore | None = None,
*,
force: bool = False,
pdf_mode: str = "auto",
) -> dict:
"""总结单篇论文的完整流程。"""
import asyncio
arxiv_id = paper.arxiv_id
# 获取或创建 summary_status
if not paper.summary_status:
db.add(SummaryStatus(paper_id=paper.id, status="pending"))
db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
db.commit()
db.refresh(paper)
status = paper.summary_status
# 跳过已完成的(除非 force
if status.status == "done" and not force:
if status.status == SummaryState.DONE and not force:
return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"}
# 跳过 permanent_failure(除非 force
if status.status == "permanent_failure" and not force:
if status.status == SummaryState.PERMANENT_FAILURE and not force:
return {
"arxiv_id": arxiv_id,
"status": "skipped",
@@ -269,182 +260,202 @@ async def summarize_one(
if semaphore:
await semaphore.acquire()
try:
return await _do_summarize_one(db, paper)
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
finally:
if semaphore:
semaphore.release()
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
import asyncio
async def _generate_with_retry(
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
"""调用 pi CLI 生成总结,最多 4 轮验证循环。
Returns:
(json_data, raw_output)
Raises:
ValueError: 4 轮验证仍未通过
"""
validation_errors: list[str] = []
json_data: dict | None = None
raw_output = ""
session_id = None
for attempt in range(1, 5):
# 清理上一轮 pi 写的不完整文件
stale = paper_dir(arxiv_id) / "summary.json"
if stale.exists():
stale.unlink()
if attempt == 1:
raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
else:
raw_output, session_id = await call_pi(
meta_path, pdf_path,
fix_errors=validation_errors,
session_id=session_id,
pdf_mode=pdf_mode,
)
# 优先读取 pi 写入的 summary.json,否则从 stdout 提取
summary_file = paper_dir(arxiv_id) / "summary.json"
try:
if summary_file.exists():
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
logger.info("Read summary.json written by pi for %s", arxiv_id)
else:
json_data = extract_json(raw_output)
except (json.JSONDecodeError, JsonNotFoundError) as exc:
logger.warning(
"JSON extraction failed for %s (attempt %d): %s",
arxiv_id, attempt, str(exc)[:200],
)
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
continue
validation_errors = _validate_summary(json_data, arxiv_id)
if not validation_errors:
break
logger.warning(
"Validation failed for %s (attempt %d): %s",
arxiv_id, attempt, "; ".join(validation_errors),
)
if validation_errors:
exc = ValueError(
f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}"
)
exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用
raise exc
return json_data, raw_output
def _persist_summary(
db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
schema = SummarySchema.model_validate(json_data)
quality = assess_quality(schema)
_save_files(paper.arxiv_id, schema, raw_output)
_update_summary_in_db(db, paper, schema, quality, raw_output)
# 状态 → done
paper.summary_status.status = SummaryState.DONE
paper.summary_status.quality = quality
paper.summary_status.completed_at = utc_now()
paper.summary_status.raw_output_saved = True
db.commit()
# 触发性增强(失败不影响总结)
_maybe_extract_images(paper.arxiv_id, schema)
_maybe_index_chroma(paper.arxiv_id, paper, schema)
return quality
def _handle_summary_failure(
db: Session, paper: Paper, exc: Exception, raw_output: str,
) -> dict:
"""记录失败:保存 raw_output、重试计数、错误分类。"""
error_type = _classify_error(exc)
logger.error(
"Summarize failed: %s error_type=%s %s",
paper.arxiv_id, error_type, str(exc)[:200],
)
arxiv_id = paper.arxiv_id
status = paper.summary_status
now = datetime.now(timezone.utc)
if raw_output:
_save_files(paper.arxiv_id, None, raw_output)
status.raw_output_saved = True
status.retry_count = (status.retry_count or 0) + 1
status.error_type = error_type
status.error = str(exc)[:2000]
if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
status.status = SummaryState.PERMANENT_FAILURE
else:
status.status = SummaryState.PENDING
status.completed_at = utc_now()
db.commit()
return {
"arxiv_id": paper.arxiv_id,
"status": "failed",
"error_type": error_type,
"error": str(exc)[:200],
"retry_count": status.retry_count,
}
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
"""从 PDF 提取图片和表格(失败不影响总结)。"""
try:
from app.services.pdf_image_extractor import (
extract_images_from_pdf,
filter_images_by_summary,
)
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
extract_images_from_pdf(arxiv_id, pdf_path)
if schema.figures:
filter_images_by_summary(arxiv_id, schema.figures)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
"""写入 ChromaDB 语义索引(失败不影响总结)。"""
try:
from app.services.embedder import index_paper
texts_dict = {
"arxiv_id": arxiv_id,
"title_zh": schema.title_zh or "",
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"one_line": schema.one_line or "",
"motivation_problem": schema.motivation.problem or "",
"method_key_idea": schema.method.key_idea or "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
index_paper(arxiv_id, texts_dict)
except Exception:
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
async def _do_summarize_one(
db: Session, paper: Paper, pdf_mode: str = "auto"
) -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
arxiv_id = paper.arxiv_id
# 状态 → processing
status.status = "processing"
status.started_at = now
paper.summary_status.status = SummaryState.PROCESSING
paper.summary_status.started_at = utc_now()
db.commit()
raw_output = ""
try:
# 写 meta.json
meta_path = write_meta_json(paper)
# 下载 PDF
await download_pdf(arxiv_id, paper.pdf_url)
# 带验证的生成循环:最多 4 轮,同一 session 内 pi 可看到之前写的文件
json_data = None
validation_errors = []
session_id = None
for attempt in range(1, 5):
# 清理上一轮 pi 通过 write_file 写的不完整文件
stale = paper_dir(arxiv_id) / "summary.json"
if stale.exists():
stale.unlink()
json_data, raw_output = await _generate_with_retry(
arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf",
pdf_mode=pdf_mode,
)
if attempt == 1:
raw_output, session_id = await call_pi(
meta_path, Path("data/tmp") / arxiv_id / "paper.pdf"
)
else:
# 验证失败,同一 session 内带着错误信息让 pi 修正
raw_output, session_id = await call_pi(
meta_path,
Path("data/tmp") / arxiv_id / "paper.pdf",
fix_errors=validation_errors,
session_id=session_id,
)
# 优先从 pi write_file 写入的 summary.json 读取,否则从 stdout 提取
# 如果都失败,当作验证错误,继续下一次尝试
json_data = None
summary_file = paper_dir(arxiv_id) / "summary.json"
try:
if summary_file.exists():
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
logger.info("Read summary.json written by pi for %s", arxiv_id)
else:
json_data = extract_json(raw_output)
except (json.JSONDecodeError, JsonNotFoundError) as exc:
logger.warning(
"JSON extraction failed for %s (attempt %d): %s",
arxiv_id,
attempt,
str(exc)[:200],
)
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
continue
# 运行验证脚本
validation_errors = _validate_summary(json_data, arxiv_id)
if not validation_errors:
break
logger.warning(
"Validation failed for %s (attempt %d): %s",
arxiv_id,
attempt,
"; ".join(validation_errors),
)
if validation_errors:
raise ValueError(
f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}"
)
# Pydantic 校验
schema = SummarySchema.model_validate(json_data)
# 质量评估
quality = assess_quality(schema)
# 保存文件
_save_files(arxiv_id, schema, raw_output)
# 更新 DB
_update_summary_in_db(db, paper, schema, quality, raw_output)
# 状态 → done
status.status = "done"
status.quality = quality
status.completed_at = datetime.now(timezone.utc)
status.raw_output_saved = True
db.commit()
# PDF 图片提取(可选增强,失败不影响总结)
try:
from app.services.pdf_image_extractor import (
extract_images_from_pdf,
filter_images_by_summary,
)
pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf"
extract_images_from_pdf(arxiv_id, pdf_path)
# 根据 summary 中 figures 字段过滤,只保留被引用的图表
if schema.figures:
filter_images_by_summary(arxiv_id, schema.figures)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
# 同步写入语义索引(失败仅 log
try:
from app.services.embedder import index_paper
texts_dict = {
"arxiv_id": arxiv_id,
"title_zh": schema.title_zh or "",
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"one_line": schema.one_line or "",
"motivation_problem": schema.motivation.problem or "",
"method_key_idea": schema.method.key_idea or "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
index_paper(arxiv_id, texts_dict)
except Exception:
logger.warning(
"Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True
)
quality = _persist_summary(db, paper, json_data, raw_output)
logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
except Exception as exc:
error_type = _classify_error(exc)
logger.error(
"Summarize failed: %s error_type=%s %s",
arxiv_id,
error_type,
str(exc)[:200],
)
# 保存 raw_output(如果有)
if raw_output:
_save_raw_output_only(arxiv_id, raw_output)
status.raw_output_saved = True
# 重试逻辑
status.retry_count = (status.retry_count or 0) + 1
status.error_type = error_type
status.error = str(exc)[:2000]
if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
status.status = "permanent_failure"
else:
status.status = "pending"
status.completed_at = datetime.now(timezone.utc)
db.commit()
return {
"arxiv_id": arxiv_id,
"status": "failed",
"error_type": error_type,
"error": str(exc)[:200],
"retry_count": status.retry_count,
}
# 从异常对象获取 raw_output_generate_with_retry 失败时仍有输出)
fail_output = getattr(exc, "raw_output", raw_output)
return _handle_summary_failure(db, paper, exc, fail_output)
finally:
cleanup_tmp(arxiv_id)
@@ -458,22 +469,18 @@ async def summarize_single(
arxiv_id: str,
*,
force: bool = True,
pdf_mode: str = "auto",
_session_factory=None,
) -> dict:
"""单篇总结入口(供 admin 路由和 CLI 调用)。
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
paper = (
db.query(Paper)
.filter(Paper.arxiv_id == arxiv_id)
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
)
.first()
)
paper = db.execute(
select(Paper)
.where(Paper.arxiv_id == arxiv_id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
if not paper:
return {"status": "not_found", "arxiv_id": arxiv_id}
@@ -482,17 +489,12 @@ async def summarize_single(
# 每篇用独立 session 避免并发问题
paper_db = make_session()
try:
paper_in_new_session = (
paper_db.query(Paper)
.filter(Paper.arxiv_id == arxiv_id)
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
)
.first()
)
result = await summarize_one(paper_db, paper_in_new_session, force=force)
paper_in_new_session = paper_db.execute(
select(Paper)
.where(Paper.arxiv_id == arxiv_id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
result = await summarize_one(paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode)
finally:
paper_db.close()
@@ -506,15 +508,14 @@ async def summarize_batch(
db: Session,
arxiv_ids: list[str] | None = None,
*,
pdf_mode: str = "auto",
_session_factory=None,
) -> dict:
"""批量总结入口。arxiv_ids=None 时处理所有 pending 论文。
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
import asyncio
now = datetime.now(timezone.utc)
now = utc_now()
# TaskLock 防重入
lock = TaskLock(
@@ -543,20 +544,16 @@ async def summarize_batch(
try:
# 查询待总结论文
query = db.query(Paper).options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
)
stmt = select(Paper).options(*PAPER_DEFAULT_LOAD)
if arxiv_ids:
query = query.filter(Paper.arxiv_id.in_(arxiv_ids))
stmt = stmt.where(Paper.arxiv_id.in_(arxiv_ids))
else:
# 只处理 pending 或 failed(可重试的)
query = query.join(SummaryStatus).filter(
SummaryStatus.status.in_(["pending", "failed"])
stmt = stmt.join(SummaryStatus).where(
SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.FAILED])
)
papers = query.all()
papers = db.execute(stmt).unique().scalars().all()
total = len(papers)
logger.info("Summarize batch: %d papers to process", total)
@@ -564,7 +561,7 @@ async def summarize_batch(
log_entry.status = "success"
log_entry.papers_found = 0
log_entry.papers_new = 0
log_entry.completed_at = datetime.now(timezone.utc)
log_entry.completed_at = utc_now()
release_lock(db, lock)
return {
"status": "success",
@@ -581,17 +578,12 @@ async def summarize_batch(
async def _process_paper(paper: Paper) -> dict:
paper_db = make_session()
try:
p = (
paper_db.query(Paper)
.filter(Paper.id == paper.id)
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
)
.first()
)
return await summarize_one(paper_db, p, semaphore)
p = paper_db.execute(
select(Paper)
.where(Paper.id == paper.id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
return await summarize_one(paper_db, p, semaphore, pdf_mode=pdf_mode)
finally:
paper_db.close()
@@ -619,7 +611,7 @@ async def summarize_batch(
log_entry.status = "success" if failed == 0 else "failed"
log_entry.papers_found = total
log_entry.papers_new = done
log_entry.completed_at = datetime.now(timezone.utc)
log_entry.completed_at = utc_now()
db.commit()
logger.info(
@@ -641,7 +633,7 @@ async def summarize_batch(
logger.exception("Summarize batch failed")
log_entry.status = "failed"
log_entry.error = str(exc)[:2000]
log_entry.completed_at = datetime.now(timezone.utc)
log_entry.completed_at = utc_now()
db.commit()
return {"status": "failed", "error": str(exc)}
+34 -31
View File
@@ -2,23 +2,24 @@
from __future__ import annotations
from datetime import datetime, timezone
from sqlalchemy import or_
from sqlalchemy import or_, select
from sqlalchemy.orm import Session, joinedload
from app.models import Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
from app.models import PAPER_FULL_LOAD, Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
from app.utils import utc_now
# ── 收藏 ──────────────────────────────────────────────────────────────
def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
"""切换收藏状态。返回 {"bookmarked": bool, "arxiv_id": str}。"""
paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
if not paper:
return {"error": "not_found"}
existing = db.query(UserBookmark).filter(UserBookmark.paper_id == paper.id).first()
existing = db.execute(
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
).scalar_one_or_none()
if existing:
db.delete(existing)
db.commit()
@@ -26,7 +27,7 @@ def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
else:
bookmark = UserBookmark(
paper_id=paper.id,
created_at=datetime.now(timezone.utc),
created_at=utc_now(),
)
db.add(bookmark)
db.commit()
@@ -43,16 +44,14 @@ def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
if status not in VALID_STATUSES:
return {"error": "invalid_status", "valid": sorted(VALID_STATUSES)}
paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
if not paper:
return {"error": "not_found"}
now = datetime.now(timezone.utc)
existing = (
db.query(UserReadingStatus)
.filter(UserReadingStatus.paper_id == paper.id)
.first()
)
now = utc_now()
existing = db.execute(
select(UserReadingStatus).where(UserReadingStatus.paper_id == paper.id)
).scalar_one_or_none()
if existing:
existing.status = status
existing.updated_at = now
@@ -73,11 +72,13 @@ def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
def get_note(db: Session, arxiv_id: str) -> dict | None:
"""获取笔记。返回 {"arxiv_id", "content", "updated_at"} 或 None(论文不存在时)。"""
paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
if not paper:
return None
note = db.query(UserNote).filter(UserNote.paper_id == paper.id).first()
note = db.execute(
select(UserNote).where(UserNote.paper_id == paper.id)
).scalar_one_or_none()
if not note:
return {"arxiv_id": arxiv_id, "content": "", "updated_at": None}
@@ -90,12 +91,14 @@ def get_note(db: Session, arxiv_id: str) -> dict | None:
def save_note(db: Session, arxiv_id: str, content: str) -> dict:
"""创建或更新笔记。返回 {"arxiv_id", "content", "updated_at"}。"""
paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
if not paper:
return {"error": "not_found"}
now = datetime.now(timezone.utc)
existing = db.query(UserNote).filter(UserNote.paper_id == paper.id).first()
now = utc_now()
existing = db.execute(
select(UserNote).where(UserNote.paper_id == paper.id)
).scalar_one_or_none()
if existing:
existing.content = content
existing.updated_at = now
@@ -126,7 +129,7 @@ def query_reading_list(
) -> list[Paper]:
"""根据筛选条件查询阅读列表。"""
# 基础:有任意用户数据的论文
base = db.query(Paper).filter(
stmt = select(Paper).where(
or_(
Paper.bookmark.has(),
Paper.reading_status.has(),
@@ -136,25 +139,25 @@ def query_reading_list(
# 应用筛选
if filter_type == "has_note":
base = base.filter(Paper.note.has())
stmt = stmt.where(Paper.note.has())
elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
base = base.filter(
stmt = stmt.where(
Paper.reading_status.has(UserReadingStatus.status == filter_type)
)
# 应用标签
if tag:
base = base.filter(Paper.tags.any(PaperTag.tag == tag))
stmt = stmt.where(Paper.tags.any(PaperTag.tag == tag))
return (
base.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
joinedload(Paper.note),
db.execute(
stmt.options(
joinedload(Paper.note),
*PAPER_FULL_LOAD,
)
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
)
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
.unique()
.scalars()
.all()
)