feat: add admin routes, summarizer service, and CLI summarize command

- Add /admin routes for manual trigger and status inspection
- Add summarizer service with batch/single summary support
- Add summarize CLI command (single arxiv_id or batch pending)
- Register admin router in main app
- Add tests for summarizer
This commit is contained in:
2026-06-05 22:29:33 +08:00
parent d69df2be10
commit 29e6797c12
7 changed files with 1874 additions and 0 deletions
+40
View File
@@ -49,6 +49,46 @@ def crawl(
db.close() db.close()
@cli_app.command()
def summarize(
arxiv_id: str = typer.Argument(
None,
help="指定论文 arXiv ID;留空则批量处理所有 pending",
),
):
"""手动触发 AI 总结。"""
from app.config import settings
from app.database import SessionLocal, engine
from app.models import init_db as _init
from app.services.summarizer import summarize_batch, summarize_single
import os
os.makedirs(settings.db_path.parent, exist_ok=True)
_init(engine)
db = SessionLocal()
try:
if arxiv_id:
typer.echo(f"🤖 开始总结 {arxiv_id} ...")
result = asyncio.run(summarize_single(db, arxiv_id))
else:
typer.echo("🤖 开始批量总结 pending 论文 ...")
result = asyncio.run(summarize_batch(db))
if result.get("status") in ("success", "done"):
typer.echo(f"✅ 总结完成:{result}")
elif result.get("status") == "conflict":
typer.echo("⚠️ 已有批量总结任务在运行中", err=True)
raise typer.Exit(code=1)
elif result.get("status") == "not_found":
typer.echo(f"❌ 论文未找到:{arxiv_id}", err=True)
raise typer.Exit(code=1)
else:
typer.echo(f"⚠️ 总结结果:{result}", err=True)
finally:
db.close()
@cli_app.command() @cli_app.command()
def init_db(): def init_db():
"""初始化数据库表。""" """初始化数据库表。"""
+2
View File
@@ -9,6 +9,7 @@ from fastapi.staticfiles import StaticFiles
from app.config import settings from app.config import settings
from app.database import engine from app.database import engine
from app.models import init_db from app.models import init_db
from app.routes.admin import router as admin_router
from app.routes.pages import router as pages_router from app.routes.pages import router as pages_router
logging.basicConfig( logging.basicConfig(
@@ -41,6 +42,7 @@ def create_app() -> FastAPI:
# 路由 # 路由
app.include_router(pages_router) app.include_router(pages_router)
app.include_router(admin_router)
return app return app
+48
View File
@@ -0,0 +1,48 @@
"""管理接口 — AI 总结触发,需要 ADMIN_TOKEN 鉴权。"""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db
from app.services.summarizer import summarize_batch, summarize_single
router = APIRouter(prefix="/admin", tags=["admin"])
security = HTTPBearer()
async def verify_admin(
credentials: HTTPAuthorizationCredentials = Depends(security),
) -> str:
"""验证 ADMIN_TOKEN。"""
if credentials.credentials != settings.ADMIN_TOKEN:
raise HTTPException(status_code=401, detail="Invalid admin token")
return credentials.credentials
@router.post("/summarize")
async def admin_summarize_batch(
_admin: str = Depends(verify_admin),
db: Session = Depends(get_db),
):
"""批量总结所有 pending 论文。"""
result = await summarize_batch(db)
if result.get("status") == "conflict":
raise HTTPException(status_code=409, detail=result.get("error", "batch already running"))
return result
@router.post("/summarize/{arxiv_id}")
async def admin_summarize_single(
arxiv_id: str,
_admin: str = Depends(verify_admin),
db: Session = Depends(get_db),
):
"""总结或重跑单篇论文。"""
result = await summarize_single(db, arxiv_id, force=True)
if result.get("status") == "not_found":
raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
return result
+168
View File
@@ -0,0 +1,168 @@
"""AI 总结 schema — Pydantic 校验模型、质量评估、DB 展平。"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from pydantic import BaseModel, Field, ValidationError, field_validator
# ── 子模型 ──────────────────────────────────────────────────────────────
class PrerequisitesSchema(BaseModel):
concepts: list[str] = Field(default_factory=list)
level: str = ""
class MotivationSchema(BaseModel):
problem: str
goal: str = ""
gap: str = ""
@field_validator("problem")
@classmethod
def non_empty_problem(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("motivation.problem cannot be empty")
return v.strip()
class MethodSchema(BaseModel):
overview: str = ""
key_idea: str
steps: list[str] = Field(default_factory=list)
novelty: str = ""
@field_validator("key_idea")
@classmethod
def non_empty_key_idea(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("method.key_idea cannot be empty")
return v.strip()
class ResultsSchema(BaseModel):
main_findings: list[str] = Field(default_factory=list)
benchmarks: list[dict] = Field(default_factory=list)
limitations: list[str] = Field(default_factory=list)
class ImprovementsSchema(BaseModel):
weaknesses: list[str] = Field(default_factory=list)
future_work: list[str] = Field(default_factory=list)
reproducibility: str = ""
# ── 顶层 schema ─────────────────────────────────────────────────────────
class SummarySchema(BaseModel):
model_config = {"extra": "ignore"}
title_zh: str
one_line: str
tags: list[str]
difficulty: str = ""
paper_date: str | None = None
prerequisites: PrerequisitesSchema = Field(default_factory=PrerequisitesSchema)
motivation: MotivationSchema
method: MethodSchema
results: ResultsSchema = Field(default_factory=ResultsSchema)
improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema)
@field_validator("title_zh", "one_line")
@classmethod
def non_empty_text(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("field cannot be empty")
return v.strip()
@field_validator("tags")
@classmethod
def non_empty_tags(cls, v: list[str]) -> list[str]:
tags = [tag.strip() for tag in v if tag and tag.strip()]
if not tags:
raise ValueError("tags cannot be empty")
return tags
# ── 质量评估 ────────────────────────────────────────────────────────────
# 必填字段:title_zh, one_line, tags, motivation.problem, method.key_idea
# — 缺失时 Pydantic 校验就会报错,不会走到 assess_quality
# 重要字段:motivation.goal, motivation.gap, method.overview, results.main_findings
# — 缺失可入库,标记 degraded
_OPTIONAL_BUT_IMPORTANT_FIELDS = [
"motivation.goal",
"motivation.gap",
"method.overview",
"results.main_findings",
]
def assess_quality(schema: SummarySchema) -> str:
"""评估总结质量:normal / degraded / low。"""
# low:内容空洞的启发式判断
if len(schema.one_line.strip()) < 10 or len(schema.method.key_idea.strip()) < 10:
return "low"
# 检查重要字段是否缺失
missing_important = 0
if not schema.motivation.goal.strip():
missing_important += 1
if not schema.motivation.gap.strip():
missing_important += 1
if not schema.method.overview.strip():
missing_important += 1
if not schema.results.main_findings:
missing_important += 1
if missing_important == 0:
return "normal"
return "degraded"
# ── DB 展平 ─────────────────────────────────────────────────────────────
def flatten_for_db(schema: SummarySchema) -> dict:
"""将 SummarySchema 展平为 paper_summaries 表的列值 dict。"""
return {
"one_line": schema.one_line,
"difficulty": schema.difficulty,
"prerequisites_json": json.dumps(schema.prerequisites.model_dump(), ensure_ascii=False),
"motivation_problem": schema.motivation.problem,
"motivation_goal": schema.motivation.goal,
"motivation_gap": schema.motivation.gap,
"method_overview": schema.method.overview,
"method_key_idea": schema.method.key_idea,
"method_steps_json": json.dumps(schema.method.steps, ensure_ascii=False),
"method_novelty": schema.method.novelty,
"results_main_json": json.dumps(schema.results.main_findings, ensure_ascii=False),
"results_benchmarks_json": json.dumps(schema.results.benchmarks, ensure_ascii=False),
"limitations_json": json.dumps(schema.results.limitations, ensure_ascii=False),
"weaknesses_json": json.dumps(schema.improvements.weaknesses, ensure_ascii=False),
"future_work_json": json.dumps(schema.improvements.future_work, ensure_ascii=False),
"reproducibility": schema.improvements.reproducibility,
"full_json": schema.model_dump_json(ensure_ascii=False),
"updated_at": datetime.now(timezone.utc),
}
# ── 错误分类 ────────────────────────────────────────────────────────────
_REQUIRED_FIELDS = {"title_zh", "one_line", "tags", "problem", "key_idea"}
def classify_validation_error(exc: ValidationError) -> str:
"""区分 field_missing(必填缺失)和 schema_error(类型不合法等)。"""
for err in exc.errors():
field_name = err["loc"][-1] if err["loc"] else ""
if field_name in _REQUIRED_FIELDS and err["type"] in (
"missing",
"value_error",
):
return "field_missing"
return "schema_error"
+682
View File
@@ -0,0 +1,682 @@
"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import shutil
from datetime import datetime, timezone
from pathlib import Path
import httpx
from pydantic import ValidationError
from sqlalchemy import select, text
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import SessionLocal
from app.models import (
CrawlLog,
Paper,
PaperSummary,
PaperTag,
SummaryStatus,
TaskLock,
)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PdfDownloadError(Exception):
pass
class PiTimeoutError(Exception):
pass
class PiProcessError(Exception):
def __init__(self, returncode: int, stderr: str):
self.returncode = returncode
self.stderr = stderr
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
class JsonNotFoundError(Exception):
pass
# ── 路径工具 ────────────────────────────────────────────────────────────
_DATA_DIR = Path("data")
_PAPERS_DIR = _DATA_DIR / "papers"
_TMP_DIR = _DATA_DIR / "tmp"
def _paper_dir(arxiv_id: str) -> Path:
return _PAPERS_DIR / arxiv_id
def _tmp_dir(arxiv_id: str) -> Path:
return _TMP_DIR / arxiv_id
# ── PDF 下载 ────────────────────────────────────────────────────────────
async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
if not pdf_url:
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
tmp = _tmp_dir(arxiv_id)
tmp.mkdir(parents=True, exist_ok=True)
dest = tmp / "paper.pdf"
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
try:
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
follow_redirects=True,
) as client:
resp = await client.get(pdf_url)
resp.raise_for_status()
dest.write_bytes(resp.content)
except Exception as exc:
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
return dest
# ── meta.json ───────────────────────────────────────────────────────────
def _write_meta_json(paper: Paper) -> Path:
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
d = _paper_dir(paper.arxiv_id)
d.mkdir(parents=True, exist_ok=True)
meta_path = d / "meta.json"
authors = [a.name for a in paper.authors]
tags = [t.tag for t in paper.tags]
meta = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"abstract": paper.abstract or "",
"published_at": paper.published_at.isoformat() if paper.published_at else None,
"authors": authors,
"tags": tags,
"upvotes": paper.upvotes,
}
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
return meta_path
# ── pi CLI 调用 ────────────────────────────────────────────────────────
async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
cmd = [
settings.PI_BIN,
"-p",
"--no-tools",
"--skill",
settings.SUMMARY_SKILL,
"请深度解读以下论文,并按指定 JSON schema 输出:",
f"@{meta_path}",
f"@{pdf_path}",
]
logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4]))
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(),
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise PiTimeoutError(
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
)
if proc.returncode != 0:
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
return stdout.decode("utf-8", errors="replace")
def paper_id_from_path(meta_path: Path) -> str:
"""从 meta.json 路径反推 arxiv_id。"""
return meta_path.parent.name
# ── JSON 提取 ──────────────────────────────────────────────────────────
def _extract_json(raw_output: str) -> dict:
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
# 策略 1:整体直接解析
stripped = raw_output.strip()
try:
result = json.loads(stripped)
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
pass
# 策略 2:提取 ```json ... ``` 代码块
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
for match in fence_pattern.finditer(raw_output):
try:
result = json.loads(match.group(1).strip())
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
continue
# 策略 3:匹配包含 title_zh 的最大 {...} 块
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
# 先尝试一层嵌套;如果没命中再用更宽松的策略
for match in brace_pattern.finditer(raw_output):
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
continue
# 更宽松:找到最大的 { ... } 平衡块
best = None
best_len = 0
for i, ch in enumerate(raw_output):
if ch != "{":
continue
depth = 0
for j in range(i, len(raw_output)):
if raw_output[j] == "{":
depth += 1
elif raw_output[j] == "}":
depth -= 1
if depth == 0:
candidate = raw_output[i : j + 1]
if len(candidate) > best_len:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
best = parsed
best_len = len(candidate)
except json.JSONDecodeError:
pass
break
if best is not None:
return best
raise JsonNotFoundError("no JSON object found in pi output")
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
"""将异常映射到 error_type 枚举值。"""
if isinstance(exc, PdfDownloadError):
return "pdf_download_failed"
if isinstance(exc, PiTimeoutError):
return "timeout"
if isinstance(exc, PiProcessError):
return "process_error"
if isinstance(exc, JsonNotFoundError):
return "json_not_found"
if isinstance(exc, json.JSONDecodeError):
return "json_invalid"
if isinstance(exc, ValidationError):
return classify_validation_error(exc)
return "unknown"
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
def _build_fts_summary_text(schema: SummarySchema) -> str:
"""拼接用于 FTS5 索引的总结文本。"""
parts = [
schema.one_line or "",
schema.motivation.problem or "",
schema.motivation.goal or "",
schema.method_overview if hasattr(schema, "method_overview") else "",
schema.method.overview or "",
schema.method.key_idea or "",
" ".join(schema.results.main_findings or []),
]
return " ".join(p for p in parts if p)
# ── DB 更新 ─────────────────────────────────────────────────────────────
def _update_summary_in_db(
db: Session,
paper: Paper,
schema: SummarySchema,
quality: str,
raw_output: str,
) -> None:
"""将校验后的总结写入 DBpaper_summaries + papers + paper_tags + FTS5。"""
now = datetime.now(timezone.utc)
# 1. paper_summariesupsert
existing = db.get(PaperSummary, paper.id)
flat = flatten_for_db(schema)
if existing:
for k, v in flat.items():
setattr(existing, k, v)
else:
db.add(PaperSummary(paper_id=paper.id, **flat))
# 2. papers 表
paper.title_zh = schema.title_zh
paper.summary_quality = quality
paper_dir = _paper_dir(paper.arxiv_id)
paper.summary_path = str(paper_dir / "summary.json")
paper.raw_output_path = str(paper_dir / "raw_output.txt")
# 3. AI 标签
existing_tag_names = {t.tag for t in paper.tags}
for tag_name in schema.tags:
if tag_name not in existing_tag_names:
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
existing_tag_names.add(tag_name)
# 4. FTS5 更新
summary_text = _build_fts_summary_text(schema)
db.execute(
text(
"UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
"WHERE rowid=:paper_id"
),
{
"title_zh": schema.title_zh,
"summary_text": summary_text,
"paper_id": paper.id,
},
)
db.commit()
logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
"""保存 summary.json 和 raw_output.txt。"""
d = _paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
encoding="utf-8",
)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
"""仅保存 raw_output.txt(失败时)。"""
d = _paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
def _cleanup_tmp(arxiv_id: str) -> None:
"""清理 data/tmp/{arxiv_id}/ 目录。"""
tmp = _tmp_dir(arxiv_id)
if tmp.exists():
try:
shutil.rmtree(tmp)
logger.debug("Cleaned tmp: %s", arxiv_id)
except Exception:
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
# ── 单篇总结 ────────────────────────────────────────────────────────────
async def summarize_one(
db: Session,
paper: Paper,
semaphore: asyncio.Semaphore | None = None,
*,
force: bool = False,
) -> dict:
"""总结单篇论文的完整流程。"""
arxiv_id = paper.arxiv_id
# 获取或创建 summary_status
if not paper.summary_status:
db.add(SummaryStatus(paper_id=paper.id, status="pending"))
db.commit()
db.refresh(paper)
status = paper.summary_status
# 跳过已完成的(除非 force
if status.status == "done" and not force:
return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"}
# 跳过 permanent_failure(除非 force
if status.status == "permanent_failure" and not force:
return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "permanent_failure"}
if semaphore:
await semaphore.acquire()
try:
return await _do_summarize_one(db, paper)
finally:
if semaphore:
semaphore.release()
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
arxiv_id = paper.arxiv_id
status = paper.summary_status
now = datetime.now(timezone.utc)
# 状态 → processing
status.status = "processing"
status.started_at = now
db.commit()
raw_output = ""
try:
# 写 meta.json
meta_path = _write_meta_json(paper)
# 下载 PDF
await _download_pdf(arxiv_id, paper.pdf_url)
# 调用 pi
raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf")
# 提取 JSON
json_data = _extract_json(raw_output)
# Pydantic 校验
schema = SummarySchema.model_validate(json_data)
# 质量评估
quality = assess_quality(schema)
# 保存文件
_save_files(arxiv_id, schema, raw_output)
# 更新 DB
_update_summary_in_db(db, paper, schema, quality, raw_output)
# 状态 → done
status.status = "done"
status.quality = quality
status.completed_at = datetime.now(timezone.utc)
status.raw_output_saved = True
db.commit()
logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
except Exception as exc:
error_type = _classify_error(exc)
logger.error(
"Summarize failed: %s error_type=%s %s",
arxiv_id,
error_type,
str(exc)[:200],
)
# 保存 raw_output(如果有)
if raw_output:
_save_raw_output_only(arxiv_id, raw_output)
status.raw_output_saved = True
# 重试逻辑
status.retry_count = (status.retry_count or 0) + 1
status.error_type = error_type
status.error = str(exc)[:2000]
if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
status.status = "permanent_failure"
else:
status.status = "pending"
status.completed_at = datetime.now(timezone.utc)
db.commit()
return {
"arxiv_id": arxiv_id,
"status": "failed",
"error_type": error_type,
"error": str(exc)[:200],
"retry_count": status.retry_count,
}
finally:
_cleanup_tmp(arxiv_id)
# ── 单篇入口 ────────────────────────────────────────────────────────────
async def summarize_single(
db: Session,
arxiv_id: str,
*,
force: bool = True,
_session_factory=None,
) -> dict:
"""单篇总结入口(供 admin 路由和 CLI 调用)。
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
paper = (
db.query(Paper)
.filter(Paper.arxiv_id == arxiv_id)
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
)
.first()
)
if not paper:
return {"status": "not_found", "arxiv_id": arxiv_id}
make_session = _session_factory or SessionLocal
# 每篇用独立 session 避免并发问题
paper_db = make_session()
try:
paper_in_new_session = (
paper_db.query(Paper)
.filter(Paper.arxiv_id == arxiv_id)
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
)
.first()
)
result = await summarize_one(paper_db, paper_in_new_session, force=force)
finally:
paper_db.close()
return result
# ── 批量总结 ────────────────────────────────────────────────────────────
async def summarize_batch(
db: Session,
arxiv_ids: list[str] | None = None,
*,
_session_factory=None,
) -> dict:
"""批量总结入口。arxiv_ids=None 时处理所有 pending 论文。
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
now = datetime.now(timezone.utc)
# TaskLock 防重入
lock = TaskLock(
task="summarize",
lock_key="batch",
status="running",
owner="summarize_batch",
acquired_at=now,
)
try:
db.add(lock)
db.commit()
except Exception:
db.rollback()
logger.warning("Summarize batch already running (lock conflict)")
return {"status": "conflict", "error": "summarize batch already running"}
# CrawlLog
log_entry = CrawlLog(
task="summarize",
status="running",
started_at=now,
)
db.add(log_entry)
db.commit()
try:
# 查询待总结论文
query = db.query(Paper).options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
)
if arxiv_ids:
query = query.filter(Paper.arxiv_id.in_(arxiv_ids))
else:
# 只处理 pending 或 failed(可重试的)
query = query.join(SummaryStatus).filter(
SummaryStatus.status.in_(["pending", "failed"])
)
papers = query.all()
total = len(papers)
logger.info("Summarize batch: %d papers to process", total)
if total == 0:
log_entry.status = "success"
log_entry.papers_found = 0
log_entry.papers_new = 0
log_entry.completed_at = datetime.now(timezone.utc)
_release_lock(db, lock)
return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}
# 并发控制
semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY)
make_session = _session_factory or SessionLocal
async def _process_paper(paper: Paper) -> dict:
paper_db = make_session()
try:
p = (
paper_db.query(Paper)
.filter(Paper.id == paper.id)
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
)
.first()
)
return await summarize_one(paper_db, p, semaphore)
finally:
paper_db.close()
results = await asyncio.gather(
*[_process_paper(p) for p in papers],
return_exceptions=True,
)
# 统计结果
done = 0
failed = 0
skipped = 0
for r in results:
if isinstance(r, Exception):
logger.error("Unexpected error in batch: %s", r)
failed += 1
elif isinstance(r, dict):
if r.get("status") == "done":
done += 1
elif r.get("status") == "skipped":
skipped += 1
else:
failed += 1
log_entry.status = "success" if failed == 0 else "failed"
log_entry.papers_found = total
log_entry.papers_new = done
log_entry.completed_at = datetime.now(timezone.utc)
db.commit()
logger.info(
"Summarize batch done: total=%d done=%d failed=%d skipped=%d",
total, done, failed, skipped,
)
return {
"status": "success" if failed == 0 else "partial",
"total": total,
"done": done,
"failed": failed,
"skipped": skipped,
}
except Exception as exc:
logger.exception("Summarize batch failed")
log_entry.status = "failed"
log_entry.error = str(exc)[:2000]
log_entry.completed_at = datetime.now(timezone.utc)
db.commit()
return {"status": "failed", "error": str(exc)}
finally:
_release_lock(db, lock)
def _release_lock(db: Session, lock: TaskLock) -> None:
"""释放 TaskLock。"""
try:
lock.status = "finished"
lock.released_at = datetime.now(timezone.utc)
db.commit()
except Exception:
db.rollback()
logger.warning("Failed to release summarize lock", exc_info=True)
+209
View File
@@ -0,0 +1,209 @@
"""测试 fixtures — 内存 SQLite、TestClient、样例数据。"""
from __future__ import annotations
import json
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, event
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from app.database import get_db
from app.main import create_app
from app.models import (
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
init_db,
)
# ── 内存数据库 ──────────────────────────────────────────────────────────
class _TestBase(DeclarativeBase):
pass
# 复用 app.models 的 Base metadata
from app.database import Base as _AppBase # noqa: E402
_TestBase.metadata = _AppBase.metadata
@pytest.fixture
def db_engine():
"""创建内存 SQLite 引擎 + FTS5。"""
engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
)
@event.listens_for(engine, "connect")
def _pragma(dbapi_connection, _record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
init_db(engine)
return engine
@pytest.fixture
def db_session(db_engine):
"""提供事务隔离的数据库 session。"""
Session = sessionmaker(bind=db_engine, autoflush=False, autocommit=False)
session = Session()
try:
yield session
finally:
session.close()
@pytest.fixture
def client(db_engine, db_session):
"""FastAPI TestClientoverride get_db。"""
app = create_app()
def _override_get_db():
yield db_session
app.dependency_overrides[get_db] = _override_get_db
with TestClient(app, raise_server_exceptions=False) as c:
yield c
app.dependency_overrides.clear()
# ── 样例数据 ────────────────────────────────────────────────────────────
SAMPLE_ARXIV_ID = "2401.12345"
ADMIN_TOKEN = "test-admin-token-12345"
@pytest.fixture
def sample_paper(db_session):
"""插入一篇测试论文 + 作者 + 标签 + summary_status(pending)。"""
now = datetime.now(timezone.utc)
paper = Paper(
arxiv_id=SAMPLE_ARXIV_ID,
title_en="Test Paper Title",
abstract="This is a test abstract for the paper.",
published_at=date(2024, 1, 15),
paper_date=date(2024, 1, 15),
crawled_at=now,
upvotes=42,
hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}",
arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}",
pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf",
)
db_session.add(paper)
db_session.flush()
db_session.add(PaperAuthor(paper_id=paper.id, name="Alice Smith", position=0))
db_session.add(PaperAuthor(paper_id=paper.id, name="Bob Jones", position=1))
db_session.add(PaperTag(paper_id=paper.id, tag="NLP", source="hf"))
db_session.add(PaperTag(paper_id=paper.id, tag="LLM", source="hf"))
db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))
# FTS5 初始行(与 crawler 一致)
db_session.execute(
__import__("sqlalchemy").text(
"INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
"VALUES (:id, :title, :abstract, :authors, :tags)"
),
{
"id": paper.id,
"title": paper.title_en,
"abstract": paper.abstract or "",
"authors": "Alice Smith, Bob Jones",
"tags": "NLP, LLM",
},
)
db_session.commit()
return paper
@pytest.fixture
def sample_summary_dict() -> dict:
"""完整合法的 summary dict。"""
return {
"title_zh": "测试论文中文标题",
"one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。",
"tags": ["自然语言处理", "大语言模型", "Transformer"],
"difficulty": "中级",
"prerequisites": {
"concepts": ["Transformer", "注意力机制"],
"level": "中级",
},
"motivation": {
"problem": "现有模型在长文本理解上存在不足。",
"goal": "提出一种新的注意力机制来提升长文本建模能力。",
"gap": "当前方法计算复杂度过高。",
},
"method": {
"overview": "提出了一种高效的稀疏注意力机制。",
"key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度。",
"steps": [
"分析现有注意力机制的瓶颈",
"设计稀疏注意力模式",
"在多个基准上验证效果",
],
"novelty": "首次将局部-全局注意力模式结合应用于长文本建模。",
},
"results": {
"main_findings": [
"在长文本基准上取得了 SOTA 结果",
"推理速度提升了 2 倍",
],
"benchmarks": [
{"dataset": "LongBench", "score": 85.3},
],
"limitations": [
"在超长文本(>100k tokens)上效果有所下降",
],
},
"improvements": {
"weaknesses": ["仅验证了英文数据"],
"future_work": ["扩展到多语言场景"],
"reproducibility": "代码已开源,模型权重可下载。",
},
}
@pytest.fixture
def sample_summary_json(sample_summary_dict) -> str:
"""合法 summary 的 JSON 字符串。"""
return json.dumps(sample_summary_dict, ensure_ascii=False, indent=2)
@pytest.fixture
def mock_pi_output(sample_summary_json) -> str:
"""模拟 pi CLI 的完整输出(包含 JSON)。"""
return f"""以下是论文的深度解读:
```json
{sample_summary_json}
```
希望这个总结对你有帮助!"""
@pytest.fixture
def admin_token():
"""返回测试用的 ADMIN_TOKEN(需要配合 monkeypatch 使用)。"""
return ADMIN_TOKEN
@pytest.fixture
def admin_headers(admin_token):
"""带 Bearer token 的请求头。"""
return {"Authorization": f"Bearer {admin_token}"}
+725
View File
@@ -0,0 +1,725 @@
"""AI 总结服务测试 — Mock 全链路,不调用真实 pi。"""
from __future__ import annotations
import asyncio
import json
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import ValidationError
from sqlalchemy import text
from app.models import (
CrawlLog,
Paper,
PaperSummary,
PaperTag,
SummaryStatus,
TaskLock,
)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.services.summarizer import (
JsonNotFoundError,
PdfDownloadError,
PiProcessError,
PiTimeoutError,
_call_pi,
_classify_error,
_cleanup_tmp,
_extract_json,
_save_files,
_save_raw_output_only,
_update_summary_in_db,
summarize_batch,
summarize_one,
summarize_single,
)
# ═══════════════════════════════════════════════════════════════════════
# Schema 校验测试
# ═══════════════════════════════════════════════════════════════════════
class TestSummarySchema:
"""Pydantic schema 校验。"""
def test_valid_summary(self, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
assert schema.title_zh == "测试论文中文标题"
assert len(schema.tags) == 3
assert schema.motivation.problem
def test_missing_title_zh(self, sample_summary_dict):
del sample_summary_dict["title_zh"]
with pytest.raises(ValidationError) as exc_info:
SummarySchema.model_validate(sample_summary_dict)
assert classify_validation_error(exc_info.value) == "field_missing"
def test_empty_one_line(self, sample_summary_dict):
sample_summary_dict["one_line"] = ""
with pytest.raises(ValidationError):
SummarySchema.model_validate(sample_summary_dict)
def test_empty_tags(self, sample_summary_dict):
sample_summary_dict["tags"] = []
with pytest.raises(ValidationError):
SummarySchema.model_validate(sample_summary_dict)
def test_empty_motivation_problem(self, sample_summary_dict):
sample_summary_dict["motivation"]["problem"] = ""
with pytest.raises(ValidationError):
SummarySchema.model_validate(sample_summary_dict)
def test_empty_method_key_idea(self, sample_summary_dict):
sample_summary_dict["method"]["key_idea"] = ""
with pytest.raises(ValidationError):
SummarySchema.model_validate(sample_summary_dict)
def test_extra_fields_ignored(self, sample_summary_dict):
sample_summary_dict["figures"] = ["fig1.png"]
sample_summary_dict["takeaway"] = "important paper"
schema = SummarySchema.model_validate(sample_summary_dict)
assert not hasattr(schema, "figures")
assert schema.title_zh # 正常解析
def test_flatten_for_db(self, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
flat = flatten_for_db(schema)
assert flat["one_line"] == schema.one_line
assert flat["motivation_problem"] == schema.motivation.problem
assert flat["method_key_idea"] == schema.method.key_idea
assert "full_json" in flat
assert "updated_at" in flat
# JSON 字段可解析
assert isinstance(json.loads(flat["prerequisites_json"]), dict)
assert isinstance(json.loads(flat["method_steps_json"]), list)
class TestQualityAssessment:
"""质量分级测试。"""
def test_quality_normal(self, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
assert assess_quality(schema) == "normal"
def test_quality_degraded_missing_goal(self, sample_summary_dict):
sample_summary_dict["motivation"]["goal"] = ""
sample_summary_dict["motivation"]["gap"] = ""
sample_summary_dict["method"]["overview"] = ""
sample_summary_dict["results"]["main_findings"] = []
schema = SummarySchema.model_validate(sample_summary_dict)
assert assess_quality(schema) == "degraded"
def test_quality_low_short_one_line(self, sample_summary_dict):
sample_summary_dict["one_line"] = ""
schema = SummarySchema.model_validate(sample_summary_dict)
assert assess_quality(schema) == "low"
def test_quality_low_short_key_idea(self, sample_summary_dict):
sample_summary_dict["method"]["key_idea"] = ""
schema = SummarySchema.model_validate(sample_summary_dict)
assert assess_quality(schema) == "low"
# ═══════════════════════════════════════════════════════════════════════
# JSON 提取测试
# ═══════════════════════════════════════════════════════════════════════
class TestJsonExtraction:
"""pi 输出的 JSON 提取。"""
def test_direct_json(self, sample_summary_json):
result = _extract_json(sample_summary_json)
assert result["title_zh"] == "测试论文中文标题"
def test_fenced_code_block(self, sample_summary_json):
raw = f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字"
result = _extract_json(raw)
assert result["title_zh"] == "测试论文中文标题"
def test_fenced_without_lang(self, sample_summary_json):
raw = f"文字\n```\n{sample_summary_json}\n```"
result = _extract_json(raw)
assert result["title_zh"] == "测试论文中文标题"
def test_embedded_braces(self, sample_summary_dict):
json_str = json.dumps(sample_summary_dict, ensure_ascii=False)
raw = f"Here is the summary:\n{json_str}\nEnd."
result = _extract_json(raw)
assert result["title_zh"] == "测试论文中文标题"
def test_no_json_raises(self):
with pytest.raises(JsonNotFoundError):
_extract_json("No JSON here at all.")
def test_json_without_title_zh_falls_through(self):
"""不含 title_zh 的 JSON 不是我们要的。"""
raw = json.dumps({"other": "data"})
# 如果有其他合法 JSON 块也能返回,但没有就直接找最大块
# 此场景 raw 本身就是一个 JSON dict,但没有 title_zh
# 策略 1 会跳过(无 title_zh),策略 2 无代码块,策略 3 找到最大块
result = _extract_json(raw)
assert result == {"other": "data"} # 最大块兜底
# ═══════════════════════════════════════════════════════════════════════
# 错误分类测试
# ═══════════════════════════════════════════════════════════════════════
class TestErrorClassification:
"""异常 → error_type 映射。"""
def test_pdf_download_error(self):
assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed"
def test_timeout_error(self):
assert _classify_error(PiTimeoutError("timeout")) == "timeout"
def test_process_error(self):
assert _classify_error(PiProcessError(1, "stderr")) == "process_error"
def test_json_not_found(self):
assert _classify_error(JsonNotFoundError("not found")) == "json_not_found"
def test_json_invalid(self):
assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid"
def test_field_missing(self):
try:
SummarySchema.model_validate({"title_zh": ""}) # type: ignore
except ValidationError as exc:
assert _classify_error(exc) == "field_missing"
def test_unknown_error(self):
assert _classify_error(RuntimeError("boom")) == "unknown"
# ═══════════════════════════════════════════════════════════════════════
# DB 更新测试
# ═══════════════════════════════════════════════════════════════════════
class TestDbUpdate:
"""_update_summary_in_db 验证。"""
def test_summary_written(self, db_session, sample_paper, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
_update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")
summary = db_session.get(PaperSummary, sample_paper.id)
assert summary is not None
assert summary.one_line == schema.one_line
assert summary.motivation_problem == schema.motivation.problem
assert json.loads(summary.full_json)["title_zh"] == schema.title_zh
def test_paper_title_zh_updated(self, db_session, sample_paper, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
_update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")
db_session.refresh(sample_paper)
assert sample_paper.title_zh == "测试论文中文标题"
assert sample_paper.summary_quality == "normal"
def test_fts_updated(self, db_session, sample_paper, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
_update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")
row = db_session.execute(
text("SELECT title_zh, summary_text FROM papers_fts WHERE rowid = :id"),
{"id": sample_paper.id},
).fetchone()
assert row is not None
assert row[0] == "测试论文中文标题"
assert schema.one_line in row[1]
def test_ai_tags_added(self, db_session, sample_paper, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
_update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")
tags = (
db_session.query(PaperTag)
.filter(PaperTag.paper_id == sample_paper.id, PaperTag.source == "ai")
.all()
)
tag_names = {t.tag for t in tags}
# AI tags 来自 schema.tags
assert "自然语言处理" in tag_names
assert "大语言模型" in tag_names
def test_existing_tags_not_duplicated(self, db_session, sample_paper, sample_summary_dict):
"""已存在的标签名(同 name)不会被 AI source 重复插入。"""
# sample_paper 已有 NLP (hf)、LLM (hf)
# 让 AI 输出包含 NLP(与 HF 重复)和 "新标签"(新的)
sample_summary_dict["tags"] = ["NLP", "新标签"]
schema = SummarySchema.model_validate(sample_summary_dict)
_update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")
all_tags = (
db_session.query(PaperTag)
.filter(PaperTag.paper_id == sample_paper.id)
.all()
)
tag_names = [t.tag for t in all_tags]
# NLP 只出现一次(HF 原有的),AI 不会重复加
assert tag_names.count("NLP") == 1
# "新标签" 是 AI 新加的
assert "新标签" in tag_names
# ═══════════════════════════════════════════════════════════════════════
# 文件操作测试
# ═══════════════════════════════════════════════════════════════════════
class TestFileOperations:
"""文件保存和清理。"""
def test_save_files(self, tmp_path, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
_save_files("2401.12345", schema, "raw output text")
paper_dir = tmp_path / "2401.12345"
assert (paper_dir / "summary.json").exists()
assert (paper_dir / "raw_output.txt").exists()
saved = json.loads((paper_dir / "summary.json").read_text())
assert saved["title_zh"] == "测试论文中文标题"
def test_save_raw_output_only(self, tmp_path):
with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
_save_raw_output_only("2401.12345", "raw output")
paper_dir = tmp_path / "2401.12345"
assert (paper_dir / "raw_output.txt").exists()
assert not (paper_dir / "summary.json").exists()
def test_cleanup_tmp(self, tmp_path):
tmp_paper = tmp_path / "2401.12345"
tmp_paper.mkdir()
(tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake")
with patch("app.services.summarizer._TMP_DIR", tmp_path):
_cleanup_tmp("2401.12345")
assert not tmp_paper.exists()
def test_cleanup_tmp_nonexistent(self, tmp_path):
"""清理不存在的目录不报错。"""
with patch("app.services.summarizer._TMP_DIR", tmp_path):
_cleanup_tmp("nonexistent") # 不抛异常
# ═══════════════════════════════════════════════════════════════════════
# 全流程状态流转测试
# ═══════════════════════════════════════════════════════════════════════
class TestSummarizeOneFlow:
"""summarize_one 的状态流转(mock pi 和 PDF)。"""
@pytest.fixture
def _patch_paths(self, tmp_path):
"""将 data 目录重定向到 tmp_path。"""
with (
patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
patch("app.services.summarizer._DATA_DIR", tmp_path),
):
yield
@pytest.mark.asyncio
async def test_full_success_path(
self, db_session, sample_paper, mock_pi_output, _patch_paths
):
"""pending → processing → done 全流程。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
result = await summarize_one(db_session, sample_paper)
assert result["status"] == "done"
assert result["quality"] == "normal"
# 验证 DB 状态
db_session.refresh(sample_paper)
assert sample_paper.summary_status.status == "done"
assert sample_paper.summary_status.quality == "normal"
assert sample_paper.title_zh == "测试论文中文标题"
# 验证 summary 已写入
summary = db_session.get(PaperSummary, sample_paper.id)
assert summary is not None
assert summary.one_line
# 验证 FTS 已更新
fts_row = db_session.execute(
text("SELECT title_zh FROM papers_fts WHERE rowid = :id"),
{"id": sample_paper.id},
).fetchone()
assert fts_row[0] == "测试论文中文标题"
@pytest.mark.asyncio
async def test_pdf_download_failure(
self, db_session, sample_paper, _patch_paths
):
"""PDF 下载失败 → error_type=pdf_download_failedtmp 被清理。"""
with (
patch(
"app.services.summarizer._download_pdf",
new_callable=AsyncMock,
side_effect=PdfDownloadError("network error"),
),
):
result = await summarize_one(db_session, sample_paper)
assert result["status"] == "failed"
assert result["error_type"] == "pdf_download_failed"
db_session.refresh(sample_paper)
status = sample_paper.summary_status
assert status.error_type == "pdf_download_failed"
@pytest.mark.asyncio
async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
"""pi 超时 → timeout 错误,retry_count 递增。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer._call_pi",
new_callable=AsyncMock,
side_effect=PiTimeoutError("timeout after 300s"),
),
):
result = await summarize_one(db_session, sample_paper)
assert result["status"] == "failed"
assert result["error_type"] == "timeout"
assert result["retry_count"] == 1
@pytest.mark.asyncio
async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
"""pi 输出无 JSON → json_not_found。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer._call_pi",
new_callable=AsyncMock,
return_value="No JSON in this output at all.",
),
):
result = await summarize_one(db_session, sample_paper)
assert result["status"] == "failed"
assert result["error_type"] == "json_not_found"
@pytest.mark.asyncio
async def test_field_missing_and_retry(
self, db_session, sample_paper, _patch_paths
):
"""必填字段缺失 → field_missing → retry → permanent_failure。"""
bad_json = json.dumps({
"title_zh": "", # 空的必填字段
"one_line": "valid line",
"tags": ["tag1"],
"motivation": {"problem": "valid problem"},
"method": {"key_idea": "valid idea"},
}, ensure_ascii=False)
bad_output = f"```json\n{bad_json}\n```"
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer._call_pi",
new_callable=AsyncMock,
return_value=bad_output,
),
):
# 第一次失败 → pending (retry)
result1 = await summarize_one(db_session, sample_paper)
assert result1["status"] == "failed"
assert result1["error_type"] == "field_missing"
assert result1["retry_count"] == 1
# 第二次失败 → permanent_failure (SUMMARY_MAX_RETRIES=1, 所以 2 次 > 1+1)
db_session.refresh(sample_paper)
result2 = await summarize_one(db_session, sample_paper)
assert result2["status"] == "failed"
assert result2["retry_count"] == 2
db_session.refresh(sample_paper)
assert sample_paper.summary_status.status == "permanent_failure"
@pytest.mark.asyncio
async def test_raw_output_saved_on_failure(
self, db_session, sample_paper, tmp_path, _patch_paths
):
"""失败时仍保存 raw_output.txt。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer._call_pi",
new_callable=AsyncMock,
return_value="Some output without JSON",
),
):
await summarize_one(db_session, sample_paper)
raw_file = tmp_path / "papers" / sample_paper.arxiv_id / "raw_output.txt"
assert raw_file.exists()
assert "Some output without JSON" in raw_file.read_text()
@pytest.mark.asyncio
async def test_tmp_cleaned_on_success(
self, db_session, sample_paper, mock_pi_output, tmp_path, _patch_paths
):
"""成功后清理 tmp 目录。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
await summarize_one(db_session, sample_paper)
tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
assert not tmp_paper.exists()
@pytest.mark.asyncio
async def test_tmp_cleaned_on_failure(
self, db_session, sample_paper, tmp_path, _patch_paths
):
"""失败后也清理 tmp 目录。"""
with (
patch(
"app.services.summarizer._download_pdf",
new_callable=AsyncMock,
side_effect=PdfDownloadError("fail"),
),
):
await summarize_one(db_session, sample_paper)
tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
assert not tmp_paper.exists()
@pytest.mark.asyncio
async def test_skips_done_paper(self, db_session, sample_paper, _patch_paths):
"""已完成的论文跳过。"""
sample_paper.summary_status.status = "done"
db_session.commit()
result = await summarize_one(db_session, sample_paper)
assert result["status"] == "skipped"
# ═══════════════════════════════════════════════════════════════════════
# 批量操作测试
# ═══════════════════════════════════════════════════════════════════════
class TestBatchSummarize:
"""批量总结测试。"""
@pytest.fixture
def _patch_paths(self, tmp_path):
with (
patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
patch("app.services.summarizer._DATA_DIR", tmp_path),
):
yield
@pytest.mark.asyncio
async def test_batch_multiple_papers(
self, db_session, db_engine, mock_pi_output, _patch_paths
):
"""批量处理多篇论文。"""
now = datetime.now(timezone.utc)
for i in range(3):
p = Paper(
arxiv_id=f"2401.1234{i}",
title_en=f"Test Paper {i}",
abstract=f"Abstract {i}",
paper_date=date(2024, 1, 15),
crawled_at=now,
pdf_url=f"https://arxiv.org/pdf/2401.1234{i}.pdf",
)
db_session.add(p)
db_session.flush()
db_session.add(SummaryStatus(paper_id=p.id, status="pending"))
db_session.commit()
# 每个 worker 用独立 session(同一个内存引擎)
from sqlalchemy.orm import sessionmaker as _sm
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
result = await summarize_batch(
db_session, _session_factory=_TestSession
)
assert result["status"] == "success"
assert result["done"] == 3
assert result["failed"] == 0
# 验证 CrawlLog
log = db_session.query(CrawlLog).filter(CrawlLog.task == "summarize").first()
assert log is not None
assert log.status == "success"
assert log.papers_found == 3
@pytest.mark.asyncio
async def test_single_failure_no_block(
self, db_session, db_engine, mock_pi_output, _patch_paths
):
"""一篇失败不阻塞其他。"""
now = datetime.now(timezone.utc)
for i in range(2):
p = Paper(
arxiv_id=f"2401.5678{i}",
title_en=f"Paper {i}",
abstract=f"Abstract {i}",
paper_date=date(2024, 1, 15),
crawled_at=now,
pdf_url=f"https://arxiv.org/pdf/2401.5678{i}.pdf",
)
db_session.add(p)
db_session.flush()
db_session.add(SummaryStatus(paper_id=p.id, status="pending"))
db_session.commit()
from sqlalchemy.orm import sessionmaker as _sm
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
call_count = 0
async def _mock_call_pi(meta_path, pdf_path):
nonlocal call_count
call_count += 1
if call_count == 1:
raise PiTimeoutError("timeout")
return mock_pi_output
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", side_effect=_mock_call_pi),
):
result = await summarize_batch(
db_session, _session_factory=_TestSession
)
assert result["done"] == 1
assert result["failed"] == 1
@pytest.mark.asyncio
async def test_task_lock_conflict(self, db_session, _patch_paths):
"""TaskLock 防止并发 batch。"""
# 先插入一个 running 锁
db_session.add(
TaskLock(
task="summarize",
lock_key="batch",
status="running",
acquired_at=datetime.now(timezone.utc),
)
)
db_session.commit()
result = await summarize_batch(db_session)
assert result["status"] == "conflict"
@pytest.mark.asyncio
async def test_task_lock_released(self, db_session, db_engine, mock_pi_output, _patch_paths):
"""完成后释放 TaskLock。"""
from sqlalchemy.orm import sessionmaker as _sm
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
await summarize_batch(
db_session, _session_factory=_TestSession
)
locks = db_session.query(TaskLock).filter(
TaskLock.task == "summarize",
TaskLock.lock_key == "batch",
).all()
for lock in locks:
assert lock.status == "finished"
assert lock.released_at is not None
@pytest.mark.asyncio
async def test_batch_empty(self, db_session, _patch_paths):
"""无 pending 论文时返回空结果。"""
result = await summarize_batch(db_session)
assert result["status"] == "success"
assert result["total"] == 0
# ═══════════════════════════════════════════════════════════════════════
# Admin 路由鉴权测试
# ═══════════════════════════════════════════════════════════════════════
class TestAdminAuth:
"""管理接口鉴权 — 只测 HTTP 层,mock 掉实际服务调用。"""
def test_no_token_returns_401(self, client):
"""无 Bearer token 返回 401。"""
resp = client.post("/admin/summarize")
assert resp.status_code in (401, 403)
def test_wrong_token_returns_401(self, client):
resp = client.post(
"/admin/summarize",
headers={"Authorization": "Bearer wrong-token"},
)
assert resp.status_code == 401
def test_correct_token_batch(self, client, admin_headers):
"""正确 token 调用 batch summarizemock 掉服务层。"""
import app.config as config_mod
original = config_mod.settings.ADMIN_TOKEN
config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345"
try:
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
mock.return_value = {"status": "success", "done": 0, "failed": 0, "total": 0}
resp = client.post("/admin/summarize", headers=admin_headers)
assert resp.status_code == 200
assert resp.json()["status"] == "success"
finally:
config_mod.settings.ADMIN_TOKEN = original
def test_single_paper_not_found(self, client, admin_headers):
"""单篇总结不存在的论文返回 404。"""
import app.config as config_mod
original = config_mod.settings.ADMIN_TOKEN
config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345"
try:
with patch(
"app.routes.admin.summarize_single",
new_callable=AsyncMock,
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
):
resp = client.post(
"/admin/summarize/nonexistent.99999",
headers=admin_headers,
)
assert resp.status_code == 404
finally:
config_mod.settings.ADMIN_TOKEN = original