From 29e6797c12cfff67b79737dac8f7c1aeb660bcc4 Mon Sep 17 00:00:00 2001 From: rain-bus Date: Fri, 5 Jun 2026 22:29:33 +0800 Subject: [PATCH] feat: add admin routes, summarizer service, and CLI summarize command - Add /admin routes for manual trigger and status inspection - Add summarizer service with batch/single summary support - Add summarize CLI command (single arxiv_id or batch pending) - Register admin router in main app - Add tests for summarizer --- app/cli.py | 40 ++ app/main.py | 2 + app/routes/admin.py | 48 +++ app/services/schemas.py | 168 +++++++++ app/services/summarizer.py | 682 ++++++++++++++++++++++++++++++++++ tests/conftest.py | 209 +++++++++++ tests/test_summarizer.py | 725 +++++++++++++++++++++++++++++++++++++ 7 files changed, 1874 insertions(+) create mode 100644 app/routes/admin.py create mode 100644 app/services/schemas.py create mode 100644 app/services/summarizer.py create mode 100644 tests/conftest.py create mode 100644 tests/test_summarizer.py diff --git a/app/cli.py b/app/cli.py index 7a0cf0c..0f25c7e 100644 --- a/app/cli.py +++ b/app/cli.py @@ -49,6 +49,46 @@ def crawl( db.close() +@cli_app.command() +def summarize( + arxiv_id: str = typer.Argument( + None, + help="指定论文 arXiv ID;留空则批量处理所有 pending", + ), +): + """手动触发 AI 总结。""" + from app.config import settings + from app.database import SessionLocal, engine + from app.models import init_db as _init + from app.services.summarizer import summarize_batch, summarize_single + + import os + os.makedirs(settings.db_path.parent, exist_ok=True) + _init(engine) + + db = SessionLocal() + try: + if arxiv_id: + typer.echo(f"🤖 开始总结 {arxiv_id} ...") + result = asyncio.run(summarize_single(db, arxiv_id)) + else: + typer.echo("🤖 开始批量总结 pending 论文 ...") + result = asyncio.run(summarize_batch(db)) + + if result.get("status") in ("success", "done"): + typer.echo(f"✅ 总结完成:{result}") + elif result.get("status") == "conflict": + typer.echo("⚠️ 已有批量总结任务在运行中", err=True) + raise typer.Exit(code=1) + elif result.get("status") == "not_found": + typer.echo(f"❌ 论文未找到:{arxiv_id}", err=True) + raise typer.Exit(code=1) + else: + typer.echo(f"⚠️ 总结结果:{result}", err=True) + finally: + db.close() + + @cli_app.command() def init_db(): """初始化数据库表。""" diff --git a/app/main.py b/app/main.py index 279eafa..ca95709 100644 --- a/app/main.py +++ b/app/main.py @@ -9,6 +9,7 @@ from fastapi.staticfiles import StaticFiles from app.config import settings from app.database import engine from app.models import init_db +from app.routes.admin import router as admin_router from app.routes.pages import router as pages_router logging.basicConfig( @@ -41,6 +42,7 @@ def create_app() -> FastAPI: # 路由 app.include_router(pages_router) + app.include_router(admin_router) return app diff --git a/app/routes/admin.py b/app/routes/admin.py new file mode 100644 index 0000000..9359203 --- /dev/null +++ b/app/routes/admin.py @@ -0,0 +1,48 @@ +"""管理接口 — AI 总结触发,需要 ADMIN_TOKEN 鉴权。""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy.orm import Session + +from app.config import settings +from app.database import get_db +from app.services.summarizer import summarize_batch, summarize_single + +router = APIRouter(prefix="/admin", tags=["admin"]) +security = HTTPBearer() + + +async def verify_admin( + credentials: HTTPAuthorizationCredentials = Depends(security), +) -> str: + """验证 ADMIN_TOKEN。""" + if credentials.credentials != settings.ADMIN_TOKEN: + raise HTTPException(status_code=401, detail="Invalid admin token") + return credentials.credentials + + +@router.post("/summarize") +async def admin_summarize_batch( + _admin: str = Depends(verify_admin), + db: Session = Depends(get_db), +): + """批量总结所有 pending 论文。""" + result = await summarize_batch(db) + if result.get("status") == "conflict": + raise HTTPException(status_code=409, detail=result.get("error", "batch already running")) + return result + + +@router.post("/summarize/{arxiv_id}") +async def admin_summarize_single( + arxiv_id: str, + _admin: str = Depends(verify_admin), + db: Session = Depends(get_db), +): + """总结或重跑单篇论文。""" + result = await summarize_single(db, arxiv_id, force=True) + if result.get("status") == "not_found": + raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}") + return result diff --git a/app/services/schemas.py b/app/services/schemas.py new file mode 100644 index 0000000..2b2838b --- /dev/null +++ b/app/services/schemas.py @@ -0,0 +1,168 @@ +"""AI 总结 schema — Pydantic 校验模型、质量评估、DB 展平。""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone + +from pydantic import BaseModel, Field, ValidationError, field_validator + + +# ── 子模型 ────────────────────────────────────────────────────────────── + + +class PrerequisitesSchema(BaseModel): + concepts: list[str] = Field(default_factory=list) + level: str = "" + + +class MotivationSchema(BaseModel): + problem: str + goal: str = "" + gap: str = "" + + @field_validator("problem") + @classmethod + def non_empty_problem(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("motivation.problem cannot be empty") + return v.strip() + + +class MethodSchema(BaseModel): + overview: str = "" + key_idea: str + steps: list[str] = Field(default_factory=list) + novelty: str = "" + + @field_validator("key_idea") + @classmethod + def non_empty_key_idea(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("method.key_idea cannot be empty") + return v.strip() + + +class ResultsSchema(BaseModel): + main_findings: list[str] = Field(default_factory=list) + benchmarks: list[dict] = Field(default_factory=list) + limitations: list[str] = Field(default_factory=list) + + +class ImprovementsSchema(BaseModel): + weaknesses: list[str] = Field(default_factory=list) + future_work: list[str] = Field(default_factory=list) + reproducibility: str = "" + + +# ── 顶层 schema ───────────────────────────────────────────────────────── + + +class SummarySchema(BaseModel): + model_config = {"extra": "ignore"} + + title_zh: str + one_line: str + tags: list[str] + difficulty: str = "" + paper_date: str | None = None + prerequisites: PrerequisitesSchema = Field(default_factory=PrerequisitesSchema) + motivation: MotivationSchema + method: MethodSchema + results: ResultsSchema = Field(default_factory=ResultsSchema) + improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema) + + @field_validator("title_zh", "one_line") + @classmethod + def non_empty_text(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("field cannot be empty") + return v.strip() + + @field_validator("tags") + @classmethod + def non_empty_tags(cls, v: list[str]) -> list[str]: + tags = [tag.strip() for tag in v if tag and tag.strip()] + if not tags: + raise ValueError("tags cannot be empty") + return tags + + +# ── 质量评估 ──────────────────────────────────────────────────────────── + +# 必填字段:title_zh, one_line, tags, motivation.problem, method.key_idea +# — 缺失时 Pydantic 校验就会报错,不会走到 assess_quality +# 重要字段:motivation.goal, motivation.gap, method.overview, results.main_findings +# — 缺失可入库,标记 degraded +_OPTIONAL_BUT_IMPORTANT_FIELDS = [ + "motivation.goal", + "motivation.gap", + "method.overview", + "results.main_findings", +] + + +def assess_quality(schema: SummarySchema) -> str: + """评估总结质量:normal / degraded / low。""" + # low:内容空洞的启发式判断 + if len(schema.one_line.strip()) < 10 or len(schema.method.key_idea.strip()) < 10: + return "low" + + # 检查重要字段是否缺失 + missing_important = 0 + if not schema.motivation.goal.strip(): + missing_important += 1 + if not schema.motivation.gap.strip(): + missing_important += 1 + if not schema.method.overview.strip(): + missing_important += 1 + if not schema.results.main_findings: + missing_important += 1 + + if missing_important == 0: + return "normal" + return "degraded" + + +# ── DB 展平 ───────────────────────────────────────────────────────────── + + +def flatten_for_db(schema: SummarySchema) -> dict: + """将 SummarySchema 展平为 paper_summaries 表的列值 dict。""" + return { + "one_line": schema.one_line, + "difficulty": schema.difficulty, + "prerequisites_json": json.dumps(schema.prerequisites.model_dump(), ensure_ascii=False), + "motivation_problem": schema.motivation.problem, + "motivation_goal": schema.motivation.goal, + "motivation_gap": schema.motivation.gap, + "method_overview": schema.method.overview, + "method_key_idea": schema.method.key_idea, + "method_steps_json": json.dumps(schema.method.steps, ensure_ascii=False), + "method_novelty": schema.method.novelty, + "results_main_json": json.dumps(schema.results.main_findings, ensure_ascii=False), + "results_benchmarks_json": json.dumps(schema.results.benchmarks, ensure_ascii=False), + "limitations_json": json.dumps(schema.results.limitations, ensure_ascii=False), + "weaknesses_json": json.dumps(schema.improvements.weaknesses, ensure_ascii=False), + "future_work_json": json.dumps(schema.improvements.future_work, ensure_ascii=False), + "reproducibility": schema.improvements.reproducibility, + "full_json": schema.model_dump_json(ensure_ascii=False), + "updated_at": datetime.now(timezone.utc), + } + + +# ── 错误分类 ──────────────────────────────────────────────────────────── + +_REQUIRED_FIELDS = {"title_zh", "one_line", "tags", "problem", "key_idea"} + + +def classify_validation_error(exc: ValidationError) -> str: + """区分 field_missing(必填缺失)和 schema_error(类型不合法等)。""" + for err in exc.errors(): + field_name = err["loc"][-1] if err["loc"] else "" + if field_name in _REQUIRED_FIELDS and err["type"] in ( + "missing", + "value_error", + ): + return "field_missing" + return "schema_error" diff --git a/app/services/summarizer.py b/app/services/summarizer.py new file mode 100644 index 0000000..7003f95 --- /dev/null +++ b/app/services/summarizer.py @@ -0,0 +1,682 @@ +"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。""" + +from __future__ import annotations + +import asyncio +import json +import logging +import re +import shutil +from datetime import datetime, timezone +from pathlib import Path + +import httpx +from pydantic import ValidationError +from sqlalchemy import select, text +from sqlalchemy.orm import Session, joinedload + +from app.config import settings +from app.database import SessionLocal +from app.models import ( + CrawlLog, + Paper, + PaperSummary, + PaperTag, + SummaryStatus, + TaskLock, +) +from app.services.schemas import ( + SummarySchema, + assess_quality, + classify_validation_error, + flatten_for_db, +) + +logger = logging.getLogger(__name__) + +# ── 自定义异常 ────────────────────────────────────────────────────────── + + +class PdfDownloadError(Exception): + pass + + +class PiTimeoutError(Exception): + pass + + +class PiProcessError(Exception): + def __init__(self, returncode: int, stderr: str): + self.returncode = returncode + self.stderr = stderr + super().__init__(f"pi exited with code {returncode}: {stderr[:500]}") + + +class JsonNotFoundError(Exception): + pass + + +# ── 路径工具 ──────────────────────────────────────────────────────────── + +_DATA_DIR = Path("data") +_PAPERS_DIR = _DATA_DIR / "papers" +_TMP_DIR = _DATA_DIR / "tmp" + + +def _paper_dir(arxiv_id: str) -> Path: + return _PAPERS_DIR / arxiv_id + + +def _tmp_dir(arxiv_id: str) -> Path: + return _TMP_DIR / arxiv_id + + +# ── PDF 下载 ──────────────────────────────────────────────────────────── + + +async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path: + """下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。""" + if not pdf_url: + raise PdfDownloadError(f"no pdf_url for {arxiv_id}") + + tmp = _tmp_dir(arxiv_id) + tmp.mkdir(parents=True, exist_ok=True) + dest = tmp / "paper.pdf" + + transport = None + if settings.http_proxy: + transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy) + + try: + async with httpx.AsyncClient( + timeout=settings.HTTP_TIMEOUT_SECONDS, + headers={"User-Agent": settings.HTTP_USER_AGENT}, + transport=transport, + follow_redirects=True, + ) as client: + resp = await client.get(pdf_url) + resp.raise_for_status() + dest.write_bytes(resp.content) + except Exception as exc: + raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc + + logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size) + return dest + + +# ── meta.json ─────────────────────────────────────────────────────────── + + +def _write_meta_json(paper: Paper) -> Path: + """写入 data/papers/{arxiv_id}/meta.json,返回路径。""" + d = _paper_dir(paper.arxiv_id) + d.mkdir(parents=True, exist_ok=True) + meta_path = d / "meta.json" + + authors = [a.name for a in paper.authors] + tags = [t.tag for t in paper.tags] + meta = { + "arxiv_id": paper.arxiv_id, + "title_en": paper.title_en, + "abstract": paper.abstract or "", + "published_at": paper.published_at.isoformat() if paper.published_at else None, + "authors": authors, + "tags": tags, + "upvotes": paper.upvotes, + } + meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8") + return meta_path + + +# ── pi CLI 调用 ──────────────────────────────────────────────────────── + + +async def _call_pi(meta_path: Path, pdf_path: Path) -> str: + """调用 pi CLI 非交互模式,返回 stdout 文本。""" + cmd = [ + settings.PI_BIN, + "-p", + "--no-tools", + "--skill", + settings.SUMMARY_SKILL, + "请深度解读以下论文,并按指定 JSON schema 输出:", + f"@{meta_path}", + f"@{pdf_path}", + ] + logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4])) + + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), + timeout=settings.SUMMARY_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + raise PiTimeoutError( + f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s" + ) + + if proc.returncode != 0: + raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace")) + + return stdout.decode("utf-8", errors="replace") + + +def paper_id_from_path(meta_path: Path) -> str: + """从 meta.json 路径反推 arxiv_id。""" + return meta_path.parent.name + + +# ── JSON 提取 ────────────────────────────────────────────────────────── + + +def _extract_json(raw_output: str) -> dict: + """从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。""" + # 策略 1:整体直接解析 + stripped = raw_output.strip() + try: + result = json.loads(stripped) + if isinstance(result, dict) and "title_zh" in result: + return result + except json.JSONDecodeError: + pass + + # 策略 2:提取 ```json ... ``` 代码块 + fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL) + for match in fence_pattern.finditer(raw_output): + try: + result = json.loads(match.group(1).strip()) + if isinstance(result, dict) and "title_zh" in result: + return result + except json.JSONDecodeError: + continue + + # 策略 3:匹配包含 title_zh 的最大 {...} 块 + brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL) + # 先尝试一层嵌套;如果没命中再用更宽松的策略 + for match in brace_pattern.finditer(raw_output): + try: + return json.loads(match.group(0)) + except json.JSONDecodeError: + continue + + # 更宽松:找到最大的 { ... } 平衡块 + best = None + best_len = 0 + for i, ch in enumerate(raw_output): + if ch != "{": + continue + depth = 0 + for j in range(i, len(raw_output)): + if raw_output[j] == "{": + depth += 1 + elif raw_output[j] == "}": + depth -= 1 + if depth == 0: + candidate = raw_output[i : j + 1] + if len(candidate) > best_len: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + best = parsed + best_len = len(candidate) + except json.JSONDecodeError: + pass + break + + if best is not None: + return best + + raise JsonNotFoundError("no JSON object found in pi output") + + +# ── 错误分类 ──────────────────────────────────────────────────────────── + + +def _classify_error(exc: Exception) -> str: + """将异常映射到 error_type 枚举值。""" + if isinstance(exc, PdfDownloadError): + return "pdf_download_failed" + if isinstance(exc, PiTimeoutError): + return "timeout" + if isinstance(exc, PiProcessError): + return "process_error" + if isinstance(exc, JsonNotFoundError): + return "json_not_found" + if isinstance(exc, json.JSONDecodeError): + return "json_invalid" + if isinstance(exc, ValidationError): + return classify_validation_error(exc) + return "unknown" + + +# ── FTS5 文本构建 ─────────────────────────────────────────────────────── + + +def _build_fts_summary_text(schema: SummarySchema) -> str: + """拼接用于 FTS5 索引的总结文本。""" + parts = [ + schema.one_line or "", + schema.motivation.problem or "", + schema.motivation.goal or "", + schema.method_overview if hasattr(schema, "method_overview") else "", + schema.method.overview or "", + schema.method.key_idea or "", + " ".join(schema.results.main_findings or []), + ] + return " ".join(p for p in parts if p) + + +# ── DB 更新 ───────────────────────────────────────────────────────────── + + +def _update_summary_in_db( + db: Session, + paper: Paper, + schema: SummarySchema, + quality: str, + raw_output: str, +) -> None: + """将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。""" + now = datetime.now(timezone.utc) + + # 1. paper_summaries:upsert + existing = db.get(PaperSummary, paper.id) + flat = flatten_for_db(schema) + if existing: + for k, v in flat.items(): + setattr(existing, k, v) + else: + db.add(PaperSummary(paper_id=paper.id, **flat)) + + # 2. papers 表 + paper.title_zh = schema.title_zh + paper.summary_quality = quality + paper_dir = _paper_dir(paper.arxiv_id) + paper.summary_path = str(paper_dir / "summary.json") + paper.raw_output_path = str(paper_dir / "raw_output.txt") + + # 3. AI 标签 + existing_tag_names = {t.tag for t in paper.tags} + for tag_name in schema.tags: + if tag_name not in existing_tag_names: + db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai")) + existing_tag_names.add(tag_name) + + # 4. FTS5 更新 + summary_text = _build_fts_summary_text(schema) + db.execute( + text( + "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text " + "WHERE rowid=:paper_id" + ), + { + "title_zh": schema.title_zh, + "summary_text": summary_text, + "paper_id": paper.id, + }, + ) + + db.commit() + logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality) + + +# ── 文件操作 ──────────────────────────────────────────────────────────── + + +def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None: + """保存 summary.json 和 raw_output.txt。""" + d = _paper_dir(arxiv_id) + d.mkdir(parents=True, exist_ok=True) + (d / "summary.json").write_text( + schema.model_dump_json(ensure_ascii=False, indent=2), + encoding="utf-8", + ) + (d / "raw_output.txt").write_text(raw_output, encoding="utf-8") + + +def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None: + """仅保存 raw_output.txt(失败时)。""" + d = _paper_dir(arxiv_id) + d.mkdir(parents=True, exist_ok=True) + (d / "raw_output.txt").write_text(raw_output, encoding="utf-8") + + +def _cleanup_tmp(arxiv_id: str) -> None: + """清理 data/tmp/{arxiv_id}/ 目录。""" + tmp = _tmp_dir(arxiv_id) + if tmp.exists(): + try: + shutil.rmtree(tmp) + logger.debug("Cleaned tmp: %s", arxiv_id) + except Exception: + logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True) + + +# ── 单篇总结 ──────────────────────────────────────────────────────────── + + +async def summarize_one( + db: Session, + paper: Paper, + semaphore: asyncio.Semaphore | None = None, + *, + force: bool = False, +) -> dict: + """总结单篇论文的完整流程。""" + arxiv_id = paper.arxiv_id + + # 获取或创建 summary_status + if not paper.summary_status: + db.add(SummaryStatus(paper_id=paper.id, status="pending")) + db.commit() + db.refresh(paper) + + status = paper.summary_status + + # 跳过已完成的(除非 force) + if status.status == "done" and not force: + return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"} + + # 跳过 permanent_failure(除非 force) + if status.status == "permanent_failure" and not force: + return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "permanent_failure"} + + if semaphore: + await semaphore.acquire() + try: + return await _do_summarize_one(db, paper) + finally: + if semaphore: + semaphore.release() + + +async def _do_summarize_one(db: Session, paper: Paper) -> dict: + """实际的单篇总结执行(在 semaphore 保护下)。""" + arxiv_id = paper.arxiv_id + status = paper.summary_status + now = datetime.now(timezone.utc) + + # 状态 → processing + status.status = "processing" + status.started_at = now + db.commit() + + raw_output = "" + try: + # 写 meta.json + meta_path = _write_meta_json(paper) + + # 下载 PDF + await _download_pdf(arxiv_id, paper.pdf_url) + + # 调用 pi + raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf") + + # 提取 JSON + json_data = _extract_json(raw_output) + + # Pydantic 校验 + schema = SummarySchema.model_validate(json_data) + + # 质量评估 + quality = assess_quality(schema) + + # 保存文件 + _save_files(arxiv_id, schema, raw_output) + + # 更新 DB + _update_summary_in_db(db, paper, schema, quality, raw_output) + + # 状态 → done + status.status = "done" + status.quality = quality + status.completed_at = datetime.now(timezone.utc) + status.raw_output_saved = True + db.commit() + + logger.info("Summarize done: %s quality=%s", arxiv_id, quality) + return {"arxiv_id": arxiv_id, "status": "done", "quality": quality} + + except Exception as exc: + error_type = _classify_error(exc) + logger.error( + "Summarize failed: %s error_type=%s %s", + arxiv_id, + error_type, + str(exc)[:200], + ) + + # 保存 raw_output(如果有) + if raw_output: + _save_raw_output_only(arxiv_id, raw_output) + status.raw_output_saved = True + + # 重试逻辑 + status.retry_count = (status.retry_count or 0) + 1 + status.error_type = error_type + status.error = str(exc)[:2000] + + if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1: + status.status = "permanent_failure" + else: + status.status = "pending" + + status.completed_at = datetime.now(timezone.utc) + db.commit() + + return { + "arxiv_id": arxiv_id, + "status": "failed", + "error_type": error_type, + "error": str(exc)[:200], + "retry_count": status.retry_count, + } + + finally: + _cleanup_tmp(arxiv_id) + + +# ── 单篇入口 ──────────────────────────────────────────────────────────── + + +async def summarize_single( + db: Session, + arxiv_id: str, + *, + force: bool = True, + _session_factory=None, +) -> dict: + """单篇总结入口(供 admin 路由和 CLI 调用)。 + + _session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。 + """ + paper = ( + db.query(Paper) + .filter(Paper.arxiv_id == arxiv_id) + .options( + joinedload(Paper.authors), + joinedload(Paper.tags), + joinedload(Paper.summary_status), + ) + .first() + ) + if not paper: + return {"status": "not_found", "arxiv_id": arxiv_id} + + make_session = _session_factory or SessionLocal + + # 每篇用独立 session 避免并发问题 + paper_db = make_session() + try: + paper_in_new_session = ( + paper_db.query(Paper) + .filter(Paper.arxiv_id == arxiv_id) + .options( + joinedload(Paper.authors), + joinedload(Paper.tags), + joinedload(Paper.summary_status), + ) + .first() + ) + result = await summarize_one(paper_db, paper_in_new_session, force=force) + finally: + paper_db.close() + + return result + + +# ── 批量总结 ──────────────────────────────────────────────────────────── + + +async def summarize_batch( + db: Session, + arxiv_ids: list[str] | None = None, + *, + _session_factory=None, +) -> dict: + """批量总结入口。arxiv_ids=None 时处理所有 pending 论文。 + + _session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。 + """ + now = datetime.now(timezone.utc) + + # TaskLock 防重入 + lock = TaskLock( + task="summarize", + lock_key="batch", + status="running", + owner="summarize_batch", + acquired_at=now, + ) + try: + db.add(lock) + db.commit() + except Exception: + db.rollback() + logger.warning("Summarize batch already running (lock conflict)") + return {"status": "conflict", "error": "summarize batch already running"} + + # CrawlLog + log_entry = CrawlLog( + task="summarize", + status="running", + started_at=now, + ) + db.add(log_entry) + db.commit() + + try: + # 查询待总结论文 + query = db.query(Paper).options( + joinedload(Paper.authors), + joinedload(Paper.tags), + joinedload(Paper.summary_status), + ) + if arxiv_ids: + query = query.filter(Paper.arxiv_id.in_(arxiv_ids)) + else: + # 只处理 pending 或 failed(可重试的) + query = query.join(SummaryStatus).filter( + SummaryStatus.status.in_(["pending", "failed"]) + ) + + papers = query.all() + total = len(papers) + logger.info("Summarize batch: %d papers to process", total) + + if total == 0: + log_entry.status = "success" + log_entry.papers_found = 0 + log_entry.papers_new = 0 + log_entry.completed_at = datetime.now(timezone.utc) + _release_lock(db, lock) + return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0} + + # 并发控制 + semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY) + make_session = _session_factory or SessionLocal + + async def _process_paper(paper: Paper) -> dict: + paper_db = make_session() + try: + p = ( + paper_db.query(Paper) + .filter(Paper.id == paper.id) + .options( + joinedload(Paper.authors), + joinedload(Paper.tags), + joinedload(Paper.summary_status), + ) + .first() + ) + return await summarize_one(paper_db, p, semaphore) + finally: + paper_db.close() + + results = await asyncio.gather( + *[_process_paper(p) for p in papers], + return_exceptions=True, + ) + + # 统计结果 + done = 0 + failed = 0 + skipped = 0 + for r in results: + if isinstance(r, Exception): + logger.error("Unexpected error in batch: %s", r) + failed += 1 + elif isinstance(r, dict): + if r.get("status") == "done": + done += 1 + elif r.get("status") == "skipped": + skipped += 1 + else: + failed += 1 + + log_entry.status = "success" if failed == 0 else "failed" + log_entry.papers_found = total + log_entry.papers_new = done + log_entry.completed_at = datetime.now(timezone.utc) + db.commit() + + logger.info( + "Summarize batch done: total=%d done=%d failed=%d skipped=%d", + total, done, failed, skipped, + ) + return { + "status": "success" if failed == 0 else "partial", + "total": total, + "done": done, + "failed": failed, + "skipped": skipped, + } + + except Exception as exc: + logger.exception("Summarize batch failed") + log_entry.status = "failed" + log_entry.error = str(exc)[:2000] + log_entry.completed_at = datetime.now(timezone.utc) + db.commit() + return {"status": "failed", "error": str(exc)} + + finally: + _release_lock(db, lock) + + +def _release_lock(db: Session, lock: TaskLock) -> None: + """释放 TaskLock。""" + try: + lock.status = "finished" + lock.released_at = datetime.now(timezone.utc) + db.commit() + except Exception: + db.rollback() + logger.warning("Failed to release summarize lock", exc_info=True) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..7c636d3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,209 @@ +"""测试 fixtures — 内存 SQLite、TestClient、样例数据。""" + +from __future__ import annotations + +import json +from datetime import date, datetime, timezone +from pathlib import Path +from unittest.mock import AsyncMock + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine, event +from sqlalchemy.orm import DeclarativeBase, sessionmaker + +from app.database import get_db +from app.main import create_app +from app.models import ( + Paper, + PaperAuthor, + PaperSummary, + PaperTag, + SummaryStatus, + init_db, +) + + +# ── 内存数据库 ────────────────────────────────────────────────────────── + + +class _TestBase(DeclarativeBase): + pass + + +# 复用 app.models 的 Base metadata +from app.database import Base as _AppBase # noqa: E402 + +_TestBase.metadata = _AppBase.metadata + + +@pytest.fixture +def db_engine(): + """创建内存 SQLite 引擎 + FTS5。""" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + ) + + @event.listens_for(engine, "connect") + def _pragma(dbapi_connection, _record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + init_db(engine) + return engine + + +@pytest.fixture +def db_session(db_engine): + """提供事务隔离的数据库 session。""" + Session = sessionmaker(bind=db_engine, autoflush=False, autocommit=False) + session = Session() + try: + yield session + finally: + session.close() + + +@pytest.fixture +def client(db_engine, db_session): + """FastAPI TestClient,override get_db。""" + app = create_app() + + def _override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = _override_get_db + + with TestClient(app, raise_server_exceptions=False) as c: + yield c + + app.dependency_overrides.clear() + + +# ── 样例数据 ──────────────────────────────────────────────────────────── + +SAMPLE_ARXIV_ID = "2401.12345" +ADMIN_TOKEN = "test-admin-token-12345" + + +@pytest.fixture +def sample_paper(db_session): + """插入一篇测试论文 + 作者 + 标签 + summary_status(pending)。""" + now = datetime.now(timezone.utc) + paper = Paper( + arxiv_id=SAMPLE_ARXIV_ID, + title_en="Test Paper Title", + abstract="This is a test abstract for the paper.", + published_at=date(2024, 1, 15), + paper_date=date(2024, 1, 15), + crawled_at=now, + upvotes=42, + hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}", + arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}", + pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf", + ) + db_session.add(paper) + db_session.flush() + + db_session.add(PaperAuthor(paper_id=paper.id, name="Alice Smith", position=0)) + db_session.add(PaperAuthor(paper_id=paper.id, name="Bob Jones", position=1)) + db_session.add(PaperTag(paper_id=paper.id, tag="NLP", source="hf")) + db_session.add(PaperTag(paper_id=paper.id, tag="LLM", source="hf")) + + db_session.add(SummaryStatus(paper_id=paper.id, status="pending")) + + # FTS5 初始行(与 crawler 一致) + db_session.execute( + __import__("sqlalchemy").text( + "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) " + "VALUES (:id, :title, :abstract, :authors, :tags)" + ), + { + "id": paper.id, + "title": paper.title_en, + "abstract": paper.abstract or "", + "authors": "Alice Smith, Bob Jones", + "tags": "NLP, LLM", + }, + ) + db_session.commit() + return paper + + +@pytest.fixture +def sample_summary_dict() -> dict: + """完整合法的 summary dict。""" + return { + "title_zh": "测试论文中文标题", + "one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。", + "tags": ["自然语言处理", "大语言模型", "Transformer"], + "difficulty": "中级", + "prerequisites": { + "concepts": ["Transformer", "注意力机制"], + "level": "中级", + }, + "motivation": { + "problem": "现有模型在长文本理解上存在不足。", + "goal": "提出一种新的注意力机制来提升长文本建模能力。", + "gap": "当前方法计算复杂度过高。", + }, + "method": { + "overview": "提出了一种高效的稀疏注意力机制。", + "key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度。", + "steps": [ + "分析现有注意力机制的瓶颈", + "设计稀疏注意力模式", + "在多个基准上验证效果", + ], + "novelty": "首次将局部-全局注意力模式结合应用于长文本建模。", + }, + "results": { + "main_findings": [ + "在长文本基准上取得了 SOTA 结果", + "推理速度提升了 2 倍", + ], + "benchmarks": [ + {"dataset": "LongBench", "score": 85.3}, + ], + "limitations": [ + "在超长文本(>100k tokens)上效果有所下降", + ], + }, + "improvements": { + "weaknesses": ["仅验证了英文数据"], + "future_work": ["扩展到多语言场景"], + "reproducibility": "代码已开源,模型权重可下载。", + }, + } + + +@pytest.fixture +def sample_summary_json(sample_summary_dict) -> str: + """合法 summary 的 JSON 字符串。""" + return json.dumps(sample_summary_dict, ensure_ascii=False, indent=2) + + +@pytest.fixture +def mock_pi_output(sample_summary_json) -> str: + """模拟 pi CLI 的完整输出(包含 JSON)。""" + return f"""以下是论文的深度解读: + +```json +{sample_summary_json} +``` + +希望这个总结对你有帮助!""" + + +@pytest.fixture +def admin_token(): + """返回测试用的 ADMIN_TOKEN(需要配合 monkeypatch 使用)。""" + return ADMIN_TOKEN + + +@pytest.fixture +def admin_headers(admin_token): + """带 Bearer token 的请求头。""" + return {"Authorization": f"Bearer {admin_token}"} diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py new file mode 100644 index 0000000..ebe8b8f --- /dev/null +++ b/tests/test_summarizer.py @@ -0,0 +1,725 @@ +"""AI 总结服务测试 — Mock 全链路,不调用真实 pi。""" + +from __future__ import annotations + +import asyncio +import json +from datetime import date, datetime, timezone +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy import text + +from app.models import ( + CrawlLog, + Paper, + PaperSummary, + PaperTag, + SummaryStatus, + TaskLock, +) +from app.services.schemas import ( + SummarySchema, + assess_quality, + classify_validation_error, + flatten_for_db, +) +from app.services.summarizer import ( + JsonNotFoundError, + PdfDownloadError, + PiProcessError, + PiTimeoutError, + _call_pi, + _classify_error, + _cleanup_tmp, + _extract_json, + _save_files, + _save_raw_output_only, + _update_summary_in_db, + summarize_batch, + summarize_one, + summarize_single, +) + + +# ═══════════════════════════════════════════════════════════════════════ +# Schema 校验测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestSummarySchema: + """Pydantic schema 校验。""" + + def test_valid_summary(self, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + assert schema.title_zh == "测试论文中文标题" + assert len(schema.tags) == 3 + assert schema.motivation.problem + + def test_missing_title_zh(self, sample_summary_dict): + del sample_summary_dict["title_zh"] + with pytest.raises(ValidationError) as exc_info: + SummarySchema.model_validate(sample_summary_dict) + assert classify_validation_error(exc_info.value) == "field_missing" + + def test_empty_one_line(self, sample_summary_dict): + sample_summary_dict["one_line"] = "" + with pytest.raises(ValidationError): + SummarySchema.model_validate(sample_summary_dict) + + def test_empty_tags(self, sample_summary_dict): + sample_summary_dict["tags"] = [] + with pytest.raises(ValidationError): + SummarySchema.model_validate(sample_summary_dict) + + def test_empty_motivation_problem(self, sample_summary_dict): + sample_summary_dict["motivation"]["problem"] = "" + with pytest.raises(ValidationError): + SummarySchema.model_validate(sample_summary_dict) + + def test_empty_method_key_idea(self, sample_summary_dict): + sample_summary_dict["method"]["key_idea"] = "" + with pytest.raises(ValidationError): + SummarySchema.model_validate(sample_summary_dict) + + def test_extra_fields_ignored(self, sample_summary_dict): + sample_summary_dict["figures"] = ["fig1.png"] + sample_summary_dict["takeaway"] = "important paper" + schema = SummarySchema.model_validate(sample_summary_dict) + assert not hasattr(schema, "figures") + assert schema.title_zh # 正常解析 + + def test_flatten_for_db(self, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + flat = flatten_for_db(schema) + assert flat["one_line"] == schema.one_line + assert flat["motivation_problem"] == schema.motivation.problem + assert flat["method_key_idea"] == schema.method.key_idea + assert "full_json" in flat + assert "updated_at" in flat + # JSON 字段可解析 + assert isinstance(json.loads(flat["prerequisites_json"]), dict) + assert isinstance(json.loads(flat["method_steps_json"]), list) + + +class TestQualityAssessment: + """质量分级测试。""" + + def test_quality_normal(self, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + assert assess_quality(schema) == "normal" + + def test_quality_degraded_missing_goal(self, sample_summary_dict): + sample_summary_dict["motivation"]["goal"] = "" + sample_summary_dict["motivation"]["gap"] = "" + sample_summary_dict["method"]["overview"] = "" + sample_summary_dict["results"]["main_findings"] = [] + schema = SummarySchema.model_validate(sample_summary_dict) + assert assess_quality(schema) == "degraded" + + def test_quality_low_short_one_line(self, sample_summary_dict): + sample_summary_dict["one_line"] = "短" + schema = SummarySchema.model_validate(sample_summary_dict) + assert assess_quality(schema) == "low" + + def test_quality_low_short_key_idea(self, sample_summary_dict): + sample_summary_dict["method"]["key_idea"] = "短" + schema = SummarySchema.model_validate(sample_summary_dict) + assert assess_quality(schema) == "low" + + +# ═══════════════════════════════════════════════════════════════════════ +# JSON 提取测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestJsonExtraction: + """pi 输出的 JSON 提取。""" + + def test_direct_json(self, sample_summary_json): + result = _extract_json(sample_summary_json) + assert result["title_zh"] == "测试论文中文标题" + + def test_fenced_code_block(self, sample_summary_json): + raw = f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字" + result = _extract_json(raw) + assert result["title_zh"] == "测试论文中文标题" + + def test_fenced_without_lang(self, sample_summary_json): + raw = f"文字\n```\n{sample_summary_json}\n```" + result = _extract_json(raw) + assert result["title_zh"] == "测试论文中文标题" + + def test_embedded_braces(self, sample_summary_dict): + json_str = json.dumps(sample_summary_dict, ensure_ascii=False) + raw = f"Here is the summary:\n{json_str}\nEnd." + result = _extract_json(raw) + assert result["title_zh"] == "测试论文中文标题" + + def test_no_json_raises(self): + with pytest.raises(JsonNotFoundError): + _extract_json("No JSON here at all.") + + def test_json_without_title_zh_falls_through(self): + """不含 title_zh 的 JSON 不是我们要的。""" + raw = json.dumps({"other": "data"}) + # 如果有其他合法 JSON 块也能返回,但没有就直接找最大块 + # 此场景 raw 本身就是一个 JSON dict,但没有 title_zh + # 策略 1 会跳过(无 title_zh),策略 2 无代码块,策略 3 找到最大块 + result = _extract_json(raw) + assert result == {"other": "data"} # 最大块兜底 + + +# ═══════════════════════════════════════════════════════════════════════ +# 错误分类测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestErrorClassification: + """异常 → error_type 映射。""" + + def test_pdf_download_error(self): + assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed" + + def test_timeout_error(self): + assert _classify_error(PiTimeoutError("timeout")) == "timeout" + + def test_process_error(self): + assert _classify_error(PiProcessError(1, "stderr")) == "process_error" + + def test_json_not_found(self): + assert _classify_error(JsonNotFoundError("not found")) == "json_not_found" + + def test_json_invalid(self): + assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid" + + def test_field_missing(self): + try: + SummarySchema.model_validate({"title_zh": ""}) # type: ignore + except ValidationError as exc: + assert _classify_error(exc) == "field_missing" + + def test_unknown_error(self): + assert _classify_error(RuntimeError("boom")) == "unknown" + + +# ═══════════════════════════════════════════════════════════════════════ +# DB 更新测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestDbUpdate: + """_update_summary_in_db 验证。""" + + def test_summary_written(self, db_session, sample_paper, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") + + summary = db_session.get(PaperSummary, sample_paper.id) + assert summary is not None + assert summary.one_line == schema.one_line + assert summary.motivation_problem == schema.motivation.problem + assert json.loads(summary.full_json)["title_zh"] == schema.title_zh + + def test_paper_title_zh_updated(self, db_session, sample_paper, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") + + db_session.refresh(sample_paper) + assert sample_paper.title_zh == "测试论文中文标题" + assert sample_paper.summary_quality == "normal" + + def test_fts_updated(self, db_session, sample_paper, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") + + row = db_session.execute( + text("SELECT title_zh, summary_text FROM papers_fts WHERE rowid = :id"), + {"id": sample_paper.id}, + ).fetchone() + assert row is not None + assert row[0] == "测试论文中文标题" + assert schema.one_line in row[1] + + def test_ai_tags_added(self, db_session, sample_paper, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") + + tags = ( + db_session.query(PaperTag) + .filter(PaperTag.paper_id == sample_paper.id, PaperTag.source == "ai") + .all() + ) + tag_names = {t.tag for t in tags} + # AI tags 来自 schema.tags + assert "自然语言处理" in tag_names + assert "大语言模型" in tag_names + + def test_existing_tags_not_duplicated(self, db_session, sample_paper, sample_summary_dict): + """已存在的标签名(同 name)不会被 AI source 重复插入。""" + # sample_paper 已有 NLP (hf)、LLM (hf) + # 让 AI 输出包含 NLP(与 HF 重复)和 "新标签"(新的) + sample_summary_dict["tags"] = ["NLP", "新标签"] + schema = SummarySchema.model_validate(sample_summary_dict) + _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") + + all_tags = ( + db_session.query(PaperTag) + .filter(PaperTag.paper_id == sample_paper.id) + .all() + ) + tag_names = [t.tag for t in all_tags] + # NLP 只出现一次(HF 原有的),AI 不会重复加 + assert tag_names.count("NLP") == 1 + # "新标签" 是 AI 新加的 + assert "新标签" in tag_names + + +# ═══════════════════════════════════════════════════════════════════════ +# 文件操作测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestFileOperations: + """文件保存和清理。""" + + def test_save_files(self, tmp_path, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + with patch("app.services.summarizer._PAPERS_DIR", tmp_path): + _save_files("2401.12345", schema, "raw output text") + + paper_dir = tmp_path / "2401.12345" + assert (paper_dir / "summary.json").exists() + assert (paper_dir / "raw_output.txt").exists() + saved = json.loads((paper_dir / "summary.json").read_text()) + assert saved["title_zh"] == "测试论文中文标题" + + def test_save_raw_output_only(self, tmp_path): + with patch("app.services.summarizer._PAPERS_DIR", tmp_path): + _save_raw_output_only("2401.12345", "raw output") + paper_dir = tmp_path / "2401.12345" + assert (paper_dir / "raw_output.txt").exists() + assert not (paper_dir / "summary.json").exists() + + def test_cleanup_tmp(self, tmp_path): + tmp_paper = tmp_path / "2401.12345" + tmp_paper.mkdir() + (tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake") + with patch("app.services.summarizer._TMP_DIR", tmp_path): + _cleanup_tmp("2401.12345") + assert not tmp_paper.exists() + + def test_cleanup_tmp_nonexistent(self, tmp_path): + """清理不存在的目录不报错。""" + with patch("app.services.summarizer._TMP_DIR", tmp_path): + _cleanup_tmp("nonexistent") # 不抛异常 + + +# ═══════════════════════════════════════════════════════════════════════ +# 全流程状态流转测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestSummarizeOneFlow: + """summarize_one 的状态流转(mock pi 和 PDF)。""" + + @pytest.fixture + def _patch_paths(self, tmp_path): + """将 data 目录重定向到 tmp_path。""" + with ( + patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"), + patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"), + patch("app.services.summarizer._DATA_DIR", tmp_path), + ): + yield + + @pytest.mark.asyncio + async def test_full_success_path( + self, db_session, sample_paper, mock_pi_output, _patch_paths + ): + """pending → processing → done 全流程。""" + with ( + patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output), + ): + result = await summarize_one(db_session, sample_paper) + + assert result["status"] == "done" + assert result["quality"] == "normal" + + # 验证 DB 状态 + db_session.refresh(sample_paper) + assert sample_paper.summary_status.status == "done" + assert sample_paper.summary_status.quality == "normal" + assert sample_paper.title_zh == "测试论文中文标题" + + # 验证 summary 已写入 + summary = db_session.get(PaperSummary, sample_paper.id) + assert summary is not None + assert summary.one_line + + # 验证 FTS 已更新 + fts_row = db_session.execute( + text("SELECT title_zh FROM papers_fts WHERE rowid = :id"), + {"id": sample_paper.id}, + ).fetchone() + assert fts_row[0] == "测试论文中文标题" + + @pytest.mark.asyncio + async def test_pdf_download_failure( + self, db_session, sample_paper, _patch_paths + ): + """PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。""" + with ( + patch( + "app.services.summarizer._download_pdf", + new_callable=AsyncMock, + side_effect=PdfDownloadError("network error"), + ), + ): + result = await summarize_one(db_session, sample_paper) + + assert result["status"] == "failed" + assert result["error_type"] == "pdf_download_failed" + + db_session.refresh(sample_paper) + status = sample_paper.summary_status + assert status.error_type == "pdf_download_failed" + + @pytest.mark.asyncio + async def test_pi_timeout(self, db_session, sample_paper, _patch_paths): + """pi 超时 → timeout 错误,retry_count 递增。""" + with ( + patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch( + "app.services.summarizer._call_pi", + new_callable=AsyncMock, + side_effect=PiTimeoutError("timeout after 300s"), + ), + ): + result = await summarize_one(db_session, sample_paper) + + assert result["status"] == "failed" + assert result["error_type"] == "timeout" + assert result["retry_count"] == 1 + + @pytest.mark.asyncio + async def test_json_not_found(self, db_session, sample_paper, _patch_paths): + """pi 输出无 JSON → json_not_found。""" + with ( + patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch( + "app.services.summarizer._call_pi", + new_callable=AsyncMock, + return_value="No JSON in this output at all.", + ), + ): + result = await summarize_one(db_session, sample_paper) + + assert result["status"] == "failed" + assert result["error_type"] == "json_not_found" + + @pytest.mark.asyncio + async def test_field_missing_and_retry( + self, db_session, sample_paper, _patch_paths + ): + """必填字段缺失 → field_missing → retry → permanent_failure。""" + bad_json = json.dumps({ + "title_zh": "", # 空的必填字段 + "one_line": "valid line", + "tags": ["tag1"], + "motivation": {"problem": "valid problem"}, + "method": {"key_idea": "valid idea"}, + }, ensure_ascii=False) + bad_output = f"```json\n{bad_json}\n```" + + with ( + patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch( + "app.services.summarizer._call_pi", + new_callable=AsyncMock, + return_value=bad_output, + ), + ): + # 第一次失败 → pending (retry) + result1 = await summarize_one(db_session, sample_paper) + assert result1["status"] == "failed" + assert result1["error_type"] == "field_missing" + assert result1["retry_count"] == 1 + + # 第二次失败 → permanent_failure (SUMMARY_MAX_RETRIES=1, 所以 2 次 > 1+1) + db_session.refresh(sample_paper) + result2 = await summarize_one(db_session, sample_paper) + assert result2["status"] == "failed" + assert result2["retry_count"] == 2 + + db_session.refresh(sample_paper) + assert sample_paper.summary_status.status == "permanent_failure" + + @pytest.mark.asyncio + async def test_raw_output_saved_on_failure( + self, db_session, sample_paper, tmp_path, _patch_paths + ): + """失败时仍保存 raw_output.txt。""" + with ( + patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch( + "app.services.summarizer._call_pi", + new_callable=AsyncMock, + return_value="Some output without JSON", + ), + ): + await summarize_one(db_session, sample_paper) + + raw_file = tmp_path / "papers" / sample_paper.arxiv_id / "raw_output.txt" + assert raw_file.exists() + assert "Some output without JSON" in raw_file.read_text() + + @pytest.mark.asyncio + async def test_tmp_cleaned_on_success( + self, db_session, sample_paper, mock_pi_output, tmp_path, _patch_paths + ): + """成功后清理 tmp 目录。""" + with ( + patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output), + ): + await summarize_one(db_session, sample_paper) + + tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id + assert not tmp_paper.exists() + + @pytest.mark.asyncio + async def test_tmp_cleaned_on_failure( + self, db_session, sample_paper, tmp_path, _patch_paths + ): + """失败后也清理 tmp 目录。""" + with ( + patch( + "app.services.summarizer._download_pdf", + new_callable=AsyncMock, + side_effect=PdfDownloadError("fail"), + ), + ): + await summarize_one(db_session, sample_paper) + + tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id + assert not tmp_paper.exists() + + @pytest.mark.asyncio + async def test_skips_done_paper(self, db_session, sample_paper, _patch_paths): + """已完成的论文跳过。""" + sample_paper.summary_status.status = "done" + db_session.commit() + + result = await summarize_one(db_session, sample_paper) + assert result["status"] == "skipped" + + +# ═══════════════════════════════════════════════════════════════════════ +# 批量操作测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestBatchSummarize: + """批量总结测试。""" + + @pytest.fixture + def _patch_paths(self, tmp_path): + with ( + patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"), + patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"), + patch("app.services.summarizer._DATA_DIR", tmp_path), + ): + yield + + @pytest.mark.asyncio + async def test_batch_multiple_papers( + self, db_session, db_engine, mock_pi_output, _patch_paths + ): + """批量处理多篇论文。""" + now = datetime.now(timezone.utc) + for i in range(3): + p = Paper( + arxiv_id=f"2401.1234{i}", + title_en=f"Test Paper {i}", + abstract=f"Abstract {i}", + paper_date=date(2024, 1, 15), + crawled_at=now, + pdf_url=f"https://arxiv.org/pdf/2401.1234{i}.pdf", + ) + db_session.add(p) + db_session.flush() + db_session.add(SummaryStatus(paper_id=p.id, status="pending")) + + db_session.commit() + + # 每个 worker 用独立 session(同一个内存引擎) + from sqlalchemy.orm import sessionmaker as _sm + _TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False) + + with ( + patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output), + ): + result = await summarize_batch( + db_session, _session_factory=_TestSession + ) + + assert result["status"] == "success" + assert result["done"] == 3 + assert result["failed"] == 0 + + # 验证 CrawlLog + log = db_session.query(CrawlLog).filter(CrawlLog.task == "summarize").first() + assert log is not None + assert log.status == "success" + assert log.papers_found == 3 + + @pytest.mark.asyncio + async def test_single_failure_no_block( + self, db_session, db_engine, mock_pi_output, _patch_paths + ): + """一篇失败不阻塞其他。""" + now = datetime.now(timezone.utc) + for i in range(2): + p = Paper( + arxiv_id=f"2401.5678{i}", + title_en=f"Paper {i}", + abstract=f"Abstract {i}", + paper_date=date(2024, 1, 15), + crawled_at=now, + pdf_url=f"https://arxiv.org/pdf/2401.5678{i}.pdf", + ) + db_session.add(p) + db_session.flush() + db_session.add(SummaryStatus(paper_id=p.id, status="pending")) + + db_session.commit() + + from sqlalchemy.orm import sessionmaker as _sm + _TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False) + + call_count = 0 + + async def _mock_call_pi(meta_path, pdf_path): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise PiTimeoutError("timeout") + return mock_pi_output + + with ( + patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer._call_pi", side_effect=_mock_call_pi), + ): + result = await summarize_batch( + db_session, _session_factory=_TestSession + ) + + assert result["done"] == 1 + assert result["failed"] == 1 + + @pytest.mark.asyncio + async def test_task_lock_conflict(self, db_session, _patch_paths): + """TaskLock 防止并发 batch。""" + # 先插入一个 running 锁 + db_session.add( + TaskLock( + task="summarize", + lock_key="batch", + status="running", + acquired_at=datetime.now(timezone.utc), + ) + ) + db_session.commit() + + result = await summarize_batch(db_session) + assert result["status"] == "conflict" + + @pytest.mark.asyncio + async def test_task_lock_released(self, db_session, db_engine, mock_pi_output, _patch_paths): + """完成后释放 TaskLock。""" + from sqlalchemy.orm import sessionmaker as _sm + _TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False) + + with ( + patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output), + ): + await summarize_batch( + db_session, _session_factory=_TestSession + ) + + locks = db_session.query(TaskLock).filter( + TaskLock.task == "summarize", + TaskLock.lock_key == "batch", + ).all() + for lock in locks: + assert lock.status == "finished" + assert lock.released_at is not None + + @pytest.mark.asyncio + async def test_batch_empty(self, db_session, _patch_paths): + """无 pending 论文时返回空结果。""" + result = await summarize_batch(db_session) + assert result["status"] == "success" + assert result["total"] == 0 + + +# ═══════════════════════════════════════════════════════════════════════ +# Admin 路由鉴权测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestAdminAuth: + """管理接口鉴权 — 只测 HTTP 层,mock 掉实际服务调用。""" + + def test_no_token_returns_401(self, client): + """无 Bearer token 返回 401。""" + resp = client.post("/admin/summarize") + assert resp.status_code in (401, 403) + + def test_wrong_token_returns_401(self, client): + resp = client.post( + "/admin/summarize", + headers={"Authorization": "Bearer wrong-token"}, + ) + assert resp.status_code == 401 + + def test_correct_token_batch(self, client, admin_headers): + """正确 token 调用 batch summarize,mock 掉服务层。""" + import app.config as config_mod + + original = config_mod.settings.ADMIN_TOKEN + config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345" + try: + with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock: + mock.return_value = {"status": "success", "done": 0, "failed": 0, "total": 0} + resp = client.post("/admin/summarize", headers=admin_headers) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + finally: + config_mod.settings.ADMIN_TOKEN = original + + def test_single_paper_not_found(self, client, admin_headers): + """单篇总结不存在的论文返回 404。""" + import app.config as config_mod + + original = config_mod.settings.ADMIN_TOKEN + config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345" + try: + with patch( + "app.routes.admin.summarize_single", + new_callable=AsyncMock, + return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"}, + ): + resp = client.post( + "/admin/summarize/nonexistent.99999", + headers=admin_headers, + ) + assert resp.status_code == 404 + finally: + config_mod.settings.ADMIN_TOKEN = original