feat: add admin dashboard, pipeline service, lightbox, and update dependencies
This commit is contained in:
@@ -0,0 +1,109 @@
|
||||
"""管理后台服务 — 统计聚合、系统状态。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import CrawlLog, Paper, SummaryState, TaskLock
|
||||
from app.services.scheduler import get_scheduler
|
||||
from app.utils import PAPERS_DIR, TMP_DIR
|
||||
|
||||
|
||||
def _dir_size(path: Path) -> int:
    """Return the total size in bytes of all regular files under ``path``.

    The tree is walked recursively. A missing ``path`` yields 0.

    Entries that disappear or cannot be stat'ed while the walk is in
    progress are skipped instead of aborting the whole computation: the
    directories measured here may be modified concurrently (temporary files
    are removed by the cleanup step), and a dashboard read should not fail
    because one file vanished between ``is_file()`` and ``stat()``.
    """
    if not path.exists():
        return 0

    total = 0
    for entry in path.rglob("*"):
        try:
            if entry.is_file():
                total += entry.stat().st_size
        except OSError:
            # Removed or unreadable between listing and stat: ignore it.
            continue
    return total
|
||||
|
||||
|
||||
def _fmt_size(nbytes: int) -> str:
    """Render a byte count as a human-readable string with one decimal.

    The value is divided by 1024 while stepping through B, KB, MB and GB;
    the first unit in which it is below 1024 is used. Anything that is
    still 1024 or more after the GB step is reported in TB.
    """
    units = ["B", "KB", "MB", "GB"]
    size = nbytes
    idx = 0
    while idx < len(units):
        if size < 1024:
            return f"{size:.1f} {units[idx]}"
        size = size / 1024
        idx += 1
    return f"{size:.1f} TB"
|
||||
|
||||
|
||||
def get_admin_stats(db: Session) -> dict:
    """Collect the statistics shown on the admin dashboard.

    Gathers paper counts, the distribution of summary statuses, on-disk
    storage usage, scheduler state, the five most recent crawl logs and the
    task locks currently marked as running.

    Args:
        db: SQLAlchemy session used for all queries.

    Returns:
        A dict of display-ready values. ``recent_logs`` and ``active_locks``
        hold ORM rows; ``next_run`` is an ISO-8601 string or ``None``.
    """
    # Function-scope imports keep this fix independent of the file header.
    from datetime import datetime
    from zoneinfo import ZoneInfo

    # "Today" follows the configured application timezone (the same zone
    # returned below as ``timezone``) rather than the host's local clock, so
    # the per-day count does not drift when the server runs in another zone.
    today = datetime.now(ZoneInfo(settings.APP_TIMEZONE)).date()

    # ── Paper counts ──────────────────────────────────────────────────
    total_papers = db.scalar(select(func.count(Paper.id)))
    today_papers = db.scalar(
        select(func.count(Paper.id)).where(Paper.paper_date == today)
    )

    # ── Summary status distribution ───────────────────────────────────
    # LEFT JOIN so papers without a status row are counted under 'none'.
    summary_rows = db.execute(
        text("""
            SELECT COALESCE(ss.status, 'none') AS status, COUNT(*) AS cnt
            FROM papers p
            LEFT JOIN summary_status ss ON ss.paper_id = p.id
            GROUP BY status
        """)
    ).fetchall()
    status_counts = {row[0]: row[1] for row in summary_rows}

    # ── Storage overview ──────────────────────────────────────────────
    db_size = _fmt_size(settings.db_path.stat().st_size) if settings.db_path.exists() else "0 B"
    papers_size = _fmt_size(_dir_size(PAPERS_DIR))
    tmp_size = _fmt_size(_dir_size(TMP_DIR))

    # ── Scheduler state ───────────────────────────────────────────────
    scheduler = get_scheduler()
    scheduler_enabled = scheduler is not None
    next_run = None
    if scheduler_enabled:
        for job in scheduler.get_jobs():
            if job.id == "daily_pipeline":
                next_run = job.next_run_time
                break

    # ── Most recent logs (5 rows) ─────────────────────────────────────
    recent_logs = (
        db.execute(
            select(CrawlLog)
            .order_by(CrawlLog.started_at.desc())
            .limit(5)
        )
        .scalars()
        .all()
    )

    # ── Active locks ──────────────────────────────────────────────────
    active_locks = (
        db.execute(
            select(TaskLock).where(TaskLock.status == "running")
        )
        .scalars()
        .all()
    )

    # NOTE(review): the lookups below assume SummaryState members hash and
    # compare equal to the raw status strings returned by the SQL query —
    # confirm against the SummaryState definition.
    return {
        "total_papers": total_papers or 0,
        "today_papers": today_papers or 0,
        "pending_count": status_counts.get(SummaryState.PENDING, 0),
        "failed_count": status_counts.get(SummaryState.FAILED, 0)
        + status_counts.get(SummaryState.PERMANENT_FAILURE, 0),
        "done_count": status_counts.get(SummaryState.DONE, 0),
        "running_count": status_counts.get("running", 0)
        + status_counts.get(SummaryState.PROCESSING, 0),
        "none_count": status_counts.get("none", 0),
        "status_counts": status_counts,
        "db_size": db_size,
        "papers_size": papers_size,
        "tmp_size": tmp_size,
        "scheduler_enabled": scheduler_enabled,
        "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
        "timezone": settings.APP_TIMEZONE,
        "next_run": next_run.isoformat() if next_run else None,
        "recent_logs": recent_logs,
        "active_locks": active_locks,
    }
|
||||
+13
-9
@@ -2,21 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from datetime import date, datetime, timezone
|
||||
from pathlib import Path
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy import delete, select, text
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import (
|
||||
CrawlLog,
|
||||
DataDeleteJob,
|
||||
Paper,
|
||||
TaskLock,
|
||||
)
|
||||
from app.utils import PAPERS_DIR, TMP_DIR
|
||||
from app.utils import PAPERS_DIR, TMP_DIR, utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,7 +38,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
|
||||
if not TMP_DIR.exists():
|
||||
return {"scanned": 0, "removed": 0, "errors": []}
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
cutoff = now.timestamp() - (max_age_hours * 3600)
|
||||
scanned = 0
|
||||
removed = 0
|
||||
@@ -96,7 +95,7 @@ async def delete_papers_by_date_range(
|
||||
Returns:
|
||||
删除结果统计
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
|
||||
# 查询目标论文
|
||||
papers = (
|
||||
@@ -195,7 +194,7 @@ async def delete_papers_by_date_range(
|
||||
|
||||
job.status = job_status
|
||||
job.paper_count = deleted
|
||||
job.completed_at = datetime.now(timezone.utc)
|
||||
job.completed_at = utc_now()
|
||||
if job_error:
|
||||
job.error = job_error[:4000]
|
||||
db.commit()
|
||||
@@ -205,9 +204,14 @@ async def delete_papers_by_date_range(
|
||||
task="delete",
|
||||
status=job_status,
|
||||
started_at=now,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
completed_at=utc_now(),
|
||||
papers_found=total,
|
||||
papers_new=deleted,
|
||||
details_json=json.dumps({
|
||||
"total_before": total,
|
||||
"deleted": deleted,
|
||||
"failed": len(failed_items),
|
||||
}, ensure_ascii=False),
|
||||
error=job_error,
|
||||
)
|
||||
db.add(log_entry)
|
||||
|
||||
+10
-8
@@ -1,8 +1,7 @@
|
||||
"""爬虫服务 — 从 HuggingFace Daily Papers API 抓取论文元数据。"""
|
||||
|
||||
import logging
|
||||
from datetime import date as date_type
|
||||
from datetime import datetime, timezone
|
||||
from datetime import date as date_type, datetime, timezone
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select, text
|
||||
@@ -14,9 +13,10 @@ from app.models import (
|
||||
Paper,
|
||||
PaperAuthor,
|
||||
PaperTag,
|
||||
SummaryState,
|
||||
SummaryStatus,
|
||||
)
|
||||
from app.utils import make_http_client
|
||||
from app.utils import make_http_client, utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -131,15 +131,17 @@ def upsert_papers(db: Session, papers_raw: list[dict], paper_date: str) -> list[
|
||||
db.add(paper)
|
||||
db.flush()
|
||||
|
||||
seen_authors: set[str] = set()
|
||||
for idx, name in enumerate(meta["authors"]):
|
||||
if name:
|
||||
if name and name not in seen_authors:
|
||||
seen_authors.add(name)
|
||||
db.add(PaperAuthor(paper_id=paper.id, name=name, position=idx))
|
||||
|
||||
for tag_name in meta["tags"]:
|
||||
if tag_name:
|
||||
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="hf"))
|
||||
|
||||
db.add(SummaryStatus(paper_id=paper.id, status="pending"))
|
||||
db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
|
||||
|
||||
authors_text = ", ".join(meta["authors"])
|
||||
tags_text = ", ".join(meta["tags"])
|
||||
@@ -172,7 +174,7 @@ def upsert_papers(db: Session, papers_raw: list[dict], paper_date: str) -> list[
|
||||
|
||||
async def crawl_daily(db: Session, target_date: str, top_n: int | None = None) -> dict:
|
||||
"""完整的抓取流程:获取 + 入库 + 写日志。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
log_entry = CrawlLog(
|
||||
task="crawl",
|
||||
status="running",
|
||||
@@ -188,7 +190,7 @@ async def crawl_daily(db: Session, target_date: str, top_n: int | None = None) -
|
||||
log_entry.status = "success"
|
||||
log_entry.papers_found = len(raw_papers)
|
||||
log_entry.papers_new = len(new_papers)
|
||||
log_entry.completed_at = datetime.now(timezone.utc)
|
||||
log_entry.completed_at = utc_now()
|
||||
db.commit()
|
||||
return {
|
||||
"found": len(raw_papers),
|
||||
@@ -200,6 +202,6 @@ async def crawl_daily(db: Session, target_date: str, top_n: int | None = None) -
|
||||
logger.exception("Crawl failed for %s", target_date)
|
||||
log_entry.status = "failed"
|
||||
log_entry.error = str(exc)
|
||||
log_entry.completed_at = datetime.now(timezone.utc)
|
||||
log_entry.completed_at = utc_now()
|
||||
db.commit()
|
||||
return {"found": 0, "new": 0, "status": "failed", "error": str(exc)}
|
||||
|
||||
@@ -5,7 +5,8 @@ from __future__ import annotations
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Paper
|
||||
@@ -188,12 +189,11 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
paper = (
|
||||
db.query(Paper)
|
||||
.filter(Paper.arxiv_id == paper_id)
|
||||
paper = db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.arxiv_id == paper_id)
|
||||
.options(joinedload(Paper.tags), joinedload(Paper.summary))
|
||||
.first()
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not paper:
|
||||
logger.warning("Paper %s not found for indexing", paper_id)
|
||||
return False
|
||||
@@ -242,36 +242,6 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# ── 批量索引 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def index_batch(paper_ids: list[str]) -> dict:
    """Index several papers one after another; one failure does not stop the rest.

    Returns:
        {"total": int, "success": int, "failed": int}
    """
    if not paper_ids:
        return {"total": 0, "success": 0, "failed": 0}

    total = len(paper_ids)

    # Without a collection nothing can be indexed: report everything as failed.
    if get_collection() is None:
        return {"total": total, "success": 0, "failed": total}

    # index_paper reports success via a truthy return value.
    outcomes = [bool(index_paper(pid)) for pid in paper_ids]
    success = sum(outcomes)
    failed = total - success

    logger.info(
        "Batch index: total=%d success=%d failed=%d", total, success, failed
    )
    return {"total": total, "success": success, "failed": failed}
|
||||
|
||||
|
||||
# ── 删除 ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""PDF 下载与源码下载 — 从 arXiv 下载论文 PDF 和 LaTeX 源码包。"""
|
||||
"""PDF 下载 — 从 arXiv 下载论文 PDF。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.utils import PAPERS_DIR, TMP_DIR, make_http_client
|
||||
@@ -54,44 +53,6 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||
return dest
|
||||
|
||||
|
||||
# ── 源码下载 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def download_source_zip(arxiv_id: str, source_url: str, dest_dir: Path) -> None:
    """Download the arXiv source archive and extract it into ``dest_dir``.

    The archive is saved as ``source.zip`` under the paper's temporary
    directory, extracted, and then removed. The function never raises for
    download or extraction problems: they are logged and the function
    returns normally.

    Args:
        arxiv_id: Paper identifier, used for the temp directory and log lines.
        source_url: URL the archive is fetched from (redirects are followed).
        dest_dir: Extraction target; created if missing.
    """
    dest_dir.mkdir(parents=True, exist_ok=True)
    zip_path = tmp_dir(arxiv_id) / "source.zip"

    try:
        async with make_http_client(follow_redirects=True) as client:
            resp = await client.get(source_url)
            resp.raise_for_status()
            zip_path.write_bytes(resp.content)
    except Exception as exc:
        # Best effort: a missing source package is only logged at debug level.
        logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
        return

    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(dest_dir)
        logger.debug("Extracted source for %s", arxiv_id)
    except zipfile.BadZipFile:
        # Not a zip file; it may be a tar.gz, so retry with tarfile.
        import tarfile

        try:
            with tarfile.open(zip_path, "r:*") as tf:
                # filter="data" applies tarfile's safe-extraction filter.
                tf.extractall(dest_dir, filter="data")
            logger.debug("Extracted source (tar) for %s", arxiv_id)
        except Exception:
            logger.warning("Cannot extract source for %s", arxiv_id)
    except Exception:
        logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
    finally:
        # Always remove the downloaded archive, whether or not extraction worked.
        if zip_path.exists():
            zip_path.unlink()
|
||||
|
||||
|
||||
# ── 临时文件清理 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import re
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.pdf_downloader import paper_dir
|
||||
from app.utils import TMP_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,10 +41,7 @@ def _find_nearby_labels(
|
||||
"""
|
||||
matched: list[str] = []
|
||||
for rect in rects:
|
||||
if isinstance(rect, (list, tuple)):
|
||||
y_min, y_max = rect[1], rect[3]
|
||||
else:
|
||||
y_min, y_max = rect.y0, rect.y1
|
||||
y_min, y_max = rect.y0, rect.y1
|
||||
|
||||
for label_key, positions in labels.items():
|
||||
for label_page, label_y in positions:
|
||||
@@ -69,7 +67,7 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
||||
import pymupdf
|
||||
|
||||
if pdf_path is None:
|
||||
pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf"
|
||||
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
|
||||
|
||||
if not pdf_path.exists():
|
||||
logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
|
||||
@@ -162,10 +160,7 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
||||
continue
|
||||
|
||||
margin = 5
|
||||
if isinstance(bbox, (list, tuple)):
|
||||
x0, y0, x1, y1 = bbox
|
||||
else:
|
||||
x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
|
||||
x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
|
||||
clip_rect = pymupdf.Rect(x0 - margin, y0 - margin, x1 + margin, y1 + margin)
|
||||
|
||||
zoom = 2
|
||||
|
||||
+131
-68
@@ -62,26 +62,17 @@ def write_meta_json(paper) -> Path:
|
||||
# ── PDF 文本提取 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _trim_body(text: str, max_chars: int = 80_000) -> str:
|
||||
def _trim_body(text: str, max_chars: int | None = None) -> str:
|
||||
"""去除参考文献,保留正文+附录,超长时从末尾截断。
|
||||
|
||||
策略:
|
||||
1. 去掉 References/Bibliography 段落(纯引用列表,对解读无用)
|
||||
2. 正文 + 附录全部保留
|
||||
3. 如果总长超过 max_chars,从末尾截断(附录靠后,优先保留正文)
|
||||
3. 如果指定了 max_chars 且总长超过,从末尾截断(附录靠后,优先保留正文)
|
||||
"""
|
||||
import re
|
||||
|
||||
# 找 References 段落的位置(在 Appendix 之后的那个)
|
||||
# 有些论文结构:正文 -> Appendix -> References
|
||||
# 也可能是:正文 -> References -> Appendix
|
||||
# 策略:只删除明确的 References 块
|
||||
ref_pattern = re.compile(
|
||||
r"(?m)^(?:References|Bibliography|参考文献)\s*$\n"
|
||||
r"(?s:.*?)" # References 内容
|
||||
r"(?=\n(?:A\s|Appendix|Supplementary|Acknowledgment|致谢)\s|\Z)",
|
||||
)
|
||||
|
||||
# 简单策略:找到 References 标题,如果后面没有 Appendix 就全删
|
||||
# 如果后面还有 Appendix,只删 References 到 Appendix 之间的内容
|
||||
ref_match = re.search(r"(?m)^(?:References|Bibliography|参考文献)\s*$", text)
|
||||
@@ -110,26 +101,30 @@ def _trim_body(text: str, max_chars: int = 80_000) -> str:
|
||||
else:
|
||||
text = text[:ack_match.start()].rstrip()
|
||||
|
||||
# 最后:如果还超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
|
||||
if len(text) > max_chars:
|
||||
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
|
||||
if max_chars is not None and len(text) > max_chars:
|
||||
text = text[:max_chars].rstrip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def extract_pdf_text(pdf_path: Path) -> Path:
|
||||
"""用 pymupdf 提取 PDF 正文文本(自动截断参考文献和附录),保存为 .txt。"""
|
||||
def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
|
||||
"""用 pymupdf 提取 PDF 正文文本,保存为 .txt。
|
||||
|
||||
max_chars=None 时不截断,给 search/auto 模式保留完整内容。
|
||||
"""
|
||||
import pymupdf
|
||||
|
||||
txt_path = pdf_path.with_suffix(".txt")
|
||||
if txt_path.exists():
|
||||
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
|
||||
return txt_path
|
||||
|
||||
doc = pymupdf.open(str(pdf_path))
|
||||
raw_text = "\n\n".join(page.get_text() for page in doc)
|
||||
doc.close()
|
||||
|
||||
body = _trim_body(raw_text)
|
||||
body = _trim_body(raw_text, max_chars=max_chars)
|
||||
txt_path.write_text(body, encoding="utf-8")
|
||||
logger.info(
|
||||
"Extracted PDF text: %s (%d -> %d chars, -%d%%)",
|
||||
@@ -141,6 +136,91 @@ def extract_pdf_text(pdf_path: Path) -> Path:
|
||||
return txt_path
|
||||
|
||||
|
||||
# ── Prompt 构建 ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_prompt(
    arxiv_id: str,
    meta_path: Path,
    txt_path: Path,
    pdf_mode: str,
    fix_errors: list[str] | None = None,
) -> str:
    """Build the pi prompt for the given mode.

    Args:
        arxiv_id: Paper identifier, used in the output path named in the prompt.
        meta_path: Path of the metadata file (referenced only in search mode).
        txt_path: Path of the extracted paper text (referenced only in search mode).
        pdf_mode: ``"search"`` makes pi read the files itself, so the prompt
            carries the file paths; any other value produces the inject-mode
            prompt, which states that metadata and full text are already provided.
        fix_errors: When non-empty, a short correction prompt listing these
            validation errors is returned instead and ``pdf_mode`` is ignored.

    Returns:
        The prompt text.
    """
    # Correction round: only the error list is needed, so return before
    # building the (unused) schema and writing-requirement sections.
    if fix_errors:
        error_list = "\n".join(f"- {e}" for e in fix_errors)
        return (
            "你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
            f"data/papers/{arxiv_id}/summary.json:\n\n"
            f"{error_list}\n\n"
            "注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
            "修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
        )

    # Example object shown to the model. In "limitations" and
    # "reproducibility" the closing fullwidth parenthesis must stay inside
    # the quoted value, otherwise the example itself is malformed JSON.
    json_schema = (
        "## 必须包含以下字段(不要自创字段名):\n"
        '{"arxiv_id": "...", '
        '"title_zh": "中文标题", '
        '"one_line": "一句话概括(≤50字)", '
        '"tags": ["标签1","标签2"], '
        '"difficulty": "入门/进阶/前沿", '
        '"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
        '"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
        '"goal": "详细段落:本文的具体目标", '
        '"gap": "详细段落:本文的独特切入角度"}, '
        '"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
        '"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
        '"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
        '"novelty": "详细段落:技术新颖性分析"}, '
        '"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
        '"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
        '"limitations": "详细段落:局限性分析(作者承认的+你自己的观察)"}, '
        '"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
        '"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
        '"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度)"}, '
        '"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要"},'
        '{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要"}]'
        "\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table,id 严格使用 \"Figure N\" 或 \"Table N\" 格式。"
        "}"
    )

    writing_requirements = (
        "## 写作要求\n"
        "- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
        "- 必须包含论文中的具体数据、数字、实验指标\n"
        "- 像资深同事给同事讲论文一样,专业但易懂\n"
        "- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
        "  例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n"
    )

    if pdf_mode == "search":
        # Search mode: pi reads the files on its own, so name them explicitly.
        return (
            "请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
            "## 工作流程\n"
            f"1. 先用 read 工具读取 {meta_path} 了解论文元信息(标题、作者、摘要)\n"
            f"2. 再用 read 工具阅读 {txt_path}(论文正文全文),可以多次读取定位关键段落\n"
            f"3. 充分理解后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n\n"
            + writing_requirements
            + "\n"
            + json_schema
        )

    # Inject mode (any other value): the prompt says content is already provided.
    return (
        "请深度解读以下论文,严格按下面的 JSON schema 输出结果。\n\n"
        "## 工作流程\n"
        "论文元信息和正文全文已在上文提供,请仔细阅读。\n"
        f"1. 充分理解论文后,用 write_file 将结果保存到 data/papers/{arxiv_id}/summary.json\n"
        "2. 用 bash 运行 python scripts/validate_summary.py 验证\n\n"
        + writing_requirements
        + "\n"
        + json_schema
    )
|
||||
|
||||
|
||||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -149,63 +229,41 @@ async def call_pi(
|
||||
pdf_path: Path,
|
||||
fix_errors: list[str] | None = None,
|
||||
session_id: str | None = None,
|
||||
pdf_mode: str = "inject",
|
||||
) -> tuple[str, str]:
|
||||
"""调用 pi CLI 非交互模式,返回 (stdout 文本, session_id)。
|
||||
|
||||
fix_errors: 如果非空,表示上一次验证失败的错误列表,pi 需要修正这些问题。
|
||||
session_id: 如果非空,用 --continue 延续该 session;否则创建新 session。
|
||||
pdf_mode: "inject" = 全量注入 prompt(@file),"search" = pi 自主 read 文件。
|
||||
"""
|
||||
arxiv_id = meta_path.parent.name
|
||||
|
||||
# 将 PDF 转为文本文件,以 @txt 方式传给 pi
|
||||
txt_path = extract_pdf_text(pdf_path)
|
||||
# 提取 PDF 全文(不截断),根据实际大小自动选择模式
|
||||
txt_path = extract_pdf_text(pdf_path, max_chars=None)
|
||||
txt_size = len(txt_path.read_text(encoding="utf-8"))
|
||||
|
||||
if fix_errors:
|
||||
# 验证失败后的修正提示(同一 session 内,pi 能看到之前写的文件)
|
||||
error_list = "\n".join(f"- {e}" for e in fix_errors)
|
||||
prompt_text = (
|
||||
"你之前生成的 JSON 存在以下问题,请修正后重新用 write_file 保存到 "
|
||||
f"data/papers/{arxiv_id}/summary.json:\n\n"
|
||||
f"{error_list}\n\n"
|
||||
"注意:所有字符串字段必须是详细段落(≥50字),不能是数组或列表。"
|
||||
"修正后请用 bash 运行 python scripts/validate_summary.py 验证。"
|
||||
)
|
||||
else:
|
||||
prompt_text = (
|
||||
"请深度解读以下论文,严格按下面的 JSON schema 输出结果。"
|
||||
"只输出一个 JSON 对象,不要输出其他内容。\n\n"
|
||||
"## 写作要求\n"
|
||||
"- 每个字符串字段必须写成详细段落(200-500字),不要用列表或数组\n"
|
||||
"- 必须包含论文中的具体数据、数字、实验指标\n"
|
||||
"- 像资深同事给同事讲论文一样,专业但易懂\n"
|
||||
"- 数学公式、符号、变量必须使用 LaTeX 格式:行内公式用 $...$,独立公式用 $$...$$\n"
|
||||
" 例如:损失函数 $\\mathcal{L} = -\\sum_{i} \\log p(y_i | x_i)$,学习率 $\\eta$\n\n"
|
||||
"## 必须包含以下字段(不要自创字段名):\n"
|
||||
'{"arxiv_id": "...", '
|
||||
'"title_zh": "中文标题", '
|
||||
'"one_line": "一句话概括(≤50字)", '
|
||||
'"tags": ["标签1","标签2"], '
|
||||
'"difficulty": "入门/进阶/前沿", '
|
||||
'"prerequisites": {"concepts": [{"term":"术语","explanation":"详细解释这个概念是什么、怎么工作的(50-150字)","why_matters":"为什么读懂本文需要它"}]}, '
|
||||
'"motivation": {"problem": "详细段落:现有方法的具体问题(包含具体场景和数据)", '
|
||||
'"goal": "详细段落:本文的具体目标", '
|
||||
'"gap": "详细段落:本文的独特切入角度"}, '
|
||||
'"method": {"overview": "详细段落:方法整体思路(先直觉再技术路线)", '
|
||||
'"key_idea": "详细段落:核心创新点(和已有方法的本质区别)", '
|
||||
'"steps": "详细段落:方法步骤的完整描述(每步的输入输出和具体操作)", '
|
||||
'"novelty": "详细段落:技术新颖性分析"}, '
|
||||
'"results": {"main_findings": "详细段落:核心发现(带具体数字和指标,逐一分析每个实验)", '
|
||||
'"benchmarks": [{"task":"任务","metric":"指标","this_work":"本文结果","baseline":"基线","improvement":"提升"}], '
|
||||
'"limitations": "详细段落:局限性分析(作者承认的+你自己的观察)"}, '
|
||||
'"improvements": {"weaknesses": "详细段落:独立分析的弱点(具体场景,每个弱点给改进方向)", '
|
||||
'"future_work": "详细段落:未来研究方向(作者提出的+基于成果可延伸的)", '
|
||||
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度)"}, '
|
||||
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要"},'
|
||||
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要"}]'
|
||||
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table,id 严格使用 \"Figure N\" 或 \"Table N\" 格式。"
|
||||
"}\n\n"
|
||||
"请深度解读以下论文:"
|
||||
)
|
||||
actual_mode = pdf_mode
|
||||
if pdf_mode == "auto":
|
||||
if txt_size > 80_000:
|
||||
actual_mode = "search"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars > 80k → search", arxiv_id, txt_size
|
||||
)
|
||||
else:
|
||||
actual_mode = "inject"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars ≤ 80k → inject", arxiv_id, txt_size
|
||||
)
|
||||
|
||||
# inject 模式需要截断过长的文本(避免撑爆 context)
|
||||
if actual_mode == "inject" and txt_size > 80_000:
|
||||
body = txt_path.read_text(encoding="utf-8")
|
||||
trimmed = body[:80_000].rstrip()
|
||||
txt_path.write_text(trimmed, encoding="utf-8")
|
||||
logger.info("Truncated %s for inject: %d → %d chars", arxiv_id, txt_size, len(trimmed))
|
||||
|
||||
prompt_text = _build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
|
||||
|
||||
# 构建 session ID(每篇论文一个独立 session)
|
||||
if session_id is None:
|
||||
@@ -213,10 +271,12 @@ async def call_pi(
|
||||
|
||||
session_id = f"summary-{arxiv_id}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 工具列表:search 模式需要 read 工具
|
||||
tools = "bash,write_file" if actual_mode != "search" else "bash,write_file,read"
|
||||
cmd = [
|
||||
settings.PI_BIN,
|
||||
"-p",
|
||||
"--tools", "bash,write_file",
|
||||
"--tools", tools,
|
||||
]
|
||||
if fix_errors:
|
||||
cmd += ["--session", session_id, "--continue"]
|
||||
@@ -227,11 +287,14 @@ async def call_pi(
|
||||
settings.SUMMARY_SKILL,
|
||||
prompt_text,
|
||||
]
|
||||
if not fix_errors:
|
||||
# 首次调用传文件,后续 --continue 不需要(session 内已有)
|
||||
if not fix_errors and actual_mode != "search":
|
||||
# inject 模式:首次调用传 @file;search 模式 pi 自己 read,不注入
|
||||
cmd += [f"@{meta_path}", f"@{txt_path}"]
|
||||
|
||||
logger.info("Calling pi for %s (fix=%s, session=%s)", arxiv_id, bool(fix_errors), session_id)
|
||||
logger.info(
|
||||
"Calling pi for %s (fix=%s, session=%s, mode=%s)",
|
||||
arxiv_id, bool(fix_errors), session_id, actual_mode,
|
||||
)
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
"""流水线服务 — crawl → summarize → cleanup 的共享编排逻辑。
|
||||
|
||||
供 admin 手动触发和 scheduler 定时调度共用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date as date_type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import CrawlLog, TaskLock
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
from app.services.crawler import crawl_daily
|
||||
from app.services.summarizer import summarize_batch
|
||||
from app.utils import utc_now, yesterday_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
    """Run the full pipeline: crawl → summarize → cleanup.

    A ``TaskLock`` row keyed by the target date prevents re-entry, and a
    ``CrawlLog`` row records the outcome of the run.

    Args:
        db: Database session.
        target_date: Target date as ``YYYY-MM-DD``.
        owner: Caller label (e.g. "admin_trigger" / "daily_pipeline").

    Returns:
        ``{"status": "success", "message": ...}`` on success, or
        ``{"status": "failed", "error": str}`` when a step raised or the
        crawl step reported a failure.

    Raises:
        ValueError: ``target_date`` is not a valid ISO date (raised before
            any lock is taken).
        RuntimeError: a lock row for this date already exists.
    """
    # Function-scope import keeps this fix independent of the file header.
    from sqlalchemy.exc import IntegrityError

    # Parse first: a malformed date must fail before the lock is committed,
    # otherwise the lock row would be left in "running" with nobody to
    # release it.
    log_date = date_type.fromisoformat(target_date)

    now = utc_now()
    lock_key = f"pipeline-{target_date}"

    # ── Acquire the lock ────────────────────────────────────────────────
    lock = TaskLock(
        task="scheduler",
        lock_key=lock_key,
        status="running",
        owner=owner,
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except IntegrityError as exc:
        # Only a constraint violation means another run holds the lock.
        db.rollback()
        raise RuntimeError(f"Pipeline already running for {target_date}") from exc
    except Exception:
        # Any other database error is a real failure, not "already running".
        db.rollback()
        raise

    # ── Write the scheduler log row ─────────────────────────────────────
    log_entry = CrawlLog(
        task="scheduler",
        status="running",
        date=log_date,
        started_at=now,
    )
    db.add(log_entry)
    db.commit()

    error_msg = None
    crawl_result: dict = {}
    try:
        # Step 1: crawl the target date; if it succeeds with no papers,
        # fall back to yesterday.
        # NOTE(review): the fallback uses yesterday_str(), not the day before
        # target_date — confirm this is intended for manually chosen dates.
        crawl_result = await crawl_daily(db, target_date)
        logger.info("Pipeline [%s]: crawl %s, found=%d new=%d",
                    owner, target_date,
                    crawl_result.get("found", 0), crawl_result.get("new", 0))

        if crawl_result.get("status") == "success" and crawl_result.get("found") == 0:
            yesterday = yesterday_str()
            logger.info("Pipeline [%s]: falling back to %s", owner, yesterday)
            crawl_result = await crawl_daily(db, yesterday)

        # Step 2: summarize. This still runs after a failed crawl, as before.
        summarize_result = await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
        logger.info("Pipeline [%s]: summarize done, result=%s", owner, summarize_result)

        # Step 3: cleanup
        cleanup_result = cleanup_tmp()
        logger.info("Pipeline [%s]: cleanup done, removed=%d",
                    owner, cleanup_result.get("removed", 0))

        log_entry.papers_found = crawl_result.get("found", 0)
        log_entry.papers_new = crawl_result.get("new", 0)

        # crawl_daily reports failure through its result dict instead of
        # raising, so surface that here rather than logging a false success.
        if crawl_result.get("status") == "failed":
            log_entry.status = "failed"
            error_msg = str(crawl_result.get("error") or "crawl failed")[:2000]
        else:
            log_entry.status = "success"

    except Exception as exc:
        logger.exception("Pipeline [%s] failed", owner)
        # Reset the session so the bookkeeping commits below cannot fail on
        # a transaction left broken by the step that raised.
        db.rollback()
        log_entry.status = "failed"
        error_msg = str(exc)[:2000]

    finally:
        log_entry.completed_at = utc_now()
        if error_msg:
            log_entry.error = error_msg
        db.commit()

        # Release the lock
        lock.status = "finished"
        lock.released_at = utc_now()
        db.commit()

    if error_msg:
        return {"status": "failed", "error": error_msg}
    return {"status": "success", "message": "Pipeline completed"}
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
@@ -12,10 +11,8 @@ from zoneinfo import ZoneInfo
|
||||
|
||||
from app.config import settings
|
||||
from app.database import SessionLocal
|
||||
from app.models import CrawlLog, TaskLock
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
from app.services.crawler import crawl_daily
|
||||
from app.services.summarizer import summarize_batch
|
||||
from app.services.pipeline import run_pipeline
|
||||
from app.utils import today_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -92,85 +89,15 @@ def stop_scheduler() -> None:
|
||||
async def _daily_pipeline() -> None:
|
||||
"""每日流水线:抓取 → 总结 → 清理。
|
||||
|
||||
使用 task_locks 表防止重入:同一天的 pipeline 任务只有一个能运行。
|
||||
委托给 pipeline.run_pipeline 执行,使用 task_locks 防重入。
|
||||
"""
|
||||
tz = ZoneInfo(settings.APP_TIMEZONE)
|
||||
today = datetime.now(tz).strftime("%Y-%m-%d")
|
||||
now = datetime.now(timezone.utc)
|
||||
lock_key = f"pipeline-{today}"
|
||||
today = today_str()
|
||||
|
||||
db: Session = SessionLocal()
|
||||
try:
|
||||
# 尝试获取锁
|
||||
lock = TaskLock(
|
||||
task="scheduler",
|
||||
lock_key=lock_key,
|
||||
status="running",
|
||||
owner="daily_pipeline",
|
||||
acquired_at=now,
|
||||
)
|
||||
try:
|
||||
db.add(lock)
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
logger.warning("Daily pipeline already running for %s, skipping", today)
|
||||
return
|
||||
|
||||
# 写调度日志
|
||||
log_entry = CrawlLog(
|
||||
task="scheduler",
|
||||
status="running",
|
||||
date=datetime.now(tz).date(),
|
||||
started_at=now,
|
||||
)
|
||||
db.add(log_entry)
|
||||
db.commit()
|
||||
|
||||
error_msg = None
|
||||
try:
|
||||
# Step 1: 抓取
|
||||
logger.info("Scheduler pipeline: crawl %s", today)
|
||||
crawl_result = await crawl_daily(db, today)
|
||||
logger.info(
|
||||
"Scheduler pipeline: crawl done, found=%d new=%d",
|
||||
crawl_result.get("found", 0),
|
||||
crawl_result.get("new", 0),
|
||||
)
|
||||
|
||||
# Step 2: 总结 pending 论文
|
||||
logger.info("Scheduler pipeline: summarize batch")
|
||||
summarize_result = await summarize_batch(db)
|
||||
logger.info(
|
||||
"Scheduler pipeline: summarize done, result=%s", summarize_result
|
||||
)
|
||||
|
||||
# Step 3: 清理临时文件
|
||||
logger.info("Scheduler pipeline: cleanup tmp")
|
||||
cleanup_result = cleanup_tmp()
|
||||
logger.info(
|
||||
"Scheduler pipeline: cleanup done, removed=%d",
|
||||
cleanup_result.get("removed", 0),
|
||||
)
|
||||
|
||||
log_entry.status = "success"
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("Scheduler pipeline failed for %s", today)
|
||||
log_entry.status = "failed"
|
||||
error_msg = str(exc)[:2000]
|
||||
|
||||
finally:
|
||||
log_entry.completed_at = datetime.now(timezone.utc)
|
||||
if error_msg:
|
||||
log_entry.error = error_msg
|
||||
db.commit()
|
||||
|
||||
# 释放锁
|
||||
lock.status = "finished"
|
||||
lock.released_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
await run_pipeline(db, today, owner="daily_pipeline")
|
||||
except RuntimeError:
|
||||
logger.warning("Daily pipeline already running for %s, skipping", today)
|
||||
except Exception:
|
||||
logger.exception("Unexpected error in daily pipeline")
|
||||
finally:
|
||||
|
||||
+29
-32
@@ -3,10 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
|
||||
from app.utils import sanitize_html, utc_now
|
||||
|
||||
|
||||
# ── 子模型 ──────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -90,18 +90,6 @@ class SummarySchema(BaseModel):
|
||||
|
||||
# ── 质量评估 ────────────────────────────────────────────────────────────
|
||||
|
||||
# 必填字段:title_zh, one_line, tags, motivation.problem, method.key_idea
|
||||
# — 缺失时 Pydantic 校验就会报错,不会走到 assess_quality
|
||||
# 重要字段:motivation.goal, motivation.gap, method.overview, results.main_findings
|
||||
# — 缺失可入库,标记 degraded
|
||||
_OPTIONAL_BUT_IMPORTANT_FIELDS = [
|
||||
"motivation.goal",
|
||||
"motivation.gap",
|
||||
"method.overview",
|
||||
"results.main_findings",
|
||||
]
|
||||
|
||||
|
||||
def assess_quality(schema: SummarySchema) -> str:
|
||||
"""评估总结质量:normal / degraded / low。"""
|
||||
# low:内容空洞的启发式判断
|
||||
@@ -128,31 +116,40 @@ def assess_quality(schema: SummarySchema) -> str:
|
||||
|
||||
|
||||
def flatten_for_db(schema: SummarySchema) -> dict:
    """Flatten a ``SummarySchema`` into column values for ``paper_summaries``.

    Free-text fields are passed through ``sanitize_html`` before storage
    because the front end renders them with ``|safe``.  ``difficulty``, the
    benchmark and figure JSON blobs, and ``full_json`` are stored as-is.

    Args:
        schema: The validated summary to persist.

    Returns:
        A dict keyed by ``paper_summaries`` column name, including an
        ``updated_at`` timestamp taken from ``utc_now()``.
    """
    # Sanitize the nested free-text fields of every prerequisite concept.
    # This works on the dumped dict, not on the schema object itself.
    prereqs = schema.prerequisites.model_dump()
    for concept in prereqs.get("concepts", []):
        if not isinstance(concept, dict):
            continue
        for field in ("explanation", "why_matters"):
            if concept.get(field):
                concept[field] = sanitize_html(concept[field])

    motivation = schema.motivation
    method = schema.method
    results = schema.results
    improvements = schema.improvements

    return {
        "one_line": sanitize_html(schema.one_line),
        "difficulty": schema.difficulty,
        "prerequisites_json": json.dumps(prereqs, ensure_ascii=False),
        "motivation_problem": sanitize_html(motivation.problem),
        "motivation_goal": sanitize_html(motivation.goal),
        "motivation_gap": sanitize_html(motivation.gap),
        "method_overview": sanitize_html(method.overview),
        "method_key_idea": sanitize_html(method.key_idea),
        "method_steps_json": sanitize_html(method.steps),
        "method_novelty": sanitize_html(method.novelty),
        "results_main_json": sanitize_html(results.main_findings),
        "results_benchmarks_json": json.dumps(
            results.benchmarks, ensure_ascii=False
        ),
        "limitations_json": sanitize_html(results.limitations),
        "weaknesses_json": sanitize_html(improvements.weaknesses),
        "future_work_json": sanitize_html(improvements.future_work),
        "reproducibility": sanitize_html(improvements.reproducibility),
        "figures_json": json.dumps(schema.figures, ensure_ascii=False),
        # NOTE(review): full_json is dumped from the unsanitized schema;
        # confirm it is never rendered with |safe.
        "full_json": schema.model_dump_json(ensure_ascii=False),
        "updated_at": utc_now(),
    }
|
||||
|
||||
|
||||
|
||||
+16
-28
@@ -6,11 +6,11 @@ import logging
|
||||
import math
|
||||
import re
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Paper
|
||||
from app.models import PAPER_FULL_LOAD, Paper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -213,21 +213,15 @@ def _search_semantic(
|
||||
arxiv_ids = [c["arxiv_id"] for c in candidates]
|
||||
distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
|
||||
|
||||
papers_query = (
|
||||
db.query(Paper)
|
||||
.filter(Paper.arxiv_id.in_(arxiv_ids))
|
||||
.options(
|
||||
joinedload(Paper.authors),
|
||||
joinedload(Paper.tags),
|
||||
joinedload(Paper.summary_status),
|
||||
joinedload(Paper.bookmark),
|
||||
joinedload(Paper.reading_status),
|
||||
)
|
||||
stmt = (
|
||||
select(Paper)
|
||||
.where(Paper.arxiv_id.in_(arxiv_ids))
|
||||
.options(*PAPER_FULL_LOAD)
|
||||
)
|
||||
if tag:
|
||||
papers_query = papers_query.filter(Paper.tags.any(tag=tag))
|
||||
stmt = stmt.where(Paper.tags.any(tag=tag))
|
||||
|
||||
papers = papers_query.all()
|
||||
papers = db.execute(stmt).unique().scalars().all()
|
||||
|
||||
# 按语义距离排序
|
||||
id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
|
||||
@@ -257,11 +251,7 @@ def _search_tag_only(
|
||||
offset: int,
|
||||
) -> dict:
|
||||
"""只有标签筛选,无关键词。"""
|
||||
order = (
|
||||
"p.paper_date DESC, p.upvotes DESC"
|
||||
if sort == "date"
|
||||
else "p.paper_date DESC, p.upvotes DESC"
|
||||
)
|
||||
order = "p.paper_date DESC, p.upvotes DESC"
|
||||
|
||||
rows_sql = text(f"""
|
||||
SELECT p.id
|
||||
@@ -307,15 +297,13 @@ def _load_papers_by_ids(
|
||||
return []
|
||||
|
||||
papers = (
|
||||
db.query(Paper)
|
||||
.filter(Paper.id.in_(paper_ids))
|
||||
.options(
|
||||
joinedload(Paper.authors),
|
||||
joinedload(Paper.tags),
|
||||
joinedload(Paper.summary_status),
|
||||
joinedload(Paper.bookmark),
|
||||
joinedload(Paper.reading_status),
|
||||
db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.id.in_(paper_ids))
|
||||
.options(*PAPER_FULL_LOAD)
|
||||
)
|
||||
.unique()
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
+217
-225
@@ -2,23 +2,24 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.database import SessionLocal
|
||||
from app.models import (
|
||||
PAPER_DEFAULT_LOAD,
|
||||
CrawlLog,
|
||||
Paper,
|
||||
PaperSummary,
|
||||
PaperTag,
|
||||
SummaryState,
|
||||
SummaryStatus,
|
||||
TaskLock,
|
||||
)
|
||||
@@ -42,7 +43,7 @@ from app.services.schemas import (
|
||||
classify_validation_error,
|
||||
flatten_for_db,
|
||||
)
|
||||
from app.utils import PAPERS_DIR, release_lock
|
||||
from app.utils import TMP_DIR, release_lock, utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,8 +97,6 @@ def _update_summary_in_db(
|
||||
"""将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。"""
|
||||
from sqlalchemy import text
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# 1. paper_summaries:upsert
|
||||
existing = db.get(PaperSummary, paper.id)
|
||||
flat = flatten_for_db(schema)
|
||||
@@ -213,21 +212,14 @@ def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
|
||||
# ── 文件操作 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
|
||||
"""保存 summary.json 和 raw_output.txt。"""
|
||||
d = paper_dir(arxiv_id)
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
(d / "summary.json").write_text(
|
||||
schema.model_dump_json(ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
|
||||
|
||||
|
||||
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
|
||||
"""仅保存 raw_output.txt(失败时)。"""
|
||||
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
    """Write a paper's summary artifacts into its paper directory.

    ``raw_output.txt`` is always written.  ``summary.json`` is written only
    when a validated ``schema`` is supplied; pass ``None`` to keep just the
    raw model output (used on the failure path).

    Args:
        arxiv_id: Identifier used by ``paper_dir`` to locate the directory.
        schema: Validated summary to serialize, or ``None`` to skip it.
        raw_output: Raw text produced by the generator.
    """
    target_dir = paper_dir(arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Explicit None check: "no schema" is signalled by None, not by falsiness.
    if schema is not None:
        (target_dir / "summary.json").write_text(
            schema.model_dump_json(ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    (target_dir / "raw_output.txt").write_text(raw_output, encoding="utf-8")
|
||||
|
||||
|
||||
@@ -240,26 +232,25 @@ async def summarize_one(
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
*,
|
||||
force: bool = False,
|
||||
pdf_mode: str = "auto",
|
||||
) -> dict:
|
||||
"""总结单篇论文的完整流程。"""
|
||||
import asyncio
|
||||
|
||||
arxiv_id = paper.arxiv_id
|
||||
|
||||
# 获取或创建 summary_status
|
||||
if not paper.summary_status:
|
||||
db.add(SummaryStatus(paper_id=paper.id, status="pending"))
|
||||
db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
|
||||
db.commit()
|
||||
db.refresh(paper)
|
||||
|
||||
status = paper.summary_status
|
||||
|
||||
# 跳过已完成的(除非 force)
|
||||
if status.status == "done" and not force:
|
||||
if status.status == SummaryState.DONE and not force:
|
||||
return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"}
|
||||
|
||||
# 跳过 permanent_failure(除非 force)
|
||||
if status.status == "permanent_failure" and not force:
|
||||
if status.status == SummaryState.PERMANENT_FAILURE and not force:
|
||||
return {
|
||||
"arxiv_id": arxiv_id,
|
||||
"status": "skipped",
|
||||
@@ -269,182 +260,202 @@ async def summarize_one(
|
||||
if semaphore:
|
||||
await semaphore.acquire()
|
||||
try:
|
||||
return await _do_summarize_one(db, paper)
|
||||
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
|
||||
finally:
|
||||
if semaphore:
|
||||
semaphore.release()
|
||||
|
||||
|
||||
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
||||
"""实际的单篇总结执行(在 semaphore 保护下)。"""
|
||||
import asyncio
|
||||
async def _generate_with_retry(
    arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
    """Call the pi CLI to generate a summary, re-prompting until it validates.

    Up to four rounds are run.  The first round is a plain generation; each
    later round reuses the returned session id and passes the previous
    round's validation errors so pi can correct its output.

    Args:
        arxiv_id: Paper identifier; also locates ``summary.json``.
        meta_path: Path to the paper's metadata file handed to pi.
        pdf_path: Path to the downloaded PDF handed to pi.
        pdf_mode: Forwarded unchanged to ``call_pi``.

    Returns:
        ``(json_data, raw_output)`` from the first round that validates.

    Raises:
        ValueError: No round produced a summary that passes validation.

    Any exception leaving this function (including ones raised by
    ``call_pi`` in a later round) carries a ``raw_output`` attribute when
    some output had already been produced, so the caller can still save it.
    """
    max_attempts = 4
    # The same file is both the stale artifact to clear and the preferred
    # source of the result, so resolve its path once.
    summary_file = paper_dir(arxiv_id) / "summary.json"

    validation_errors: list[str] = []
    json_data: dict | None = None
    raw_output = ""
    session_id = None

    try:
        for attempt in range(1, max_attempts + 1):
            # Drop any incomplete summary.json pi wrote in the previous round.
            if summary_file.exists():
                summary_file.unlink()

            if attempt == 1:
                raw_output, session_id = await call_pi(
                    meta_path, pdf_path, pdf_mode=pdf_mode
                )
            else:
                raw_output, session_id = await call_pi(
                    meta_path, pdf_path,
                    fix_errors=validation_errors,
                    session_id=session_id,
                    pdf_mode=pdf_mode,
                )

            # Prefer the summary.json pi wrote itself; otherwise pull the
            # JSON out of stdout.
            try:
                if summary_file.exists():
                    json_data = json.loads(summary_file.read_text(encoding="utf-8"))
                    logger.info("Read summary.json written by pi for %s", arxiv_id)
                else:
                    json_data = extract_json(raw_output)
            except (json.JSONDecodeError, JsonNotFoundError) as parse_err:
                logger.warning(
                    "JSON extraction failed for %s (attempt %d): %s",
                    arxiv_id, attempt, str(parse_err)[:200],
                )
                # Treat an unparseable reply as a validation error and retry.
                validation_errors = [f"无法提取有效 JSON: {str(parse_err)[:100]}"]
                continue

            validation_errors = _validate_summary(json_data, arxiv_id)
            if not validation_errors:
                break
            logger.warning(
                "Validation failed for %s (attempt %d): %s",
                arxiv_id, attempt, "; ".join(validation_errors),
            )

        if validation_errors:
            failure = ValueError(
                f"Summary validation failed after {max_attempts} attempts: "
                f"{'; '.join(validation_errors)}"
            )
            failure.raw_output = raw_output
            raise failure
    except Exception as exc:
        # Output from an earlier round would otherwise be lost when a later
        # call raises; attach it unless the exception already carries one.
        if raw_output and not hasattr(exc, "raw_output"):
            exc.raw_output = raw_output
        raise

    return json_data, raw_output
|
||||
|
||||
|
||||
def _persist_summary(
    db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
    """Validate, grade and store a generated summary, then mark the paper done.

    Steps, in order: Pydantic validation, quality assessment, writing the
    summary files, updating the database rows, flipping the summary status to
    done (committed), and finally the image-extraction and ChromaDB-indexing
    enhancements.

    Args:
        db: Session used for the updates and the commit.
        paper: The paper whose summary is being stored.
        json_data: Parsed summary JSON to validate.
        raw_output: Raw generator output, saved next to the summary.

    Returns:
        The quality label produced by ``assess_quality``.
    """
    validated = SummarySchema.model_validate(json_data)
    grade = assess_quality(validated)

    _save_files(paper.arxiv_id, validated, raw_output)
    _update_summary_in_db(db, paper, validated, grade, raw_output)

    # Mark the summary as done.
    state = paper.summary_status
    state.status = SummaryState.DONE
    state.quality = grade
    state.completed_at = utc_now()
    state.raw_output_saved = True
    db.commit()

    # Optional enhancements; their failures do not affect the stored summary.
    _maybe_extract_images(paper.arxiv_id, validated)
    _maybe_index_chroma(paper.arxiv_id, paper, validated)

    return grade
|
||||
|
||||
|
||||
def _handle_summary_failure(
    db: Session, paper: Paper, exc: Exception, raw_output: str,
) -> dict:
    """Record a failed summarization attempt and decide whether it may retry.

    Saves ``raw_output`` (when non-empty), bumps the retry counter, stores the
    classified error, and sets the status to ``PERMANENT_FAILURE`` once the
    counter reaches ``settings.SUMMARY_MAX_RETRIES + 1``; otherwise the paper
    goes back to ``PENDING``.  The change is committed.

    Args:
        db: Session used for the commit.
        paper: The paper whose summarization failed.
        exc: The exception that aborted the attempt.
        raw_output: Whatever output the generator produced, possibly empty.

    Returns:
        A result dict with ``arxiv_id``, ``status`` (``"failed"``),
        ``error_type``, a truncated ``error`` message and ``retry_count``.
    """
    error_type = _classify_error(exc)
    message = str(exc)
    logger.error(
        "Summarize failed: %s error_type=%s %s",
        paper.arxiv_id, error_type, message[:200],
    )

    status = paper.summary_status

    if raw_output:
        # No schema on the failure path: only raw_output.txt is written.
        _save_files(paper.arxiv_id, None, raw_output)
        status.raw_output_saved = True

    status.retry_count = (status.retry_count or 0) + 1
    status.error_type = error_type
    status.error = message[:2000]

    if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
        status.status = SummaryState.PERMANENT_FAILURE
    else:
        status.status = SummaryState.PENDING

    status.completed_at = utc_now()
    db.commit()

    return {
        "arxiv_id": paper.arxiv_id,
        "status": "failed",
        "error_type": error_type,
        "error": message[:200],
        "retry_count": status.retry_count,
    }
|
||||
|
||||
|
||||
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
    """Extract figures and tables from the paper's PDF on a best-effort basis.

    Runs the extractor on ``TMP_DIR/<arxiv_id>/paper.pdf`` and, when the
    summary lists figures, filters the extracted images by that list.  Any
    exception (including a failed import of the extractor) is logged as a
    warning and swallowed so summarization is not affected.
    """
    try:
        from app.services.pdf_image_extractor import (
            extract_images_from_pdf,
            filter_images_by_summary,
        )

        extract_images_from_pdf(arxiv_id, TMP_DIR / arxiv_id / "paper.pdf")

        referenced = schema.figures
        if referenced:
            filter_images_by_summary(arxiv_id, referenced)
    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
    """Write the paper into the ChromaDB semantic index on a best-effort basis.

    Builds a flat dict of text fields from the paper row and its summary and
    hands it to ``index_paper``.  Missing values become empty strings.  Any
    exception (including a failed import of the embedder) is logged as a
    warning and swallowed so summarization is not affected.
    """
    try:
        from app.services.embedder import index_paper

        document = {
            "arxiv_id": arxiv_id,
            "title_zh": schema.title_zh or "",
            "title_en": paper.title_en or "",
            "tags": " ".join(entry.tag for entry in (paper.tags or [])),
            "one_line": schema.one_line or "",
            "motivation_problem": schema.motivation.problem or "",
            "method_key_idea": schema.method.key_idea or "",
            "paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
        }
        index_paper(arxiv_id, document)
    except Exception:
        logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
async def _do_summarize_one(
|
||||
db: Session, paper: Paper, pdf_mode: str = "auto"
|
||||
) -> dict:
|
||||
"""实际的单篇总结执行(在 semaphore 保护下)。"""
|
||||
arxiv_id = paper.arxiv_id
|
||||
|
||||
# 状态 → processing
|
||||
status.status = "processing"
|
||||
status.started_at = now
|
||||
paper.summary_status.status = SummaryState.PROCESSING
|
||||
paper.summary_status.started_at = utc_now()
|
||||
db.commit()
|
||||
|
||||
raw_output = ""
|
||||
try:
|
||||
# 写 meta.json
|
||||
meta_path = write_meta_json(paper)
|
||||
|
||||
# 下载 PDF
|
||||
await download_pdf(arxiv_id, paper.pdf_url)
|
||||
|
||||
# 带验证的生成循环:最多 4 轮,同一 session 内 pi 可看到之前写的文件
|
||||
json_data = None
|
||||
validation_errors = []
|
||||
session_id = None
|
||||
for attempt in range(1, 5):
|
||||
# 清理上一轮 pi 通过 write_file 写的不完整文件
|
||||
stale = paper_dir(arxiv_id) / "summary.json"
|
||||
if stale.exists():
|
||||
stale.unlink()
|
||||
json_data, raw_output = await _generate_with_retry(
|
||||
arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf",
|
||||
pdf_mode=pdf_mode,
|
||||
)
|
||||
|
||||
if attempt == 1:
|
||||
raw_output, session_id = await call_pi(
|
||||
meta_path, Path("data/tmp") / arxiv_id / "paper.pdf"
|
||||
)
|
||||
else:
|
||||
# 验证失败,同一 session 内带着错误信息让 pi 修正
|
||||
raw_output, session_id = await call_pi(
|
||||
meta_path,
|
||||
Path("data/tmp") / arxiv_id / "paper.pdf",
|
||||
fix_errors=validation_errors,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# 优先从 pi write_file 写入的 summary.json 读取,否则从 stdout 提取
|
||||
# 如果都失败,当作验证错误,继续下一次尝试
|
||||
json_data = None
|
||||
summary_file = paper_dir(arxiv_id) / "summary.json"
|
||||
try:
|
||||
if summary_file.exists():
|
||||
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
|
||||
logger.info("Read summary.json written by pi for %s", arxiv_id)
|
||||
else:
|
||||
json_data = extract_json(raw_output)
|
||||
except (json.JSONDecodeError, JsonNotFoundError) as exc:
|
||||
logger.warning(
|
||||
"JSON extraction failed for %s (attempt %d): %s",
|
||||
arxiv_id,
|
||||
attempt,
|
||||
str(exc)[:200],
|
||||
)
|
||||
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
|
||||
continue
|
||||
|
||||
# 运行验证脚本
|
||||
validation_errors = _validate_summary(json_data, arxiv_id)
|
||||
if not validation_errors:
|
||||
break
|
||||
logger.warning(
|
||||
"Validation failed for %s (attempt %d): %s",
|
||||
arxiv_id,
|
||||
attempt,
|
||||
"; ".join(validation_errors),
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
raise ValueError(
|
||||
f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}"
|
||||
)
|
||||
|
||||
# Pydantic 校验
|
||||
schema = SummarySchema.model_validate(json_data)
|
||||
|
||||
# 质量评估
|
||||
quality = assess_quality(schema)
|
||||
|
||||
# 保存文件
|
||||
_save_files(arxiv_id, schema, raw_output)
|
||||
|
||||
# 更新 DB
|
||||
_update_summary_in_db(db, paper, schema, quality, raw_output)
|
||||
|
||||
# 状态 → done
|
||||
status.status = "done"
|
||||
status.quality = quality
|
||||
status.completed_at = datetime.now(timezone.utc)
|
||||
status.raw_output_saved = True
|
||||
db.commit()
|
||||
|
||||
# PDF 图片提取(可选增强,失败不影响总结)
|
||||
try:
|
||||
from app.services.pdf_image_extractor import (
|
||||
extract_images_from_pdf,
|
||||
filter_images_by_summary,
|
||||
)
|
||||
pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf"
|
||||
extract_images_from_pdf(arxiv_id, pdf_path)
|
||||
# 根据 summary 中 figures 字段过滤,只保留被引用的图表
|
||||
if schema.figures:
|
||||
filter_images_by_summary(arxiv_id, schema.figures)
|
||||
except Exception:
|
||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
|
||||
# 同步写入语义索引(失败仅 log)
|
||||
try:
|
||||
from app.services.embedder import index_paper
|
||||
|
||||
texts_dict = {
|
||||
"arxiv_id": arxiv_id,
|
||||
"title_zh": schema.title_zh or "",
|
||||
"title_en": paper.title_en or "",
|
||||
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
|
||||
"one_line": schema.one_line or "",
|
||||
"motivation_problem": schema.motivation.problem or "",
|
||||
"method_key_idea": schema.method.key_idea or "",
|
||||
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
|
||||
}
|
||||
index_paper(arxiv_id, texts_dict)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True
|
||||
)
|
||||
quality = _persist_summary(db, paper, json_data, raw_output)
|
||||
|
||||
logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
|
||||
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
|
||||
|
||||
except Exception as exc:
|
||||
error_type = _classify_error(exc)
|
||||
logger.error(
|
||||
"Summarize failed: %s error_type=%s %s",
|
||||
arxiv_id,
|
||||
error_type,
|
||||
str(exc)[:200],
|
||||
)
|
||||
|
||||
# 保存 raw_output(如果有)
|
||||
if raw_output:
|
||||
_save_raw_output_only(arxiv_id, raw_output)
|
||||
status.raw_output_saved = True
|
||||
|
||||
# 重试逻辑
|
||||
status.retry_count = (status.retry_count or 0) + 1
|
||||
status.error_type = error_type
|
||||
status.error = str(exc)[:2000]
|
||||
|
||||
if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
|
||||
status.status = "permanent_failure"
|
||||
else:
|
||||
status.status = "pending"
|
||||
|
||||
status.completed_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"arxiv_id": arxiv_id,
|
||||
"status": "failed",
|
||||
"error_type": error_type,
|
||||
"error": str(exc)[:200],
|
||||
"retry_count": status.retry_count,
|
||||
}
|
||||
# 从异常对象获取 raw_output(_generate_with_retry 失败时仍有输出)
|
||||
fail_output = getattr(exc, "raw_output", raw_output)
|
||||
return _handle_summary_failure(db, paper, exc, fail_output)
|
||||
|
||||
finally:
|
||||
cleanup_tmp(arxiv_id)
|
||||
@@ -458,22 +469,18 @@ async def summarize_single(
|
||||
arxiv_id: str,
|
||||
*,
|
||||
force: bool = True,
|
||||
pdf_mode: str = "auto",
|
||||
_session_factory=None,
|
||||
) -> dict:
|
||||
"""单篇总结入口(供 admin 路由和 CLI 调用)。
|
||||
|
||||
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
|
||||
"""
|
||||
paper = (
|
||||
db.query(Paper)
|
||||
.filter(Paper.arxiv_id == arxiv_id)
|
||||
.options(
|
||||
joinedload(Paper.authors),
|
||||
joinedload(Paper.tags),
|
||||
joinedload(Paper.summary_status),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
paper = db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.arxiv_id == arxiv_id)
|
||||
.options(*PAPER_DEFAULT_LOAD)
|
||||
).unique().scalar_one_or_none()
|
||||
if not paper:
|
||||
return {"status": "not_found", "arxiv_id": arxiv_id}
|
||||
|
||||
@@ -482,17 +489,12 @@ async def summarize_single(
|
||||
# 每篇用独立 session 避免并发问题
|
||||
paper_db = make_session()
|
||||
try:
|
||||
paper_in_new_session = (
|
||||
paper_db.query(Paper)
|
||||
.filter(Paper.arxiv_id == arxiv_id)
|
||||
.options(
|
||||
joinedload(Paper.authors),
|
||||
joinedload(Paper.tags),
|
||||
joinedload(Paper.summary_status),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
result = await summarize_one(paper_db, paper_in_new_session, force=force)
|
||||
paper_in_new_session = paper_db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.arxiv_id == arxiv_id)
|
||||
.options(*PAPER_DEFAULT_LOAD)
|
||||
).unique().scalar_one_or_none()
|
||||
result = await summarize_one(paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode)
|
||||
finally:
|
||||
paper_db.close()
|
||||
|
||||
@@ -506,15 +508,14 @@ async def summarize_batch(
|
||||
db: Session,
|
||||
arxiv_ids: list[str] | None = None,
|
||||
*,
|
||||
pdf_mode: str = "auto",
|
||||
_session_factory=None,
|
||||
) -> dict:
|
||||
"""批量总结入口。arxiv_ids=None 时处理所有 pending 论文。
|
||||
|
||||
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
|
||||
# TaskLock 防重入
|
||||
lock = TaskLock(
|
||||
@@ -543,20 +544,16 @@ async def summarize_batch(
|
||||
|
||||
try:
|
||||
# 查询待总结论文
|
||||
query = db.query(Paper).options(
|
||||
joinedload(Paper.authors),
|
||||
joinedload(Paper.tags),
|
||||
joinedload(Paper.summary_status),
|
||||
)
|
||||
stmt = select(Paper).options(*PAPER_DEFAULT_LOAD)
|
||||
if arxiv_ids:
|
||||
query = query.filter(Paper.arxiv_id.in_(arxiv_ids))
|
||||
stmt = stmt.where(Paper.arxiv_id.in_(arxiv_ids))
|
||||
else:
|
||||
# 只处理 pending 或 failed(可重试的)
|
||||
query = query.join(SummaryStatus).filter(
|
||||
SummaryStatus.status.in_(["pending", "failed"])
|
||||
stmt = stmt.join(SummaryStatus).where(
|
||||
SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.FAILED])
|
||||
)
|
||||
|
||||
papers = query.all()
|
||||
papers = db.execute(stmt).unique().scalars().all()
|
||||
total = len(papers)
|
||||
logger.info("Summarize batch: %d papers to process", total)
|
||||
|
||||
@@ -564,7 +561,7 @@ async def summarize_batch(
|
||||
log_entry.status = "success"
|
||||
log_entry.papers_found = 0
|
||||
log_entry.papers_new = 0
|
||||
log_entry.completed_at = datetime.now(timezone.utc)
|
||||
log_entry.completed_at = utc_now()
|
||||
release_lock(db, lock)
|
||||
return {
|
||||
"status": "success",
|
||||
@@ -581,17 +578,12 @@ async def summarize_batch(
|
||||
async def _process_paper(paper: Paper) -> dict:
|
||||
paper_db = make_session()
|
||||
try:
|
||||
p = (
|
||||
paper_db.query(Paper)
|
||||
.filter(Paper.id == paper.id)
|
||||
.options(
|
||||
joinedload(Paper.authors),
|
||||
joinedload(Paper.tags),
|
||||
joinedload(Paper.summary_status),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return await summarize_one(paper_db, p, semaphore)
|
||||
p = paper_db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.id == paper.id)
|
||||
.options(*PAPER_DEFAULT_LOAD)
|
||||
).unique().scalar_one_or_none()
|
||||
return await summarize_one(paper_db, p, semaphore, pdf_mode=pdf_mode)
|
||||
finally:
|
||||
paper_db.close()
|
||||
|
||||
@@ -619,7 +611,7 @@ async def summarize_batch(
|
||||
log_entry.status = "success" if failed == 0 else "failed"
|
||||
log_entry.papers_found = total
|
||||
log_entry.papers_new = done
|
||||
log_entry.completed_at = datetime.now(timezone.utc)
|
||||
log_entry.completed_at = utc_now()
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
@@ -641,7 +633,7 @@ async def summarize_batch(
|
||||
logger.exception("Summarize batch failed")
|
||||
log_entry.status = "failed"
|
||||
log_entry.error = str(exc)[:2000]
|
||||
log_entry.completed_at = datetime.now(timezone.utc)
|
||||
log_entry.completed_at = utc_now()
|
||||
db.commit()
|
||||
return {"status": "failed", "error": str(exc)}
|
||||
|
||||
|
||||
+34
-31
@@ -2,23 +2,24 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.models import Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
|
||||
from app.models import PAPER_FULL_LOAD, Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
|
||||
from app.utils import utc_now
|
||||
|
||||
# ── 收藏 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
|
||||
"""切换收藏状态。返回 {"bookmarked": bool, "arxiv_id": str}。"""
|
||||
paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
|
||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
||||
if not paper:
|
||||
return {"error": "not_found"}
|
||||
|
||||
existing = db.query(UserBookmark).filter(UserBookmark.paper_id == paper.id).first()
|
||||
existing = db.execute(
|
||||
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
||||
).scalar_one_or_none()
|
||||
if existing:
|
||||
db.delete(existing)
|
||||
db.commit()
|
||||
@@ -26,7 +27,7 @@ def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
|
||||
else:
|
||||
bookmark = UserBookmark(
|
||||
paper_id=paper.id,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
created_at=utc_now(),
|
||||
)
|
||||
db.add(bookmark)
|
||||
db.commit()
|
||||
@@ -43,16 +44,14 @@ def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
|
||||
if status not in VALID_STATUSES:
|
||||
return {"error": "invalid_status", "valid": sorted(VALID_STATUSES)}
|
||||
|
||||
paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
|
||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
||||
if not paper:
|
||||
return {"error": "not_found"}
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
existing = (
|
||||
db.query(UserReadingStatus)
|
||||
.filter(UserReadingStatus.paper_id == paper.id)
|
||||
.first()
|
||||
)
|
||||
now = utc_now()
|
||||
existing = db.execute(
|
||||
select(UserReadingStatus).where(UserReadingStatus.paper_id == paper.id)
|
||||
).scalar_one_or_none()
|
||||
if existing:
|
||||
existing.status = status
|
||||
existing.updated_at = now
|
||||
@@ -73,11 +72,13 @@ def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
|
||||
|
||||
def get_note(db: Session, arxiv_id: str) -> dict | None:
|
||||
"""获取笔记。返回 {"arxiv_id", "content", "updated_at"} 或 None(论文不存在时)。"""
|
||||
paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
|
||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
||||
if not paper:
|
||||
return None
|
||||
|
||||
note = db.query(UserNote).filter(UserNote.paper_id == paper.id).first()
|
||||
note = db.execute(
|
||||
select(UserNote).where(UserNote.paper_id == paper.id)
|
||||
).scalar_one_or_none()
|
||||
if not note:
|
||||
return {"arxiv_id": arxiv_id, "content": "", "updated_at": None}
|
||||
|
||||
@@ -90,12 +91,14 @@ def get_note(db: Session, arxiv_id: str) -> dict | None:
|
||||
|
||||
def save_note(db: Session, arxiv_id: str, content: str) -> dict:
|
||||
"""创建或更新笔记。返回 {"arxiv_id", "content", "updated_at"}。"""
|
||||
paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
|
||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
||||
if not paper:
|
||||
return {"error": "not_found"}
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
existing = db.query(UserNote).filter(UserNote.paper_id == paper.id).first()
|
||||
now = utc_now()
|
||||
existing = db.execute(
|
||||
select(UserNote).where(UserNote.paper_id == paper.id)
|
||||
).scalar_one_or_none()
|
||||
if existing:
|
||||
existing.content = content
|
||||
existing.updated_at = now
|
||||
@@ -126,7 +129,7 @@ def query_reading_list(
|
||||
) -> list[Paper]:
|
||||
"""根据筛选条件查询阅读列表。"""
|
||||
# 基础:有任意用户数据的论文
|
||||
base = db.query(Paper).filter(
|
||||
stmt = select(Paper).where(
|
||||
or_(
|
||||
Paper.bookmark.has(),
|
||||
Paper.reading_status.has(),
|
||||
@@ -136,25 +139,25 @@ def query_reading_list(
|
||||
|
||||
# 应用筛选
|
||||
if filter_type == "has_note":
|
||||
base = base.filter(Paper.note.has())
|
||||
stmt = stmt.where(Paper.note.has())
|
||||
elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
|
||||
base = base.filter(
|
||||
stmt = stmt.where(
|
||||
Paper.reading_status.has(UserReadingStatus.status == filter_type)
|
||||
)
|
||||
|
||||
# 应用标签
|
||||
if tag:
|
||||
base = base.filter(Paper.tags.any(PaperTag.tag == tag))
|
||||
stmt = stmt.where(Paper.tags.any(PaperTag.tag == tag))
|
||||
|
||||
return (
|
||||
base.options(
|
||||
joinedload(Paper.authors),
|
||||
joinedload(Paper.tags),
|
||||
joinedload(Paper.summary_status),
|
||||
joinedload(Paper.bookmark),
|
||||
joinedload(Paper.reading_status),
|
||||
joinedload(Paper.note),
|
||||
db.execute(
|
||||
stmt.options(
|
||||
joinedload(Paper.note),
|
||||
*PAPER_FULL_LOAD,
|
||||
)
|
||||
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
|
||||
)
|
||||
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
|
||||
.unique()
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user