feat: add admin routes, summarizer service, and CLI summarize command
- Add /admin routes for manual trigger and status inspection - Add summarizer service with batch/single summary support - Add summarize CLI command (single arxiv_id or batch pending) - Register admin router in main app - Add tests for summarizer
This commit is contained in:
+40
@@ -49,6 +49,46 @@ def crawl(
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
@cli_app.command()
def summarize(
    arxiv_id: str = typer.Argument(
        None,
        help="指定论文 arXiv ID;留空则批量处理所有 pending",
    ),
):
    """手动触发 AI 总结。"""
    # NOTE: the docstring above is kept verbatim because Typer prints it as
    # the command's help text.
    #
    # Manually trigger AI summarization: a single paper when ``arxiv_id`` is
    # given, otherwise the whole pending batch. Exits with code 1 on lock
    # conflict, unknown paper, or any failed/partial result so that shell
    # scripts and schedulers can detect the failure.

    # Imports stay function-local, as in the original command.
    from app.config import settings
    from app.database import SessionLocal, engine
    from app.models import init_db as _init
    from app.services.summarizer import summarize_batch, summarize_single

    import os

    # Make sure the DB directory and tables exist before opening a session.
    os.makedirs(settings.db_path.parent, exist_ok=True)
    _init(engine)

    db = SessionLocal()
    try:
        if arxiv_id:
            typer.echo(f"🤖 开始总结 {arxiv_id} ...")
            result = asyncio.run(summarize_single(db, arxiv_id))
        else:
            typer.echo("🤖 开始批量总结 pending 论文 ...")
            result = asyncio.run(summarize_batch(db))

        status = result.get("status")
        if status in ("success", "done"):
            typer.echo(f"✅ 总结完成:{result}")
        elif status == "conflict":
            typer.echo("⚠️ 已有批量总结任务在运行中", err=True)
            raise typer.Exit(code=1)
        elif status == "not_found":
            typer.echo(f"❌ 论文未找到:{arxiv_id}", err=True)
            raise typer.Exit(code=1)
        else:
            typer.echo(f"⚠️ 总结结果:{result}", err=True)
            # "failed" / "partial" results used to exit 0 even though they
            # were reported on stderr; only a benign skip is a success.
            if status != "skipped":
                raise typer.Exit(code=1)
    finally:
        db.close()
|
||||||
|
|
||||||
|
|
||||||
@cli_app.command()
|
@cli_app.command()
|
||||||
def init_db():
|
def init_db():
|
||||||
"""初始化数据库表。"""
|
"""初始化数据库表。"""
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from fastapi.staticfiles import StaticFiles
|
|||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import engine
|
from app.database import engine
|
||||||
from app.models import init_db
|
from app.models import init_db
|
||||||
|
from app.routes.admin import router as admin_router
|
||||||
from app.routes.pages import router as pages_router
|
from app.routes.pages import router as pages_router
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -41,6 +42,7 @@ def create_app() -> FastAPI:
|
|||||||
|
|
||||||
# 路由
|
# 路由
|
||||||
app.include_router(pages_router)
|
app.include_router(pages_router)
|
||||||
|
app.include_router(admin_router)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
"""管理接口 — AI 总结触发,需要 ADMIN_TOKEN 鉴权。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.database import get_db
|
||||||
|
from app.services.summarizer import summarize_batch, summarize_single
|
||||||
|
|
||||||
|
# All admin endpoints are mounted under the /admin prefix.
router = APIRouter(prefix="/admin", tags=["admin"])
# Bearer-token extractor used as a dependency by verify_admin.
security = HTTPBearer()
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_admin(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> str:
    """Validate the presented bearer token against ``settings.ADMIN_TOKEN``.

    Returns:
        The presented token when it matches.

    Raises:
        HTTPException: 401 when the token does not match or no admin token
            is configured.
    """
    import secrets

    expected = settings.ADMIN_TOKEN
    presented = credentials.credentials

    # An unset/empty ADMIN_TOKEN must never authenticate anyone.
    # compare_digest takes time independent of where the values differ;
    # compare bytes because it rejects non-ASCII ``str`` arguments.
    if not expected or not secrets.compare_digest(
        presented.encode("utf-8"), str(expected).encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid admin token")
    return presented
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/summarize")
async def admin_summarize_batch(
    _admin: str = Depends(verify_admin),
    db: Session = Depends(get_db),
):
    """批量总结所有 pending 论文。"""
    # NOTE: docstring kept verbatim; FastAPI publishes it as the endpoint
    # description.
    # Run the batch summarizer and return its result dict; a lock conflict
    # is translated into HTTP 409.
    outcome = await summarize_batch(db)
    if outcome.get("status") != "conflict":
        return outcome
    raise HTTPException(
        status_code=409,
        detail=outcome.get("error", "batch already running"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/summarize/{arxiv_id}")
async def admin_summarize_single(
    arxiv_id: str,
    _admin: str = Depends(verify_admin),
    db: Session = Depends(get_db),
):
    """总结或重跑单篇论文。"""
    # NOTE: docstring kept verbatim; FastAPI publishes it as the endpoint
    # description.
    # Summarize (or re-run, since force=True) one paper; an unknown arXiv ID
    # becomes HTTP 404, any other result dict is returned as-is.
    outcome = await summarize_single(db, arxiv_id, force=True)
    paper_missing = outcome.get("status") == "not_found"
    if paper_missing:
        raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
    return outcome
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
"""AI 总结 schema — Pydantic 校验模型、质量评估、DB 展平。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
# ── 子模型 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class PrerequisitesSchema(BaseModel):
    # Optional "prerequisites" section of a summary; both fields may be
    # omitted and then default to an empty list / empty string.
    concepts: list[str] = Field(default_factory=list)
    level: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class MotivationSchema(BaseModel):
    # "motivation" section: ``problem`` is required and must not be blank;
    # ``goal`` and ``gap`` default to the empty string.
    problem: str
    goal: str = ""
    gap: str = ""

    @field_validator("problem")
    @classmethod
    def non_empty_problem(cls, v: str) -> str:
        """Reject blank text; return the value with surrounding whitespace removed."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("motivation.problem cannot be empty")
|
||||||
|
|
||||||
|
|
||||||
|
class MethodSchema(BaseModel):
    # "method" section: ``key_idea`` is required and must not be blank;
    # every other field has an empty default.
    overview: str = ""
    key_idea: str
    steps: list[str] = Field(default_factory=list)
    novelty: str = ""

    @field_validator("key_idea")
    @classmethod
    def non_empty_key_idea(cls, v: str) -> str:
        """Reject blank text; return the value with surrounding whitespace removed."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("method.key_idea cannot be empty")
|
||||||
|
|
||||||
|
|
||||||
|
class ResultsSchema(BaseModel):
    # "results" section; every list defaults to empty when omitted.
    main_findings: list[str] = Field(default_factory=list)
    # Free-form dicts; their key layout is not constrained by this model.
    benchmarks: list[dict] = Field(default_factory=list)
    limitations: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ImprovementsSchema(BaseModel):
    # "improvements" section; all fields are optional with empty defaults.
    weaknesses: list[str] = Field(default_factory=list)
    future_work: list[str] = Field(default_factory=list)
    reproducibility: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── 顶层 schema ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class SummarySchema(BaseModel):
    # Top-level summary document. Unknown keys in the input are ignored.
    model_config = {"extra": "ignore"}

    # Required: title_zh, one_line, tags, motivation, method.
    title_zh: str
    one_line: str
    tags: list[str]
    difficulty: str = ""
    paper_date: str | None = None
    prerequisites: PrerequisitesSchema = Field(default_factory=PrerequisitesSchema)
    motivation: MotivationSchema
    method: MethodSchema
    results: ResultsSchema = Field(default_factory=ResultsSchema)
    improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema)

    @field_validator("title_zh", "one_line")
    @classmethod
    def non_empty_text(cls, v: str) -> str:
        """Reject blank text; return the value with surrounding whitespace removed."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("field cannot be empty")

    @field_validator("tags")
    @classmethod
    def non_empty_tags(cls, v: list[str]) -> list[str]:
        """Strip every tag, drop blank ones, and require at least one to remain."""
        kept: list[str] = []
        for raw in v:
            if not raw:
                continue
            name = raw.strip()
            if name:
                kept.append(name)
        if not kept:
            raise ValueError("tags cannot be empty")
        return kept
|
||||||
|
|
||||||
|
|
||||||
|
# ── 质量评估 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Required fields: title_zh, one_line, tags, motivation.problem, method.key_idea
#   - when one is missing, Pydantic validation fails, so assess_quality is never reached
# Important fields: motivation.goal, motivation.gap, method.overview, results.main_findings
#   - a summary missing these can still be stored; it is marked "degraded"
_OPTIONAL_BUT_IMPORTANT_FIELDS = [
    "motivation.goal",
    "motivation.gap",
    "method.overview",
    "results.main_findings",
]
|
||||||
|
|
||||||
|
|
||||||
|
def assess_quality(schema: SummarySchema) -> str:
    """Grade a validated summary as ``"normal"``, ``"degraded"`` or ``"low"``.

    Returns:
        ``"low"`` when ``one_line`` or ``method.key_idea`` has fewer than 10
        characters after stripping (heuristic for hollow content);
        ``"degraded"`` when any of motivation.goal, motivation.gap,
        method.overview or results.main_findings is blank;
        ``"normal"`` otherwise.
    """
    if len(schema.one_line.strip()) < 10 or len(schema.method.key_idea.strip()) < 10:
        return "low"

    important_present = (
        bool(schema.motivation.goal.strip()),
        bool(schema.motivation.gap.strip()),
        bool(schema.method.overview.strip()),
        # A list holding only blank strings carries no findings; treat it as
        # missing, consistent with the stripped checks above.
        any(item and item.strip() for item in schema.results.main_findings),
    )
    return "normal" if all(important_present) else "degraded"
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB 展平 ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_for_db(schema: SummarySchema) -> dict:
    """Flatten a SummarySchema into a dict of column values for ``paper_summaries``.

    Scalar fields are copied as-is, list/dict fields are serialized to JSON
    text with non-ASCII characters left unescaped, ``full_json`` holds the
    whole model, and ``updated_at`` is the current UTC time.
    """

    def _dump(value) -> str:
        # One place for the JSON options shared by every *_json column.
        return json.dumps(value, ensure_ascii=False)

    return {
        "one_line": schema.one_line,
        "difficulty": schema.difficulty,
        "prerequisites_json": _dump(schema.prerequisites.model_dump()),
        "motivation_problem": schema.motivation.problem,
        "motivation_goal": schema.motivation.goal,
        "motivation_gap": schema.motivation.gap,
        "method_overview": schema.method.overview,
        "method_key_idea": schema.method.key_idea,
        "method_steps_json": _dump(schema.method.steps),
        "method_novelty": schema.method.novelty,
        "results_main_json": _dump(schema.results.main_findings),
        "results_benchmarks_json": _dump(schema.results.benchmarks),
        "limitations_json": _dump(schema.results.limitations),
        "weaknesses_json": _dump(schema.improvements.weaknesses),
        "future_work_json": _dump(schema.improvements.future_work),
        "reproducibility": schema.improvements.reproducibility,
        # No ``ensure_ascii`` keyword here: Pydantic already emits non-ASCII
        # text unescaped by default, and Pydantic v2 releases that predate
        # the keyword raise TypeError when it is passed.
        "full_json": schema.model_dump_json(),
        "updated_at": datetime.now(timezone.utc),
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_REQUIRED_FIELDS = {"title_zh", "one_line", "tags", "problem", "key_idea"}
# Required nested objects: when one is absent altogether, the required field
# inside it (motivation.problem / method.key_idea) is absent as well.
_REQUIRED_PARENTS = {"motivation", "method"}


def classify_validation_error(exc: ValidationError) -> str:
    """Classify a validation failure as ``"field_missing"`` or ``"schema_error"``.

    Returns ``"field_missing"`` as soon as any reported error is a
    ``missing`` / ``value_error`` on a required field, or a ``missing`` on a
    whole required top-level object; every other failure (wrong types etc.)
    is ``"schema_error"``.
    """
    for err in exc.errors():
        loc = err["loc"]
        # The last location element is the innermost field name (or a list
        # index, which can never match the sets below).
        field_name = loc[-1] if loc else ""
        err_type = err["type"]
        if field_name in _REQUIRED_FIELDS and err_type in ("missing", "value_error"):
            return "field_missing"
        if err_type == "missing" and len(loc) == 1 and field_name in _REQUIRED_PARENTS:
            return "field_missing"
    return "schema_error"
|
||||||
@@ -0,0 +1,682 @@
|
|||||||
|
"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy import select, text
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.database import SessionLocal
|
||||||
|
from app.models import (
|
||||||
|
CrawlLog,
|
||||||
|
Paper,
|
||||||
|
PaperSummary,
|
||||||
|
PaperTag,
|
||||||
|
SummaryStatus,
|
||||||
|
TaskLock,
|
||||||
|
)
|
||||||
|
from app.services.schemas import (
|
||||||
|
SummarySchema,
|
||||||
|
assess_quality,
|
||||||
|
classify_validation_error,
|
||||||
|
flatten_for_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class PdfDownloadError(Exception):
    """Raised when a paper's PDF cannot be obtained."""
