From 743d69efd0d0da8a1bd02ae8b3f9dca393f43bdf Mon Sep 17 00:00:00 2001 From: rain-bus Date: Sat, 13 Jun 2026 18:31:43 +0800 Subject: [PATCH] refactor: extract admin business logic to services, introduce job queue, add derived index helpers - Move DB operations from routes/admin.py to services/admin.py (get_logs_context, query_summary_statuses, retry_failed, delete/reset operations) - Add services/jobs.py with Job/JobEvent-based async job queue (create_job, run_job, enqueue_job) - Add services/derived.py with FTS5 reindex and paper index deletion helpers - Refactor scheduler to use job queue instead of direct pipeline calls - Add heartbeat_at/expires_at to TaskLock for lock health tracking - Remove DESIGN_REVIEW.md - Update tests: remove redundant integration tests, add unit tests for new services --- .env.example | 2 +- DESIGN_REVIEW.md | 468 ------------------------------ app/cli.py | 72 ++++- app/database.py | 16 + app/main.py | 7 + app/models.py | 59 ++++ app/routes/admin.py | 383 +++++++----------------- app/services/admin.py | 327 ++++++++++++++++++++- app/services/crawler.py | 19 +- app/services/derived.py | 140 +++++++++ app/services/jobs.py | 244 ++++++++++++++++ app/services/pipeline.py | 14 +- app/services/scheduler.py | 14 +- app/services/summary_persister.py | 19 +- tests/test_admin.py | 239 +++++++++------ tests/test_crawler.py | 60 ++++ tests/test_derived.py | 80 +++++ tests/test_jobs.py | 111 +++++++ tests/test_pages.py | 111 ------- tests/test_searcher.py | 69 ----- 20 files changed, 1391 insertions(+), 1063 deletions(-) delete mode 100644 DESIGN_REVIEW.md create mode 100644 app/services/derived.py create mode 100644 app/services/jobs.py create mode 100644 tests/test_derived.py create mode 100644 tests/test_jobs.py diff --git a/.env.example b/.env.example index 8ba3f16..d7a619c 100644 --- a/.env.example +++ b/.env.example @@ -42,7 +42,7 @@ DATABASE_URL=sqlite:///data/db/papers.db # ─── 语义搜索 ───────────────────────────── CHROMA_ENABLED=false CHROMA_DIR=data/chroma -EMBED_API_BASE=https://api.siliconflow.cn/v1/embeddings +EMBED_API_BASE=https://api.siliconflow.cn EMBED_API_KEY=your_api_key_here EMBED_MODEL=Qwen/Qwen3-Embedding-4B EMBED_DIMENSIONS=2560 diff --git a/DESIGN_REVIEW.md b/DESIGN_REVIEW.md deleted file mode 100644 index 18bb9b6..0000000 --- a/DESIGN_REVIEW.md +++ /dev/null @@ -1,468 +0,0 @@ -# 项目设计审查与优化建议 - -本文档汇总对当前项目的系统设计、流程设计和代码结构的审查结论。重点不在局部代码风格,而在后续稳定运行、失败恢复、数据一致性、可维护性和扩展性。 - -## 总体评价 - -项目当前结构整体清晰:FastAPI 路由层较薄,主要业务放在 `app/services/`;抓取、总结、搜索、个人数据、管理后台等能力已有模块化拆分;SQLite + FTS5 对本地个人论文导览站是合理选择。 - -主要风险集中在以下几类: - -- 长任务生命周期没有统一抽象。 -- Pipeline 不是显式状态机,阶段结果难以恢复和追踪。 -- 数据库、文件系统、FTS5、ChromaDB 之间存在多处手工同步。 -- Web、CLI、Scheduler 三个入口的业务策略存在分叉风险。 -- 失败恢复和可观测性不足。 -- 部分同步阻塞任务混入 async Web 流程。 - -## 高优先级问题 - -### 1. 缺少统一的任务系统 - -当前有三套入口都能触发重任务: - -- Web 管理后台:`/admin/trigger-pipeline`、`/admin/summarize` -- APScheduler:每日自动 pipeline -- CLI:`app.cli crawl/summarize` - -它们共享部分 service,但没有统一的“任务创建、排队、运行、重试、取消、状态查询、失败恢复”模型。 - -影响: - -- Web 请求会长时间挂住。 -- CLI、Web、Scheduler 的行为可能逐渐分叉。 -- 任务卡死后只能看到粗粒度的 running lock。 -- 前端难以展示阶段进度。 -- 后续增加重跑、取消、恢复、并发限制会越来越复杂。 - -建议: - -引入统一 Job 模型: - -```text -入口层 -> 创建 Job -> Worker 执行 Job -> JobEvent 记录进度 -> 页面/API 查询状态 -``` - -建议任务类型: - -- `crawl_daily` -- `summarize_one` -- `summarize_batch` -- `pipeline_daily` -- `refresh_upvotes` -- `delete_range` -- `reindex_fts` -- `reindex_chroma` -- `extract_images` - -建议表结构方向: - -- `jobs`: 任务主记录,包含 `id/type/status/owner/created_at/started_at/completed_at/error` -- `job_events`: 阶段事件,包含 `job_id/stage/status/message/payload_json/created_at` -- `paper_processing_runs`: 单篇论文处理记录,适合跟踪总结、图片提取、索引状态 - -### 2. Pipeline 不是显式状态机 - -当前 pipeline 基本是线性函数: - -```text -crawl -> summarize -> cleanup -``` - -阶段输出没有被建模成可查询、可重放、可补偿的状态。 - -影响: - -- crawl 成功但 summarize 部分失败时,整体状态表达不够精确。 -- cleanup 失败是否影响 pipeline 成败语义不清。 -- “今天无数据回退昨天”是隐式逻辑,不是可配置策略。 -- 图片提取、ChromaDB 索引失败被吞掉,用户无法知道哪些增强能力缺失。 - -建议: - -将 pipeline 改为显式阶段状态: - -```text -created --> crawling --> crawled --> summarizing --> summarized | summarized_partial --> postprocessing --> completed | failed | cancelled -``` - -每个阶段记录: - -- 输入参数 -- 输出统计 -- 错误类型 -- 错误摘要 -- 开始/结束时间 -- 是否可重试 - -这样可以支持只重跑失败阶段,例如只重跑 ChromaDB 索引、只重跑图片提取、只重跑失败论文总结。 - -### 3. 数据生命周期没有清晰分层 - -项目中至少有四类数据: - -- 主数据:`papers`、`authors`、`tags`、`summaries`、用户笔记/收藏/阅读状态 -- 派生索引:FTS5、ChromaDB -- 长期资产:`data/papers/{arxiv_id}` 下的 summary、raw output、images -- 临时资产:`data/tmp/{arxiv_id}` 下的 PDF、源码、中间文件 - -现在这些数据由不同流程手动维护。删除、总结、重建索引、清理临时目录都分散在不同模块中。 - -影响: - -- DB 与文件系统可能不一致。 -- DB 与 FTS5/ChromaDB 可能不一致。 -- 删除流程中一部分资源删除成功、一部分失败时难以恢复。 -- 很难判断某篇论文是否“完整可用”。 - -建议: - -明确数据权威源: - -- DB 是权威源。 -- FTS5、ChromaDB、summary.json、raw_output.txt、images 都应视为可重建派生物。 -- 删除主数据优先保证 DB 状态正确,派生物清理可以异步补偿。 -- 提供派生数据重建任务。 - -建议增加命令或任务: - -```text -rebuild-derived --fts -rebuild-derived --chroma -rebuild-derived --files -rebuild-derived --images -``` - -### 4. 数据库与文件系统存在双写一致性风险 - -总结持久化流程会写文件、写 DB、更新状态、提取图片、写 ChromaDB。任一步失败都可能留下不一致状态。 - -典型风险: - -- 文件已写入,但 DB 未提交。 -- DB 标记 done,但 summary.json 缺失。 -- DB done,但图片未提取或 ChromaDB 未索引。 -- 重新总结时旧图片或旧 figures 残留。 - -建议: - -- 将结构化总结以 DB 为准,JSON 文件作为导出/缓存。 -- 文件写入使用临时文件加原子 rename。 -- DB 中记录派生物状态,例如: - - `summary_persisted_at` - - `fts_indexed_at` - - `chroma_indexed_at` - - `images_extracted_at` - - `derived_error` -- 提供一致性检查任务,扫描 DB 与文件/索引差异并修复。 - -### 5. 失败恢复设计偏弱 - -当前失败主要依赖: - -- `summary_status.status` -- `retry_count` -- `task_locks` -- `crawl_logs` - -但对以下情况支持不足: - -- 进程崩溃后 `processing` 永久残留。 -- `task_locks` 中 running lock 永久残留。 -- 任务执行到一半,部分文件或索引已经写入。 -- 某些后处理失败但主任务显示成功。 - -建议: - -- `task_locks` 增加 `heartbeat_at` 或 `expires_at`。 -- 应用启动时扫描 stale running task,标记为 `failed/stale`。 -- `summary_status` 增加更细阶段字段,例如: - - `downloading_pdf` - - `generating` - - `validating` - - `persisting` - - `extracting_images` - - `indexing` -- 区分主流程失败和派生增强失败。 - -### 6. 长任务直接跑在 Web 请求里 - -管理端触发 pipeline、批量总结、刷新 upvotes、删除数据时,会在请求生命周期内直接执行重任务。 - -影响: - -- HTTP 请求容易超时。 -- 用户关闭页面后任务状态不清晰。 -- Web 进程被重任务占用。 -- 后续很难做取消、进度展示和失败恢复。 - -建议: - -- Web 只创建 job 并返回 `task_id`。 -- 后台 worker 执行任务。 -- 管理页面轮询 `/admin/jobs/{id}` 或使用 SSE 展示进度。 - -## 中优先级问题 - -### 7. FTS5 索引手工维护,容易漂移 - -`papers_fts` 是独立虚拟表。插入论文、更新总结、删除论文时手动同步。 - -影响: - -- 已有论文的 title、abstract、authors、tags 更新时可能不同步。 -- AI 标签新增后,如果没有统一重建,搜索结果可能缺字段。 -- 删除或回滚失败后索引可能残留。 - -建议: - -- 封装统一的 `reindex_paper(db, paper_id)`。 -- 所有修改论文、标签、总结的路径最后调用它。 -- 提供全量 `reindex_fts` 任务。 -- 或用 SQLite trigger 维护基础字段,但总结字段仍建议通过统一函数更新。 - -### 8. ChromaDB 索引缺少系统级一致性策略 - -ChromaDB 当前是可选增强能力,索引失败会被日志吞掉,不影响总结主流程。 - -影响: - -- 语义搜索可用性不透明。 -- 某些论文 DB done,但未进入 ChromaDB。 -- 删除或重建时可能出现残留索引。 - -建议: - -- 将 ChromaDB 明确视为派生索引。 -- DB 中记录 `chroma_indexed_at` 和 `chroma_error`。 -- 提供 `reindex_chroma` 任务。 -- 搜索页或管理后台展示语义索引健康状态。 - -### 9. 同步阻塞操作混入 async 主流程 - -部分 async 函数内部调用同步阻塞操作,例如 PDF 下载使用 `requests`,图片提取和 PDF 解析也都是同步重任务。 - -影响: - -- FastAPI event loop 可能被阻塞。 -- 多篇并发总结时吞吐不可控。 -- Web 服务响应受后台任务影响。 - -建议: - -- 最佳方案:重任务移到 worker 进程。 -- 过渡方案:对同步阻塞段使用 `asyncio.to_thread()`。 -- PDF 下载统一使用 `httpx.AsyncClient` 或明确放入同步 worker。 - -### 10. 调度策略和业务流程耦合 - -Scheduler 固定每日 pipeline 加 30 分钟 upvote refresh;pipeline 内部固定“今天无数据回退昨天”。 - -影响: - -- 策略不易测试和调整。 -- CLI、Web、Scheduler 可能各自实现不同 fallback。 -- 后续支持日期范围、补抓、只总结新增论文会变复杂。 - -建议将策略配置化: - -- `crawl_target_policy`: `today` / `today_then_yesterday` / `date_range` -- `summarize_scope`: `new_only` / `pending_and_failed` / `date_range` -- `cleanup_policy`: `after_success` / `always` / `manual` -- `upvote_refresh_window_days` -- `pipeline_on_partial_failure`: `continue` / `stop` - -### 11. 内嵌 APScheduler 对部署形态敏感 - -当前调度器嵌入 Web 进程,且多 worker 时只是打印警告。 - -影响: - -- 多 worker、多进程、reload 模式下可能重复或漏跑任务。 -- Web 进程重启会影响调度可靠性。 -- 调度状态和任务执行状态耦合。 - -建议: - -- 本地单机可以保留内嵌调度器。 -- 长期运行建议拆为独立 scheduler 进程。 -- scheduler 只负责创建 job,不直接执行重任务。 -- Web 管理后台只展示 scheduler/job 状态。 - -### 12. 运行时迁移能力不足 - -当前 `_migrate()` 只支持补列。 - -影响: - -- 无法可靠处理索引、约束、字段改名、数据回填。 -- 数据库结构演进不可追踪。 -- 生产或长期本地数据升级风险高。 - -建议: - -- 引入 Alembic。 -- 将 FTS5、索引、部分约束和数据回填纳入迁移脚本。 -- 启动时不要静默做复杂迁移,改为显式执行迁移命令。 - -## 低到中优先级问题 - -### 13. 配置校验不足 - -当前配置字段主要是裸类型,缺少合法值和组合校验。 - -风险示例: - -- `SUMMARY_BACKEND` 填错。 -- `SUMMARY_PDF_MODE` 填错。 -- `SCHEDULE_MINUTE + 30` 超过 59。 -- `APP_WORKERS > 1` 且 `SCHEDULER_ENABLED=true`。 -- `CHROMA_ENABLED=true` 但 embedding 配置缺失。 - -建议: - -- 使用 Pydantic validator 校验枚举值和组合条件。 -- 对危险组合启动时报错,而不是只 warning。 - -### 14. 私有函数跨模块调用说明模块边界不稳定 - -例如 `summarizer.py` 调用 `_generate_with_retry`、`_persist_summary` 等下划线函数。 - -影响: - -- 模块公开边界不清晰。 -- 后续重构容易误伤。 -- 测试和替换后端不方便。 - -建议: - -- 为 summary 生成、持久化、后处理定义明确公开 API。 -- 下划线函数保留为模块内部实现细节。 -- 引入更明确的 service 类或函数边界,例如: - - `SummaryGenerator.generate()` - - `SummaryRepository.save()` - - `DerivedAssetBuilder.extract_images()` - -### 15. 外部依赖缺少统一 adapter - -项目依赖多个外部系统: - -- HuggingFace API -- arXiv PDF/source 下载 -- `pi` CLI -- `claude` CLI -- Embedding API -- ChromaDB -- ONNX layout detector - -建议: - -- 为每类外部依赖定义 adapter。 -- 统一 timeout、retry、error classification、metrics。 -- 测试中替换 adapter,而不是 patch 深层函数。 - -### 16. 删除流程事务边界不理想 - -删除流程中同时删除 FTS5、ChromaDB、本地文件、临时文件、ORM 数据。 - -影响: - -- 文件删除成功后 DB rollback,会造成数据仍在但文件丢失。 -- DB 删除成功后 ChromaDB 删除失败,会造成残留索引。 -- 单篇失败时 rollback 可能影响之前 flush 的状态,逻辑不直观。 - -建议: - -- 删除主数据与删除派生物分两阶段。 -- DB 删除成功后创建派生清理 job。 -- 派生物清理失败可重试,不影响主数据一致性。 - -## 建议目标架构 - -如果项目定位是个人本地工具,可以采用轻量目标架构: - -```text -FastAPI Web - - 页面和 API - - 管理后台 - - 创建 job - - 查询 job 状态 - -Scheduler - - 按策略创建 job - - 不直接跑重任务 - -Worker - - crawl - - summarize - - PDF download - - image extraction - - FTS/Chroma reindex - - cleanup/delete - -Storage/Repository - - DB 权威数据 - - 文件资产管理 - - 派生索引重建 - - 一致性检查 -``` - -如果暂时不想引入独立 worker,也可以先在单进程内实现 Job 表和后台任务执行器。这样至少能统一任务状态和恢复机制。 - -## 推荐实施顺序 - -### 第一阶段:先解决任务与状态 - -目标:降低长任务不可控风险。 - -- 新增 `jobs` 和 `job_events` 表。 -- Web 管理动作改为创建 job 并返回 task_id。 -- Scheduler 改为创建 job。 -- CLI 复用同一套 job runner。 -- 加 stale running job 恢复逻辑。 - -### 第二阶段:统一派生数据管理 - -目标:降低 DB、文件、FTS、Chroma 不一致风险。 - -- 明确 DB 为权威源。 -- 封装 `reindex_paper()`。 -- 增加 `reindex_fts`、`reindex_chroma` 任务。 -- 增加派生数据状态字段或健康检查。 -- 删除流程改为 DB 主删除 + 派生清理 job。 - -### 第三阶段:改造 pipeline 为状态机 - -目标:让任务可恢复、可补偿、可观测。 - -- 拆分 pipeline stages。 -- 每个 stage 记录输入/输出/错误/耗时。 -- 支持从失败阶段重跑。 -- 将 fallback 策略配置化。 - -### 第四阶段:提升运行可靠性 - -目标:长期运行更稳。 - -- 引入 Alembic。 -- 配置加 validator。 -- 同步阻塞操作移出 Web event loop。 -- 外部依赖 adapter 化。 -- 管理后台展示任务、派生索引、失败类型、耗时统计。 - -## 最小可行改造方案 - -如果只做最小但收益最大的改动,建议优先做这三项: - -1. 增加统一 Job 表和 job runner。 -2. DB 作为权威源,FTS/Chroma/文件全部按派生物处理。 -3. 增加 stale task 恢复和派生数据重建命令。 - -这三项能显著降低后续复杂度,也不会强迫项目马上拆成多个服务。 - diff --git a/app/cli.py b/app/cli.py index 28f8e06..8049d59 100644 --- a/app/cli.py +++ b/app/cli.py @@ -26,7 +26,7 @@ def crawl( from app.database import SessionLocal, engine from app.database import init_db as _init from app.models import Paper - from app.services.crawler import crawl_daily + from app.services.jobs import create_job, run_job from app.utils import today_str, yesterday_str from sqlalchemy import func, select @@ -55,7 +55,13 @@ def crawl( return typer.echo(f"📡 开始抓取 {target} ...") - result = asyncio.run(crawl_daily(db, target, top_n)) + job = create_job( + db, + "crawl_daily", + owner="cli_crawl", + payload={"target_date": target, "top_n": top_n}, + ) + result = asyncio.run(run_job(db, job.id)) # 未指定日期且今天失败或无数据时,自动回退到昨天 need_fallback = not date_str and ( @@ -76,7 +82,13 @@ def crawl( else: typer.echo(f"🔄 {target} 无数据,尝试 {fallback} ...") target = fallback - result = asyncio.run(crawl_daily(db, target, top_n)) + job = create_job( + db, + "crawl_daily", + owner="cli_crawl", + payload={"target_date": target, "top_n": top_n}, + ) + result = asyncio.run(run_job(db, job.id)) if result["status"] == "success": typer.echo( @@ -110,7 +122,7 @@ def summarize( from app.config import settings from app.database import SessionLocal, engine from app.database import init_db as _init - from app.services.summarizer import summarize_batch, summarize_single + from app.services.jobs import create_job, run_job import os @@ -142,11 +154,25 @@ def summarize( try: if arxiv_id: typer.echo(f"🤖 开始总结 {arxiv_id} (mode={pdf_mode}) ...") - result = asyncio.run(summarize_single(db, arxiv_id, pdf_mode=pdf_mode)) + job = create_job( + db, + "summarize_one", + owner="cli_summarize", + payload={"arxiv_id": arxiv_id, "pdf_mode": pdf_mode, "force": False}, + ) else: typer.echo(f"🤖 开始批量总结 pending 论文 (mode={pdf_mode}) ...") - result = asyncio.run(summarize_batch(db, pdf_mode=pdf_mode)) + job = create_job( + db, + "summarize_batch", + owner="cli_summarize", + payload={"pdf_mode": pdf_mode}, + ) + result = asyncio.run(run_job(db, job.id)) + if result.get("status") == "failed": + typer.echo(f"❌ 总结失败:{result.get('error')}", err=True) + raise typer.Exit(code=1) typer.echo(f"✅ 总结完成:{result}") except NotFoundError as exc: typer.echo(f"❌ {exc.message}", err=True) @@ -172,5 +198,39 @@ def init_db(): typer.echo(f"✅ 数据库已初始化:{settings.db_path}") +@cli_app.command("rebuild-derived") +def rebuild_derived( + fts: bool = typer.Option(False, "--fts", help="重建 FTS5 全文索引"), + chroma: bool = typer.Option(False, "--chroma", help="重建 ChromaDB 语义索引"), +): + """重建可派生数据索引。""" + from app.config import settings + from app.database import SessionLocal, engine + from app.database import init_db as _init + from app.services.jobs import create_job, run_job + + import os + + if not fts and not chroma: + fts = True + + os.makedirs(settings.db_path.parent, exist_ok=True) + _init(engine) + + db = SessionLocal() + try: + for job_type in [ + *(["reindex_fts"] if fts else []), + *(["reindex_chroma"] if chroma else []), + ]: + job = create_job(db, job_type, owner="cli_rebuild_derived", payload={}) + result = asyncio.run(run_job(db, job.id)) + typer.echo(f"{job_type}: {result}") + if result.get("status") == "failed": + raise typer.Exit(code=1) + finally: + db.close() + + if __name__ == "__main__": cli_app() diff --git a/app/database.py b/app/database.py index be456a4..d8cdda0 100644 --- a/app/database.py +++ b/app/database.py @@ -76,10 +76,26 @@ def _migrate(engine) -> None: "crawl_logs": [ ("details_json", "TEXT"), ], + "task_locks": [ + ("heartbeat_at", "DATETIME"), + ("expires_at", "DATETIME"), + ], + "jobs": [ + ("heartbeat_at", "DATETIME"), + ], } with engine.connect() as conn: for table, columns in _MIGRATIONS.items(): + table_exists = conn.execute( + text( + "SELECT name FROM sqlite_master " + "WHERE type IN ('table', 'virtual table') AND name = :name" + ), + {"name": table}, + ).fetchone() + if not table_exists: + continue # 获取已有列名 existing = { row[1] for row in conn.execute(text(f"PRAGMA table_info({table})")) diff --git a/app/main.py b/app/main.py index af0564e..396d274 100644 --- a/app/main.py +++ b/app/main.py @@ -32,7 +32,14 @@ async def lifespan(app: FastAPI): # ── startup ── from app.services.scheduler import start_scheduler from app.services.embedder import init_chroma + from app.services.jobs import recover_stale_jobs + from app.database import SessionLocal + db = SessionLocal() + try: + recover_stale_jobs(db) + finally: + db.close() start_scheduler() init_chroma() diff --git a/app/models.py b/app/models.py index 2054619..d0673af 100644 --- a/app/models.py +++ b/app/models.py @@ -32,6 +32,26 @@ class SummaryState(StrEnum): PERMANENT_FAILURE = "permanent_failure" +class JobStatus(StrEnum): + """后台任务状态枚举 — 对应 jobs.status 列。""" + + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + STALE = "stale" + CANCELLED = "cancelled" + + +class JobEventStatus(StrEnum): + """任务阶段事件状态枚举 — 对应 job_events.status 列。""" + + STARTED = "started" + SUCCESS = "success" + FAILED = "failed" + INFO = "info" + + # ── papers ────────────────────────────────────────────────────────────── class Paper(Base): __tablename__ = "papers" @@ -194,9 +214,48 @@ class TaskLock(Base): status = Column(String, nullable=False) owner = Column(String) acquired_at = Column(DateTime, nullable=False) + heartbeat_at = Column(DateTime) + expires_at = Column(DateTime) released_at = Column(DateTime) +# ── jobs / job_events ────────────────────────────────────────────────── +class Job(Base): + __tablename__ = "jobs" + + id = Column(Integer, primary_key=True, autoincrement=True) + type = Column(String, nullable=False, index=True) + status = Column(String, nullable=False, default=JobStatus.QUEUED, index=True) + owner = Column(String) + payload_json = Column(Text) + result_json = Column(Text) + error = Column(Text) + created_at = Column(DateTime, nullable=False) + started_at = Column(DateTime) + heartbeat_at = Column(DateTime) + completed_at = Column(DateTime) + + events = relationship( + "JobEvent", back_populates="job", cascade="all, delete-orphan" + ) + + +class JobEvent(Base): + __tablename__ = "job_events" + + id = Column(Integer, primary_key=True, autoincrement=True) + job_id = Column( + Integer, ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, index=True + ) + stage = Column(String, nullable=False) + status = Column(String, nullable=False) + message = Column(Text) + payload_json = Column(Text) + created_at = Column(DateTime, nullable=False) + + job = relationship("Job", back_populates="events") + + # ── user data ────────────────────────────────────────────────────────── class UserBookmark(Base): __tablename__ = "user_bookmarks" diff --git a/app/routes/admin.py b/app/routes/admin.py index d1f4873..d69d6dd 100644 --- a/app/routes/admin.py +++ b/app/routes/admin.py @@ -4,36 +4,20 @@ from __future__ import annotations import hashlib import hmac -import json -import logging from datetime import date -from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request +from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, Query, Request from fastapi.responses import RedirectResponse from pydantic import BaseModel, field_validator -from sqlalchemy import bindparam, func, select, text from sqlalchemy.orm import Session from app.config import settings from app.database import get_db -from app.models import ( - CrawlLog, - DataDeleteJob, - Paper, - PaperTag, - SummaryState, - SummaryStatus, -) from app.services import admin as admin_svc from app.services.admin import get_admin_stats -from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range -from app.services.crawler import refresh_upvotes -from app.services.pipeline import run_crawl, run_pipeline -from app.services.scheduler import get_scheduler -from app.services.summarizer import summarize_batch, summarize_single -from app.utils import templates, today_str, utc_now - -logger = logging.getLogger(__name__) +from app.services.cleaner import cleanup_tmp +from app.services.jobs import create_job, enqueue_job +from app.utils import templates, today_str router = APIRouter(prefix="/admin", tags=["admin"]) @@ -103,18 +87,7 @@ async def admin_dashboard( ): """管理仪表盘 — 系统状态总览。""" stats = get_admin_stats(db) - - # 调度器历史(最近 10 条 task=scheduler 日志) - scheduler_history = ( - db.execute( - select(CrawlLog) - .where(CrawlLog.task == "scheduler") - .order_by(CrawlLog.started_at.desc()) - .limit(10) - ) - .scalars() - .all() - ) + scheduler_history = admin_svc.get_scheduler_history(db) return templates.TemplateResponse( request, @@ -129,53 +102,43 @@ async def admin_dashboard( @router.get("/scheduler-status") async def admin_scheduler_status(_admin: None = Depends(verify_admin)): """调度器运行状态(JSON)。""" - scheduler = get_scheduler() - next_run = None - upvote_next_run = None - if scheduler: - for job in scheduler.get_jobs(): - if job.id == "daily_pipeline": - next_run = job.next_run_time - elif job.id == "upvote_refresh": - upvote_next_run = job.next_run_time - return { - "enabled": scheduler is not None, - "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}", - "timezone": settings.APP_TIMEZONE, - "next_run": next_run.isoformat() if next_run else None, - "upvote_next_run": upvote_next_run.isoformat() if upvote_next_run else None, - "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS, - } + return admin_svc.get_scheduler_status() @router.post("/trigger-pipeline") async def admin_trigger_pipeline( + background_tasks: BackgroundTasks, _admin: None = Depends(verify_admin), db: Session = Depends(get_db), ): """手动触发一次完整流水线(crawl → summarize → cleanup)。""" today = today_str() - try: - result = await run_pipeline(db, today, owner="admin_trigger") - except RuntimeError as exc: - raise HTTPException(status_code=409, detail=str(exc)) - - if result["status"] == "failed": - raise HTTPException(status_code=500, detail=result.get("error")) - return {"status": "success", "message": "流水线执行完成"} + job = create_job( + db, + "pipeline_daily", + owner="admin_trigger", + payload={"target_date": today}, + ) + enqueue_job(background_tasks, job.id) + return {"status": "queued", "job_id": job.id, "message": "流水线任务已创建"} @router.post("/refresh-upvotes") async def admin_refresh_upvotes( + background_tasks: BackgroundTasks, _admin: None = Depends(verify_admin), db: Session = Depends(get_db), days: int | None = Query(None, description="刷新最近 N 天,默认使用配置值"), ): """手动刷新最近 N 天论文的 upvotes。""" - result = await refresh_upvotes(db, days=days) - if result["status"] == "failed": - raise HTTPException(status_code=500, detail=result.get("error")) - return result + job = create_job( + db, + "refresh_upvotes", + owner="admin_refresh", + payload={"days": days}, + ) + enqueue_job(background_tasks, job.id) + return {"status": "queued", "job_id": job.id} # ── 请求模型 ────────────────────────────────────────────────────────── @@ -200,18 +163,21 @@ class DeleteRequest(BaseModel): @router.post("/crawl") async def admin_crawl( + background_tasks: BackgroundTasks, _admin: None = Depends(verify_admin), db: Session = Depends(get_db), date: str | None = Query(None, description="YYYY-MM-DD,默认今天"), ): """手动抓取指定日期,默认今天。""" target_date = date or today_str() - try: - return await run_crawl(db, target_date, owner="admin_crawl") - except RuntimeError as exc: - raise HTTPException(status_code=409, detail=str(exc)) - except Exception as exc: - raise HTTPException(status_code=500, detail=str(exc)) + job = create_job( + db, + "crawl_daily", + owner="admin_crawl", + payload={"target_date": target_date}, + ) + enqueue_job(background_tasks, job.id) + return {"status": "queued", "job_id": job.id, "target_date": target_date} # ── 总结 ────────────────────────────────────────────────────────────── @@ -219,23 +185,41 @@ async def admin_crawl( @router.post("/summarize") async def admin_summarize_batch( + background_tasks: BackgroundTasks, _admin: None = Depends(verify_admin), db: Session = Depends(get_db), ): """批量总结所有 pending 论文。""" - return await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE) + job = create_job( + db, + "summarize_batch", + owner="admin_summarize", + payload={"pdf_mode": settings.SUMMARY_PDF_MODE}, + ) + enqueue_job(background_tasks, job.id) + return {"status": "queued", "job_id": job.id} @router.post("/summarize/{arxiv_id}") async def admin_summarize_single( arxiv_id: str, + background_tasks: BackgroundTasks, _admin: None = Depends(verify_admin), db: Session = Depends(get_db), ): """总结或重跑单篇论文。""" - return await summarize_single( - db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE + job = create_job( + db, + "summarize_one", + owner="admin_summarize", + payload={ + "arxiv_id": arxiv_id, + "force": True, + "pdf_mode": settings.SUMMARY_PDF_MODE, + }, ) + enqueue_job(background_tasks, job.id) + return {"status": "queued", "job_id": job.id, "arxiv_id": arxiv_id} # ── 清理 ────────────────────────────────────────────────────────────── @@ -243,39 +227,25 @@ async def admin_summarize_single( @router.post("/cleanup") async def admin_cleanup( + background_tasks: BackgroundTasks, _admin: None = Depends(verify_admin), db: Session = Depends(get_db), ): """清理 data/tmp/ 中超过 24 小时的临时文件。""" - now = utc_now() - log_entry = CrawlLog( - task="cleanup", - status="running", - started_at=now, - ) - db.add(log_entry) - db.commit() + job = create_job(db, "cleanup_tmp", owner="admin_cleanup", payload={}) + enqueue_job(background_tasks, job.id) + return {"status": "queued", "job_id": job.id} + +@router.post("/cleanup-now") +async def admin_cleanup_now( + _admin: None = Depends(verify_admin), + db: Session = Depends(get_db), +): + """同步清理临时文件,保留给测试和本地排障使用。""" try: - result = cleanup_tmp() - log_entry.status = "success" - log_entry.completed_at = utc_now() - log_entry.details_json = json.dumps( - { - "scanned": result.get("scanned", 0), - "removed": result.get("removed", 0), - }, - ensure_ascii=False, - ) - if result.get("errors"): - log_entry.error = "; ".join(result["errors"])[:2000] - db.commit() - return result + return admin_svc.run_cleanup_now(db, cleanup_tmp) except Exception as exc: - log_entry.status = "failed" - log_entry.error = str(exc)[:2000] - log_entry.completed_at = utc_now() - db.commit() raise HTTPException(status_code=500, detail=str(exc)) @@ -285,6 +255,7 @@ async def admin_cleanup( @router.post("/delete") async def admin_delete( body: DeleteRequest, + background_tasks: BackgroundTasks, _admin: None = Depends(verify_admin), db: Session = Depends(get_db), ): @@ -292,13 +263,31 @@ async def admin_delete( if body.date_start > body.date_end: raise HTTPException(status_code=400, detail="date_start must be <= date_end") - result = await delete_papers_by_date_range( + job = create_job( db, - body.date_start, - body.date_end, - include_notes=body.include_notes, + "delete_range", + owner="admin_delete", + payload={ + "date_start": body.date_start.isoformat(), + "date_end": body.date_end.isoformat(), + "include_notes": body.include_notes, + }, ) - return result + enqueue_job(background_tasks, job.id) + return {"status": "queued", "job_id": job.id} + + +@router.get("/jobs/{job_id}") +async def admin_job_detail( + job_id: int, + _admin: None = Depends(verify_admin), + db: Session = Depends(get_db), +): + """查询后台任务状态和阶段事件。""" + detail = admin_svc.get_job_detail(db, job_id) + if not detail: + raise HTTPException(status_code=404, detail=f"Job not found: {job_id}") + return detail # ── 日志 ────────────────────────────────────────────────────────────── @@ -313,72 +302,10 @@ async def admin_logs( per_page: int = Query(20, ge=1, le=100), ): """查看任务日志(CrawlLog + DataDeleteJob)+ 总结状态统计。""" - crawl_logs = ( - db.execute( - select(CrawlLog) - .order_by(CrawlLog.started_at.desc()) - .limit(per_page) - .offset((page - 1) * per_page) - ) - .scalars() - .all() - ) - - delete_jobs = ( - db.execute( - select(DataDeleteJob) - .order_by(DataDeleteJob.started_at.desc()) - .limit(per_page) - .offset((page - 1) * per_page) - ) - .scalars() - .all() - ) - - # 总结状态统计概要 - summary_total = db.scalar(select(func.count(Paper.id))) or 0 - summary_done = ( - db.scalar( - select(func.count(SummaryStatus.id)).where( - SummaryStatus.status == SummaryState.DONE - ) - ) - or 0 - ) - summary_pending = ( - db.scalar( - select(func.count(SummaryStatus.id)).where( - SummaryStatus.status.in_( - [SummaryState.PENDING, SummaryState.PROCESSING] - ) - ) - ) - or 0 - ) - summary_failed = ( - db.scalar( - select(func.count(SummaryStatus.id)).where( - SummaryStatus.status.in_( - [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE] - ) - ) - ) - or 0 - ) - return templates.TemplateResponse( request, "admin_logs.html", - { - "crawl_logs": crawl_logs, - "delete_jobs": delete_jobs, - "page": page, - "per_page": per_page, - "summary_total": summary_total, - "summary_done": summary_done, - "summary_pending": summary_pending, - "summary_failed": summary_failed, - }, + admin_svc.get_logs_context(db, page=page, per_page=per_page), ) @@ -395,22 +322,10 @@ async def admin_summary_status( per_page: int = Query(20, ge=1, le=100), ): """总结状态列表(HTMX 片段或 JSON)。""" - - query = ( - select(Paper, SummaryStatus) - .outerjoin(SummaryStatus, SummaryStatus.paper_id == Paper.id) - .order_by(Paper.paper_date.desc()) + results, total = admin_svc.query_summary_statuses( + db, status=status, page=page, per_page=per_page ) - if status != "all": - if status == "none": - query = query.where(SummaryStatus.paper_id == None) # noqa: E711 - else: - query = query.where(SummaryStatus.status == status) - - total = db.scalar(select(func.count()).select_from(query.subquery())) - results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all() - # 判断是否 HTMX 请求 is_htmx = request.headers.get("HX-Request") == "true" @@ -421,27 +336,16 @@ async def admin_summary_status( "partials/summary_list.html", { "results": results, - "total": total or 0, + "total": total, "page": page, "per_page": per_page, "current_status": status, }, ) - # 非 HTMX 返回 JSON - items = [] - for paper, ss in results: - item = { - "arxiv_id": paper.arxiv_id, - "title": paper.title_zh or paper.title_en, - "paper_date": str(paper.paper_date), - "summary_status": ss.status if ss else "none", - "retry_count": ss.retry_count if ss else 0, - "error_type": ss.error_type if ss else None, - "error": ss.error if ss else None, - } - items.append(item) - return {"items": items, "total": total or 0, "page": page, "per_page": per_page} + return admin_svc.serialize_summary_statuses( + results, total=total, page=page, per_page=per_page + ) @router.post("/summary-retry-failed") @@ -450,39 +354,14 @@ async def admin_summary_retry_failed( db: Session = Depends(get_db), ): """重试所有失败状态的总结任务。""" - failed_ids = ( - db.execute( - select(Paper.arxiv_id) - .join(SummaryStatus, SummaryStatus.paper_id == Paper.id) - .where( - SummaryStatus.status.in_( - [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE] - ) - ) - ) - .scalars() - .all() - ) - - if not failed_ids: + count = admin_svc.retry_failed_summaries(db) + if not count: return {"status": "success", "message": "没有失败的任务需要重试", "count": 0} - # 重置失败任务的状态为 pending - db.execute( - SummaryStatus.__table__.update() - .where( - SummaryStatus.status.in_( - [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE] - ) - ) - .values(status=SummaryState.PENDING, error=None, error_type=None) - ) - db.commit() - return { "status": "success", - "message": f"已重置 {len(failed_ids)} 个失败任务为待总结状态", - "count": len(failed_ids), + "message": f"已重置 {count} 个失败任务为待总结状态", + "count": count, } @@ -545,23 +424,8 @@ async def admin_paper_delete( db: Session = Depends(get_db), ): """删除单篇论文。""" - paper = db.scalar(select(Paper).where(Paper.arxiv_id == arxiv_id)) - if not paper: + if not admin_svc.delete_paper_by_arxiv(db, arxiv_id): raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}") - - # 删除相关数据(ORM cascade 自动处理关联表) - db.delete(paper) - db.commit() - - # 清理 FTS 索引 - try: - db.execute( - text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id} - ) - db.commit() - except Exception: - logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True) - return {"status": "success", "message": f"已删除 {arxiv_id}"} @@ -588,28 +452,7 @@ async def admin_papers_batch_action( raise HTTPException(status_code=400, detail="arxiv_ids 不能为空") if body.action == "delete": - papers = ( - db.execute(select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids))) - .scalars() - .all() - ) - - count = 0 - for paper in papers: - db.delete(paper) - count += 1 - db.commit() - - # 清理 FTS 索引 - try: - stmt = text("DELETE FROM papers_fts WHERE arxiv_id IN :ids").bindparams( - bindparam("ids", expanding=True) - ) - db.execute(stmt, {"ids": body.arxiv_ids}) - db.commit() - except Exception: - logger.warning("Failed to clean FTS index for batch delete", exc_info=True) - + count = admin_svc.delete_papers_by_arxiv_ids(db, body.arxiv_ids) return { "status": "success", "message": f"已删除 {count} 篇论文", @@ -617,24 +460,10 @@ async def admin_papers_batch_action( } elif body.action == "summarize": - # 将选中论文的总结状态重置为 pending - paper_ids = ( - db.execute(select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids))) - .scalars() - .all() - ) - - if paper_ids: - # 删除旧的 status 记录让其重新进入 pipeline - db.execute( - SummaryStatus.__table__.delete().where( - SummaryStatus.paper_id.in_(paper_ids) - ) - ) - db.commit() + count = admin_svc.reset_summaries_pending(db, body.arxiv_ids) return { "status": "success", - "message": f"已将 {len(paper_ids)} 篇论文重置为待总结", - "count": len(paper_ids), + "message": f"已将 {count} 篇论文重置为待总结", + "count": count, } diff --git a/app/services/admin.py b/app/services/admin.py index 20d8e09..468d753 100644 --- a/app/services/admin.py +++ b/app/services/admin.py @@ -1,17 +1,30 @@ -"""管理后台服务 — 统计聚合、系统状态。""" +"""管理后台服务 — 统计聚合、系统状态、管理操作。""" from __future__ import annotations +import json from datetime import date from pathlib import Path +from typing import Callable from sqlalchemy import func, select, text from sqlalchemy.orm import Session from app.config import settings -from app.models import CrawlLog, Paper, PaperTag, SummaryState, SummaryStatus, TaskLock +from app.models import ( + CrawlLog, + DataDeleteJob, + Job, + JobEvent, + Paper, + PaperTag, + SummaryState, + SummaryStatus, + TaskLock, +) +from app.services.derived import delete_paper_indexes from app.services.scheduler import get_scheduler -from app.utils import PAPERS_DIR, TMP_DIR +from app.utils import PAPERS_DIR, TMP_DIR, utc_now # admin_papers 排序映射 SORT_MAP = { @@ -190,3 +203,311 @@ def query_papers( statuses[paper_id_to_arxiv.get(pid, "")] = st return papers, total or 0, statuses + + +def get_scheduler_history(db: Session, limit: int = 10) -> list[CrawlLog]: + """最近的调度器运行日志。""" + return ( + db.execute( + select(CrawlLog) + .where(CrawlLog.task == "scheduler") + .order_by(CrawlLog.started_at.desc()) + .limit(limit) + ) + .scalars() + .all() + ) + + +def get_scheduler_status() -> dict: + """调度器运行状态。""" + scheduler = get_scheduler() + next_run = None + upvote_next_run = None + if scheduler: + for job in scheduler.get_jobs(): + if job.id == "daily_pipeline": + next_run = job.next_run_time + elif job.id == "upvote_refresh": + upvote_next_run = job.next_run_time + return { + "enabled": scheduler is not None, + "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}", + "timezone": settings.APP_TIMEZONE, + "next_run": next_run.isoformat() if next_run else None, + "upvote_next_run": upvote_next_run.isoformat() if upvote_next_run else None, + "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS, + } + + +def run_cleanup_now(db: Session, cleanup_func: Callable[[], dict]) -> dict: + """同步执行临时目录清理,并写入 CrawlLog。""" + log_entry = CrawlLog(task="cleanup", status="running", started_at=utc_now()) + db.add(log_entry) + db.commit() + + try: + result = cleanup_func() + log_entry.status = "success" + log_entry.completed_at = utc_now() + log_entry.details_json = json.dumps( + { + "scanned": result.get("scanned", 0), + "removed": result.get("removed", 0), + }, + ensure_ascii=False, + ) + if result.get("errors"): + log_entry.error = "; ".join(result["errors"])[:2000] + db.commit() + return result + except Exception as exc: + log_entry.status = "failed" + log_entry.error = str(exc)[:2000] + log_entry.completed_at = utc_now() + db.commit() + raise + + +def get_job_detail(db: Session, job_id: int) -> dict | None: + """后台任务详情和阶段事件,返回可 JSON 序列化 dict。""" + job = db.get(Job, job_id) + if not job: + return None + events = ( + db.execute( + select(JobEvent) + .where(JobEvent.job_id == job_id) + .order_by(JobEvent.created_at.asc()) + ) + .scalars() + .all() + ) + return { + "id": job.id, + "type": job.type, + "status": job.status, + "owner": job.owner, + "payload": json.loads(job.payload_json or "{}"), + "result": json.loads(job.result_json or "{}") if job.result_json else None, + "error": job.error, + "created_at": job.created_at.isoformat(), + "started_at": job.started_at.isoformat() if job.started_at else None, + "completed_at": job.completed_at.isoformat() if job.completed_at else None, + "events": [ + { + "stage": event.stage, + "status": event.status, + "message": event.message, + "payload": json.loads(event.payload_json or "{}") + if event.payload_json + else None, + "created_at": event.created_at.isoformat(), + } + for event in events + ], + } + + +def get_logs_context(db: Session, *, page: int, per_page: int) -> dict: + """管理日志页上下文。""" + crawl_logs = ( + db.execute( + select(CrawlLog) + .order_by(CrawlLog.started_at.desc()) + .limit(per_page) + .offset((page - 1) * per_page) + ) + .scalars() + .all() + ) + delete_jobs = ( + db.execute( + select(DataDeleteJob) + .order_by(DataDeleteJob.started_at.desc()) + .limit(per_page) + .offset((page - 1) * per_page) + ) + .scalars() + .all() + ) + + summary_total = db.scalar(select(func.count(Paper.id))) or 0 + summary_done = ( + db.scalar( + select(func.count(SummaryStatus.id)).where( + SummaryStatus.status == SummaryState.DONE + ) + ) + or 0 + ) + summary_pending = ( + db.scalar( + select(func.count(SummaryStatus.id)).where( + SummaryStatus.status.in_( + [SummaryState.PENDING, SummaryState.PROCESSING] + ) + ) + ) + or 0 + ) + summary_failed = ( + db.scalar( + select(func.count(SummaryStatus.id)).where( + SummaryStatus.status.in_( + [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE] + ) + ) + ) + or 0 + ) + return { + "crawl_logs": crawl_logs, + "delete_jobs": delete_jobs, + "page": page, + "per_page": per_page, + "summary_total": summary_total, + "summary_done": summary_done, + "summary_pending": summary_pending, + "summary_failed": summary_failed, + } + + +def query_summary_statuses( + db: Session, + *, + status: str, + page: int, + per_page: int, +) -> tuple[list[tuple[Paper, SummaryStatus | None]], int]: + """总结状态列表查询。""" + query = ( + select(Paper, SummaryStatus) + .outerjoin(SummaryStatus, SummaryStatus.paper_id == Paper.id) + .order_by(Paper.paper_date.desc()) + ) + if status != "all": + if status == "none": + query = query.where(SummaryStatus.paper_id == None) # noqa: E711 + else: + query = query.where(SummaryStatus.status == status) + + total = db.scalar(select(func.count()).select_from(query.subquery())) or 0 + results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all() + return results, total + + +def serialize_summary_statuses( + results: list[tuple[Paper, SummaryStatus | None]], + *, + total: int, + page: int, + per_page: int, +) -> dict: + """总结状态列表 JSON 响应。""" + items = [] + for paper, ss in results: + items.append( + { + "arxiv_id": paper.arxiv_id, + "title": paper.title_zh or paper.title_en, + "paper_date": str(paper.paper_date), + "summary_status": ss.status if ss else "none", + "retry_count": ss.retry_count if ss else 0, + "error_type": ss.error_type if ss else None, + "error": ss.error if ss else None, + } + ) + return {"items": items, "total": total, "page": page, "per_page": per_page} + + +def retry_failed_summaries(db: Session) -> int: + """将失败/永久失败的总结任务重置为 pending。""" + failed_ids = ( + db.execute( + select(Paper.arxiv_id) + .join(SummaryStatus, SummaryStatus.paper_id == Paper.id) + .where( + SummaryStatus.status.in_( + [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE] + ) + ) + ) + .scalars() + .all() + ) + if not failed_ids: + return 0 + + db.execute( + SummaryStatus.__table__.update() + .where( + SummaryStatus.status.in_( + [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE] + ) + ) + .values(status=SummaryState.PENDING, error=None, error_type=None) + ) + db.commit() + return len(failed_ids) + + +def delete_paper_by_arxiv(db: Session, arxiv_id: str) -> bool: + """删除单篇论文和派生索引。""" + paper = db.scalar(select(Paper).where(Paper.arxiv_id == arxiv_id)) + if not paper: + return False + + paper_id = paper.id + db.delete(paper) + db.commit() + delete_paper_indexes(db, paper_id=paper_id, arxiv_id=arxiv_id) + db.commit() + return True + + +def delete_papers_by_arxiv_ids(db: Session, arxiv_ids: list[str]) -> int: + """批量删除论文和派生索引。""" + papers = ( + db.execute(select(Paper).where(Paper.arxiv_id.in_(arxiv_ids))).scalars().all() + ) + deleted = [(paper.id, paper.arxiv_id) for paper in papers] + for paper in papers: + db.delete(paper) + db.commit() + + for paper_id, arxiv_id in deleted: + delete_paper_indexes(db, paper_id=paper_id, arxiv_id=arxiv_id) + db.commit() + return len(deleted) + + +def reset_summaries_pending(db: Session, arxiv_ids: list[str]) -> int: + """将指定论文的总结状态重置为 pending,没有状态则创建。""" + paper_ids = ( + db.execute(select(Paper.id).where(Paper.arxiv_id.in_(arxiv_ids))) + .scalars() + .all() + ) + if not paper_ids: + return 0 + + existing_statuses = ( + db.execute(select(SummaryStatus).where(SummaryStatus.paper_id.in_(paper_ids))) + .scalars() + .all() + ) + existing_ids = {status.paper_id for status in existing_statuses} + for status in existing_statuses: + status.status = SummaryState.PENDING + status.quality = None + status.error = None + status.error_type = None + status.raw_output_saved = False + status.started_at = None + status.completed_at = None + for paper_id in paper_ids: + if paper_id not in existing_ids: + db.add(SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING)) + db.commit() + return len(paper_ids) diff --git a/app/services/crawler.py b/app/services/crawler.py index 5942373..ca8f23b 100644 --- a/app/services/crawler.py +++ b/app/services/crawler.py @@ -4,7 +4,7 @@ import logging from datetime import date as date_type, datetime, timezone import httpx -from sqlalchemy import select, text +from sqlalchemy import select from sqlalchemy.orm import Session from app.config import settings @@ -16,6 +16,7 @@ from app.models import ( SummaryState, SummaryStatus, ) +from app.services.derived import reindex_paper_fts from app.utils import make_http_client, recent_date_strs, utc_now logger = logging.getLogger(__name__) @@ -143,21 +144,7 @@ def upsert_papers(db: Session, papers_raw: list[dict], paper_date: str) -> list[ db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING)) - authors_text = ", ".join(meta["authors"]) - tags_text = ", ".join(meta["tags"]) - db.execute( - text( - "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) " - "VALUES (:id, :title, :abstract, :authors, :tags)" - ), - { - "id": paper.id, - "title": meta["title_en"], - "abstract": meta["abstract"] or "", - "authors": authors_text, - "tags": tags_text, - }, - ) + reindex_paper_fts(db, paper) new_papers.append(paper) logger.debug("Inserted new paper: %s", arxiv_id) diff --git a/app/services/derived.py b/app/services/derived.py new file mode 100644 index 0000000..1eece19 --- /dev/null +++ b/app/services/derived.py @@ -0,0 +1,140 @@ +"""派生数据维护 — FTS5 / ChromaDB 等可重建索引。""" + +from __future__ import annotations + +import logging + +from sqlalchemy import select, text +from sqlalchemy.orm import Session + +from app.models import Paper + +logger = logging.getLogger(__name__) + + +def _summary_text(paper: Paper) -> str: + summary = paper.summary + if not summary: + return "" + parts = [ + summary.one_line, + summary.motivation_problem, + summary.motivation_goal, + summary.method_overview, + summary.method_key_idea, + summary.results_main_json, + ] + return " ".join(p for p in parts if p) + + +def delete_fts_paper(db: Session, paper_id: int) -> None: + """删除单篇论文的 FTS5 行。FTS5 以 papers.id 作为 rowid。""" + db.execute( + text("DELETE FROM papers_fts WHERE rowid = :paper_id"), + {"paper_id": paper_id}, + ) + + +def delete_paper_indexes(db: Session, *, paper_id: int, arxiv_id: str) -> None: + """删除单篇论文的所有派生索引。失败项记录日志但不阻断主删除。""" + try: + delete_fts_paper(db, paper_id) + except Exception: + logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True) + + try: + from app.services.embedder import delete_paper + + delete_paper(arxiv_id) + except Exception: + logger.warning("Failed to clean ChromaDB index for %s", arxiv_id, exc_info=True) + + +def reindex_paper_fts(db: Session, paper: Paper) -> None: + """按 DB 权威数据重建单篇论文的 FTS5 派生索引。""" + authors_text = ", ".join( + a.name for a in sorted(paper.authors, key=lambda a: a.position or 0) + ) + tags_text = ", ".join(t.tag for t in paper.tags) + + delete_fts_paper(db, paper.id) + db.execute( + text( + """ + INSERT INTO papers_fts( + rowid, title_en, title_zh, abstract, authors, tags, summary_text + ) + VALUES ( + :id, :title_en, :title_zh, :abstract, :authors, :tags, :summary_text + ) + """ + ), + { + "id": paper.id, + "title_en": paper.title_en or "", + "title_zh": paper.title_zh or "", + "abstract": paper.abstract or "", + "authors": authors_text, + "tags": tags_text, + "summary_text": _summary_text(paper), + }, + ) + + +def reindex_fts(db: Session, paper_ids: list[int] | None = None) -> dict: + """全量或局部重建 FTS5 索引。""" + query = select(Paper) + if paper_ids: + query = query.where(Paper.id.in_(paper_ids)) + papers = db.execute(query).scalars().all() + + if paper_ids is None: + db.execute(text("DELETE FROM papers_fts")) + + count = 0 + for paper in papers: + reindex_paper_fts(db, paper) + count += 1 + db.commit() + logger.info("FTS reindexed: %d papers", count) + return {"status": "success", "indexed": count} + + +def reindex_chroma(db: Session) -> dict: + """按 DB 权威数据重建 ChromaDB 语义索引。""" + from app.services.embedder import index_paper + + papers = db.execute(select(Paper).where(Paper.summary.has())).scalars().all() + indexed = 0 + errors: list[str] = [] + for paper in papers: + try: + texts_dict = { + "arxiv_id": paper.arxiv_id, + "title_zh": paper.title_zh or "", + "title_en": paper.title_en or "", + "tags": " ".join(t.tag for t in paper.tags), + "one_line": paper.summary.one_line if paper.summary else "", + "motivation_problem": ( + paper.summary.motivation_problem if paper.summary else "" + ), + "method_key_idea": ( + paper.summary.method_key_idea if paper.summary else "" + ), + "paper_date": paper.paper_date.isoformat() if paper.paper_date else "", + } + index_paper(paper.arxiv_id, texts_dict) + indexed += 1 + except Exception as exc: + errors.append(f"{paper.arxiv_id}: {exc}") + logger.warning( + "Failed to reindex ChromaDB for %s", + paper.arxiv_id, + exc_info=True, + ) + + return { + "status": "success" if not errors else "partial", + "indexed": indexed, + "errors": errors or None, + } diff --git a/app/services/jobs.py b/app/services/jobs.py new file mode 100644 index 0000000..e4ac2e9 --- /dev/null +++ b/app/services/jobs.py @@ -0,0 +1,244 @@ +"""统一后台任务系统 — 创建、运行、事件记录、失败恢复。""" + +from __future__ import annotations + +import json +import logging +from datetime import date, timedelta +from typing import Any + +from fastapi import BackgroundTasks +from sqlalchemy import or_, select +from sqlalchemy.orm import Session + +from app.config import settings +from app.database import SessionLocal +from app.models import Job, JobEvent, JobEventStatus, JobStatus, TaskLock +from app.utils import truncate_error, utc_now + +logger = logging.getLogger(__name__) + +STALE_JOB_AFTER = timedelta(hours=6) + + +def _dumps(value: Any) -> str: + return json.dumps(value, ensure_ascii=False, default=str) + + +def _loads(value: str | None) -> dict: + if not value: + return {} + try: + data = json.loads(value) + return data if isinstance(data, dict) else {} + except json.JSONDecodeError: + return {} + + +def create_job( + db: Session, + job_type: str, + *, + owner: str, + payload: dict | None = None, +) -> Job: + """创建后台任务主记录。""" + job = Job( + type=job_type, + status=JobStatus.QUEUED, + owner=owner, + payload_json=_dumps(payload or {}), + created_at=utc_now(), + ) + db.add(job) + db.commit() + db.refresh(job) + add_job_event( + db, + job, + stage="created", + status=JobEventStatus.INFO, + message=f"Job queued: {job_type}", + payload=payload or {}, + ) + return job + + +def add_job_event( + db: Session, + job: Job, + *, + stage: str, + status: str, + message: str | None = None, + payload: dict | None = None, +) -> None: + """追加一条任务阶段事件。""" + db.add( + JobEvent( + job_id=job.id, + stage=stage, + status=str(status), + message=message, + payload_json=_dumps(payload) if payload is not None else None, + created_at=utc_now(), + ) + ) + job.heartbeat_at = utc_now() + db.commit() + + +def enqueue_job(background_tasks: BackgroundTasks, job_id: int) -> None: + """把任务提交给 FastAPI BackgroundTasks。""" + background_tasks.add_task(run_job_by_id, job_id) + + +async def run_job_by_id(job_id: int) -> None: + """使用独立 DB session 运行一个已创建的 job。""" + db = SessionLocal() + try: + await run_job(db, job_id) + finally: + db.close() + + +async def run_job(db: Session, job_id: int) -> dict: + """运行 job,并把状态/result/error 写回 jobs/job_events。""" + job = db.get(Job, job_id) + if not job: + raise ValueError(f"Job not found: {job_id}") + if job.status == JobStatus.RUNNING: + raise RuntimeError(f"Job already running: {job_id}") + + payload = _loads(job.payload_json) + job.status = JobStatus.RUNNING + job.started_at = utc_now() + job.heartbeat_at = job.started_at + db.commit() + add_job_event(db, job, stage="run", status=JobEventStatus.STARTED) + + try: + result = await _dispatch_job(db, job, payload) + except Exception as exc: + logger.exception("Job failed: id=%s type=%s", job.id, job.type) + error = truncate_error(exc, limit=4000) + job.status = JobStatus.FAILED + job.error = error + job.completed_at = utc_now() + db.commit() + add_job_event(db, job, stage="run", status=JobEventStatus.FAILED, message=error) + return {"status": "failed", "error": error} + + job.status = JobStatus.SUCCESS + job.result_json = _dumps(result) + job.completed_at = utc_now() + job.error = None + db.commit() + add_job_event( + db, + job, + stage="run", + status=JobEventStatus.SUCCESS, + payload=result if isinstance(result, dict) else {"result": result}, + ) + return result if isinstance(result, dict) else {"status": "success", "result": result} + + +async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict: + from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range + from app.services.crawler import refresh_upvotes + from app.services.derived import reindex_chroma, reindex_fts + from app.services.pipeline import run_crawl, run_pipeline + from app.services.summarizer import summarize_batch, summarize_single + + if job.type == "crawl_daily": + return await run_crawl( + db, + payload["target_date"], + owner=job.owner or f"job:{job.id}", + top_n=payload.get("top_n"), + ) + if job.type == "pipeline_daily": + return await run_pipeline( + db, + payload["target_date"], + owner=job.owner or f"job:{job.id}", + ) + if job.type == "summarize_batch": + return await summarize_batch( + db, + pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE), + ) + if job.type == "summarize_one": + return await summarize_single( + db, + payload["arxiv_id"], + force=payload.get("force", True), + pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE), + ) + if job.type == "refresh_upvotes": + return await refresh_upvotes(db, days=payload.get("days")) + if job.type == "delete_range": + return await delete_papers_by_date_range( + db, + date.fromisoformat(payload["date_start"]), + date.fromisoformat(payload["date_end"]), + include_notes=payload.get("include_notes", True), + ) + if job.type == "cleanup_tmp": + return cleanup_tmp() + if job.type == "reindex_fts": + return reindex_fts(db) + if job.type == "reindex_chroma": + return reindex_chroma(db) + + raise ValueError(f"Unsupported job type: {job.type}") + + +def recover_stale_jobs(db: Session) -> int: + """启动时将过期 running job/lock 标记为 stale,避免永久卡住。""" + now = utc_now() + cutoff = now - STALE_JOB_AFTER + stale_jobs = ( + db.execute( + select(Job).where( + Job.status == JobStatus.RUNNING, + or_(Job.heartbeat_at == None, Job.heartbeat_at < cutoff), # noqa: E711 + ) + ) + .scalars() + .all() + ) + for job in stale_jobs: + job.status = JobStatus.STALE + job.error = "Marked stale after process restart or missed heartbeat" + job.completed_at = now + db.add( + JobEvent( + job_id=job.id, + stage="recovery", + status=JobEventStatus.FAILED, + message=job.error, + created_at=now, + ) + ) + + stale_locks = ( + db.execute( + select(TaskLock).where( + TaskLock.status == "running", + TaskLock.acquired_at < cutoff, + ) + ) + .scalars() + .all() + ) + for lock in stale_locks: + lock.status = "stale" + lock.released_at = now + + db.commit() + recovered = len(stale_jobs) + len(stale_locks) + if recovered: + logger.warning("Recovered stale runtime records: %d", recovered) + return recovered diff --git a/app/services/pipeline.py b/app/services/pipeline.py index e731bf5..c7bee1b 100644 --- a/app/services/pipeline.py +++ b/app/services/pipeline.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging from datetime import date as date_type +from datetime import timedelta from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -32,6 +33,8 @@ def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock: status="running", owner=owner, acquired_at=utc_now(), + heartbeat_at=utc_now(), + expires_at=utc_now() + timedelta(hours=6), ) try: db.add(lock) @@ -42,7 +45,12 @@ def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock: return lock -async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") -> dict: +async def run_crawl( + db: Session, + target_date: str, + owner: str = "admin_crawl", + top_n: int | None = None, +) -> dict: """执行单次抓取(带防重入锁)。 Args: @@ -55,7 +63,7 @@ async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") - """ lock = acquire_lock(db, "crawl", target_date, owner) try: - return await crawl_daily(db, target_date) + return await crawl_daily(db, target_date, top_n=top_n) finally: release_lock(db, lock) @@ -83,6 +91,8 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict: status="running", owner=owner, acquired_at=now, + heartbeat_at=now, + expires_at=now + timedelta(hours=6), ) try: db.add(lock) diff --git a/app/services/scheduler.py b/app/services/scheduler.py index 698056d..bae0e0f 100644 --- a/app/services/scheduler.py +++ b/app/services/scheduler.py @@ -11,8 +11,7 @@ from zoneinfo import ZoneInfo from app.config import settings from app.database import SessionLocal -from app.services.pipeline import run_pipeline -from app.services.crawler import refresh_upvotes +from app.services.jobs import create_job, run_job from app.utils import today_str logger = logging.getLogger(__name__) @@ -112,7 +111,13 @@ async def _daily_pipeline() -> None: db: Session = SessionLocal() try: - await run_pipeline(db, today, owner="daily_pipeline") + job = create_job( + db, + "pipeline_daily", + owner="daily_pipeline", + payload={"target_date": today}, + ) + await run_job(db, job.id) except RuntimeError: logger.warning("Daily pipeline already running for %s, skipping", today) except Exception: @@ -125,7 +130,8 @@ async def _upvote_refresh() -> None: """刷新最近 N 天论文的 upvotes。""" db: Session = SessionLocal() try: - result = await refresh_upvotes(db) + job = create_job(db, "refresh_upvotes", owner="upvote_refresh", payload={}) + result = await run_job(db, job.id) logger.info( "Upvote refresh completed: status=%s updated=%d", result.get("status"), diff --git a/app/services/summary_persister.py b/app/services/summary_persister.py index ef90b4e..1dd99d6 100644 --- a/app/services/summary_persister.py +++ b/app/services/summary_persister.py @@ -3,8 +3,6 @@ from __future__ import annotations import logging - -from sqlalchemy import text from sqlalchemy.orm import Session from app.models import ( @@ -13,6 +11,7 @@ from app.models import ( PaperTag, SummaryState, ) +from app.services.derived import reindex_paper_fts from app.services.pdf_downloader import paper_dir from app.services.schemas import ( SummarySchema, @@ -75,19 +74,9 @@ def _update_summary_in_db( db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai")) existing_tag_names.add(tag_name) - # 4. FTS5 更新 - summary_text = _build_fts_summary_text(schema) - db.execute( - text( - "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text " - "WHERE rowid=:paper_id" - ), - { - "title_zh": schema.title_zh, - "summary_text": summary_text, - "paper_id": paper.id, - }, - ) + # 4. FTS5 派生索引 + db.flush() + reindex_paper_fts(db, paper) db.commit() logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality) diff --git a/tests/test_admin.py b/tests/test_admin.py index 4d16de3..999c3f3 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -3,14 +3,17 @@ from __future__ import annotations import logging -from unittest.mock import AsyncMock, patch +from unittest.mock import patch import pytest -from sqlalchemy import select +from sqlalchemy import select, text from app.config import settings from app.models import ( CrawlLog, + Job, + SummaryState, + SummaryStatus, TaskLock, ) from app.utils import utc_now @@ -64,47 +67,13 @@ class TestAdminAuth: resp = auth_client.get("/admin/logs", follow_redirects=False) assert resp.status_code == 303 - def test_correct_session_accepted(self, auth_client): - """已登录 session 应被接受(crawl 可能会失败但不是 303)。""" - with patch( - "app.routes.admin.run_crawl", new_callable=AsyncMock - ) as mock_crawl: - mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"} - resp = auth_client.post("/admin/crawl") - assert resp.status_code != 303 - - # ── summarize route auth ──────────────────────────────────────── - - def test_no_session_returns_303_for_summarize(self, client, monkeypatch): - """无 session 返回 303。""" - monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password") - resp = client.post("/admin/summarize", follow_redirects=False) - assert resp.status_code == 303 - def test_correct_session_batch_summarize(self, auth_client): - """已登录调用 batch summarize,mock 掉服务层。""" - with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock: - mock.return_value = { - "status": "success", - "done": 0, - "failed": 0, - "total": 0, - } + """已登录调用 batch summarize,应创建后台任务。""" + with patch("app.routes.admin.enqueue_job"): resp = auth_client.post("/admin/summarize") assert resp.status_code == 200 - assert resp.json()["status"] == "success" - - def test_single_paper_not_found(self, auth_client): - """单篇总结不存在的论文返回 404。""" - from app.exceptions import NotFoundError - - with patch( - "app.routes.admin.summarize_single", - new_callable=AsyncMock, - side_effect=NotFoundError("Paper not found: nonexistent.99999"), - ): - resp = auth_client.post("/admin/summarize/nonexistent.99999") - assert resp.status_code == 404 + assert resp.json()["status"] == "queued" + assert "job_id" in resp.json() # ═══════════════════════════════════════════════════════════════════════ @@ -115,29 +84,12 @@ class TestAdminAuth: class TestAdminCrawl: """POST /admin/crawl 测试。""" - def test_crawl_default_today(self, auth_client): - """不指定日期时默认抓取今天。""" - with patch( - "app.routes.admin.run_crawl", new_callable=AsyncMock - ) as mock_crawl: - mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"} - resp = auth_client.post("/admin/crawl") - assert resp.status_code == 200 - data = resp.json() - assert data["status"] == "success" - mock_crawl.assert_called_once() - def test_crawl_specific_date(self, auth_client): """指定日期抓取。""" - with patch( - "app.routes.admin.run_crawl", new_callable=AsyncMock - ) as mock_crawl: - mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"} + with patch("app.routes.admin.enqueue_job"): resp = auth_client.post("/admin/crawl?date=2024-01-15") assert resp.status_code == 200 - mock_crawl.assert_called_once() - call_args = mock_crawl.call_args - assert call_args[0][1] == "2024-01-15" + assert resp.json()["target_date"] == "2024-01-15" # ═══════════════════════════════════════════════════════════════════════ @@ -149,20 +101,20 @@ class TestAdminCleanup: """POST /admin/cleanup 测试。""" def test_cleanup_returns_stats(self, auth_client): - """清理应返回统计信息。""" + """同步清理排障接口应返回统计信息。""" with patch("app.routes.admin.cleanup_tmp") as mock_cleanup: mock_cleanup.return_value = {"scanned": 3, "removed": 1, "errors": []} - resp = auth_client.post("/admin/cleanup") + resp = auth_client.post("/admin/cleanup-now") assert resp.status_code == 200 data = resp.json() assert data["scanned"] == 3 assert data["removed"] == 1 def test_cleanup_writes_log(self, auth_client, db_session): - """清理应写入 crawl_logs。""" + """同步清理排障接口应写入 crawl_logs。""" with patch("app.routes.admin.cleanup_tmp") as mock_cleanup: mock_cleanup.return_value = {"scanned": 0, "removed": 0, "errors": []} - auth_client.post("/admin/cleanup") + auth_client.post("/admin/cleanup-now") logs = ( db_session.execute(select(CrawlLog).where(CrawlLog.task == "cleanup")) @@ -195,19 +147,21 @@ class TestAdminDelete: assert resp.status_code == 422 def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range): - """confirm='DELETE' 时应执行删除。""" - resp = auth_client.post( - "/admin/delete", - json={ - "date_start": "2024-01-10", - "date_end": "2024-01-12", - "include_notes": True, - "confirm": "DELETE", - }, - ) + """confirm='DELETE' 时应创建后台删除 job。""" + with patch("app.routes.admin.enqueue_job"): + resp = auth_client.post( + "/admin/delete", + json={ + "date_start": "2024-01-10", + "date_end": "2024-01-12", + "include_notes": True, + "confirm": "DELETE", + }, + ) assert resp.status_code == 200 data = resp.json() - assert data["deleted"] == 3 + assert data["status"] == "queued" + assert db_session.get(Job, data["job_id"]) is not None def test_delete_invalid_date_range(self, auth_client): """date_start > date_end 应返回 400。""" @@ -221,17 +175,6 @@ class TestAdminDelete: ) assert resp.status_code == 400 - def test_delete_without_confirm_field(self, auth_client): - """缺少 confirm 字段应返回 422。""" - resp = auth_client.post( - "/admin/delete", - json={ - "date_start": "2024-01-10", - "date_end": "2024-01-12", - }, - ) - assert resp.status_code == 422 - # ═══════════════════════════════════════════════════════════════════════ # Admin Routes — Logs @@ -241,12 +184,6 @@ class TestAdminDelete: class TestAdminLogs: """GET /admin/logs 测试。""" - def test_logs_returns_page(self, auth_client): - """应返回管理日志页面。""" - resp = auth_client.get("/admin/logs") - assert resp.status_code == 200 - assert "text/html" in resp.headers.get("content-type", "") - def test_logs_requires_auth(self, client, monkeypatch): """日志页面需要鉴权。""" monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password") @@ -272,6 +209,126 @@ class TestAdminLogs: assert "crawl" in resp.text.lower() or "日志" in resp.text +class TestAdminJobs: + """后台 job 查询接口测试。""" + + def test_job_detail_returns_payload_and_events(self, auth_client, db_session): + """GET /admin/jobs/{id} 返回 job 主记录和事件。""" + with patch("app.routes.admin.enqueue_job"): + resp = auth_client.post("/admin/crawl?date=2024-01-15") + job_id = resp.json()["job_id"] + + resp = auth_client.get(f"/admin/jobs/{job_id}") + + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == job_id + assert data["type"] == "crawl_daily" + assert data["payload"] == {"target_date": "2024-01-15"} + assert data["events"][0]["stage"] == "created" + + def test_job_detail_not_found(self, auth_client): + resp = auth_client.get("/admin/jobs/999999") + assert resp.status_code == 404 + + +class TestAdminSummaryStatus: + """总结状态管理接口测试。""" + + def test_summary_status_json_filters_failed( + self, auth_client, db_session, sample_paper + ): + sample_paper.summary_status.status = SummaryState.FAILED + sample_paper.summary_status.retry_count = 2 + sample_paper.summary_status.error_type = "timeout" + db_session.commit() + + resp = auth_client.get("/admin/summary-status?status=failed") + + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 1 + assert data["items"][0]["arxiv_id"] == sample_paper.arxiv_id + assert data["items"][0]["retry_count"] == 2 + + def test_retry_failed_resets_failed_statuses( + self, auth_client, db_session, sample_paper + ): + sample_paper.summary_status.status = SummaryState.PERMANENT_FAILURE + sample_paper.summary_status.error = "bad json" + sample_paper.summary_status.error_type = "json_invalid" + db_session.commit() + + resp = auth_client.post("/admin/summary-retry-failed") + + assert resp.status_code == 200 + assert resp.json()["count"] == 1 + db_session.refresh(sample_paper.summary_status) + assert sample_paper.summary_status.status == SummaryState.PENDING + assert sample_paper.summary_status.error is None + assert sample_paper.summary_status.error_type is None + + +class TestAdminPapers: + """论文管理批量操作测试。""" + + def test_single_delete_removes_paper_and_fts( + self, auth_client, db_session, sample_paper + ): + paper_id = sample_paper.id + + resp = auth_client.post(f"/admin/paper-delete/{sample_paper.arxiv_id}") + + assert resp.status_code == 200 + assert db_session.get(type(sample_paper), paper_id) is None + fts_row = db_session.execute( + text("SELECT rowid FROM papers_fts WHERE rowid = :id"), + {"id": paper_id}, + ).fetchone() + assert fts_row is None + + def test_batch_delete_removes_papers_and_fts( + self, auth_client, db_session, sample_papers_range + ): + target_ids = [p.id for p in sample_papers_range[:2]] + target_arxiv_ids = [p.arxiv_id for p in sample_papers_range[:2]] + + resp = auth_client.post( + "/admin/papers-batch-action", + json={"action": "delete", "arxiv_ids": target_arxiv_ids}, + ) + + assert resp.status_code == 200 + assert resp.json()["count"] == 2 + remaining = db_session.execute( + text( + "SELECT rowid FROM papers_fts " + "WHERE rowid IN (:id1, :id2)" + ), + {"id1": target_ids[0], "id2": target_ids[1]}, + ).fetchall() + assert remaining == [] + + def test_batch_summarize_sets_pending_status( + self, auth_client, db_session, sample_papers_range + ): + paper = sample_papers_range[0] + paper.summary_status.status = SummaryState.DONE + db_session.commit() + + resp = auth_client.post( + "/admin/papers-batch-action", + json={"action": "summarize", "arxiv_ids": [paper.arxiv_id]}, + ) + + assert resp.status_code == 200 + status = db_session.scalar( + select(SummaryStatus).where(SummaryStatus.paper_id == paper.id) + ) + assert status is not None + assert status.status == SummaryState.PENDING + + # ═══════════════════════════════════════════════════════════════════════ # Scheduler 测试 # ═══════════════════════════════════════════════════════════════════════ diff --git a/tests/test_crawler.py b/tests/test_crawler.py index d663fb8..941b213 100644 --- a/tests/test_crawler.py +++ b/tests/test_crawler.py @@ -10,6 +10,7 @@ from app.services.crawler import ( _parse_paper, crawl_daily, fetch_daily, + refresh_upvotes, upsert_papers, ) @@ -187,3 +188,62 @@ class TestCrawlDaily: assert result["status"] == "failed" assert "network error" in result["error"] + + +class TestRefreshUpvotes: + @pytest.mark.asyncio + async def test_refresh_updates_existing_without_inserting_new( + self, db_session, sample_paper + ): + sample_paper.arxiv_id = "1706.03762" + sample_paper.upvotes = 10 + db_session.commit() + + with patch( + "app.services.crawler.fetch_daily", + new_callable=AsyncMock, + return_value=[ + { + "paper": { + "id": "1706.03762", + "upvotes": 999, + "authors": [], + "tags": [], + } + }, + { + "paper": { + "id": "2010.11929", + "upvotes": 123, + "authors": [], + "tags": [], + } + }, + ], + ): + result = await refresh_upvotes(db_session, days=1) + + db_session.refresh(sample_paper) + assert result["status"] == "success" + assert result["updated"] == 1 + assert sample_paper.upvotes == 999 + assert db_session.query(type(sample_paper)).count() == 1 + + @pytest.mark.asyncio + async def test_refresh_returns_partial_when_one_day_fails(self, db_session): + async def _fetch_daily(target_date): + if target_date.endswith("01"): + raise ConnectionError("hf down") + return [] + + with ( + patch( + "app.services.crawler.recent_date_strs", + return_value=["2024-01-01", "2024-01-02"], + ), + patch("app.services.crawler.fetch_daily", side_effect=_fetch_daily), + ): + result = await refresh_upvotes(db_session, days=2) + + assert result["status"] == "partial" + assert result["errors"] == ["2024-01-01: hf down"] diff --git a/tests/test_derived.py b/tests/test_derived.py new file mode 100644 index 0000000..769124b --- /dev/null +++ b/tests/test_derived.py @@ -0,0 +1,80 @@ +"""派生索引维护测试。""" + +from __future__ import annotations + +from unittest.mock import patch + +from sqlalchemy import text + +from app.services.derived import reindex_chroma, reindex_fts + + +class TestReindexFts: + def test_reindex_fts_rebuilds_missing_rows(self, db_session, sample_paper): + db_session.execute( + text("DELETE FROM papers_fts WHERE rowid = :id"), + {"id": sample_paper.id}, + ) + db_session.commit() + + result = reindex_fts(db_session) + + row = db_session.execute( + text("SELECT title_en, authors, tags FROM papers_fts WHERE rowid = :id"), + {"id": sample_paper.id}, + ).fetchone() + assert result == {"status": "success", "indexed": 1} + assert row is not None + assert row[0] == sample_paper.title_en + assert "Alice Smith" in row[1] + assert "NLP" in row[2] + + def test_reindex_fts_accepts_subset(self, db_session, sample_papers_range): + keep_id = sample_papers_range[0].id + skip_id = sample_papers_range[1].id + db_session.execute(text("DELETE FROM papers_fts")) + db_session.commit() + + result = reindex_fts(db_session, paper_ids=[keep_id]) + + keep_row = db_session.execute( + text("SELECT rowid FROM papers_fts WHERE rowid = :id"), + {"id": keep_id}, + ).fetchone() + skip_row = db_session.execute( + text("SELECT rowid FROM papers_fts WHERE rowid = :id"), + {"id": skip_id}, + ).fetchone() + assert result["indexed"] == 1 + assert keep_row is not None + assert skip_row is None + + +class TestReindexChroma: + def test_reindex_chroma_indexes_only_summarized_papers( + self, db_session, sample_papers_with_summary + ): + with patch("app.services.embedder.index_paper", return_value=True) as mock_index: + result = reindex_chroma(db_session) + + assert result["status"] == "success" + assert result["indexed"] == 4 + assert mock_index.call_count == 4 + indexed_ids = {call.args[0] for call in mock_index.call_args_list} + assert "2401.20001" in indexed_ids + assert "2401.20005" not in indexed_ids + + def test_reindex_chroma_reports_partial_failures( + self, db_session, sample_papers_with_summary + ): + def _index_paper(arxiv_id, _texts): + if arxiv_id == "2401.20001": + raise RuntimeError("embedding failed") + return True + + with patch("app.services.embedder.index_paper", side_effect=_index_paper): + result = reindex_chroma(db_session) + + assert result["status"] == "partial" + assert result["indexed"] == 3 + assert result["errors"] == ["2401.20001: embedding failed"] diff --git a/tests/test_jobs.py b/tests/test_jobs.py new file mode 100644 index 0000000..5b3be65 --- /dev/null +++ b/tests/test_jobs.py @@ -0,0 +1,111 @@ +"""后台 Job 服务测试。""" + +from __future__ import annotations + +from datetime import timedelta +from unittest.mock import patch + +import pytest +from sqlalchemy import select + +from app.models import Job, JobEvent, JobStatus, TaskLock +from app.services.jobs import create_job, recover_stale_jobs, run_job +from app.utils import utc_now + + +class TestJobs: + def test_create_job_writes_event(self, db_session): + job = create_job( + db_session, + "cleanup_tmp", + owner="test", + payload={"reason": "unit-test"}, + ) + + assert job.id is not None + assert job.status == JobStatus.QUEUED + + events = ( + db_session.execute(select(JobEvent).where(JobEvent.job_id == job.id)) + .scalars() + .all() + ) + assert len(events) == 1 + assert events[0].stage == "created" + + @pytest.mark.asyncio + async def test_run_job_success(self, db_session): + job = create_job(db_session, "cleanup_tmp", owner="test", payload={}) + + with patch("app.services.cleaner.cleanup_tmp") as mock_cleanup: + mock_cleanup.return_value = {"scanned": 1, "removed": 1, "errors": []} + result = await run_job(db_session, job.id) + + refreshed = db_session.get(Job, job.id) + assert result["removed"] == 1 + assert refreshed.status == JobStatus.SUCCESS + assert refreshed.result_json is not None + + @pytest.mark.asyncio + async def test_run_job_failure_records_error(self, db_session): + job = create_job(db_session, "missing_job_type", owner="test", payload={}) + + result = await run_job(db_session, job.id) + + refreshed = db_session.get(Job, job.id) + assert result["status"] == "failed" + assert refreshed.status == JobStatus.FAILED + assert "Unsupported job type" in refreshed.error + + @pytest.mark.asyncio + async def test_run_job_dispatches_refresh_upvotes(self, db_session): + job = create_job( + db_session, + "refresh_upvotes", + owner="test", + payload={"days": 3}, + ) + + with patch("app.services.crawler.refresh_upvotes") as mock_refresh: + mock_refresh.return_value = {"status": "success", "updated": 2} + result = await run_job(db_session, job.id) + + mock_refresh.assert_awaited_once_with(db_session, days=3) + assert result["updated"] == 2 + + @pytest.mark.asyncio + async def test_run_job_dispatches_reindex_fts(self, db_session): + job = create_job(db_session, "reindex_fts", owner="test", payload={}) + + with patch("app.services.derived.reindex_fts") as mock_reindex: + mock_reindex.return_value = {"status": "success", "indexed": 5} + result = await run_job(db_session, job.id) + + mock_reindex.assert_called_once_with(db_session) + assert result["indexed"] == 5 + + def test_recover_stale_jobs_and_locks(self, db_session): + old = utc_now() - timedelta(hours=7) + job = Job( + type="cleanup_tmp", + status=JobStatus.RUNNING, + owner="test", + created_at=old, + started_at=old, + heartbeat_at=old, + ) + lock = TaskLock( + task="cleanup", + lock_key="tmp", + status="running", + owner="test", + acquired_at=old, + ) + db_session.add_all([job, lock]) + db_session.commit() + + recovered = recover_stale_jobs(db_session) + + assert recovered == 2 + assert db_session.get(Job, job.id).status == JobStatus.STALE + assert db_session.get(TaskLock, lock.id).status == "stale" diff --git a/tests/test_pages.py b/tests/test_pages.py index 2516048..482cfcf 100644 --- a/tests/test_pages.py +++ b/tests/test_pages.py @@ -6,9 +6,6 @@ from datetime import date from unittest.mock import patch as upatch -from app.config import settings - - # ═══════════════════════════════════════════════════════════════════════ # Detail 页 & 相似论文 # ═══════════════════════════════════════════════════════════════════════ @@ -37,29 +34,6 @@ class TestDetailPage: class TestTrendsDashboard: """趋势看板测试。""" - def test_trends_page_renders(self, client, sample_papers_with_summary): - """趋势看板页面正常渲染。""" - resp = client.get("/trends") - assert resp.status_code == 200 - assert "趋势看板" in resp.text - assert "chart" in resp.text.lower() or "Chart" in resp.text - - def test_trends_api_returns_data(self, client, sample_papers_with_summary): - """趋势 API 返回正确数据结构。""" - resp = client.get("/api/stats/trends") - assert resp.status_code == 200 - data = resp.json() - - assert "daily_counts" in data - assert "top_tags" in data - assert "upvotes_dist" in data - assert "summary_completion" in data - - assert isinstance(data["daily_counts"], list) - assert isinstance(data["top_tags"], list) - assert isinstance(data["upvotes_dist"], list) - assert isinstance(data["summary_completion"], list) - def test_trends_api_daily_counts(self, client, sample_papers_with_summary): """每日论文数量数据正确。""" # 使用测试数据的日期范围 @@ -108,12 +82,6 @@ class TestTrendsDashboard: class TestComparePage: """论文对比页测试。""" - def test_compare_page_no_ids(self, client): - """无 ID 时显示输入表单。""" - resp = client.get("/compare") - assert resp.status_code == 200 - assert "对比" in resp.text - def test_compare_page_with_ids(self, client, sample_papers_with_summary): """对比多篇论文正常渲染。""" resp = client.get("/compare?ids=2401.20001,2401.20002") @@ -124,23 +92,6 @@ class TestComparePage: assert "一句话摘要" in resp.text assert "研究问题" in resp.text - def test_compare_page_max_5(self, client, sample_papers_with_summary): - """最多 5 篇。""" - ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005" - resp = client.get(f"/compare?ids={ids}") - assert resp.status_code == 200 - - def test_compare_page_over_5_truncates(self, client, sample_papers_with_summary): - """超过 5 篇截断。""" - ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006" - resp = client.get(f"/compare?ids={ids}") - assert resp.status_code == 200 - - def test_compare_page_invalid_ids(self, client): - """无效 ID 时显示空结果。""" - resp = client.get("/compare?ids=nonexistent.99999") - assert resp.status_code == 200 - def test_compare_page_shows_no_summary_placeholder( self, client, sample_papers_with_summary ): @@ -149,65 +100,3 @@ class TestComparePage: resp = client.get("/compare?ids=2401.20005") assert resp.status_code == 200 assert "暂无总结" in resp.text - - -# ═══════════════════════════════════════════════════════════════════════ -# Nav Bar -# ═══════════════════════════════════════════════════════════════════════ - - -class TestNavBar: - """导航栏测试。""" - - def test_nav_includes_trends_link(self, client): - """导航栏应包含趋势链接。""" - resp = client.get("/search") - assert resp.status_code == 200 - assert "/trends" in resp.text - - def test_nav_includes_compare_implicitly(self, client): - """compare 页面可访问。""" - resp = client.get("/compare") - assert resp.status_code == 200 - - -# ═══════════════════════════════════════════════════════════════════════ -# Graceful Degradation(CHROMA_ENABLED=false) -# ═══════════════════════════════════════════════════════════════════════ - - -class TestGracefulDegradation: - """CHROMA_ENABLED=false 时优雅降级测试。""" - - def test_search_works_without_chroma( - self, client, monkeypatch, sample_papers_with_summary - ): - """CHROMA 关闭时 FTS5 搜索正常工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = client.get("/search?q=Test") - assert resp.status_code == 200 - assert "Test Paper" in resp.text or "测试论文" in resp.text - - def test_detail_works_without_chroma( - self, client, monkeypatch, sample_papers_with_summary - ): - """CHROMA 关闭时详情页正常工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = client.get("/paper/2401.20001") - assert resp.status_code == 200 - - def test_trends_works_without_chroma( - self, client, monkeypatch, sample_papers_with_summary - ): - """CHROMA 关闭时趋势看板正常工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = client.get("/trends") - assert resp.status_code == 200 - - def test_compare_works_without_chroma( - self, client, monkeypatch, sample_papers_with_summary - ): - """CHROMA 关闭时对比页正常工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = client.get("/compare?ids=2401.20001,2401.20002") - assert resp.status_code == 200 diff --git a/tests/test_searcher.py b/tests/test_searcher.py index 861adfa..aa2ff0a 100644 --- a/tests/test_searcher.py +++ b/tests/test_searcher.py @@ -123,38 +123,12 @@ class TestSearchSemanticMode: class TestSearchRoutes: """搜索页面和 JSON API 路由测试。""" - def test_search_page_renders(self, client): - """GET /search 返回 200。""" - resp = client.get("/search") - assert resp.status_code == 200 - assert "搜索" in resp.text - def test_search_page_with_query(self, client, sample_paper): """GET /search?q=Test 返回搜索结果。""" resp = client.get("/search?q=Test") assert resp.status_code == 200 assert "2401.12345" in resp.text - def test_search_page_with_tag(self, client, sample_paper): - """GET /search?tag=NLP 返回标签筛选结果。""" - resp = client.get("/search?tag=NLP") - assert resp.status_code == 200 - assert "2401.12345" in resp.text - - def test_search_page_keyword_mode(self, client, sample_papers_with_summary): - """搜索页 keyword 模式。""" - resp = client.get("/search?q=Test&mode=keyword") - assert resp.status_code == 200 - assert "Test" in resp.text or "测试" in resp.text - - def test_search_page_semantic_disabled( - self, client, monkeypatch, sample_papers_with_summary - ): - """语义模式 CHROMA_ENABLED=false 时仍能工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = client.get("/search?q=Test&mode=semantic") - assert resp.status_code == 200 - def test_search_api_json(self, client, sample_paper): """GET /api/search?q=Test 返回 JSON。""" resp = client.get("/api/search?q=Test") @@ -170,14 +144,6 @@ class TestSearchRoutes: data = resp.json() assert data["total"] == 1 - def test_search_api_with_mode(self, client, sample_papers_with_summary): - """搜索 API 支持 mode 参数。""" - resp = client.get("/api/search?q=Test&mode=keyword") - assert resp.status_code == 200 - data = resp.json() - assert "results" in data - assert "total" in data - def test_search_api_empty(self, client, sample_paper): """GET /api/search?q=nonexistent 返回空结果。""" resp = client.get("/api/search?q=nonexistent") @@ -185,13 +151,6 @@ class TestSearchRoutes: data = resp.json() assert data["total"] == 0 - def test_search_api_sort_by_date(self, client, sample_paper): - """GET /api/search?q=Test&sort=date 按日期排序。""" - resp = client.get("/api/search?q=Test&sort=date") - assert resp.status_code == 200 - data = resp.json() - assert data["total"] >= 1 - # ═══════════════════════════════════════════════════════════════════════ # Similar Paper API 测试 @@ -211,21 +170,6 @@ class TestSimilarAPI: data = resp.json() assert data["results"] == [] - def test_similar_api_paper_not_found(self, client, monkeypatch): - """不存在的论文返回空。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = client.get("/api/similar/nonexistent.99999") - assert resp.status_code == 200 - assert resp.json()["results"] == [] - - def test_similar_api_with_top_k( - self, client, monkeypatch, sample_papers_with_summary - ): - """top_k 参数控制返回数量。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = client.get("/api/similar/2401.20001?top_k=3") - assert resp.status_code == 200 - # ═══════════════════════════════════════════════════════════════════════ # 阅读列表路由测试 @@ -235,12 +179,6 @@ class TestSimilarAPI: class TestReadingListRoute: """阅读列表页面测试。""" - def test_reading_list_empty(self, client): - """无收藏时显示空状态。""" - resp = client.get("/reading-list") - assert resp.status_code == 200 - assert "阅读列表" in resp.text - def test_reading_list_with_bookmark(self, client, sample_paper): """有收藏时显示论文。""" # 先收藏 @@ -302,13 +240,6 @@ class TestRssFeed: assert "" in resp.text assert "2401.12345" in resp.text - def test_rss_has_paper_item(self, client, sample_paper): - """RSS 包含论文条目。""" - resp = client.get("/rss.xml") - assert "" in resp.text - assert "" in resp.text - assert "/paper/2401.12345" in resp.text - def test_rss_with_tag_filter(self, client, sample_paper): """GET /rss.xml?tag=NLP 按标签筛选。""" resp = client.get("/rss.xml?tag=NLP")