|
||||||
|
|
||||||
|
|
||||||
|
class PiTimeoutError(Exception):
    """Raised when the pi CLI does not finish within the configured timeout."""
|
||||||
|
|
||||||
|
|
||||||
|
class PiProcessError(Exception):
    """Raised when the pi CLI exits with a non-zero return code."""

    def __init__(self, returncode: int, stderr: str):
        # Only the first 500 characters of stderr go into the message; the
        # complete text remains available on ``self.stderr``.
        message = f"pi exited with code {returncode}: {stderr[:500]}"
        super().__init__(message)
        self.returncode = returncode
        self.stderr = stderr
|
||||||
|
|
||||||
|
|
||||||
|
class JsonNotFoundError(Exception):
    """Raised when no JSON object can be extracted from the pi output."""
|
||||||
|
|
||||||
|
|
||||||
|
# ── 路径工具 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Relative paths: they resolve against the process's current working directory.
_DATA_DIR = Path("data")
# Per-paper artifacts (meta.json, summary.json, raw_output.txt).
_PAPERS_DIR = _DATA_DIR / "papers"
# Scratch space for downloaded PDFs; removed again by _cleanup_tmp.
_TMP_DIR = _DATA_DIR / "tmp"
|
||||||
|
|
||||||
|
|
||||||
|
def _paper_dir(arxiv_id: str) -> Path:
    """Return the per-paper artifact directory, ``data/papers/{arxiv_id}``."""
    return _PAPERS_DIR.joinpath(arxiv_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _tmp_dir(arxiv_id: str) -> Path:
    """Return the per-paper scratch directory, ``data/tmp/{arxiv_id}``."""
    return _TMP_DIR.joinpath(arxiv_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── PDF 下载 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
    """Download the paper's PDF to ``data/tmp/{arxiv_id}/paper.pdf``.

    Returns:
        Path of the written file.

    Raises:
        PdfDownloadError: when ``pdf_url`` is empty, or when the request or
            the file write fails (the original exception is chained).
    """
    if not pdf_url:
        raise PdfDownloadError(f"no pdf_url for {arxiv_id}")

    target_dir = _tmp_dir(arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    dest = target_dir / "paper.pdf"

    # Route through the configured proxy only when one is set.
    transport = (
        httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
        if settings.http_proxy
        else None
    )

    try:
        async with httpx.AsyncClient(
            timeout=settings.HTTP_TIMEOUT_SECONDS,
            headers={"User-Agent": settings.HTTP_USER_AGENT},
            transport=transport,
            follow_redirects=True,
        ) as client:
            response = await client.get(pdf_url)
            response.raise_for_status()
            dest.write_bytes(response.content)
    except Exception as exc:
        raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc

    size_bytes = dest.stat().st_size
    logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, size_bytes)
    return dest
|
||||||
|
|
||||||
|
|
||||||
|
# ── meta.json ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _write_meta_json(paper: Paper) -> Path:
    """Write ``data/papers/{arxiv_id}/meta.json`` for ``paper`` and return its path.

    The file is UTF-8, indented JSON with the paper's ID, English title,
    abstract, publication timestamp (ISO format or null), author names,
    tags and upvotes.
    """
    out_dir = _paper_dir(paper.arxiv_id)
    out_dir.mkdir(parents=True, exist_ok=True)

    published = paper.published_at
    payload = {
        "arxiv_id": paper.arxiv_id,
        "title_en": paper.title_en,
        "abstract": paper.abstract or "",
        "published_at": published.isoformat() if published else None,
        "authors": [author.name for author in paper.authors],
        "tags": [entry.tag for entry in paper.tags],
        "upvotes": paper.upvotes,
    }

    meta_path = out_dir / "meta.json"
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    meta_path.write_text(serialized, encoding="utf-8")
    return meta_path
|
||||||
|
|
||||||
|
|
||||||
|
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
    """Run the pi CLI non-interactively on the given files and return its stdout.

    Args:
        meta_path: Path to the paper's meta.json (passed as ``@path``).
        pdf_path: Path to the downloaded PDF (passed as ``@path``).

    Returns:
        The process's stdout decoded as UTF-8 (undecodable bytes replaced).

    Raises:
        PiTimeoutError: the process ran longer than
            ``settings.SUMMARY_TIMEOUT_SECONDS`` and was killed.
        PiProcessError: the process exited with a non-zero return code.
    """
    cmd = [
        settings.PI_BIN,
        "-p",
        "--no-tools",
        "--skill",
        settings.SUMMARY_SKILL,
        "请深度解读以下论文,并按指定 JSON schema 输出:",
        f"@{meta_path}",
        f"@{pdf_path}",
    ]
    # Log only the binary and flags, not the prompt or file arguments.
    logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4]))

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    async def _terminate() -> None:
        # The child may have exited on its own in the meantime, in which case
        # kill() raises ProcessLookupError; that is not a failure here.
        try:
            proc.kill()
        except ProcessLookupError:
            pass
        await proc.wait()

    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(),
            timeout=settings.SUMMARY_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError as exc:
        await _terminate()
        raise PiTimeoutError(
            f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
        ) from exc
    except asyncio.CancelledError:
        # Do not leave the child running when the surrounding task is cancelled.
        await _terminate()
        raise

    if proc.returncode != 0:
        raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))

    return stdout.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
|
||||||
|
def paper_id_from_path(meta_path: Path) -> str:
    """Derive the arXiv ID from a meta.json path: the name of its parent directory."""
    containing_dir = meta_path.parent
    return containing_dir.name
|
||||||
|
|
||||||
|
|
||||||
|
# ── JSON 提取 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json(raw_output: str) -> dict:
|
||||||
|
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
|
||||||
|
# 策略 1:整体直接解析
|
||||||
|
stripped = raw_output.strip()
|
||||||
|
try:
|
||||||
|
result = json.loads(stripped)
|
||||||
|
if isinstance(result, dict) and "title_zh" in result:
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 策略 2:提取 ```json ... ``` 代码块
|
||||||
|
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
|
||||||
|
for match in fence_pattern.finditer(raw_output):
|
||||||
|
try:
|
||||||
|
result = json.loads(match.group(1).strip())
|
||||||
|
if isinstance(result, dict) and "title_zh" in result:
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 策略 3:匹配包含 title_zh 的最大 {...} 块
|
||||||
|
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
|
||||||
|
# 先尝试一层嵌套;如果没命中再用更宽松的策略
|
||||||
|
for match in brace_pattern.finditer(raw_output):
|
||||||
|
try:
|
||||||
|
return json.loads(match.group(0))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 更宽松:找到最大的 { ... } 平衡块
|
||||||
|
best = None
|
||||||
|
best_len = 0
|
||||||
|
for i, ch in enumerate(raw_output):
|
||||||
|
if ch != "{":
|
||||||
|
continue
|
||||||
|
depth = 0
|
||||||
|
for j in range(i, len(raw_output)):
|
||||||
|
if raw_output[j] == "{":
|
||||||
|
depth += 1
|
||||||
|
elif raw_output[j] == "}":
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
candidate = raw_output[i : j + 1]
|
||||||
|
if len(candidate) > best_len:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(candidate)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
best = parsed
|
||||||
|
best_len = len(candidate)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
break
|
||||||
|
|
||||||
|
if best is not None:
|
||||||
|
return best
|
||||||
|
|
||||||
|
raise JsonNotFoundError("no JSON object found in pi output")
|
||||||
|
|
||||||
|
|
||||||
|
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_error(exc: Exception) -> str:
    """Map an exception raised during summarization to its ``error_type`` string.

    Validation errors are delegated to ``classify_validation_error``;
    anything unrecognized is ``"unknown"``.
    """
    # Checked top to bottom; the first matching type wins.
    labelled_types = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_type, label in labelled_types:
        if isinstance(exc, exc_type):
            return label

    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)
    return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_fts_summary_text(schema: SummarySchema) -> str:
|
||||||
|
"""拼接用于 FTS5 索引的总结文本。"""
|
||||||
|
parts = [
|
||||||
|
schema.one_line or "",
|
||||||
|
schema.motivation.problem or "",
|
||||||
|
schema.motivation.goal or "",
|
||||||
|
schema.method_overview if hasattr(schema, "method_overview") else "",
|
||||||
|
schema.method.overview or "",
|
||||||
|
schema.method.key_idea or "",
|
||||||
|
" ".join(schema.results.main_findings or []),
|
||||||
|
]
|
||||||
|
return " ".join(p for p in parts if p)
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB 更新 ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Persist a validated summary and commit.

    Writes, in one transaction: the ``paper_summaries`` row (insert or
    update), the summary columns on ``papers``, any new AI tags, and the
    paper's row in ``papers_fts``.

    ``raw_output`` is not used here; it is kept so the signature stays
    unchanged for callers.
    """
    # 1. paper_summaries: update the existing row or insert a new one.
    existing = db.get(PaperSummary, paper.id)
    flat = flatten_for_db(schema)
    if existing:
        for k, v in flat.items():
            setattr(existing, k, v)
    else:
        db.add(PaperSummary(paper_id=paper.id, **flat))

    # 2. papers table: title, quality and artifact paths.
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    paper_dir = _paper_dir(paper.arxiv_id)
    paper.summary_path = str(paper_dir / "summary.json")
    paper.raw_output_path = str(paper_dir / "raw_output.txt")

    # 3. AI tags: add only names the paper does not have yet (also guards
    #    against duplicates within schema.tags).
    existing_tag_names = {t.tag for t in paper.tags}
    for tag_name in schema.tags:
        if tag_name not in existing_tag_names:
            db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
            existing_tag_names.add(tag_name)

    # 4. Full-text index: refresh the row whose rowid equals the paper id.
    summary_text = _build_fts_summary_text(schema)
    db.execute(
        text(
            "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
            "WHERE rowid=:paper_id"
        ),
        {
            "title_zh": schema.title_zh,
            "summary_text": summary_text,
            "paper_id": paper.id,
        },
    )

    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 文件操作 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
    """Write ``summary.json`` and ``raw_output.txt`` into the paper's directory.

    Both files are written as UTF-8; the directory is created if needed.
    """
    d = _paper_dir(arxiv_id)
    d.mkdir(parents=True, exist_ok=True)
    (d / "summary.json").write_text(
        # No ``ensure_ascii`` keyword: Pydantic already leaves non-ASCII text
        # unescaped by default, and Pydantic v2 releases that predate the
        # keyword raise TypeError when it is passed.
        schema.model_dump_json(indent=2),
        encoding="utf-8",
    )
    (d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
    """Write only ``raw_output.txt`` for the paper (used when summarization failed)."""
    out_dir = _paper_dir(arxiv_id)
    out_dir.mkdir(parents=True, exist_ok=True)
    raw_path = out_dir / "raw_output.txt"
    raw_path.write_text(raw_output, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_tmp(arxiv_id: str) -> None:
    """Remove the paper's ``data/tmp/{arxiv_id}/`` directory, if present.

    Best effort: a failure to delete is logged as a warning, never raised.
    """
    scratch = _tmp_dir(arxiv_id)
    if not scratch.exists():
        return
    try:
        shutil.rmtree(scratch)
        logger.debug("Cleaned tmp: %s", arxiv_id)
    except Exception:
        logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 单篇总结 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_one(
    db: Session,
    paper: Paper,
    semaphore: asyncio.Semaphore | None = None,
    *,
    force: bool = False,
) -> dict:
    """Run the full summarization flow for one paper.

    Creates a ``pending`` SummaryStatus row when the paper has none. Papers
    whose status is ``done`` or ``permanent_failure`` are skipped unless
    ``force`` is true. When a semaphore is given, the actual work runs while
    holding it.

    Returns:
        A result dict: either a ``"skipped"`` record with its reason, or
        whatever ``_do_summarize_one`` returns.
    """
    arxiv_id = paper.arxiv_id

    # Get or create the summary_status row.
    if not paper.summary_status:
        db.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db.commit()
        db.refresh(paper)

    current = paper.summary_status.status

    if not force:
        # Finished papers and permanent failures are left alone.
        if current == "done":
            return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"}
        if current == "permanent_failure":
            return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "permanent_failure"}

    if not semaphore:
        return await _do_summarize_one(db, paper)

    # ``async with`` acquires the semaphore and releases it on every exit path.
    async with semaphore:
        return await _do_summarize_one(db, paper)
|
||||||
|
|
||||||
|
|
||||||
|
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
    """Execute the summarization of one paper (called under the semaphore).

    Marks the status ``processing``, then: writes meta.json, downloads the
    PDF, calls pi, extracts and validates the JSON, grades it, saves the
    files and updates the DB. On success the status becomes ``done``.

    On any exception the error is recorded on the status row, the retry
    counter is incremented, and the status becomes ``pending`` again or
    ``permanent_failure`` once ``settings.SUMMARY_MAX_RETRIES + 1`` attempts
    have failed. The temp directory is cleaned up in every case.

    Returns:
        ``{"arxiv_id", "status": "done", "quality"}`` on success, otherwise
        ``{"arxiv_id", "status": "failed", "error_type", "error", "retry_count"}``.
    """
    arxiv_id = paper.arxiv_id
    status = paper.summary_status
    now = datetime.now(timezone.utc)

    # Status -> processing.
    status.status = "processing"
    status.started_at = now
    db.commit()

    raw_output = ""
    try:
        # Write meta.json.
        meta_path = _write_meta_json(paper)

        # Download the PDF.
        await _download_pdf(arxiv_id, paper.pdf_url)

        # Call pi.
        raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf")

        # Extract the JSON.
        json_data = _extract_json(raw_output)

        # Pydantic validation.
        schema = SummarySchema.model_validate(json_data)

        # Quality assessment.
        quality = assess_quality(schema)

        # Save files.
        _save_files(arxiv_id, schema, raw_output)

        # Update the DB.
        _update_summary_in_db(db, paper, schema, quality, raw_output)

        # Status -> done.
        status.status = "done"
        status.quality = quality
        status.completed_at = datetime.now(timezone.utc)
        status.raw_output_saved = True
        db.commit()

        logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
        return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}

    except Exception as exc:
        # If the failure came from a flush/commit, the session refuses
        # further work until it is rolled back. Rolling back also discards
        # any half-applied summary writes so that only the status update
        # below gets committed.
        db.rollback()

        error_type = _classify_error(exc)
        logger.error(
            "Summarize failed: %s error_type=%s %s",
            arxiv_id,
            error_type,
            str(exc)[:200],
        )

        # Keep the raw output (if any) for debugging.
        if raw_output:
            _save_raw_output_only(arxiv_id, raw_output)
            status.raw_output_saved = True

        # Retry bookkeeping.
        status.retry_count = (status.retry_count or 0) + 1
        status.error_type = error_type
        status.error = str(exc)[:2000]

        if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
            status.status = "permanent_failure"
        else:
            status.status = "pending"

        status.completed_at = datetime.now(timezone.utc)
        db.commit()

        return {
            "arxiv_id": arxiv_id,
            "status": "failed",
            "error_type": error_type,
            "error": str(exc)[:200],
            "retry_count": status.retry_count,
        }

    finally:
        _cleanup_tmp(arxiv_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 单篇入口 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    _session_factory=None,
) -> dict:
    """Single-paper entry point (used by the admin route and the CLI).

    Args:
        db: Caller's session, used only to check that the paper exists.
        arxiv_id: arXiv ID of the paper to summarize.
        force: Passed to ``summarize_one``; when true, papers already done
            or permanently failed are processed again.
        _session_factory: Optional session factory; tests inject one bound
            to an in-memory DB. Defaults to ``SessionLocal``.

    Returns:
        ``{"status": "not_found", "arxiv_id": ...}`` when the paper does not
        exist, otherwise the result dict of ``summarize_one``.
    """
    paper = (
        db.query(Paper)
        .filter(Paper.arxiv_id == arxiv_id)
        .options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary_status),
        )
        .first()
    )
    if not paper:
        return {"status": "not_found", "arxiv_id": arxiv_id}

    make_session = _session_factory or SessionLocal

    # The work runs in its own session, independent of the caller's.
    paper_db = make_session()
    try:
        paper_in_new_session = (
            paper_db.query(Paper)
            .filter(Paper.arxiv_id == arxiv_id)
            .options(
                joinedload(Paper.authors),
                joinedload(Paper.tags),
                joinedload(Paper.summary_status),
            )
            .first()
        )
        # The paper can be invisible to the new session (deleted in between,
        # or the factory points at a different database); report that
        # instead of failing on a None paper further down.
        if paper_in_new_session is None:
            return {"status": "not_found", "arxiv_id": arxiv_id}
        return await summarize_one(paper_db, paper_in_new_session, force=force)
    finally:
        paper_db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 批量总结 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_batch(
|
||||||
|
db: Session,
|
||||||
|
arxiv_ids: list[str] | None = None,
|
||||||
|
*,
|
||||||
|
_session_factory=None,
|
||||||
|
) -> dict:
|
||||||
|
"""批量总结入口。arxiv_ids=None 时处理所有 pending 论文。
|
||||||
|
|
||||||
|
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# TaskLock 防重入
|
||||||
|
lock = TaskLock(
|
||||||
|
task="summarize",
|
||||||
|
lock_key="batch",
|
||||||
|
status="running",
|
||||||
|
owner="summarize_batch",
|
||||||
|
acquired_at=now,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
db.add(lock)
|
||||||
|
db.commit()
|
||||||
|
except Exception:
|
||||||
|
db.rollback()
|
||||||
|
logger.warning("Summarize batch already running (lock conflict)")
|
||||||
|
return {"status": "conflict", "error": "summarize batch already running"}
|
||||||
|
|
||||||
|
# CrawlLog
|
||||||
|
log_entry = CrawlLog(
|
||||||
|
task="summarize",
|
||||||
|
status="running",
|
||||||
|
started_at=now,
|
||||||
|
)
|
||||||
|
db.add(log_entry)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 查询待总结论文
|
||||||
|
query = db.query(Paper).options(
|
||||||
|
joinedload(Paper.authors),
|
||||||
|
joinedload(Paper.tags),
|
||||||
|
joinedload(Paper.summary_status),
|
||||||
|
)
|
||||||
|
if arxiv_ids:
|
||||||
|
query = query.filter(Paper.arxiv_id.in_(arxiv_ids))
|
||||||
|
else:
|
||||||
|
# 只处理 pending 或 failed(可重试的)
|
||||||
|
query = query.join(SummaryStatus).filter(
|
||||||
|
SummaryStatus.status.in_(["pending", "failed"])
|
||||||
|
)
|
||||||
|
|
||||||
|
papers = query.all()
|
||||||
|
total = len(papers)
|
||||||
|
logger.info("Summarize batch: %d papers to process", total)
|
||||||
|
|
||||||
|
if total == 0:
|
||||||
|
log_entry.status = "success"
|
||||||
|
log_entry.papers_found = 0
|
||||||
|
log_entry.papers_new = 0
|
||||||
|
log_entry.completed_at = datetime.now(timezone.utc)
|
||||||
|
_release_lock(db, lock)
|
||||||
|
return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}
|
||||||
|
|
||||||
|
# 并发控制
|
||||||
|
semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY)
|
||||||
|
make_session = _session_factory or SessionLocal
|
||||||
|
|
||||||
|
async def _process_paper(paper: Paper) -> dict:
|
||||||
|
paper_db = make_session()
|
||||||
|
try:
|
||||||
|
p = (
|
||||||
|
paper_db.query(Paper)
|
||||||
|
.filter(Paper.id == paper.id)
|
||||||
|
.options(
|
||||||
|
joinedload(Paper.authors),
|
||||||
|
joinedload(Paper.tags),
|
||||||
|
joinedload(Paper.summary_status),
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return await summarize_one(paper_db, p, semaphore)
|
||||||
|
finally:
|
||||||
|
paper_db.close()
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[_process_paper(p) for p in papers],
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
done = 0
|
||||||
|
failed = 0
|
||||||
|
skipped = 0
|
||||||
|
for r in results:
|
||||||
|
if isinstance(r, Exception):
|
||||||
|
logger.error("Unexpected error in batch: %s", r)
|
||||||
|
failed += 1
|
||||||
|
elif isinstance(r, dict):
|
||||||
|
if r.get("status") == "done":
|
||||||
|
done += 1
|
||||||
|
elif r.get("status") == "skipped":
|
||||||
|
skipped += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
log_entry.status = "success" if failed == 0 else "failed"
|
||||||
|
log_entry.papers_found = total
|
||||||
|
log_entry.papers_new = done
|
||||||
|
log_entry.completed_at = datetime.now(timezone.utc)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Summarize batch done: total=%d done=%d failed=%d skipped=%d",
|
||||||
|
total, done, failed, skipped,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "success" if failed == 0 else "partial",
|
||||||
|
"total": total,
|
||||||
|
"done": done,
|
||||||
|
"failed": failed,
|
||||||
|
"skipped": skipped,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Summarize batch failed")
|
||||||
|
log_entry.status = "failed"
|
||||||
|
log_entry.error = str(exc)[:2000]
|
||||||
|
log_entry.completed_at = datetime.now(timezone.utc)
|
||||||
|
db.commit()
|
||||||
|
return {"status": "failed", "error": str(exc)}
|
||||||
|
|
||||||
|
finally:
|
||||||
|
_release_lock(db, lock)
|
||||||
|
|
||||||
|
|
||||||
|
def _release_lock(db: Session, lock: TaskLock) -> None:
    """Mark ``lock`` as finished and commit the change, best effort.

    Args:
        db: Session that owns ``lock``; it is committed on success and
            rolled back if anything in the release step raises.
        lock: The lock row to release. Its ``status`` becomes ``"finished"``
            and ``released_at`` is stamped with the current UTC time.

    Any exception is swallowed after the rollback and reported as a warning,
    so a failed release never propagates to the caller.
    """
    try:
        lock.status = "finished"
        release_time = datetime.now(timezone.utc)
        lock.released_at = release_time
        db.commit()
    except Exception:
        # Deliberately broad: releasing the lock must not raise.
        db.rollback()
        logger.warning("Failed to release summarize lock", exc_info=True)
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
"""测试 fixtures — 内存 SQLite、TestClient、样例数据。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import date, datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy import create_engine, event
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.main import create_app
|
||||||
|
from app.models import (
|
||||||
|
Paper,
|
||||||
|
PaperAuthor,
|
||||||
|
PaperSummary,
|
||||||
|
PaperTag,
|
||||||
|
SummaryStatus,
|
||||||
|
init_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 内存数据库 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _TestBase(DeclarativeBase):
    # Empty declarative base local to the test fixtures; its metadata is
    # rebound to the application's metadata right below.
    pass


# Reuse the application's Base metadata (the Base is imported from app.database).
from app.database import Base as _AppBase  # noqa: E402

_TestBase.metadata = _AppBase.metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def db_engine():
    """Yield a fresh in-memory SQLite engine initialised through ``init_db``.

    The engine uses ``StaticPool`` so that every session, on every thread,
    shares the one connection that holds the in-memory database. Without it,
    a connection opened on another thread (for example the thread that
    ``TestClient`` uses to serve requests) would get its own empty database
    and would not see the tables or rows created by the fixtures.

    ``PRAGMA foreign_keys=ON`` is issued on connect so foreign-key
    constraints are enforced. The engine is disposed when the test finishes.
    """
    # Local import keeps this fixture self-contained; sqlalchemy is already a
    # dependency of this module.
    from sqlalchemy.pool import StaticPool

    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )

    @event.listens_for(engine, "connect")
    def _pragma(dbapi_connection, _record):
        # SQLite does not enforce foreign keys unless asked per connection.
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()

    init_db(engine)
    try:
        yield engine
    finally:
        # Close the shared connection so the in-memory database is released.
        engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def db_session(db_engine):
    """Yield a session bound to the per-test engine and close it afterwards."""
    make_session = sessionmaker(bind=db_engine, autoflush=False, autocommit=False)
    active_session = make_session()
    try:
        yield active_session
    finally:
        active_session.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client(db_engine, db_session):
    """FastAPI ``TestClient`` whose ``get_db`` dependency yields ``db_session``.

    Server-side exceptions are not re-raised into the test
    (``raise_server_exceptions=False``). The dependency overrides are cleared
    once the client context has exited.
    """
    application = create_app()

    def _provide_test_session():
        yield db_session

    application.dependency_overrides[get_db] = _provide_test_session

    with TestClient(application, raise_server_exceptions=False) as test_client:
        yield test_client

    application.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 样例数据 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# arXiv ID of the paper inserted by the sample_paper fixture.
SAMPLE_ARXIV_ID = "2401.12345"

# Dummy admin token returned by the admin_token fixture.
ADMIN_TOKEN = "test-admin-token-12345"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_paper(db_session):
    """Insert one test paper with authors, tags and a pending summary status.

    Also inserts the matching ``papers_fts`` row (the original note says this
    mirrors what the crawler does). The author and tag columns of that row are
    built from the same lists used for the ``PaperAuthor`` / ``PaperTag``
    rows, so the two cannot drift apart.

    Returns:
        The committed ``Paper`` instance.
    """
    # Local import instead of ``__import__("sqlalchemy").text``.
    from sqlalchemy import text

    author_names = ["Alice Smith", "Bob Jones"]
    hf_tags = ["NLP", "LLM"]

    now = datetime.now(timezone.utc)
    paper = Paper(
        arxiv_id=SAMPLE_ARXIV_ID,
        title_en="Test Paper Title",
        abstract="This is a test abstract for the paper.",
        published_at=date(2024, 1, 15),
        paper_date=date(2024, 1, 15),
        crawled_at=now,
        upvotes=42,
        hf_url=f"https://huggingface.co/papers/{SAMPLE_ARXIV_ID}",
        arxiv_url=f"https://arxiv.org/abs/{SAMPLE_ARXIV_ID}",
        pdf_url=f"https://arxiv.org/pdf/{SAMPLE_ARXIV_ID}.pdf",
    )
    db_session.add(paper)
    # Flush so paper.id is assigned before the child rows reference it.
    db_session.flush()

    for position, name in enumerate(author_names):
        db_session.add(PaperAuthor(paper_id=paper.id, name=name, position=position))
    for tag in hf_tags:
        db_session.add(PaperTag(paper_id=paper.id, tag=tag, source="hf"))

    db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))

    # Initial FTS5 row for the paper.
    db_session.execute(
        text(
            "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
            "VALUES (:id, :title, :abstract, :authors, :tags)"
        ),
        {
            "id": paper.id,
            "title": paper.title_en,
            "abstract": paper.abstract or "",
            "authors": ", ".join(author_names),
            "tags": ", ".join(hf_tags),
        },
    )
    db_session.commit()
    return paper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_summary_dict() -> dict:
    """Return a complete summary dict; the schema tests validate it as-is."""
    prerequisites = {
        "concepts": ["Transformer", "注意力机制"],
        "level": "中级",
    }
    motivation = {
        "problem": "现有模型在长文本理解上存在不足。",
        "goal": "提出一种新的注意力机制来提升长文本建模能力。",
        "gap": "当前方法计算复杂度过高。",
    }
    method = {
        "overview": "提出了一种高效的稀疏注意力机制。",
        "key_idea": "使用局部-全局混合的注意力模式来降低计算复杂度。",
        "steps": [
            "分析现有注意力机制的瓶颈",
            "设计稀疏注意力模式",
            "在多个基准上验证效果",
        ],
        "novelty": "首次将局部-全局注意力模式结合应用于长文本建模。",
    }
    results = {
        "main_findings": [
            "在长文本基准上取得了 SOTA 结果",
            "推理速度提升了 2 倍",
        ],
        "benchmarks": [
            {"dataset": "LongBench", "score": 85.3},
        ],
        "limitations": [
            "在超长文本(>100k tokens)上效果有所下降",
        ],
    }
    improvements = {
        "weaknesses": ["仅验证了英文数据"],
        "future_work": ["扩展到多语言场景"],
        "reproducibility": "代码已开源,模型权重可下载。",
    }
    # Key order matches the original literal.
    return {
        "title_zh": "测试论文中文标题",
        "one_line": "这是一篇关于自然语言处理的测试论文的一句话总结。",
        "tags": ["自然语言处理", "大语言模型", "Transformer"],
        "difficulty": "中级",
        "prerequisites": prerequisites,
        "motivation": motivation,
        "method": method,
        "results": results,
        "improvements": improvements,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_summary_json(sample_summary_dict) -> str:
    """Return the sample summary as pretty-printed, non-ASCII-escaped JSON."""
    serialized = json.dumps(sample_summary_dict, indent=2, ensure_ascii=False)
    return serialized
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_pi_output(sample_summary_json) -> str:
    """Simulate complete pi CLI output: prose around a fenced JSON block."""
    return f"""以下是论文的深度解读:

```json
{sample_summary_json}
```

希望这个总结对你有帮助!"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def admin_token() -> str:
    """Return the test ADMIN_TOKEN (meant to be used together with monkeypatch)."""
    return ADMIN_TOKEN
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def admin_headers(admin_token):
    """Return request headers carrying the admin token as a Bearer credential."""
    bearer_value = f"Bearer {admin_token}"
    return {"Authorization": bearer_value}
|
||||||
@@ -0,0 +1,725 @@
|
|||||||
|
"""AI 总结服务测试 — Mock 全链路,不调用真实 pi。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from datetime import date, datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.models import (
|
||||||
|
CrawlLog,
|
||||||
|
Paper,
|
||||||
|
PaperSummary,
|
||||||
|
PaperTag,
|
||||||
|
SummaryStatus,
|
||||||
|
TaskLock,
|
||||||
|
)
|
||||||
|
from app.services.schemas import (
|
||||||
|
SummarySchema,
|
||||||
|
assess_quality,
|
||||||
|
classify_validation_error,
|
||||||
|
flatten_for_db,
|
||||||
|
)
|
||||||
|
from app.services.summarizer import (
|
||||||
|
JsonNotFoundError,
|
||||||
|
PdfDownloadError,
|
||||||
|
PiProcessError,
|
||||||
|
PiTimeoutError,
|
||||||
|
_call_pi,
|
||||||
|
_classify_error,
|
||||||
|
_cleanup_tmp,
|
||||||
|
_extract_json,
|
||||||
|
_save_files,
|
||||||
|
_save_raw_output_only,
|
||||||
|
_update_summary_in_db,
|
||||||
|
summarize_batch,
|
||||||
|
summarize_one,
|
||||||
|
summarize_single,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Schema 校验测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestSummarySchema:
    """Validation behaviour of the Pydantic summary schema."""

    def test_valid_summary(self, sample_summary_dict):
        """A complete dict validates and exposes its fields."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert parsed.title_zh == "测试论文中文标题"
        assert len(parsed.tags) == 3
        assert parsed.motivation.problem

    def test_missing_title_zh(self, sample_summary_dict):
        """Dropping title_zh is rejected and classified as field_missing."""
        sample_summary_dict.pop("title_zh")
        with pytest.raises(ValidationError) as caught:
            SummarySchema.model_validate(sample_summary_dict)
        assert classify_validation_error(caught.value) == "field_missing"

    def test_empty_one_line(self, sample_summary_dict):
        """An empty one_line is rejected."""
        sample_summary_dict["one_line"] = ""
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_empty_tags(self, sample_summary_dict):
        """An empty tag list is rejected."""
        sample_summary_dict["tags"] = []
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_empty_motivation_problem(self, sample_summary_dict):
        """An empty motivation.problem is rejected."""
        motivation = sample_summary_dict["motivation"]
        motivation["problem"] = ""
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_empty_method_key_idea(self, sample_summary_dict):
        """An empty method.key_idea is rejected."""
        method = sample_summary_dict["method"]
        method["key_idea"] = ""
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_extra_fields_ignored(self, sample_summary_dict):
        """Unknown top-level keys do not end up on the parsed model."""
        sample_summary_dict["figures"] = ["fig1.png"]
        sample_summary_dict["takeaway"] = "important paper"
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert not hasattr(parsed, "figures")
        assert parsed.title_zh  # still parsed normally

    def test_flatten_for_db(self, sample_summary_dict):
        """flatten_for_db exposes flat columns plus JSON-encoded fields."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        row = flatten_for_db(parsed)
        assert row["one_line"] == parsed.one_line
        assert row["motivation_problem"] == parsed.motivation.problem
        assert row["method_key_idea"] == parsed.method.key_idea
        assert "full_json" in row
        assert "updated_at" in row
        # The JSON columns must decode back to the expected container types.
        assert isinstance(json.loads(row["prerequisites_json"]), dict)
        assert isinstance(json.loads(row["method_steps_json"]), list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQualityAssessment:
    """Quality grading returned by assess_quality."""

    def test_quality_normal(self, sample_summary_dict):
        """The untouched sample grades as "normal"."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert assess_quality(parsed) == "normal"

    def test_quality_degraded_missing_goal(self, sample_summary_dict):
        """Blanking goal, gap, overview and main_findings grades as "degraded"."""
        motivation = sample_summary_dict["motivation"]
        motivation["goal"] = ""
        motivation["gap"] = ""
        sample_summary_dict["method"]["overview"] = ""
        sample_summary_dict["results"]["main_findings"] = []
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert assess_quality(parsed) == "degraded"

    def test_quality_low_short_one_line(self, sample_summary_dict):
        """A one-character one_line grades as "low"."""
        sample_summary_dict["one_line"] = "短"
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert assess_quality(parsed) == "low"

    def test_quality_low_short_key_idea(self, sample_summary_dict):
        """A one-character method.key_idea grades as "low"."""
        sample_summary_dict["method"]["key_idea"] = "短"
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert assess_quality(parsed) == "low"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# JSON 提取测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonExtraction:
    """Extraction of the summary JSON from pi output."""

    def test_direct_json(self, sample_summary_json):
        """Input that is nothing but JSON is parsed directly."""
        extracted = _extract_json(sample_summary_json)
        assert extracted["title_zh"] == "测试论文中文标题"

    def test_fenced_code_block(self, sample_summary_json):
        """JSON inside a ```json fence surrounded by text is found."""
        noisy = f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字"
        extracted = _extract_json(noisy)
        assert extracted["title_zh"] == "测试论文中文标题"

    def test_fenced_without_lang(self, sample_summary_json):
        """A fence without a language tag works as well."""
        noisy = f"文字\n```\n{sample_summary_json}\n```"
        extracted = _extract_json(noisy)
        assert extracted["title_zh"] == "测试论文中文标题"

    def test_embedded_braces(self, sample_summary_dict):
        """Unfenced JSON embedded between prose lines is found."""
        compact = json.dumps(sample_summary_dict, ensure_ascii=False)
        noisy = f"Here is the summary:\n{compact}\nEnd."
        extracted = _extract_json(noisy)
        assert extracted["title_zh"] == "测试论文中文标题"

    def test_no_json_raises(self):
        """Text with no JSON at all raises JsonNotFoundError."""
        with pytest.raises(JsonNotFoundError):
            _extract_json("No JSON here at all.")

    def test_json_without_title_zh_falls_through(self):
        """JSON without title_zh is not the payload we are looking for."""
        payload = {"other": "data"}
        noisy = json.dumps(payload)
        # Per the original author's notes: the input is a JSON dict without
        # title_zh, so strategy 1 skips it, strategy 2 finds no code fence,
        # and strategy 3 falls back to the largest JSON block.
        extracted = _extract_json(noisy)
        assert extracted == {"other": "data"}  # largest-block fallback
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 错误分类测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorClassification:
    """Mapping from exception to error_type string."""

    def test_pdf_download_error(self):
        """PdfDownloadError maps to "pdf_download_failed"."""
        assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed"

    def test_timeout_error(self):
        """PiTimeoutError maps to "timeout"."""
        assert _classify_error(PiTimeoutError("timeout")) == "timeout"

    def test_process_error(self):
        """PiProcessError maps to "process_error"."""
        assert _classify_error(PiProcessError(1, "stderr")) == "process_error"

    def test_json_not_found(self):
        """JsonNotFoundError maps to "json_not_found"."""
        assert _classify_error(JsonNotFoundError("not found")) == "json_not_found"

    def test_json_invalid(self):
        """json.JSONDecodeError maps to "json_invalid"."""
        assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid"

    def test_field_missing(self):
        """A schema ValidationError maps to "field_missing"."""
        # pytest.raises makes the test fail if validation unexpectedly
        # succeeds; the previous try/except form passed without asserting
        # anything in that case.
        with pytest.raises(ValidationError) as exc_info:
            SummarySchema.model_validate({"title_zh": ""})  # type: ignore
        assert _classify_error(exc_info.value) == "field_missing"

    def test_unknown_error(self):
        """Any other exception maps to "unknown"."""
        assert _classify_error(RuntimeError("boom")) == "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# DB 更新测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestDbUpdate:
    """Effects of _update_summary_in_db on the database."""

    def test_summary_written(self, db_session, sample_paper, sample_summary_dict):
        """A PaperSummary row is stored with the flattened fields."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, parsed, "normal", "raw")

        stored = db_session.get(PaperSummary, sample_paper.id)
        assert stored is not None
        assert stored.one_line == parsed.one_line
        assert stored.motivation_problem == parsed.motivation.problem
        assert json.loads(stored.full_json)["title_zh"] == parsed.title_zh

    def test_paper_title_zh_updated(self, db_session, sample_paper, sample_summary_dict):
        """The paper row receives title_zh and the quality grade."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, parsed, "normal", "raw")

        db_session.refresh(sample_paper)
        assert sample_paper.title_zh == "测试论文中文标题"
        assert sample_paper.summary_quality == "normal"

    def test_fts_updated(self, db_session, sample_paper, sample_summary_dict):
        """The papers_fts row carries the Chinese title and summary text."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, parsed, "normal", "raw")

        fts_query = text("SELECT title_zh, summary_text FROM papers_fts WHERE rowid = :id")
        fts_row = db_session.execute(fts_query, {"id": sample_paper.id}).fetchone()
        assert fts_row is not None
        assert fts_row[0] == "测试论文中文标题"
        assert parsed.one_line in fts_row[1]

    def test_ai_tags_added(self, db_session, sample_paper, sample_summary_dict):
        """Tags from the schema are stored with source "ai"."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, parsed, "normal", "raw")

        ai_tags = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id, PaperTag.source == "ai")
            .all()
        )
        ai_tag_names = {row.tag for row in ai_tags}
        # The AI tags come from schema.tags.
        assert "自然语言处理" in ai_tag_names
        assert "大语言模型" in ai_tag_names

    def test_existing_tags_not_duplicated(self, db_session, sample_paper, sample_summary_dict):
        """A tag name that already exists is not inserted again with source "ai"."""
        # sample_paper already has NLP (hf) and LLM (hf). The AI output below
        # repeats NLP and adds one new tag.
        sample_summary_dict["tags"] = ["NLP", "新标签"]
        parsed = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, parsed, "normal", "raw")

        paper_tags = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id)
            .all()
        )
        names = [row.tag for row in paper_tags]
        # NLP appears once (the HF original); the AI copy is not added.
        assert names.count("NLP") == 1
        # The new tag was added by the AI pass.
        assert "新标签" in names
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 文件操作测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileOperations:
    """Saving summary files and cleaning up temporary directories."""

    def test_save_files(self, tmp_path, sample_summary_dict):
        """_save_files writes summary.json and raw_output.txt for the paper."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
            _save_files("2401.12345", parsed, "raw output text")

        target_dir = tmp_path / "2401.12345"
        summary_file = target_dir / "summary.json"
        assert summary_file.exists()
        assert (target_dir / "raw_output.txt").exists()
        on_disk = json.loads(summary_file.read_text())
        assert on_disk["title_zh"] == "测试论文中文标题"

    def test_save_raw_output_only(self, tmp_path):
        """_save_raw_output_only writes raw_output.txt but no summary.json."""
        with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
            _save_raw_output_only("2401.12345", "raw output")
        target_dir = tmp_path / "2401.12345"
        assert (target_dir / "raw_output.txt").exists()
        assert not (target_dir / "summary.json").exists()

    def test_cleanup_tmp(self, tmp_path):
        """_cleanup_tmp removes the paper's temporary directory."""
        scratch_dir = tmp_path / "2401.12345"
        scratch_dir.mkdir()
        (scratch_dir / "paper.pdf").write_bytes(b"%PDF-fake")
        with patch("app.services.summarizer._TMP_DIR", tmp_path):
            _cleanup_tmp("2401.12345")
        assert not scratch_dir.exists()

    def test_cleanup_tmp_nonexistent(self, tmp_path):
        """Cleaning up a directory that does not exist raises nothing."""
        with patch("app.services.summarizer._TMP_DIR", tmp_path):
            _cleanup_tmp("nonexistent")  # must not raise
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 全流程状态流转测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestSummarizeOneFlow:
    """State transitions of summarize_one, with pi and the PDF download mocked."""

    @pytest.fixture
    def _patch_paths(self, tmp_path):
        """Redirect the summarizer's data directories to tmp_path."""
        with (
            patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
            patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
            patch("app.services.summarizer._DATA_DIR", tmp_path),
        ):
            yield

    @pytest.mark.asyncio
    async def test_full_success_path(
        self, db_session, sample_paper, mock_pi_output, _patch_paths
    ):
        """Full happy path: pending → processing → done."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "done"
        assert result["quality"] == "normal"

        # Verify the DB state.
        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "done"
        assert sample_paper.summary_status.quality == "normal"
        assert sample_paper.title_zh == "测试论文中文标题"

        # Verify the summary row was written.
        summary = db_session.get(PaperSummary, sample_paper.id)
        assert summary is not None
        assert summary.one_line

        # Verify the FTS row was updated.
        fts_row = db_session.execute(
            text("SELECT title_zh FROM papers_fts WHERE rowid = :id"),
            {"id": sample_paper.id},
        ).fetchone()
        assert fts_row[0] == "测试论文中文标题"

    @pytest.mark.asyncio
    async def test_pdf_download_failure(
        self, db_session, sample_paper, _patch_paths
    ):
        """PDF download failure → error_type=pdf_download_failed.

        The original note also mentions tmp cleanup; that is asserted in
        test_tmp_cleaned_on_failure, not here.
        """
        with (
            patch(
                "app.services.summarizer._download_pdf",
                new_callable=AsyncMock,
                side_effect=PdfDownloadError("network error"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "pdf_download_failed"

        db_session.refresh(sample_paper)
        status = sample_paper.summary_status
        assert status.error_type == "pdf_download_failed"

    @pytest.mark.asyncio
    async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
        """pi timeout → "timeout" error and retry_count incremented to 1."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                side_effect=PiTimeoutError("timeout after 300s"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "timeout"
        assert result["retry_count"] == 1

    @pytest.mark.asyncio
    async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
        """pi output without any JSON → json_not_found."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                return_value="No JSON in this output at all.",
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "json_not_found"

    @pytest.mark.asyncio
    async def test_field_missing_and_retry(
        self, db_session, sample_paper, _patch_paths
    ):
        """Missing required field → field_missing → retry → permanent_failure."""
        bad_json = json.dumps({
            "title_zh": "",  # required field left empty
            "one_line": "valid line",
            "tags": ["tag1"],
            "motivation": {"problem": "valid problem"},
            "method": {"key_idea": "valid idea"},
        }, ensure_ascii=False)
        bad_output = f"```json\n{bad_json}\n```"

        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                return_value=bad_output,
            ),
        ):
            # First failure → back to pending (eligible for retry).
            result1 = await summarize_one(db_session, sample_paper)
            assert result1["status"] == "failed"
            assert result1["error_type"] == "field_missing"
            assert result1["retry_count"] == 1

            # Second failure → permanent_failure. The original note assumes
            # SUMMARY_MAX_RETRIES=1.
            # NOTE(review): the setting is not pinned by this test — confirm
            # it against the configured default.
            db_session.refresh(sample_paper)
            result2 = await summarize_one(db_session, sample_paper)
            assert result2["status"] == "failed"
            assert result2["retry_count"] == 2

        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "permanent_failure"

    @pytest.mark.asyncio
    async def test_raw_output_saved_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """raw_output.txt is still saved when summarization fails."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                return_value="Some output without JSON",
            ),
        ):
            await summarize_one(db_session, sample_paper)

        raw_file = tmp_path / "papers" / sample_paper.arxiv_id / "raw_output.txt"
        assert raw_file.exists()
        assert "Some output without JSON" in raw_file.read_text()

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_success(
        self, db_session, sample_paper, mock_pi_output, tmp_path, _patch_paths
    ):
        """The paper's tmp directory is gone after a successful run."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            await summarize_one(db_session, sample_paper)

        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """The paper's tmp directory is also gone after a failed run."""
        with (
            patch(
                "app.services.summarizer._download_pdf",
                new_callable=AsyncMock,
                side_effect=PdfDownloadError("fail"),
            ),
        ):
            await summarize_one(db_session, sample_paper)

        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_skips_done_paper(self, db_session, sample_paper, _patch_paths):
        """A paper whose status is already "done" is skipped."""
        sample_paper.summary_status.status = "done"
        db_session.commit()

        result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "skipped"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 批量操作测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchSummarize:
    """Tests for ``summarize_batch``: multi-paper runs, failure isolation, TaskLock handling."""

    @staticmethod
    def _seed_pending_papers(db_session, id_prefix, title_prefix, count):
        """Insert ``count`` papers, each with a "pending" SummaryStatus, then commit.

        The arXiv id of paper ``n`` is ``f"{id_prefix}{n}"`` and its English
        title is ``f"{title_prefix} {n}"``.
        """
        crawl_time = datetime.now(timezone.utc)
        for idx in range(count):
            arxiv_id = f"{id_prefix}{idx}"
            paper = Paper(
                arxiv_id=arxiv_id,
                title_en=f"{title_prefix} {idx}",
                abstract=f"Abstract {idx}",
                paper_date=date(2024, 1, 15),
                crawled_at=crawl_time,
                pdf_url=f"https://arxiv.org/pdf/{arxiv_id}.pdf",
            )
            db_session.add(paper)
            # Flush before reading paper.id; the status row below references it.
            db_session.flush()
            db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db_session.commit()

    @staticmethod
    def _worker_session_factory(db_engine):
        """Return a session factory bound to ``db_engine``.

        Passed to ``summarize_batch`` so each worker uses its own session
        while sharing the test engine.
        """
        from sqlalchemy.orm import sessionmaker

        return sessionmaker(bind=db_engine, autoflush=False, autocommit=False)

    @pytest.fixture
    def _patch_paths(self, tmp_path):
        """Point the summarizer's papers/tmp/data directories at ``tmp_path`` during a test."""
        papers_patch = patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers")
        tmp_patch = patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp")
        data_patch = patch("app.services.summarizer._DATA_DIR", tmp_path)
        with papers_patch, tmp_patch, data_patch:
            yield

    @pytest.mark.asyncio
    async def test_batch_multiple_papers(
        self, db_session, db_engine, mock_pi_output, _patch_paths
    ):
        """Three pending papers are all summarized and a successful CrawlLog row is written."""
        self._seed_pending_papers(db_session, "2401.1234", "Test Paper", 3)
        session_factory = self._worker_session_factory(db_engine)

        download_patch = patch(
            "app.services.summarizer._download_pdf", new_callable=AsyncMock
        )
        pi_patch = patch(
            "app.services.summarizer._call_pi",
            new_callable=AsyncMock,
            return_value=mock_pi_output,
        )
        with download_patch, pi_patch:
            summary = await summarize_batch(
                db_session, _session_factory=session_factory
            )

        assert summary["status"] == "success"
        assert summary["done"] == 3
        assert summary["failed"] == 0

        # The batch run should also have recorded itself in CrawlLog.
        summarize_logs = db_session.query(CrawlLog).filter(CrawlLog.task == "summarize")
        crawl_log = summarize_logs.first()
        assert crawl_log is not None
        assert crawl_log.status == "success"
        assert crawl_log.papers_found == 3

    @pytest.mark.asyncio
    async def test_single_failure_no_block(
        self, db_session, db_engine, mock_pi_output, _patch_paths
    ):
        """One paper failing does not stop the other from being summarized."""
        self._seed_pending_papers(db_session, "2401.5678", "Paper", 2)
        session_factory = self._worker_session_factory(db_engine)

        invocations = 0

        async def _mock_call_pi(meta_path, pdf_path):
            # Only the very first invocation times out; later ones succeed.
            nonlocal invocations
            invocations += 1
            if invocations != 1:
                return mock_pi_output
            raise PiTimeoutError("timeout")

        download_patch = patch(
            "app.services.summarizer._download_pdf", new_callable=AsyncMock
        )
        pi_patch = patch("app.services.summarizer._call_pi", side_effect=_mock_call_pi)
        with download_patch, pi_patch:
            summary = await summarize_batch(
                db_session, _session_factory=session_factory
            )

        assert summary["done"] == 1
        assert summary["failed"] == 1

    @pytest.mark.asyncio
    async def test_task_lock_conflict(self, db_session, _patch_paths):
        """An existing running TaskLock makes a second batch report a conflict."""
        # Pre-insert a lock row that looks like a batch already in progress.
        running_lock = TaskLock(
            task="summarize",
            lock_key="batch",
            status="running",
            acquired_at=datetime.now(timezone.utc),
        )
        db_session.add(running_lock)
        db_session.commit()

        summary = await summarize_batch(db_session)

        assert summary["status"] == "conflict"

    @pytest.mark.asyncio
    async def test_task_lock_released(self, db_session, db_engine, mock_pi_output, _patch_paths):
        """Every batch TaskLock row present after a run is finished and released."""
        session_factory = self._worker_session_factory(db_engine)

        download_patch = patch(
            "app.services.summarizer._download_pdf", new_callable=AsyncMock
        )
        pi_patch = patch(
            "app.services.summarizer._call_pi",
            new_callable=AsyncMock,
            return_value=mock_pi_output,
        )
        with download_patch, pi_patch:
            await summarize_batch(db_session, _session_factory=session_factory)

        batch_locks = (
            db_session.query(TaskLock)
            .filter(
                TaskLock.task == "summarize",
                TaskLock.lock_key == "batch",
            )
            .all()
        )
        # NOTE(review): if no lock rows exist, this loop asserts nothing and
        # the test passes vacuously -- confirm whether a row is expected here.
        for held in batch_locks:
            assert held.status == "finished"
            assert held.released_at is not None

    @pytest.mark.asyncio
    async def test_batch_empty(self, db_session, _patch_paths):
        """With no papers seeded, the batch succeeds with a total of zero."""
        summary = await summarize_batch(db_session)

        assert summary["status"] == "success"
        assert summary["total"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Admin 路由鉴权测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminAuth:
    """Admin endpoint authentication, tested at the HTTP layer with service calls mocked."""

    def test_no_token_returns_401(self, client):
        """A request without a Bearer token is rejected; 401 or 403 is accepted."""
        response = client.post("/admin/summarize")

        assert response.status_code in (401, 403)

    def test_wrong_token_returns_401(self, client):
        """A request carrying an incorrect Bearer token gets 401."""
        bad_auth = {"Authorization": "Bearer wrong-token"}

        response = client.post("/admin/summarize", headers=bad_auth)

        assert response.status_code == 401

    def test_correct_token_batch(self, client, admin_headers):
        """With the admin token pinned, the batch endpoint returns the mocked service result."""
        import app.config as config_mod

        # Pin the configured token for this test and restore it afterwards.
        # NOTE(review): presumably this matches the token carried by the
        # admin_headers fixture -- that fixture is not shown here.
        saved_token = config_mod.settings.ADMIN_TOKEN
        config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345"
        try:
            batch_patch = patch(
                "app.routes.admin.summarize_batch",
                new_callable=AsyncMock,
                return_value={"status": "success", "done": 0, "failed": 0, "total": 0},
            )
            with batch_patch:
                response = client.post("/admin/summarize", headers=admin_headers)
                assert response.status_code == 200
                assert response.json()["status"] == "success"
        finally:
            config_mod.settings.ADMIN_TOKEN = saved_token

    def test_single_paper_not_found(self, client, admin_headers):
        """A "not_found" result from the single-paper service maps to HTTP 404."""
        import app.config as config_mod

        saved_token = config_mod.settings.ADMIN_TOKEN
        config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345"
        try:
            with patch(
                "app.routes.admin.summarize_single", new_callable=AsyncMock
            ) as single_mock:
                single_mock.return_value = {
                    "status": "not_found",
                    "arxiv_id": "nonexistent.99999",
                }
                response = client.post(
                    "/admin/summarize/nonexistent.99999",
                    headers=admin_headers,
                )
                assert response.status_code == 404
        finally:
            config_mod.settings.ADMIN_TOKEN = saved_token
|
||||||
Reference in New Issue
Block a user