refactor: extract admin business logic to services, introduce job queue, add derived index helpers

- Move DB operations from routes/admin.py to services/admin.py (get_logs_context, query_summary_statuses, retry_failed, delete/reset operations)
- Add services/jobs.py with a Job/JobEvent-based async job queue (create_job, run_job, enqueue_job)
- Add services/derived.py with FTS5 reindex and paper index deletion helpers
- Refactor scheduler to use the job queue instead of direct pipeline calls
- Add heartbeat_at/expires_at to TaskLock for lock health tracking
- Remove DESIGN_REVIEW.md
- Update tests: remove redundant integration tests, add unit tests for new services
+1
-1
@@ -42,7 +42,7 @@ DATABASE_URL=sqlite:///data/db/papers.db
 # ─── Semantic search ─────────────────────────────
 CHROMA_ENABLED=false
 CHROMA_DIR=data/chroma
-EMBED_API_BASE=https://api.siliconflow.cn/v1/embeddings
+EMBED_API_BASE=https://api.siliconflow.cn
 EMBED_API_KEY=your_api_key_here
 EMBED_MODEL=Qwen/Qwen3-Embedding-4B
 EMBED_DIMENSIONS=2560

@@ -1,468 +0,0 @@
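# Project Design Review and Optimization Recommendations

This document summarizes the conclusions of a review of the project's current system design, process design, and code structure. The focus is not on local code style but on stable long-term operation, failure recovery, data consistency, maintainability, and extensibility.

## Overall Assessment

The project's current structure is clear overall: the FastAPI routing layer is thin, and most business logic lives in `app/services/`; crawling, summarization, search, personal data, and the admin backend are already split into modules; SQLite + FTS5 is a reasonable choice for a local, personal paper-browsing site.

The main risks fall into the following categories:

- Long-running tasks have no unified lifecycle abstraction.
- The pipeline is not an explicit state machine, so stage results are hard to recover and trace.
- The database, file system, FTS5, and ChromaDB are kept in sync manually in several places.
- The business policies of the three entry points (Web, CLI, Scheduler) are at risk of diverging.
- Failure recovery and observability are insufficient.
- Some synchronous blocking work is mixed into the async Web flow.

## High-Priority Issues
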
### 1. No unified task system

Three entry points can currently trigger heavy tasks:

- Web admin backend: `/admin/trigger-pipeline`, `/admin/summarize`
- APScheduler: the automatic daily pipeline
- CLI: `app.cli crawl/summarize`

They share some services, but there is no unified model for task creation, queuing, execution, retry, cancellation, status queries, and failure recovery.

Impact:

- Web requests hang for a long time.
- The behavior of CLI, Web, and Scheduler may gradually diverge.
- When a task hangs, only a coarse-grained running lock is visible.
- The frontend has difficulty showing per-stage progress.
- Adding rerun, cancellation, recovery, and concurrency limits later will become increasingly complex.

Recommendation:

Introduce a unified Job model:

```text
Entry layer -> create Job -> Worker executes Job -> JobEvent records progress -> page/API queries status
```

Suggested task types:

- `crawl_daily`
- `summarize_one`
- `summarize_batch`
- `pipeline_daily`
- `refresh_upvotes`
- `delete_range`
- `reindex_fts`
- `reindex_chroma`
- `extract_images`

Suggested table structure:

- `jobs`: the main task record, containing `id/type/status/owner/created_at/started_at/completed_at/error`
- `job_events`: stage events, containing `job_id/stage/status/message/payload_json/created_at`
- `paper_processing_runs`: per-paper processing records, suited to tracking summarization, image extraction, and indexing status

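The Job flow described in section 1 (an entry point creates a Job, a worker runs it, JobEvents record progress) could look roughly like the sketch below. It is an in-memory illustration only: `InMemoryJobStore`, the `handlers` mapping, and the method signatures are assumptions made for this example, and a real version would persist to the `jobs` and `job_events` tables.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Dict, List, Optional, Tuple


class JobStatus(str, Enum):
    QUEUED = "queued"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"


@dataclass
class Job:
    id: int
    type: str
    owner: str
    payload: dict
    status: JobStatus = JobStatus.QUEUED
    error: Optional[str] = None
    # Each event is (stage, status, message), mirroring one job_events row.
    events: List[Tuple[str, str, str]] = field(default_factory=list)


class InMemoryJobStore:
    """Hypothetical stand-in for the jobs/job_events tables."""

    def __init__(self) -> None:
        self._jobs: Dict[int, Job] = {}

    def create_job(self, job_type: str, owner: str, payload: dict) -> Job:
        job = Job(id=len(self._jobs) + 1, type=job_type, owner=owner, payload=payload)
        self._jobs[job.id] = job
        return job

    def run_job(self, job_id: int, handlers: Dict[str, Callable[[dict], dict]]) -> Job:
        job = self._jobs[job_id]
        job.status = JobStatus.RUNNING
        job.events.append((job.type, "started", ""))
        try:
            result = handlers[job.type](job.payload)
        except Exception as exc:  # any handler failure marks the job as failed
            job.status = JobStatus.FAILED
            job.error = str(exc)
            job.events.append((job.type, "failed", str(exc)))
        else:
            job.status = JobStatus.SUCCESS
            job.events.append((job.type, "success", str(result)))
        return job
```

Because every entry point would go through the same `create_job`/`run_job` pair, Web, CLI, and Scheduler cannot drift apart in how they record status and errors.
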
### 2. The pipeline is not an explicit state machine

The current pipeline is essentially a linear function:

```text
crawl -> summarize -> cleanup
```

Stage outputs are not modeled as state that can be queried, replayed, or compensated.

Impact:

- When crawl succeeds but summarize partially fails, the overall status is not expressed precisely enough.
- It is unclear whether a cleanup failure affects the pipeline's success or failure.
- "Fall back to yesterday when today has no data" is implicit logic rather than a configurable policy.
- Image extraction and ChromaDB indexing failures are swallowed, so users cannot tell which enhancements are missing.

Recommendation:

Change the pipeline to explicit stage states:

```text
created
  -> crawling
  -> crawled
  -> summarizing
  -> summarized | summarized_partial
  -> postprocessing
  -> completed | failed | cancelled
```

Each stage records:

- Input parameters
- Output statistics
- Error type
- Error summary
- Start/end time
- Whether it is retryable

This makes it possible to rerun only the failed stage, for example only ChromaDB indexing, only image extraction, or only the summaries of failed papers.

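One way to make the stage states from section 2 explicit is a transition table that rejects illegal jumps. The sketch below is illustrative: the state names come from the diagram in that section, while the rule that `failed` and `cancelled` are reachable from any non-terminal state is an assumption, as are the names `PipelineState` and `can_transition`.

```python
from enum import Enum


class PipelineState(str, Enum):
    CREATED = "created"
    CRAWLING = "crawling"
    CRAWLED = "crawled"
    SUMMARIZING = "summarizing"
    SUMMARIZED = "summarized"
    SUMMARIZED_PARTIAL = "summarized_partial"
    POSTPROCESSING = "postprocessing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


S = PipelineState

# Forward edges of the happy path, taken from the stage diagram.
_FORWARD = {
    S.CREATED: {S.CRAWLING},
    S.CRAWLING: {S.CRAWLED},
    S.CRAWLED: {S.SUMMARIZING},
    S.SUMMARIZING: {S.SUMMARIZED, S.SUMMARIZED_PARTIAL},
    S.SUMMARIZED: {S.POSTPROCESSING},
    S.SUMMARIZED_PARTIAL: {S.POSTPROCESSING},
    S.POSTPROCESSING: {S.COMPLETED},
}
_TERMINAL = {S.COMPLETED, S.FAILED, S.CANCELLED}


def can_transition(current: PipelineState, target: PipelineState) -> bool:
    """Return True if moving from `current` to `target` is allowed."""
    if current in _TERMINAL:
        return False
    if target in (S.FAILED, S.CANCELLED):
        return True  # assumption: any running stage may fail or be cancelled
    return target in _FORWARD.get(current, set())
```
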
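### 3. The data lifecycle has no clear layering

The project has at least four categories of data:

- Primary data: `papers`, `authors`, `tags`, `summaries`, and user notes/bookmarks/reading status
- Derived indexes: FTS5, ChromaDB
- Long-term assets: summary, raw output, and images under `data/papers/{arxiv_id}`
- Temporary assets: PDFs, source files, and intermediate files under `data/tmp/{arxiv_id}`

These are currently maintained manually by different flows. Deletion, summarization, index rebuilding, and temp-directory cleanup are scattered across different modules.

Impact:

- The DB and the file system may become inconsistent.
- The DB and FTS5/ChromaDB may become inconsistent.
- When a deletion removes some resources successfully and fails on others, recovery is difficult.
- It is hard to tell whether a given paper is "fully usable".

Recommendation:

Define the authoritative data source explicitly:

- The DB is the authoritative source.
- FTS5, ChromaDB, summary.json, raw_output.txt, and images should all be treated as rebuildable derived artifacts.
- When deleting primary data, first guarantee that the DB state is correct; derived-artifact cleanup can be compensated asynchronously.
- Provide tasks for rebuilding derived data.

Suggested commands or tasks to add:

```text
rebuild-derived --fts
rebuild-derived --chroma
rebuild-derived --files
rebuild-derived --images
```
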
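### 4. Dual-write consistency risk between the database and the file system

The summary persistence flow writes files, writes the DB, updates status, extracts images, and writes to ChromaDB. A failure at any step can leave an inconsistent state.

Typical risks:

- The file has been written, but the DB has not been committed.
- The DB is marked done, but summary.json is missing.
- The DB is done, but images have not been extracted or ChromaDB has not been indexed.
- Old images or old figures are left behind when re-summarizing.

Recommendation:

- Treat the DB as the source of truth for structured summaries, with JSON files as an export/cache.
- Write files via a temporary file plus an atomic rename.
- Record derived-artifact status in the DB, for example:
  - `summary_persisted_at`
  - `fts_indexed_at`
  - `chroma_indexed_at`
  - `images_extracted_at`
  - `derived_error`
- Provide a consistency-check task that scans for differences between the DB and files/indexes and repairs them.
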
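The "temporary file plus atomic rename" recommendation in section 4 can be sketched with the standard library alone. The helper name `atomic_write_json` is hypothetical; the key point is that readers only ever see either the old file or the complete new one, because `os.replace` swaps the file in a single step when source and target are on the same filesystem.

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_write_json(path: Path, data: dict) -> None:
    """Write `data` to `path` so that a crash never leaves a half-written file."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # Create the temp file next to the target so the rename stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=path.name, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            json.dump(data, fh, ensure_ascii=False)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp_name, path)
    except BaseException:
        # Remove the partial temp file; the original target is untouched.
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise
```
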
### 5. Failure recovery design is weak

Failure handling currently relies mainly on:

- `summary_status.status`
- `retry_count`
- `task_locks`
- `crawl_logs`

However, the following cases are poorly supported:

- A `processing` status is left behind permanently after a process crash.
- A running lock in `task_locks` is left behind permanently.
- A task stops halfway, after some files or indexes have already been written.
- Some post-processing fails but the main task is shown as successful.

Recommendation:

- Add `heartbeat_at` or `expires_at` to `task_locks`.
- On application startup, scan for stale running tasks and mark them as `failed/stale`.
- Add finer-grained stage fields to `summary_status`, for example:
  - `downloading_pdf`
  - `generating`
  - `validating`
  - `persisting`
  - `extracting_images`
  - `indexing`
- Distinguish main-flow failures from derived-enhancement failures.

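A stale-lock check based on `heartbeat_at` and `expires_at`, as recommended in section 5, might look like the following. This is a sketch under assumptions: the function name `is_stale_lock`, the five-minute default timeout, and the fallback to `acquired_at` when no heartbeat has been written yet are all illustrative choices, not the project's actual rules.

```python
from datetime import datetime, timedelta
from typing import Optional


def is_stale_lock(
    acquired_at: datetime,
    heartbeat_at: Optional[datetime],
    expires_at: Optional[datetime],
    now: datetime,
    heartbeat_timeout: timedelta = timedelta(minutes=5),
) -> bool:
    """Return True when a running lock should be treated as abandoned."""
    if expires_at is not None and now >= expires_at:
        return True  # the hard deadline has passed
    # Fall back to the acquisition time if the owner never sent a heartbeat.
    last_seen = heartbeat_at or acquired_at
    return now - last_seen > heartbeat_timeout
```

A startup scan would apply this predicate to every lock still marked running and flip the stale ones to `failed/stale`.
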
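### 6. Long-running tasks run directly inside Web requests

When the admin side triggers the pipeline, batch summarization, an upvote refresh, or data deletion, the heavy task runs directly within the request lifecycle.

Impact:

- HTTP requests time out easily.
- Task status is unclear after the user closes the page.
- The Web process is tied up by heavy tasks.
- Cancellation, progress display, and failure recovery become hard to add later.

Recommendation:

- The Web layer only creates a job and returns a `task_id`.
- A background worker executes the task.
- The admin page polls `/admin/jobs/{id}` or uses SSE to show progress.
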
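## Medium-Priority Issues

### 7. The FTS5 index is maintained manually and can drift

`papers_fts` is a standalone virtual table. It is synchronized manually when papers are inserted, summaries are updated, and papers are deleted.

Impact:

- Updates to an existing paper's title, abstract, authors, or tags may not be synchronized.
- After AI tags are added, search results may be missing fields unless there is a unified rebuild.
- Index entries may be left behind after a failed deletion or rollback.

Recommendation:

- Wrap index updates in a single `reindex_paper(db, paper_id)` function.
- Call it at the end of every code path that modifies papers, tags, or summaries.
- Provide a full `reindex_fts` task.
- Alternatively, maintain the basic fields with SQLite triggers, while still updating summary fields through the unified function.
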
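A single `reindex_paper` entry point, as recommended in section 7, can be as simple as "delete the old FTS row, re-insert from the authoritative table". The sketch below uses the standard `sqlite3` module and assumes a simplified schema (`papers(id, title, abstract)` and a plain `papers_fts` table with matching columns); the real tables have more fields, and the function here takes a raw connection rather than an ORM session.

```python
import sqlite3


def create_schema(conn: sqlite3.Connection) -> None:
    """Create a simplified, hypothetical schema for the example."""
    conn.execute("CREATE TABLE papers (id INTEGER PRIMARY KEY, title TEXT, abstract TEXT)")
    conn.execute("CREATE VIRTUAL TABLE papers_fts USING fts5(title, abstract)")


def reindex_paper(conn: sqlite3.Connection, paper_id: int) -> None:
    """Rebuild one paper's FTS row from the papers table (the source of truth)."""
    conn.execute("DELETE FROM papers_fts WHERE rowid = ?", (paper_id,))
    # If the paper no longer exists, the SELECT yields nothing and the row stays deleted.
    conn.execute(
        "INSERT INTO papers_fts (rowid, title, abstract) "
        "SELECT id, title, abstract FROM papers WHERE id = ?",
        (paper_id,),
    )
```

Because the function is idempotent, calling it after every write path (and in a full `reindex_fts` loop) cannot make the index worse, only bring it back in line with the DB.
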
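### 8. ChromaDB indexing lacks a system-level consistency strategy

ChromaDB is currently an optional enhancement; indexing failures are swallowed into the logs and do not affect the main summarization flow.

Impact:

- The availability of semantic search is opaque.
- Some papers are done in the DB but never made it into ChromaDB.
- Stale index entries may remain after deletion or rebuild.

Recommendation:

- Treat ChromaDB explicitly as a derived index.
- Record `chroma_indexed_at` and `chroma_error` in the DB.
- Provide a `reindex_chroma` task.
- Show semantic index health on the search page or in the admin backend.
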
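### 9. Synchronous blocking operations mixed into the async main flow

Some async functions call synchronous blocking operations internally; for example, PDF download uses `requests`, and image extraction and PDF parsing are also heavy synchronous tasks.

Impact:

- The FastAPI event loop may be blocked.
- Throughput is unpredictable when several papers are summarized concurrently.
- Web service responsiveness is affected by background tasks.

Recommendation:

- Best option: move heavy tasks to a worker process.
- Transitional option: wrap synchronous blocking sections in `asyncio.to_thread()`.
- Use `httpx.AsyncClient` for all PDF downloads, or move them explicitly into a synchronous worker.
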
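The transitional option from section 9 is a small change at each call site. In the sketch below, `blocking_download` is a hypothetical stand-in for a synchronous `requests`-based PDF download; `asyncio.to_thread()` (Python 3.9+) runs it in a worker thread so the event loop stays free to serve other requests.

```python
import asyncio
import time
from typing import List


def blocking_download(url: str) -> bytes:
    """Hypothetical stand-in for a synchronous PDF download."""
    time.sleep(0.05)  # simulates blocking network I/O
    return url.encode()


async def download_all(urls: List[str]) -> List[bytes]:
    # Each blocking call runs in a thread; gather preserves the input order.
    results = await asyncio.gather(
        *(asyncio.to_thread(blocking_download, url) for url in urls)
    )
    return list(results)
```

This only helps with I/O-bound or GIL-releasing work; CPU-heavy steps such as layout detection would still be better off in a separate worker process.
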
### 10. Scheduling policy is coupled to the business flow

The scheduler hard-codes the daily pipeline plus a 30-minute upvote refresh; the pipeline hard-codes "fall back to yesterday when today has no data".

Impact:

- Policies are hard to test and adjust.
- CLI, Web, and Scheduler may each implement a different fallback.
- Supporting date ranges, backfills, or summarizing only newly added papers later will become complicated.

Recommendation: make the policies configurable:

- `crawl_target_policy`: `today` / `today_then_yesterday` / `date_range`
- `summarize_scope`: `new_only` / `pending_and_failed` / `date_range`
- `cleanup_policy`: `after_success` / `always` / `manual`
- `upvote_refresh_window_days`
- `pipeline_on_partial_failure`: `continue` / `stop`

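As an illustration of the `crawl_target_policy` option from section 10, the fallback can be turned into data: a function that maps a policy to the ordered list of dates to try. The names `CrawlTargetPolicy` and `candidate_dates` are hypothetical, and the interpretation (try dates in order for `today_then_yesterday`, crawl every date for `date_range`) is an assumption.

```python
from datetime import date, timedelta
from enum import Enum
from typing import List, Optional


class CrawlTargetPolicy(str, Enum):
    TODAY = "today"
    TODAY_THEN_YESTERDAY = "today_then_yesterday"
    DATE_RANGE = "date_range"


def candidate_dates(
    policy: CrawlTargetPolicy,
    today: date,
    start: Optional[date] = None,
    end: Optional[date] = None,
) -> List[date]:
    """Return the dates a crawl should consider, in order."""
    if policy is CrawlTargetPolicy.TODAY:
        return [today]
    if policy is CrawlTargetPolicy.TODAY_THEN_YESTERDAY:
        return [today, today - timedelta(days=1)]
    if start is None or end is None or start > end:
        raise ValueError("date_range requires start <= end")
    return [start + timedelta(days=i) for i in range((end - start).days + 1)]
```

With the policy expressed this way, CLI, Web, and Scheduler can share one implementation and tests can cover each policy without touching the network.
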
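### 11. The embedded APScheduler is sensitive to the deployment setup

The scheduler is currently embedded in the Web process, and with multiple workers it only prints a warning.

Impact:

- With multiple workers, multiple processes, or reload mode, tasks may run more than once or be skipped.
- Restarting the Web process affects scheduling reliability.
- Scheduling state and task execution state are coupled.

Recommendation:

- A local single-machine setup can keep the embedded scheduler.
- For long-running deployments, split it into a standalone scheduler process.
- The scheduler should only create jobs, not execute heavy tasks directly.
- The Web admin backend should only display scheduler/job status.
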
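### 12. Runtime migration capability is insufficient

The current `_migrate()` only supports adding missing columns.

Impact:

- Indexes, constraints, column renames, and data backfills cannot be handled reliably.
- The evolution of the database schema is not traceable.
- Upgrading production or long-lived local data is high-risk.

Recommendation:

- Introduce Alembic.
- Bring FTS5, indexes, some constraints, and data backfills into migration scripts.
- Do not run complex migrations silently at startup; run an explicit migration command instead.
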
## Low-to-Medium-Priority Issues

### 13. Insufficient configuration validation

Configuration fields are currently mostly bare types, with no validation of allowed values or combinations.

Example risks:

- `SUMMARY_BACKEND` is set to an invalid value.
- `SUMMARY_PDF_MODE` is set to an invalid value.
- `SCHEDULE_MINUTE + 30` exceeds 59.
- `APP_WORKERS > 1` while `SCHEDULER_ENABLED=true`.
- `CHROMA_ENABLED=true` but the embedding configuration is missing.

Recommendation:

- Use Pydantic validators to check enum values and combination constraints.
- For dangerous combinations, fail at startup instead of only logging a warning.

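The combination checks listed in section 13 can be expressed as one validation pass that fails fast at startup. The sketch below deliberately uses a plain dataclass instead of Pydantic so that it runs with the standard library alone; in the project itself the same rules would more naturally live in Pydantic validators. The `Settings` class here is a reduced, hypothetical subset of the real configuration, and the default values are assumptions.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Settings:
    # Field names follow the options mentioned above; defaults are assumptions.
    SCHEDULE_MINUTE: int = 0
    APP_WORKERS: int = 1
    SCHEDULER_ENABLED: bool = True
    CHROMA_ENABLED: bool = False
    EMBED_API_KEY: str = ""


def validate_settings(s: Settings) -> List[str]:
    """Return a list of human-readable problems; an empty list means valid."""
    errors: List[str] = []
    if not 0 <= s.SCHEDULE_MINUTE <= 59:
        errors.append("SCHEDULE_MINUTE must be between 0 and 59")
    elif s.SCHEDULE_MINUTE + 30 > 59:
        errors.append("SCHEDULE_MINUTE + 30 exceeds 59")
    if s.APP_WORKERS > 1 and s.SCHEDULER_ENABLED:
        errors.append("SCHEDULER_ENABLED requires APP_WORKERS == 1")
    if s.CHROMA_ENABLED and not s.EMBED_API_KEY:
        errors.append("CHROMA_ENABLED requires EMBED_API_KEY")
    return errors


def load_settings(**overrides) -> Settings:
    """Build settings and raise at startup instead of only warning."""
    settings = Settings(**overrides)
    errors = validate_settings(settings)
    if errors:
        raise ValueError("; ".join(errors))
    return settings
```
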
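### 14. Cross-module calls to private functions indicate unstable module boundaries

For example, `summarizer.py` calls underscore-prefixed functions such as `_generate_with_retry` and `_persist_summary`.

Impact:

- The public boundary of each module is unclear.
- Later refactoring can easily cause collateral damage.
- Testing and swapping backends is inconvenient.

Recommendation:

- Define an explicit public API for summary generation, persistence, and post-processing.
- Keep underscore-prefixed functions as module-internal implementation details.
- Introduce clearer service class or function boundaries, for example:
  - `SummaryGenerator.generate()`
  - `SummaryRepository.save()`
  - `DerivedAssetBuilder.extract_images()`
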
### 15. External dependencies lack a unified adapter

The project depends on several external systems:

- HuggingFace API
- arXiv PDF/source download
- `pi` CLI
- `claude` CLI
- Embedding API
- ChromaDB
- ONNX layout detector

Recommendation:

- Define an adapter for each kind of external dependency.
- Unify timeout, retry, error classification, and metrics.
- In tests, replace the adapter rather than patching deeply nested functions.

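### 16. Transaction boundaries in the deletion flow are not ideal

The deletion flow removes FTS5 entries, ChromaDB entries, local files, temporary files, and ORM data all at once.

Impact:

- If the DB rolls back after files were deleted successfully, the data remains but the files are lost.
- If ChromaDB deletion fails after the DB deletion succeeded, stale index entries remain.
- When a single paper fails, the rollback may affect previously flushed state, which is unintuitive.

Recommendation:

- Split deletion of primary data and deletion of derived artifacts into two phases.
- After the DB deletion succeeds, create a derived-cleanup job.
- Derived-artifact cleanup failures can be retried and do not affect primary data consistency.
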
## Suggested Target Architecture

If the project is positioned as a personal local tool, a lightweight target architecture is sufficient:

```text
FastAPI Web
  - Pages and API
  - Admin backend
  - Create jobs
  - Query job status

Scheduler
  - Create jobs according to policy
  - Do not run heavy tasks directly

Worker
  - crawl
  - summarize
  - PDF download
  - image extraction
  - FTS/Chroma reindex
  - cleanup/delete

Storage/Repository
  - Authoritative DB data
  - File asset management
  - Derived index rebuilding
  - Consistency checks
```

If you do not want to introduce a standalone worker yet, you can first implement the Job table and a background task executor within a single process. That at least unifies task status and recovery.

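## Recommended Implementation Order

### Phase 1: Address tasks and state first

Goal: reduce the risk of uncontrollable long-running tasks.

- Add the `jobs` and `job_events` tables.
- Change Web admin actions to create a job and return a task_id.
- Change the Scheduler to create jobs.
- Have the CLI reuse the same job runner.
- Add recovery logic for stale running jobs.

### Phase 2: Unify derived data management

Goal: reduce the risk of inconsistency among the DB, files, FTS, and Chroma.

- Establish the DB as the authoritative source.
- Wrap indexing in `reindex_paper()`.
- Add `reindex_fts` and `reindex_chroma` tasks.
- Add derived-data status fields or health checks.
- Change the deletion flow to primary DB deletion plus a derived-cleanup job.

### Phase 3: Turn the pipeline into a state machine

Goal: make tasks recoverable, compensable, and observable.

- Split the pipeline into stages.
- Record input/output/errors/duration for each stage.
- Support rerunning from the failed stage.
- Make the fallback policy configurable.

### Phase 4: Improve operational reliability

Goal: more stable long-term operation.

- Introduce Alembic.
- Add validators to the configuration.
- Move synchronous blocking operations off the Web event loop.
- Put external dependencies behind adapters.
- Show tasks, derived indexes, failure types, and duration statistics in the admin backend.
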
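## Minimum Viable Refactoring Plan

If only the smallest, highest-payoff changes are made, prioritize these three:

1. Add a unified Job table and job runner.
2. Treat the DB as the authoritative source and handle FTS/Chroma/files entirely as derived artifacts.
3. Add stale-task recovery and derived-data rebuild commands.

These three items significantly reduce later complexity without forcing the project to split into multiple services right away.
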
+66
-6
@@ -26,7 +26,7 @@ def crawl(
     from app.database import SessionLocal, engine
     from app.database import init_db as _init
     from app.models import Paper
-    from app.services.crawler import crawl_daily
+    from app.services.jobs import create_job, run_job
     from app.utils import today_str, yesterday_str
     from sqlalchemy import func, select

@@ -55,7 +55,13 @@ def crawl(
             return

         typer.echo(f"📡 Starting crawl for {target} ...")
-        result = asyncio.run(crawl_daily(db, target, top_n))
+        job = create_job(
+            db,
+            "crawl_daily",
+            owner="cli_crawl",
+            payload={"target_date": target, "top_n": top_n},
+        )
+        result = asyncio.run(run_job(db, job.id))

         # If no date was given and today failed or had no data, fall back to yesterday
         need_fallback = not date_str and (
@@ -76,7 +82,13 @@ def crawl(
             else:
                 typer.echo(f"🔄 No data for {target}, trying {fallback} ...")
             target = fallback
-            result = asyncio.run(crawl_daily(db, target, top_n))
+            job = create_job(
+                db,
+                "crawl_daily",
+                owner="cli_crawl",
+                payload={"target_date": target, "top_n": top_n},
+            )
+            result = asyncio.run(run_job(db, job.id))

         if result["status"] == "success":
             typer.echo(
@@ -110,7 +122,7 @@ def summarize(
     from app.config import settings
     from app.database import SessionLocal, engine
     from app.database import init_db as _init
-    from app.services.summarizer import summarize_batch, summarize_single
+    from app.services.jobs import create_job, run_job

     import os

@@ -142,11 +154,25 @@ def summarize(
     try:
         if arxiv_id:
             typer.echo(f"🤖 Summarizing {arxiv_id} (mode={pdf_mode}) ...")
-            result = asyncio.run(summarize_single(db, arxiv_id, pdf_mode=pdf_mode))
+            job = create_job(
+                db,
+                "summarize_one",
+                owner="cli_summarize",
+                payload={"arxiv_id": arxiv_id, "pdf_mode": pdf_mode, "force": False},
+            )
         else:
             typer.echo(f"🤖 Batch-summarizing pending papers (mode={pdf_mode}) ...")
-            result = asyncio.run(summarize_batch(db, pdf_mode=pdf_mode))
+            job = create_job(
+                db,
+                "summarize_batch",
+                owner="cli_summarize",
+                payload={"pdf_mode": pdf_mode},
+            )

+        result = asyncio.run(run_job(db, job.id))
+        if result.get("status") == "failed":
+            typer.echo(f"❌ Summarization failed: {result.get('error')}", err=True)
+            raise typer.Exit(code=1)
         typer.echo(f"✅ Summarization complete: {result}")
     except NotFoundError as exc:
         typer.echo(f"❌ {exc.message}", err=True)
@@ -172,5 +198,39 @@ def init_db():
     typer.echo(f"✅ Database initialized: {settings.db_path}")


+@cli_app.command("rebuild-derived")
+def rebuild_derived(
+    fts: bool = typer.Option(False, "--fts", help="Rebuild the FTS5 full-text index"),
+    chroma: bool = typer.Option(False, "--chroma", help="Rebuild the ChromaDB semantic index"),
+):
+    """Rebuild derived data indexes."""
+    from app.config import settings
+    from app.database import SessionLocal, engine
+    from app.database import init_db as _init
+    from app.services.jobs import create_job, run_job
+
+    import os
+
+    if not fts and not chroma:
+        fts = True
+
+    os.makedirs(settings.db_path.parent, exist_ok=True)
+    _init(engine)
+
+    db = SessionLocal()
+    try:
+        for job_type in [
+            *(["reindex_fts"] if fts else []),
+            *(["reindex_chroma"] if chroma else []),
+        ]:
+            job = create_job(db, job_type, owner="cli_rebuild_derived", payload={})
+            result = asyncio.run(run_job(db, job.id))
+            typer.echo(f"{job_type}: {result}")
+            if result.get("status") == "failed":
+                raise typer.Exit(code=1)
+    finally:
+        db.close()
+
+
 if __name__ == "__main__":
     cli_app()

@@ -76,10 +76,26 @@ def _migrate(engine) -> None:
         "crawl_logs": [
             ("details_json", "TEXT"),
         ],
+        "task_locks": [
+            ("heartbeat_at", "DATETIME"),
+            ("expires_at", "DATETIME"),
+        ],
+        "jobs": [
+            ("heartbeat_at", "DATETIME"),
+        ],
     }

     with engine.connect() as conn:
         for table, columns in _MIGRATIONS.items():
+            table_exists = conn.execute(
+                text(
+                    "SELECT name FROM sqlite_master "
+                    "WHERE type IN ('table', 'virtual table') AND name = :name"
+                ),
+                {"name": table},
+            ).fetchone()
+            if not table_exists:
+                continue
             # Get the existing column names
             existing = {
                 row[1] for row in conn.execute(text(f"PRAGMA table_info({table})"))

@@ -32,7 +32,14 @@ async def lifespan(app: FastAPI):
     # ── startup ──
     from app.services.scheduler import start_scheduler
     from app.services.embedder import init_chroma
+    from app.services.jobs import recover_stale_jobs
+    from app.database import SessionLocal

+    db = SessionLocal()
+    try:
+        recover_stale_jobs(db)
+    finally:
+        db.close()
     start_scheduler()
     init_chroma()


@@ -32,6 +32,26 @@ class SummaryState(StrEnum):
     PERMANENT_FAILURE = "permanent_failure"


+class JobStatus(StrEnum):
+    """Background job status enum; maps to the jobs.status column."""
+
+    QUEUED = "queued"
+    RUNNING = "running"
+    SUCCESS = "success"
+    FAILED = "failed"
+    STALE = "stale"
+    CANCELLED = "cancelled"
+
+
+class JobEventStatus(StrEnum):
+    """Job stage event status enum; maps to the job_events.status column."""
+
+    STARTED = "started"
+    SUCCESS = "success"
+    FAILED = "failed"
+    INFO = "info"
+
+
 # ── papers ──────────────────────────────────────────────────────────────
 class Paper(Base):
     __tablename__ = "papers"
@@ -194,9 +214,48 @@ class TaskLock(Base):
     status = Column(String, nullable=False)
     owner = Column(String)
     acquired_at = Column(DateTime, nullable=False)
+    heartbeat_at = Column(DateTime)
+    expires_at = Column(DateTime)
     released_at = Column(DateTime)


+# ── jobs / job_events ──────────────────────────────────────────────────
+class Job(Base):
+    __tablename__ = "jobs"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    type = Column(String, nullable=False, index=True)
+    status = Column(String, nullable=False, default=JobStatus.QUEUED, index=True)
+    owner = Column(String)
+    payload_json = Column(Text)
+    result_json = Column(Text)
+    error = Column(Text)
+    created_at = Column(DateTime, nullable=False)
+    started_at = Column(DateTime)
+    heartbeat_at = Column(DateTime)
+    completed_at = Column(DateTime)
+
+    events = relationship(
+        "JobEvent", back_populates="job", cascade="all, delete-orphan"
+    )
+
+
+class JobEvent(Base):
+    __tablename__ = "job_events"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    job_id = Column(
+        Integer, ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, index=True
+    )
+    stage = Column(String, nullable=False)
+    status = Column(String, nullable=False)
+    message = Column(Text)
+    payload_json = Column(Text)
+    created_at = Column(DateTime, nullable=False)
+
+    job = relationship("Job", back_populates="events")
+
+
 # ── user data ──────────────────────────────────────────────────────────
 class UserBookmark(Base):
     __tablename__ = "user_bookmarks"

+106
-277
@@ -4,36 +4,20 @@ from __future__ import annotations

 import hashlib
 import hmac
-import json
-import logging
 from datetime import date

-from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request
+from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, Query, Request
 from fastapi.responses import RedirectResponse
 from pydantic import BaseModel, field_validator
-from sqlalchemy import bindparam, func, select, text
 from sqlalchemy.orm import Session

 from app.config import settings
 from app.database import get_db
-from app.models import (
-    CrawlLog,
-    DataDeleteJob,
-    Paper,
-    PaperTag,
-    SummaryState,
-    SummaryStatus,
-)
 from app.services import admin as admin_svc
 from app.services.admin import get_admin_stats
-from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
-from app.services.crawler import refresh_upvotes
-from app.services.pipeline import run_crawl, run_pipeline
-from app.services.scheduler import get_scheduler
-from app.services.summarizer import summarize_batch, summarize_single
-from app.utils import templates, today_str, utc_now
-
-logger = logging.getLogger(__name__)
+from app.services.cleaner import cleanup_tmp
+from app.services.jobs import create_job, enqueue_job
+from app.utils import templates, today_str

 router = APIRouter(prefix="/admin", tags=["admin"])

@@ -103,18 +87,7 @@ async def admin_dashboard(
 ):
     """Admin dashboard: system status overview."""
     stats = get_admin_stats(db)
-
-    # Scheduler history (the 10 most recent task=scheduler log entries)
-    scheduler_history = (
-        db.execute(
-            select(CrawlLog)
-            .where(CrawlLog.task == "scheduler")
-            .order_by(CrawlLog.started_at.desc())
-            .limit(10)
-        )
-        .scalars()
-        .all()
-    )
+    scheduler_history = admin_svc.get_scheduler_history(db)

     return templates.TemplateResponse(
         request,
@@ -129,53 +102,43 @@ async def admin_dashboard(
 @router.get("/scheduler-status")
 async def admin_scheduler_status(_admin: None = Depends(verify_admin)):
     """Scheduler runtime status (JSON)."""
-    scheduler = get_scheduler()
-    next_run = None
-    upvote_next_run = None
-    if scheduler:
-        for job in scheduler.get_jobs():
-            if job.id == "daily_pipeline":
-                next_run = job.next_run_time
-            elif job.id == "upvote_refresh":
-                upvote_next_run = job.next_run_time
-    return {
-        "enabled": scheduler is not None,
-        "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
-        "timezone": settings.APP_TIMEZONE,
-        "next_run": next_run.isoformat() if next_run else None,
-        "upvote_next_run": upvote_next_run.isoformat() if upvote_next_run else None,
-        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
-    }
+    return admin_svc.get_scheduler_status()


 @router.post("/trigger-pipeline")
 async def admin_trigger_pipeline(
+    background_tasks: BackgroundTasks,
     _admin: None = Depends(verify_admin),
     db: Session = Depends(get_db),
 ):
     """Manually trigger one full pipeline run (crawl → summarize → cleanup)."""
     today = today_str()
-    try:
-        result = await run_pipeline(db, today, owner="admin_trigger")
-    except RuntimeError as exc:
-        raise HTTPException(status_code=409, detail=str(exc))
-
-    if result["status"] == "failed":
-        raise HTTPException(status_code=500, detail=result.get("error"))
-    return {"status": "success", "message": "Pipeline run completed"}
+    job = create_job(
+        db,
+        "pipeline_daily",
+        owner="admin_trigger",
+        payload={"target_date": today},
+    )
+    enqueue_job(background_tasks, job.id)
+    return {"status": "queued", "job_id": job.id, "message": "Pipeline job created"}


 @router.post("/refresh-upvotes")
 async def admin_refresh_upvotes(
+    background_tasks: BackgroundTasks,
     _admin: None = Depends(verify_admin),
     db: Session = Depends(get_db),
     days: int | None = Query(None, description="Refresh the last N days; defaults to the configured value"),
 ):
     """Manually refresh upvotes for papers from the last N days."""
-    result = await refresh_upvotes(db, days=days)
-    if result["status"] == "failed":
-        raise HTTPException(status_code=500, detail=result.get("error"))
-    return result
+    job = create_job(
+        db,
+        "refresh_upvotes",
+        owner="admin_refresh",
+        payload={"days": days},
+    )
+    enqueue_job(background_tasks, job.id)
+    return {"status": "queued", "job_id": job.id}


 # ── Request models ──────────────────────────────────────────────────────
@@ -200,18 +163,21 @@ class DeleteRequest(BaseModel):

 @router.post("/crawl")
 async def admin_crawl(
+    background_tasks: BackgroundTasks,
     _admin: None = Depends(verify_admin),
     db: Session = Depends(get_db),
     date: str | None = Query(None, description="YYYY-MM-DD, defaults to today"),
 ):
     """Manually crawl the given date; defaults to today."""
     target_date = date or today_str()
-    try:
-        return await run_crawl(db, target_date, owner="admin_crawl")
-    except RuntimeError as exc:
-        raise HTTPException(status_code=409, detail=str(exc))
-    except Exception as exc:
-        raise HTTPException(status_code=500, detail=str(exc))
+    job = create_job(
+        db,
+        "crawl_daily",
+        owner="admin_crawl",
+        payload={"target_date": target_date},
+    )
+    enqueue_job(background_tasks, job.id)
+    return {"status": "queued", "job_id": job.id, "target_date": target_date}


 # ── Summarization ───────────────────────────────────────────────────────
@@ -219,23 +185,41 @@ async def admin_crawl(

 @router.post("/summarize")
 async def admin_summarize_batch(
+    background_tasks: BackgroundTasks,
     _admin: None = Depends(verify_admin),
     db: Session = Depends(get_db),
 ):
     """Batch-summarize all pending papers."""
-    return await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
+    job = create_job(
+        db,
+        "summarize_batch",
+        owner="admin_summarize",
+        payload={"pdf_mode": settings.SUMMARY_PDF_MODE},
+    )
+    enqueue_job(background_tasks, job.id)
+    return {"status": "queued", "job_id": job.id}


 @router.post("/summarize/{arxiv_id}")
 async def admin_summarize_single(
     arxiv_id: str,
+    background_tasks: BackgroundTasks,
     _admin: None = Depends(verify_admin),
     db: Session = Depends(get_db),
 ):
     """Summarize or re-run a single paper."""
-    return await summarize_single(
-        db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE
+    job = create_job(
+        db,
+        "summarize_one",
+        owner="admin_summarize",
+        payload={
+            "arxiv_id": arxiv_id,
+            "force": True,
+            "pdf_mode": settings.SUMMARY_PDF_MODE,
+        },
     )
+    enqueue_job(background_tasks, job.id)
+    return {"status": "queued", "job_id": job.id, "arxiv_id": arxiv_id}


 # ── Cleanup ─────────────────────────────────────────────────────────────
@@ -243,39 +227,25 @@ async def admin_summarize_single(

 @router.post("/cleanup")
 async def admin_cleanup(
+    background_tasks: BackgroundTasks,
     _admin: None = Depends(verify_admin),
     db: Session = Depends(get_db),
 ):
     """Clean up temporary files in data/tmp/ older than 24 hours."""
-    now = utc_now()
-    log_entry = CrawlLog(
-        task="cleanup",
-        status="running",
-        started_at=now,
-    )
-    db.add(log_entry)
-    db.commit()
+    job = create_job(db, "cleanup_tmp", owner="admin_cleanup", payload={})
+    enqueue_job(background_tasks, job.id)
+    return {"status": "queued", "job_id": job.id}

+
+@router.post("/cleanup-now")
+async def admin_cleanup_now(
+    _admin: None = Depends(verify_admin),
+    db: Session = Depends(get_db),
+):
+    """Clean up temporary files synchronously; kept for tests and local troubleshooting."""
     try:
-        result = cleanup_tmp()
-        log_entry.status = "success"
-        log_entry.completed_at = utc_now()
-        log_entry.details_json = json.dumps(
-            {
-                "scanned": result.get("scanned", 0),
-                "removed": result.get("removed", 0),
-            },
-            ensure_ascii=False,
-        )
-        if result.get("errors"):
-            log_entry.error = "; ".join(result["errors"])[:2000]
-        db.commit()
-        return result
+        return admin_svc.run_cleanup_now(db, cleanup_tmp)
     except Exception as exc:
-        log_entry.status = "failed"
-        log_entry.error = str(exc)[:2000]
-        log_entry.completed_at = utc_now()
-        db.commit()
         raise HTTPException(status_code=500, detail=str(exc))


@@ -285,6 +255,7 @@ async def admin_cleanup(
 @router.post("/delete")
 async def admin_delete(
     body: DeleteRequest,
+    background_tasks: BackgroundTasks,
     _admin: None = Depends(verify_admin),
     db: Session = Depends(get_db),
 ):
@@ -292,13 +263,31 @@ async def admin_delete(
     if body.date_start > body.date_end:
         raise HTTPException(status_code=400, detail="date_start must be <= date_end")

-    result = await delete_papers_by_date_range(
+    job = create_job(
         db,
-        body.date_start,
-        body.date_end,
-        include_notes=body.include_notes,
+        "delete_range",
+        owner="admin_delete",
+        payload={
+            "date_start": body.date_start.isoformat(),
+            "date_end": body.date_end.isoformat(),
+            "include_notes": body.include_notes,
+        },
     )
-    return result
+    enqueue_job(background_tasks, job.id)
+    return {"status": "queued", "job_id": job.id}
+
+
+@router.get("/jobs/{job_id}")
+async def admin_job_detail(
+    job_id: int,
+    _admin: None = Depends(verify_admin),
+    db: Session = Depends(get_db),
+):
+    """Query a background job's status and stage events."""
+    detail = admin_svc.get_job_detail(db, job_id)
+    if not detail:
+        raise HTTPException(status_code=404, detail=f"Job not found: {job_id}")
+    return detail


 # ── Logs ────────────────────────────────────────────────────────────────
@@ -313,72 +302,10 @@ async def admin_logs(
     per_page: int = Query(20, ge=1, le=100),
 ):
     """View task logs (CrawlLog + DataDeleteJob) plus summary status statistics."""
-    crawl_logs = (
-        db.execute(
-            select(CrawlLog)
-            .order_by(CrawlLog.started_at.desc())
-            .limit(per_page)
-            .offset((page - 1) * per_page)
-        )
-        .scalars()
-        .all()
-    )
-
-    delete_jobs = (
-        db.execute(
-            select(DataDeleteJob)
-            .order_by(DataDeleteJob.started_at.desc())
-            .limit(per_page)
-            .offset((page - 1) * per_page)
-        )
-        .scalars()
-        .all()
-    )
-
-    # Summary status statistics overview
-    summary_total = db.scalar(select(func.count(Paper.id))) or 0
-    summary_done = (
-        db.scalar(
-            select(func.count(SummaryStatus.id)).where(
-                SummaryStatus.status == SummaryState.DONE
-            )
-        )
-        or 0
-    )
-    summary_pending = (
-        db.scalar(
-            select(func.count(SummaryStatus.id)).where(
-                SummaryStatus.status.in_(
-                    [SummaryState.PENDING, SummaryState.PROCESSING]
-                )
-            )
-        )
-        or 0
-    )
-    summary_failed = (
-        db.scalar(
-            select(func.count(SummaryStatus.id)).where(
-                SummaryStatus.status.in_(
-                    [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
-                )
-            )
-        )
-        or 0
-    )
-
     return templates.TemplateResponse(
         request,
         "admin_logs.html",
-        {
-            "crawl_logs": crawl_logs,
-            "delete_jobs": delete_jobs,
-            "page": page,
-            "per_page": per_page,
-            "summary_total": summary_total,
-            "summary_done": summary_done,
-            "summary_pending": summary_pending,
-            "summary_failed": summary_failed,
-        },
+        admin_svc.get_logs_context(db, page=page, per_page=per_page),
     )


@@ -395,22 +322,10 @@ async def admin_summary_status(
     per_page: int = Query(20, ge=1, le=100),
 ):
     """Summary status list (HTMX fragment or JSON)."""
-
-    query = (
-        select(Paper, SummaryStatus)
-        .outerjoin(SummaryStatus, SummaryStatus.paper_id == Paper.id)
-        .order_by(Paper.paper_date.desc())
+    results, total = admin_svc.query_summary_statuses(
+        db, status=status, page=page, per_page=per_page
     )

-    if status != "all":
-        if status == "none":
-            query = query.where(SummaryStatus.paper_id == None)  # noqa: E711
-        else:
-            query = query.where(SummaryStatus.status == status)
-
-    total = db.scalar(select(func.count()).select_from(query.subquery()))
-    results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
-
     # Check whether this is an HTMX request
     is_htmx = request.headers.get("HX-Request") == "true"

@@ -421,27 +336,16 @@ async def admin_summary_status(
             "partials/summary_list.html",
             {
                 "results": results,
-                "total": total or 0,
+                "total": total,
                 "page": page,
                 "per_page": per_page,
                 "current_status": status,
             },
         )

-    # Non-HTMX requests get JSON
-    items = []
-    for paper, ss in results:
-        item = {
-            "arxiv_id": paper.arxiv_id,
-            "title": paper.title_zh or paper.title_en,
-            "paper_date": str(paper.paper_date),
-            "summary_status": ss.status if ss else "none",
-            "retry_count": ss.retry_count if ss else 0,
-            "error_type": ss.error_type if ss else None,
-            "error": ss.error if ss else None,
-        }
-        items.append(item)
-    return {"items": items, "total": total or 0, "page": page, "per_page": per_page}
+    return admin_svc.serialize_summary_statuses(
+        results, total=total, page=page, per_page=per_page
+    )


 @router.post("/summary-retry-failed")
@@ -450,39 +354,14 @@ async def admin_summary_retry_failed(
     db: Session = Depends(get_db),
 ):
     """Retry all summary tasks in a failed state."""
-    failed_ids = (
-        db.execute(
-            select(Paper.arxiv_id)
-            .join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
-            .where(
-                SummaryStatus.status.in_(
-                    [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
-                )
-            )
-        )
-        .scalars()
-        .all()
-    )
-
-    if not failed_ids:
+    count = admin_svc.retry_failed_summaries(db)
+    if not count:
         return {"status": "success", "message": "没有失败的任务需要重试", "count": 0}

-    # Reset failed tasks to pending
-    db.execute(
-        SummaryStatus.__table__.update()
-        .where(
-            SummaryStatus.status.in_(
-                [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
-            )
-        )
-        .values(status=SummaryState.PENDING, error=None, error_type=None)
-    )
-    db.commit()
-
     return {
         "status": "success",
-        "message": f"已重置 {len(failed_ids)} 个失败任务为待总结状态",
-        "count": len(failed_ids),
+        "message": f"已重置 {count} 个失败任务为待总结状态",
+        "count": count,
     }


@@ -545,23 +424,8 @@ async def admin_paper_delete(
     db: Session = Depends(get_db),
 ):
     """Delete a single paper."""
-    paper = db.scalar(select(Paper).where(Paper.arxiv_id == arxiv_id))
-    if not paper:
+    if not admin_svc.delete_paper_by_arxiv(db, arxiv_id):
         raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
-
-    # Delete related data (the ORM cascade handles the association tables)
-    db.delete(paper)
-    db.commit()
-
-    # Clean up the FTS index
-    try:
-        db.execute(
-            text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id}
-        )
-        db.commit()
-    except Exception:
-        logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
-
     return {"status": "success", "message": f"已删除 {arxiv_id}"}


@@ -588,28 +452,7 @@ async def admin_papers_batch_action(
         raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")

     if body.action == "delete":
-        papers = (
-            db.execute(select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids)))
-            .scalars()
-            .all()
-        )
-
-        count = 0
-        for paper in papers:
-            db.delete(paper)
-            count += 1
-        db.commit()
-
-        # Clean up the FTS index
-        try:
-            stmt = text("DELETE FROM papers_fts WHERE arxiv_id IN :ids").bindparams(
-                bindparam("ids", expanding=True)
-            )
-            db.execute(stmt, {"ids": body.arxiv_ids})
-            db.commit()
-        except Exception:
-            logger.warning("Failed to clean FTS index for batch delete", exc_info=True)
-
+        count = admin_svc.delete_papers_by_arxiv_ids(db, body.arxiv_ids)
         return {
             "status": "success",
             "message": f"已删除 {count} 篇论文",
@@ -617,24 +460,10 @@ async def admin_papers_batch_action(
         }

     elif body.action == "summarize":
-        # Reset the summary status of the selected papers to pending
-        paper_ids = (
-            db.execute(select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids)))
-            .scalars()
-            .all()
-        )
-
-        if paper_ids:
-            # Delete the old status rows so the papers re-enter the pipeline
-            db.execute(
-                SummaryStatus.__table__.delete().where(
-                    SummaryStatus.paper_id.in_(paper_ids)
-                )
-            )
-            db.commit()
+        count = admin_svc.reset_summaries_pending(db, body.arxiv_ids)

         return {
             "status": "success",
-            "message": f"已将 {len(paper_ids)} 篇论文重置为待总结",
-            "count": len(paper_ids),
+            "message": f"已将 {count} 篇论文重置为待总结",
+            "count": count,
         }
+324
-3
@@ -1,17 +1,30 @@
-"""Admin backend service: statistics aggregation and system status."""
+"""Admin backend service: statistics aggregation, system status, and admin operations."""

 from __future__ import annotations

+import json
 from datetime import date
 from pathlib import Path
+from typing import Callable

 from sqlalchemy import func, select, text
 from sqlalchemy.orm import Session

 from app.config import settings
-from app.models import CrawlLog, Paper, PaperTag, SummaryState, SummaryStatus, TaskLock
+from app.models import (
+    CrawlLog,
+    DataDeleteJob,
+    Job,
+    JobEvent,
+    Paper,
+    PaperTag,
+    SummaryState,
+    SummaryStatus,
+    TaskLock,
+)
+from app.services.derived import delete_paper_indexes
 from app.services.scheduler import get_scheduler
-from app.utils import PAPERS_DIR, TMP_DIR
+from app.utils import PAPERS_DIR, TMP_DIR, utc_now

 # Sort mapping for admin_papers
 SORT_MAP = {
@@ -190,3 +203,311 @@ def query_papers(
         statuses[paper_id_to_arxiv.get(pid, "")] = st

     return papers, total or 0, statuses
+
+
+def get_scheduler_history(db: Session, limit: int = 10) -> list[CrawlLog]:
+    """Most recent scheduler run logs."""
+    return (
+        db.execute(
+            select(CrawlLog)
+            .where(CrawlLog.task == "scheduler")
+            .order_by(CrawlLog.started_at.desc())
+            .limit(limit)
+        )
+        .scalars()
+        .all()
+    )
+
+
+def get_scheduler_status() -> dict:
+    """Scheduler runtime status."""
+    scheduler = get_scheduler()
+    next_run = None
+    upvote_next_run = None
+    if scheduler:
+        for job in scheduler.get_jobs():
+            if job.id == "daily_pipeline":
+                next_run = job.next_run_time
+            elif job.id == "upvote_refresh":
+                upvote_next_run = job.next_run_time
+    return {
+        "enabled": scheduler is not None,
+        "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
+        "timezone": settings.APP_TIMEZONE,
+        "next_run": next_run.isoformat() if next_run else None,
+        "upvote_next_run": upvote_next_run.isoformat() if upvote_next_run else None,
+        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
+    }
+
+
+def run_cleanup_now(db: Session, cleanup_func: Callable[[], dict]) -> dict:
+    """Run the temp-directory cleanup synchronously and record it in CrawlLog."""
+    log_entry = CrawlLog(task="cleanup", status="running", started_at=utc_now())
+    db.add(log_entry)
+    db.commit()
+
+    try:
+        result = cleanup_func()
+        log_entry.status = "success"
+        log_entry.completed_at = utc_now()
+        log_entry.details_json = json.dumps(
+            {
+                "scanned": result.get("scanned", 0),
+                "removed": result.get("removed", 0),
+            },
+            ensure_ascii=False,
+        )
+        if result.get("errors"):
+            log_entry.error = "; ".join(result["errors"])[:2000]
+        db.commit()
+        return result
+    except Exception as exc:
+        log_entry.status = "failed"
+        log_entry.error = str(exc)[:2000]
+        log_entry.completed_at = utc_now()
+        db.commit()
+        raise
+
+
+def get_job_detail(db: Session, job_id: int) -> dict | None:
+    """Background job details and stage events, returned as a JSON-serializable dict."""
+    job = db.get(Job, job_id)
+    if not job:
+        return None
+    events = (
+        db.execute(
+            select(JobEvent)
+            .where(JobEvent.job_id == job_id)
+            .order_by(JobEvent.created_at.asc())
+        )
+        .scalars()
+        .all()
+    )
+    return {
+        "id": job.id,
+        "type": job.type,
+        "status": job.status,
+        "owner": job.owner,
+        "payload": json.loads(job.payload_json or "{}"),
+        "result": json.loads(job.result_json or "{}") if job.result_json else None,
+        "error": job.error,
+        "created_at": job.created_at.isoformat(),
+        "started_at": job.started_at.isoformat() if job.started_at else None,
+        "completed_at": job.completed_at.isoformat() if job.completed_at else None,
+        "events": [
+            {
+                "stage": event.stage,
+                "status": event.status,
+                "message": event.message,
+                "payload": json.loads(event.payload_json or "{}")
+                if event.payload_json
+                else None,
+                "created_at": event.created_at.isoformat(),
+            }
+            for event in events
+        ],
+    }
+
+
+def get_logs_context(db: Session, *, page: int, per_page: int) -> dict:
+    """Context for the admin logs page."""
+    crawl_logs = (
+        db.execute(
+            select(CrawlLog)
+            .order_by(CrawlLog.started_at.desc())
+            .limit(per_page)
+            .offset((page - 1) * per_page)
+        )
+        .scalars()
+        .all()
+    )
+    delete_jobs = (
+        db.execute(
+            select(DataDeleteJob)
+            .order_by(DataDeleteJob.started_at.desc())
+            .limit(per_page)
+            .offset((page - 1) * per_page)
+        )
+        .scalars()
+        .all()
+    )
+
+    summary_total = db.scalar(select(func.count(Paper.id))) or 0
+    summary_done = (
+        db.scalar(
+            select(func.count(SummaryStatus.id)).where(
+                SummaryStatus.status == SummaryState.DONE
+            )
+        )
+        or 0
+    )
+    summary_pending = (
+        db.scalar(
+            select(func.count(SummaryStatus.id)).where(
+                SummaryStatus.status.in_(
+                    [SummaryState.PENDING, SummaryState.PROCESSING]
+                )
+            )
+        )
+        or 0
+    )
+    summary_failed = (
+        db.scalar(
+            select(func.count(SummaryStatus.id)).where(
+                SummaryStatus.status.in_(
+                    [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
+                )
+            )
+        )
+        or 0
+    )
+    return {
+        "crawl_logs": crawl_logs,
+        "delete_jobs": delete_jobs,
+        "page": page,
+        "per_page": per_page,
+        "summary_total": summary_total,
+        "summary_done": summary_done,
+        "summary_pending": summary_pending,
+        "summary_failed": summary_failed,
+    }
+
+
+def query_summary_statuses(
+    db: Session,
+    *,
+    status: str,
+    page: int,
+    per_page: int,
+) -> tuple[list[tuple[Paper, SummaryStatus | None]], int]:
+    """Query the summary status list."""
+    query = (
+        select(Paper, SummaryStatus)
+        .outerjoin(SummaryStatus, SummaryStatus.paper_id == Paper.id)
+        .order_by(Paper.paper_date.desc())
+    )
+    if status != "all":
+        if status == "none":
+            query = query.where(SummaryStatus.paper_id == None)  # noqa: E711
+        else:
+            query = query.where(SummaryStatus.status == status)
+
+    total = db.scalar(select(func.count()).select_from(query.subquery())) or 0
+    results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
+    return results, total
+
+
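The paging arithmetic in `query_summary_statuses` (an offset of `(page - 1) * per_page`, a limit of `per_page`, and a separate total count) can be tried without a database. The following is a minimal sketch, not code from this commit: `paginate` and its list input are hypothetical stand-ins for the SQLAlchemy query.

```python
from __future__ import annotations

from typing import Sequence, TypeVar

T = TypeVar("T")


def paginate(rows: Sequence[T], *, page: int, per_page: int) -> tuple[list[T], int]:
    """Return one page of rows plus the total row count.

    Hypothetical helper: it mirrors OFFSET (page - 1) * per_page LIMIT per_page
    on an in-memory sequence instead of a SQL query.
    """
    if page < 1 or per_page < 1:
        raise ValueError("page and per_page must be >= 1")
    offset = (page - 1) * per_page
    return list(rows[offset : offset + per_page]), len(rows)


# 45 rows, 20 per page: page 3 holds the last 5 rows.
items, total = paginate(list(range(45)), page=3, per_page=20)
past_end, _ = paginate(list(range(45)), page=4, per_page=20)
```

As in the service function, the total is computed over the whole filtered set, so a page past the end comes back empty while `total` stays unchanged.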
+def serialize_summary_statuses(
+    results: list[tuple[Paper, SummaryStatus | None]],
+    *,
+    total: int,
+    page: int,
+    per_page: int,
+) -> dict:
+    """JSON response for the summary status list."""
+    items = []
+    for paper, ss in results:
+        items.append(
+            {
+                "arxiv_id": paper.arxiv_id,
+                "title": paper.title_zh or paper.title_en,
+                "paper_date": str(paper.paper_date),
+                "summary_status": ss.status if ss else "none",
+                "retry_count": ss.retry_count if ss else 0,
+                "error_type": ss.error_type if ss else None,
+                "error": ss.error if ss else None,
+            }
+        )
+    return {"items": items, "total": total, "page": page, "per_page": per_page}
+
+
+def retry_failed_summaries(db: Session) -> int:
+    """Reset failed/permanently failed summary tasks to pending."""
+    failed_ids = (
+        db.execute(
+            select(Paper.arxiv_id)
+            .join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
+            .where(
+                SummaryStatus.status.in_(
+                    [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
+                )
+            )
+        )
+        .scalars()
+        .all()
+    )
+    if not failed_ids:
+        return 0
+
+    db.execute(
+        SummaryStatus.__table__.update()
+        .where(
+            SummaryStatus.status.in_(
+                [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
+            )
+        )
+        .values(status=SummaryState.PENDING, error=None, error_type=None)
+    )
+    db.commit()
+    return len(failed_ids)
+
+
+def delete_paper_by_arxiv(db: Session, arxiv_id: str) -> bool:
+    """Delete a single paper and its derived indexes."""
+    paper = db.scalar(select(Paper).where(Paper.arxiv_id == arxiv_id))
+    if not paper:
+        return False
+
+    paper_id = paper.id
+    db.delete(paper)
+    db.commit()
+    delete_paper_indexes(db, paper_id=paper_id, arxiv_id=arxiv_id)
+    db.commit()
+    return True
+
+
+def delete_papers_by_arxiv_ids(db: Session, arxiv_ids: list[str]) -> int:
+    """Delete papers and their derived indexes in bulk."""
+    papers = (
+        db.execute(select(Paper).where(Paper.arxiv_id.in_(arxiv_ids))).scalars().all()
+    )
+    deleted = [(paper.id, paper.arxiv_id) for paper in papers]
+    for paper in papers:
+        db.delete(paper)
+    db.commit()
+
+    for paper_id, arxiv_id in deleted:
+        delete_paper_indexes(db, paper_id=paper_id, arxiv_id=arxiv_id)
+    db.commit()
+    return len(deleted)
+
+
+def reset_summaries_pending(db: Session, arxiv_ids: list[str]) -> int:
+    """Reset the summary status of the given papers to pending, creating a status row where none exists."""
+    paper_ids = (
+        db.execute(select(Paper.id).where(Paper.arxiv_id.in_(arxiv_ids)))
+        .scalars()
+        .all()
+    )
+    if not paper_ids:
+        return 0
+
+    existing_statuses = (
+        db.execute(select(SummaryStatus).where(SummaryStatus.paper_id.in_(paper_ids)))
+        .scalars()
+        .all()
+    )
+    existing_ids = {status.paper_id for status in existing_statuses}
+    for status in existing_statuses:
+        status.status = SummaryState.PENDING
+        status.quality = None
+        status.error = None
+        status.error_type = None
+        status.raw_output_saved = False
+        status.started_at = None
+        status.completed_at = None
+    for paper_id in paper_ids:
+        if paper_id not in existing_ids:
+            db.add(SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING))
+    db.commit()
+    return len(paper_ids)
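`reset_summaries_pending` changes behavior relative to the removed route code: instead of deleting status rows, it resets existing rows in place and creates rows for papers that have none. A minimal sketch of that reset-or-create pattern follows, using the standard-library `sqlite3` module rather than the project's SQLAlchemy models. The `summary_status` table layout and the `reset_pending` helper are assumptions made for illustration only.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical, simplified table; the real model has more columns.
conn.execute(
    "CREATE TABLE summary_status ("
    "paper_id INTEGER PRIMARY KEY, status TEXT NOT NULL, error TEXT)"
)
conn.execute("INSERT INTO summary_status VALUES (1, 'failed', 'timeout')")


def reset_pending(conn: sqlite3.Connection, paper_ids: list[int]) -> int:
    """Reset existing rows to 'pending' and create rows that are missing."""
    if not paper_ids:
        return 0
    marks = ",".join("?" for _ in paper_ids)
    existing = {
        row[0]
        for row in conn.execute(
            f"SELECT paper_id FROM summary_status WHERE paper_id IN ({marks})",
            paper_ids,
        )
    }
    # Existing rows keep their identity; only the state fields are cleared.
    conn.execute(
        "UPDATE summary_status SET status = 'pending', error = NULL "
        f"WHERE paper_id IN ({marks})",
        paper_ids,
    )
    # Papers without a status row get a fresh one.
    conn.executemany(
        "INSERT INTO summary_status (paper_id, status) VALUES (?, 'pending')",
        [(pid,) for pid in paper_ids if pid not in existing],
    )
    conn.commit()
    return len(paper_ids)


count = reset_pending(conn, [1, 2])
rows = conn.execute(
    "SELECT paper_id, status, error FROM summary_status ORDER BY paper_id"
).fetchall()
```

Paper 1 is reset in place and paper 2 gains a new row, so both end up `pending` and the returned count covers every requested paper.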
+3
-16
@@ -4,7 +4,7 @@ import logging
|
||||
from datetime import date as date_type, datetime, timezone
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
@@ -16,6 +16,7 @@ from app.models import (
|
||||
SummaryState,
|
||||
SummaryStatus,
|
||||
)
|
||||
from app.services.derived import reindex_paper_fts
|
||||
from app.utils import make_http_client, recent_date_strs, utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -143,21 +144,7 @@ def upsert_papers(db: Session, papers_raw: list[dict], paper_date: str) -> list[
|
||||
|
||||
db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
|
||||
|
||||
authors_text = ", ".join(meta["authors"])
|
||||
tags_text = ", ".join(meta["tags"])
|
||||
db.execute(
|
||||
text(
|
||||
"INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
|
||||
"VALUES (:id, :title, :abstract, :authors, :tags)"
|
||||
),
|
||||
{
|
||||
"id": paper.id,
|
||||
"title": meta["title_en"],
|
||||
"abstract": meta["abstract"] or "",
|
||||
"authors": authors_text,
|
||||
"tags": tags_text,
|
||||
},
|
||||
)
|
||||
reindex_paper_fts(db, paper)
|
||||
|
||||
new_papers.append(paper)
|
||||
logger.debug("Inserted new paper: %s", arxiv_id)
|
||||
|
||||
@@ -0,0 +1,140 @@
+"""Derived data maintenance: rebuildable indexes such as FTS5 and ChromaDB."""
+
+from __future__ import annotations
+
+import logging
+
+from sqlalchemy import select, text
+from sqlalchemy.orm import Session
+
+from app.models import Paper
+
+logger = logging.getLogger(__name__)
+
+
+def _summary_text(paper: Paper) -> str:
+    summary = paper.summary
+    if not summary:
+        return ""
+    parts = [
+        summary.one_line,
+        summary.motivation_problem,
+        summary.motivation_goal,
+        summary.method_overview,
+        summary.method_key_idea,
+        summary.results_main_json,
+    ]
+    return " ".join(p for p in parts if p)
+
+
+def delete_fts_paper(db: Session, paper_id: int) -> None:
+    """Delete the FTS5 row for a single paper. FTS5 uses papers.id as the rowid."""
+    db.execute(
+        text("DELETE FROM papers_fts WHERE rowid = :paper_id"),
+        {"paper_id": paper_id},
+    )
+
+
+def delete_paper_indexes(db: Session, *, paper_id: int, arxiv_id: str) -> None:
+    """Delete all derived indexes for a single paper. Failures are logged but do not block the main delete."""
+    try:
+        delete_fts_paper(db, paper_id)
+    except Exception:
+        logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
+
+    try:
+        from app.services.embedder import delete_paper
+
+        delete_paper(arxiv_id)
+    except Exception:
+        logger.warning("Failed to clean ChromaDB index for %s", arxiv_id, exc_info=True)
+
+
+def reindex_paper_fts(db: Session, paper: Paper) -> None:
+    """Rebuild the FTS5 derived index for a single paper from the authoritative DB data."""
+    authors_text = ", ".join(
+        a.name for a in sorted(paper.authors, key=lambda a: a.position or 0)
+    )
+    tags_text = ", ".join(t.tag for t in paper.tags)
+
+    delete_fts_paper(db, paper.id)
+    db.execute(
+        text(
+            """
+            INSERT INTO papers_fts(
+                rowid, title_en, title_zh, abstract, authors, tags, summary_text
+            )
+            VALUES (
+                :id, :title_en, :title_zh, :abstract, :authors, :tags, :summary_text
+            )
+            """
+        ),
+        {
+            "id": paper.id,
+            "title_en": paper.title_en or "",
+            "title_zh": paper.title_zh or "",
+            "abstract": paper.abstract or "",
+            "authors": authors_text,
+            "tags": tags_text,
+            "summary_text": _summary_text(paper),
+        },
+    )
+
+
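`reindex_paper_fts` deletes the FTS row by `rowid` and then inserts the current values, which makes the operation safe to repeat. The sketch below shows that delete-then-insert pattern with the standard-library `sqlite3` module. It assumes the local SQLite build includes the FTS5 extension, and it uses a hypothetical two-column table and a `reindex` helper rather than the project's real `papers_fts` schema.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Assumption: this SQLite build ships FTS5. The column set is illustrative.
conn.execute("CREATE VIRTUAL TABLE papers_fts USING fts5(title_en, summary_text)")


def reindex(
    conn: sqlite3.Connection, paper_id: int, title_en: str, summary_text: str
) -> None:
    """Rebuild one FTS row: delete by rowid, then insert the current values."""
    conn.execute("DELETE FROM papers_fts WHERE rowid = ?", (paper_id,))
    conn.execute(
        "INSERT INTO papers_fts(rowid, title_en, summary_text) VALUES (?, ?, ?)",
        (paper_id, title_en, summary_text),
    )


# First index at crawl time (no summary yet), then again after summarizing.
reindex(conn, 7, "Sparse attention", "")
reindex(conn, 7, "Sparse attention", "linear time transformer")

row_count = conn.execute("SELECT count(*) FROM papers_fts").fetchone()[0]
hits = conn.execute(
    "SELECT rowid FROM papers_fts WHERE papers_fts MATCH ?", ("transformer",)
).fetchall()
```

Because the row is always rebuilt from the source values, the crawl path and the summary path can share one helper instead of maintaining separate `INSERT` and `UPDATE` statements, which is the consolidation this commit makes.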
+def reindex_fts(db: Session, paper_ids: list[int] | None = None) -> dict:
+    """Rebuild the FTS5 index, fully or for a subset of papers."""
+    query = select(Paper)
+    if paper_ids:
+        query = query.where(Paper.id.in_(paper_ids))
+    papers = db.execute(query).scalars().all()
+
+    if paper_ids is None:
+        db.execute(text("DELETE FROM papers_fts"))
+
+    count = 0
+    for paper in papers:
+        reindex_paper_fts(db, paper)
+        count += 1
+    db.commit()
+    logger.info("FTS reindexed: %d papers", count)
+    return {"status": "success", "indexed": count}
+
+
+def reindex_chroma(db: Session) -> dict:
+    """Rebuild the ChromaDB semantic index from the authoritative DB data."""
+    from app.services.embedder import index_paper
+
+    papers = db.execute(select(Paper).where(Paper.summary.has())).scalars().all()
+    indexed = 0
+    errors: list[str] = []
+    for paper in papers:
+        try:
+            texts_dict = {
+                "arxiv_id": paper.arxiv_id,
+                "title_zh": paper.title_zh or "",
+                "title_en": paper.title_en or "",
+                "tags": " ".join(t.tag for t in paper.tags),
+                "one_line": paper.summary.one_line if paper.summary else "",
+                "motivation_problem": (
+                    paper.summary.motivation_problem if paper.summary else ""
+                ),
+                "method_key_idea": (
+                    paper.summary.method_key_idea if paper.summary else ""
+                ),
+                "paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
+            }
+            index_paper(paper.arxiv_id, texts_dict)
+            indexed += 1
+        except Exception as exc:
+            errors.append(f"{paper.arxiv_id}: {exc}")
+            logger.warning(
+                "Failed to reindex ChromaDB for %s",
+                paper.arxiv_id,
+                exc_info=True,
+            )
+
+    return {
+        "status": "success" if not errors else "partial",
+        "indexed": indexed,
+        "errors": errors or None,
+    }
@@ -0,0 +1,244 @@
+"""Unified background job system: creation, execution, event logging, and failure recovery."""
+
+from __future__ import annotations
+
+import json
+import logging
+from datetime import date, timedelta
+from typing import Any
+
+from fastapi import BackgroundTasks
+from sqlalchemy import or_, select
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.database import SessionLocal
+from app.models import Job, JobEvent, JobEventStatus, JobStatus, TaskLock
+from app.utils import truncate_error, utc_now
+
+logger = logging.getLogger(__name__)
+
+STALE_JOB_AFTER = timedelta(hours=6)
+
+
+def _dumps(value: Any) -> str:
+    return json.dumps(value, ensure_ascii=False, default=str)
+
+
+def _loads(value: str | None) -> dict:
+    if not value:
+        return {}
+    try:
+        data = json.loads(value)
+        return data if isinstance(data, dict) else {}
+    except json.JSONDecodeError:
+        return {}
+
+
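The two JSON helpers are intentionally lossy in one direction: `default=str` lets values such as `date` be stored, but they come back as plain strings, and anything that is not a JSON object is read back as an empty dict. The standalone sketch below copies the same logic under the hypothetical names `dumps` and `loads` to show the round trip.

```python
import json
from datetime import date
from typing import Any, Optional


def dumps(value: Any) -> str:
    # default=str stringifies objects json cannot encode (for example, date).
    return json.dumps(value, ensure_ascii=False, default=str)


def loads(value: Optional[str]) -> dict:
    # Tolerant reader: empty, malformed, or non-object input becomes {}.
    if not value:
        return {}
    try:
        data = json.loads(value)
    except json.JSONDecodeError:
        return {}
    return data if isinstance(data, dict) else {}


encoded = dumps({"target_date": date(2024, 1, 2), "title": "论文"})
decoded = loads(encoded)
```

After the round trip `decoded["target_date"]` is the string `"2024-01-02"`, not a `date`, so a handler that needs a real `date` has to parse it again, for example with `date.fromisoformat`.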
+def create_job(
+    db: Session,
+    job_type: str,
+    *,
+    owner: str,
+    payload: dict | None = None,
+) -> Job:
+    """Create the main record for a background job."""
+    job = Job(
+        type=job_type,
+        status=JobStatus.QUEUED,
+        owner=owner,
+        payload_json=_dumps(payload or {}),
+        created_at=utc_now(),
+    )
+    db.add(job)
+    db.commit()
+    db.refresh(job)
+    add_job_event(
+        db,
+        job,
+        stage="created",
+        status=JobEventStatus.INFO,
+        message=f"Job queued: {job_type}",
+        payload=payload or {},
+    )
+    return job
+
+
+def add_job_event(
+    db: Session,
+    job: Job,
+    *,
+    stage: str,
+    status: str,
+    message: str | None = None,
+    payload: dict | None = None,
+) -> None:
+    """Append one job stage event."""
+    db.add(
+        JobEvent(
+            job_id=job.id,
+            stage=stage,
+            status=str(status),
+            message=message,
+            payload_json=_dumps(payload) if payload is not None else None,
+            created_at=utc_now(),
+        )
+    )
+    job.heartbeat_at = utc_now()
+    db.commit()
+
+
+def enqueue_job(background_tasks: BackgroundTasks, job_id: int) -> None:
+    """Submit the job to FastAPI BackgroundTasks."""
+    background_tasks.add_task(run_job_by_id, job_id)
+
+
+async def run_job_by_id(job_id: int) -> None:
+    """Run an already-created job using its own DB session."""
+    db = SessionLocal()
+    try:
+        await run_job(db, job_id)
+    finally:
+        db.close()
+
+
+async def run_job(db: Session, job_id: int) -> dict:
+    """Run the job and write status/result/error back to jobs/job_events."""
+    job = db.get(Job, job_id)
+    if not job:
+        raise ValueError(f"Job not found: {job_id}")
+    if job.status == JobStatus.RUNNING:
+        raise RuntimeError(f"Job already running: {job_id}")
+
+    payload = _loads(job.payload_json)
+    job.status = JobStatus.RUNNING
+    job.started_at = utc_now()
+    job.heartbeat_at = job.started_at
+    db.commit()
+    add_job_event(db, job, stage="run", status=JobEventStatus.STARTED)
+
+    try:
+        result = await _dispatch_job(db, job, payload)
+    except Exception as exc:
+        logger.exception("Job failed: id=%s type=%s", job.id, job.type)
+        error = truncate_error(exc, limit=4000)
+        job.status = JobStatus.FAILED
+        job.error = error
+        job.completed_at = utc_now()
+        db.commit()
+        add_job_event(db, job, stage="run", status=JobEventStatus.FAILED, message=error)
+        return {"status": "failed", "error": error}
+
+    job.status = JobStatus.SUCCESS
+    job.result_json = _dumps(result)
+    job.completed_at = utc_now()
+    job.error = None
+    db.commit()
+    add_job_event(
+        db,
+        job,
+        stage="run",
+        status=JobEventStatus.SUCCESS,
+        payload=result if isinstance(result, dict) else {"result": result},
+    )
+    return result if isinstance(result, dict) else {"status": "success", "result": result}
+
+
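The lifecycle that `run_job` implements (queued, running, then success or failed, with an event recorded at each step) can be illustrated without a database. This is a minimal sketch under stated assumptions: `DemoJob` and `run_demo_job` are hypothetical in-memory stand-ins, and error truncation is reduced to a plain slice instead of the project's `truncate_error` helper.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable


@dataclass
class DemoJob:
    """Hypothetical in-memory stand-in for the Job row."""

    status: str = "queued"
    result: Any = None
    error: str | None = None
    events: list[str] = field(default_factory=list)


async def run_demo_job(job: DemoJob, handler: Callable[[], Awaitable[dict]]) -> dict:
    if job.status == "running":
        raise RuntimeError("job already running")
    job.status = "running"
    job.events.append("started")
    try:
        result = await handler()
    except Exception as exc:
        # The failure is recorded on the job and returned, not re-raised.
        job.status, job.error = "failed", str(exc)[:4000]
        job.events.append("failed")
        return {"status": "failed", "error": job.error}
    job.status, job.result = "success", result
    job.events.append("success")
    return result


async def ok() -> dict:
    return {"status": "success", "updated": 3}


async def boom() -> dict:
    raise ValueError("upstream timeout")


good, bad = DemoJob(), DemoJob()
good_result = asyncio.run(run_demo_job(good, ok))
bad_result = asyncio.run(run_demo_job(bad, boom))
```

One consequence visible in the sketch: exceptions raised by the handler are folded into the returned dict, so a caller's `except` clause only sees errors raised before dispatch, such as the "already running" check.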
+async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
+    from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
+    from app.services.crawler import refresh_upvotes
+    from app.services.derived import reindex_chroma, reindex_fts
+    from app.services.pipeline import run_crawl, run_pipeline
+    from app.services.summarizer import summarize_batch, summarize_single
+
+    if job.type == "crawl_daily":
+        return await run_crawl(
+            db,
+            payload["target_date"],
+            owner=job.owner or f"job:{job.id}",
+            top_n=payload.get("top_n"),
+        )
+    if job.type == "pipeline_daily":
+        return await run_pipeline(
+            db,
+            payload["target_date"],
+            owner=job.owner or f"job:{job.id}",
+        )
+    if job.type == "summarize_batch":
+        return await summarize_batch(
+            db,
+            pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE),
+        )
+    if job.type == "summarize_one":
+        return await summarize_single(
+            db,
+            payload["arxiv_id"],
+            force=payload.get("force", True),
+            pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE),
+        )
+    if job.type == "refresh_upvotes":
+        return await refresh_upvotes(db, days=payload.get("days"))
+    if job.type == "delete_range":
+        return await delete_papers_by_date_range(
+            db,
+            date.fromisoformat(payload["date_start"]),
+            date.fromisoformat(payload["date_end"]),
+            include_notes=payload.get("include_notes", True),
+        )
+    if job.type == "cleanup_tmp":
+        return cleanup_tmp()
+    if job.type == "reindex_fts":
+        return reindex_fts(db)
+    if job.type == "reindex_chroma":
+        return reindex_chroma(db)
+
+    raise ValueError(f"Unsupported job type: {job.type}")
+
+
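`_dispatch_job` routes on `job.type` with a chain of `if` statements that mixes sync and async handlers. For comparison only, the sketch below shows one alternative, a registry keyed by job type. It is not how this commit does it: the `HANDLERS` dict, the `handler` decorator, and the two demo handlers are all hypothetical.

```python
from __future__ import annotations

import asyncio
import inspect
from typing import Any, Callable

# Hypothetical registry mapping a job type to its handler.
HANDLERS: dict[str, Callable[[dict], Any]] = {}


def handler(job_type: str) -> Callable[[Callable[[dict], Any]], Callable[[dict], Any]]:
    """Register a function as the handler for one job type."""

    def register(func: Callable[[dict], Any]) -> Callable[[dict], Any]:
        HANDLERS[job_type] = func
        return func

    return register


@handler("cleanup_tmp")
def _cleanup(payload: dict) -> dict:
    # Sync handler, like cleanup_tmp() in the diff.
    return {"status": "success", "removed": 0}


@handler("refresh_upvotes")
async def _refresh(payload: dict) -> dict:
    # Async handler, like refresh_upvotes() in the diff.
    return {"status": "success", "days": payload.get("days")}


async def dispatch(job_type: str, payload: dict) -> dict:
    try:
        func = HANDLERS[job_type]
    except KeyError:
        raise ValueError(f"Unsupported job type: {job_type}") from None
    result = func(payload)
    # Await only when the handler returned a coroutine or other awaitable.
    if inspect.isawaitable(result):
        result = await result
    return result


sync_result = asyncio.run(dispatch("cleanup_tmp", {}))
async_result = asyncio.run(dispatch("refresh_upvotes", {"days": 7}))
```

The `if` chain in the commit keeps every handler's argument mapping visible in one place. A registry trades that for easier extension, so which one fits better depends on how many job types are expected.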
+def recover_stale_jobs(db: Session) -> int:
+    """On startup, mark expired running jobs/locks as stale so they do not stay stuck forever."""
+    now = utc_now()
+    cutoff = now - STALE_JOB_AFTER
+    stale_jobs = (
+        db.execute(
+            select(Job).where(
+                Job.status == JobStatus.RUNNING,
+                or_(Job.heartbeat_at == None, Job.heartbeat_at < cutoff),  # noqa: E711
+            )
+        )
+        .scalars()
+        .all()
+    )
+    for job in stale_jobs:
+        job.status = JobStatus.STALE
+        job.error = "Marked stale after process restart or missed heartbeat"
+        job.completed_at = now
+        db.add(
+            JobEvent(
+                job_id=job.id,
+                stage="recovery",
+                status=JobEventStatus.FAILED,
+                message=job.error,
+                created_at=now,
+            )
+        )
+
+    stale_locks = (
+        db.execute(
+            select(TaskLock).where(
+                TaskLock.status == "running",
+                TaskLock.acquired_at < cutoff,
+            )
+        )
+        .scalars()
+        .all()
+    )
+    for lock in stale_locks:
+        lock.status = "stale"
+        lock.released_at = now
+
+    db.commit()
+    recovered = len(stale_jobs) + len(stale_locks)
+    if recovered:
+        logger.warning("Recovered stale runtime records: %d", recovered)
+    return recovered
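The staleness rule applied to jobs in `recover_stale_jobs` is: a running job is stale when it has no heartbeat at all, or when its last heartbeat is older than the cutoff. A small pure-Python sketch of that predicate follows. `is_stale` is a hypothetical helper, and the six-hour window mirrors `STALE_JOB_AFTER` from the diff.

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone

STALE_AFTER = timedelta(hours=6)  # same window as STALE_JOB_AFTER


def is_stale(
    heartbeat_at: datetime | None,
    *,
    now: datetime,
    stale_after: timedelta = STALE_AFTER,
) -> bool:
    """True when there is no heartbeat or the last one is older than the cutoff."""
    return heartbeat_at is None or heartbeat_at < now - stale_after


now = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
fresh = is_stale(now - timedelta(hours=1), now=now)
old = is_stale(now - timedelta(hours=7), now=now)
missing = is_stale(None, now=now)
```

In the diff the lock query compares `acquired_at` with the cutoff rather than the new `heartbeat_at` or `expires_at` columns, so the predicate above describes the job branch only.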
@@ -7,6 +7,7 @@ from __future__ import annotations

 import logging
 from datetime import date as date_type
+from datetime import timedelta

 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session
@@ -32,6 +33,8 @@ def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock:
         status="running",
         owner=owner,
         acquired_at=utc_now(),
+        heartbeat_at=utc_now(),
+        expires_at=utc_now() + timedelta(hours=6),
     )
     try:
         db.add(lock)
@@ -42,7 +45,12 @@ def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock:
     return lock


-async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") -> dict:
+async def run_crawl(
+    db: Session,
+    target_date: str,
+    owner: str = "admin_crawl",
+    top_n: int | None = None,
+) -> dict:
     """Run a single crawl (guarded by a re-entrancy lock).

     Args:
@@ -55,7 +63,7 @@ async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") -
     """
     lock = acquire_lock(db, "crawl", target_date, owner)
     try:
-        return await crawl_daily(db, target_date)
+        return await crawl_daily(db, target_date, top_n=top_n)
     finally:
         release_lock(db, lock)

@@ -83,6 +91,8 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
         status="running",
         owner=owner,
         acquired_at=now,
+        heartbeat_at=now,
+        expires_at=now + timedelta(hours=6),
     )
     try:
         db.add(lock)
@@ -11,8 +11,7 @@ from zoneinfo import ZoneInfo

 from app.config import settings
 from app.database import SessionLocal
-from app.services.pipeline import run_pipeline
-from app.services.crawler import refresh_upvotes
+from app.services.jobs import create_job, run_job
 from app.utils import today_str

 logger = logging.getLogger(__name__)
@@ -112,7 +111,13 @@ async def _daily_pipeline() -> None:

     db: Session = SessionLocal()
     try:
-        await run_pipeline(db, today, owner="daily_pipeline")
+        job = create_job(
+            db,
+            "pipeline_daily",
+            owner="daily_pipeline",
+            payload={"target_date": today},
+        )
+        await run_job(db, job.id)
     except RuntimeError:
         logger.warning("Daily pipeline already running for %s, skipping", today)
     except Exception:
@@ -125,7 +130,8 @@ async def _upvote_refresh() -> None:
     """Refresh upvotes for papers from the last N days."""
     db: Session = SessionLocal()
     try:
-        result = await refresh_upvotes(db)
+        job = create_job(db, "refresh_upvotes", owner="upvote_refresh", payload={})
+        result = await run_job(db, job.id)
         logger.info(
             "Upvote refresh completed: status=%s updated=%d",
             result.get("status"),
@@ -3,8 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import (
|
||||
@@ -13,6 +11,7 @@ from app.models import (
     PaperTag,
     SummaryState,
 )
+from app.services.derived import reindex_paper_fts
 from app.services.pdf_downloader import paper_dir
 from app.services.schemas import (
     SummarySchema,
@@ -75,19 +74,9 @@ def _update_summary_in_db(
             db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
             existing_tag_names.add(tag_name)
 
-    # 4. FTS5 update
-    summary_text = _build_fts_summary_text(schema)
-    db.execute(
-        text(
-            "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
-            "WHERE rowid=:paper_id"
-        ),
-        {
-            "title_zh": schema.title_zh,
-            "summary_text": summary_text,
-            "paper_id": paper.id,
-        },
-    )
+    # 4. FTS5 derived index
+    db.flush()
+    reindex_paper_fts(db, paper)
 
     db.commit()
     logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
+139
-82
@@ -3,14 +3,17 @@
 from __future__ import annotations
 
 import logging
-from unittest.mock import AsyncMock, patch
+from unittest.mock import patch
 
 import pytest
-from sqlalchemy import select
+from sqlalchemy import select, text
 
 from app.config import settings
 from app.models import (
     CrawlLog,
+    Job,
+    SummaryState,
+    SummaryStatus,
     TaskLock,
 )
 from app.utils import utc_now
@@ -64,47 +67,13 @@ class TestAdminAuth:
|
||||
resp = auth_client.get("/admin/logs", follow_redirects=False)
|
||||
assert resp.status_code == 303
|
||||
|
||||
def test_correct_session_accepted(self, auth_client):
|
||||
"""A logged-in session should be accepted (the crawl may fail, but not with a 303)."""
|
||||
with patch(
|
||||
"app.routes.admin.run_crawl", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl")
|
||||
assert resp.status_code != 303
|
||||
|
||||
# ── summarize route auth ────────────────────────────────────────
|
||||
|
||||
def test_no_session_returns_303_for_summarize(self, client, monkeypatch):
|
||||
"""Returns 303 without a session."""
|
||||
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
|
||||
resp = client.post("/admin/summarize", follow_redirects=False)
|
||||
assert resp.status_code == 303
|
||||
|
||||
def test_correct_session_batch_summarize(self, auth_client):
|
||||
"""Batch summarize while logged in, with the service layer mocked out."""
|
||||
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = {
|
||||
"status": "success",
|
||||
"done": 0,
|
||||
"failed": 0,
|
||||
"total": 0,
|
||||
}
|
||||
"""Batch summarize while logged in should create a background job."""
|
||||
with patch("app.routes.admin.enqueue_job"):
|
||||
resp = auth_client.post("/admin/summarize")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "success"
|
||||
|
||||
def test_single_paper_not_found(self, auth_client):
|
||||
"""Summarizing a single nonexistent paper returns 404."""
|
||||
from app.exceptions import NotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routes.admin.summarize_single",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=NotFoundError("Paper not found: nonexistent.99999"),
|
||||
):
|
||||
resp = auth_client.post("/admin/summarize/nonexistent.99999")
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["status"] == "queued"
|
||||
assert "job_id" in resp.json()
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -115,29 +84,12 @@ class TestAdminAuth:
|
||||
class TestAdminCrawl:
|
||||
"""Tests for POST /admin/crawl."""
|
||||
|
||||
def test_crawl_default_today(self, auth_client):
|
||||
"""Defaults to crawling today when no date is given."""
|
||||
with patch(
|
||||
"app.routes.admin.run_crawl", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "success"
|
||||
mock_crawl.assert_called_once()
|
||||
|
||||
def test_crawl_specific_date(self, auth_client):
|
||||
"""Crawl a specific date."""
|
||||
with patch(
|
||||
"app.routes.admin.run_crawl", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
|
||||
with patch("app.routes.admin.enqueue_job"):
|
||||
resp = auth_client.post("/admin/crawl?date=2024-01-15")
|
||||
assert resp.status_code == 200
|
||||
mock_crawl.assert_called_once()
|
||||
call_args = mock_crawl.call_args
|
||||
assert call_args[0][1] == "2024-01-15"
|
||||
assert resp.json()["target_date"] == "2024-01-15"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -149,20 +101,20 @@ class TestAdminCleanup:
|
||||
"""Tests for POST /admin/cleanup."""
|
||||
|
||||
def test_cleanup_returns_stats(self, auth_client):
|
||||
"""Cleanup should return statistics."""
|
||||
"""The synchronous cleanup troubleshooting endpoint should return statistics."""
|
||||
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
|
||||
mock_cleanup.return_value = {"scanned": 3, "removed": 1, "errors": []}
|
||||
resp = auth_client.post("/admin/cleanup")
|
||||
resp = auth_client.post("/admin/cleanup-now")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["scanned"] == 3
|
||||
assert data["removed"] == 1
|
||||
|
||||
def test_cleanup_writes_log(self, auth_client, db_session):
|
||||
"""Cleanup should write to crawl_logs."""
|
||||
"""The synchronous cleanup troubleshooting endpoint should write to crawl_logs."""
|
||||
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
|
||||
mock_cleanup.return_value = {"scanned": 0, "removed": 0, "errors": []}
|
||||
auth_client.post("/admin/cleanup")
|
||||
auth_client.post("/admin/cleanup-now")
|
||||
|
||||
logs = (
|
||||
db_session.execute(select(CrawlLog).where(CrawlLog.task == "cleanup"))
|
||||
@@ -195,7 +147,8 @@ class TestAdminDelete:
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range):
|
||||
"""Deletion should run when confirm='DELETE'."""
|
||||
"""A background delete job should be created when confirm='DELETE'."""
|
||||
with patch("app.routes.admin.enqueue_job"):
|
||||
resp = auth_client.post(
|
||||
"/admin/delete",
|
||||
json={
|
||||
@@ -207,7 +160,8 @@ class TestAdminDelete:
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["deleted"] == 3
|
||||
assert data["status"] == "queued"
|
||||
assert db_session.get(Job, data["job_id"]) is not None
|
||||
|
||||
def test_delete_invalid_date_range(self, auth_client):
|
||||
"""date_start > date_end should return 400."""
|
||||
@@ -221,17 +175,6 @@ class TestAdminDelete:
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_delete_without_confirm_field(self, auth_client):
|
||||
"""A missing confirm field should return 422."""
|
||||
resp = auth_client.post(
|
||||
"/admin/delete",
|
||||
json={
|
||||
"date_start": "2024-01-10",
|
||||
"date_end": "2024-01-12",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Admin Routes — Logs
|
||||
@@ -241,12 +184,6 @@ class TestAdminDelete:
|
||||
class TestAdminLogs:
|
||||
"""Tests for GET /admin/logs."""
|
||||
|
||||
def test_logs_returns_page(self, auth_client):
|
||||
"""Should return the admin logs page."""
|
||||
resp = auth_client.get("/admin/logs")
|
||||
assert resp.status_code == 200
|
||||
assert "text/html" in resp.headers.get("content-type", "")
|
||||
|
||||
def test_logs_requires_auth(self, client, monkeypatch):
|
||||
"""The logs page requires authentication."""
|
||||
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
|
||||
@@ -272,6 +209,126 @@ class TestAdminLogs:
|
||||
assert "crawl" in resp.text.lower() or "日志" in resp.text
|
||||
|
||||
|
||||
class TestAdminJobs:
|
||||
"""Tests for the background job query endpoints."""
|
||||
|
||||
def test_job_detail_returns_payload_and_events(self, auth_client, db_session):
|
||||
"""GET /admin/jobs/{id} returns the main job record and its events."""
|
||||
with patch("app.routes.admin.enqueue_job"):
|
||||
resp = auth_client.post("/admin/crawl?date=2024-01-15")
|
||||
job_id = resp.json()["job_id"]
|
||||
|
||||
resp = auth_client.get(f"/admin/jobs/{job_id}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == job_id
|
||||
assert data["type"] == "crawl_daily"
|
||||
assert data["payload"] == {"target_date": "2024-01-15"}
|
||||
assert data["events"][0]["stage"] == "created"
|
||||
|
||||
def test_job_detail_not_found(self, auth_client):
|
||||
resp = auth_client.get("/admin/jobs/999999")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestAdminSummaryStatus:
|
||||
"""Tests for the summary status management endpoints."""
|
||||
|
||||
def test_summary_status_json_filters_failed(
|
||||
self, auth_client, db_session, sample_paper
|
||||
):
|
||||
sample_paper.summary_status.status = SummaryState.FAILED
|
||||
sample_paper.summary_status.retry_count = 2
|
||||
sample_paper.summary_status.error_type = "timeout"
|
||||
db_session.commit()
|
||||
|
||||
resp = auth_client.get("/admin/summary-status?status=failed")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 1
|
||||
assert data["items"][0]["arxiv_id"] == sample_paper.arxiv_id
|
||||
assert data["items"][0]["retry_count"] == 2
|
||||
|
||||
def test_retry_failed_resets_failed_statuses(
|
||||
self, auth_client, db_session, sample_paper
|
||||
):
|
||||
sample_paper.summary_status.status = SummaryState.PERMANENT_FAILURE
|
||||
sample_paper.summary_status.error = "bad json"
|
||||
sample_paper.summary_status.error_type = "json_invalid"
|
||||
db_session.commit()
|
||||
|
||||
resp = auth_client.post("/admin/summary-retry-failed")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["count"] == 1
|
||||
db_session.refresh(sample_paper.summary_status)
|
||||
assert sample_paper.summary_status.status == SummaryState.PENDING
|
||||
assert sample_paper.summary_status.error is None
|
||||
assert sample_paper.summary_status.error_type is None
|
||||
|
||||
|
||||
class TestAdminPapers:
|
||||
"""Tests for batch operations in paper management."""
|
||||
|
||||
def test_single_delete_removes_paper_and_fts(
|
||||
self, auth_client, db_session, sample_paper
|
||||
):
|
||||
paper_id = sample_paper.id
|
||||
|
||||
resp = auth_client.post(f"/admin/paper-delete/{sample_paper.arxiv_id}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert db_session.get(type(sample_paper), paper_id) is None
|
||||
fts_row = db_session.execute(
|
||||
text("SELECT rowid FROM papers_fts WHERE rowid = :id"),
|
||||
{"id": paper_id},
|
||||
).fetchone()
|
||||
assert fts_row is None
|
||||
|
||||
def test_batch_delete_removes_papers_and_fts(
|
||||
self, auth_client, db_session, sample_papers_range
|
||||
):
|
||||
target_ids = [p.id for p in sample_papers_range[:2]]
|
||||
target_arxiv_ids = [p.arxiv_id for p in sample_papers_range[:2]]
|
||||
|
||||
resp = auth_client.post(
|
||||
"/admin/papers-batch-action",
|
||||
json={"action": "delete", "arxiv_ids": target_arxiv_ids},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["count"] == 2
|
||||
remaining = db_session.execute(
|
||||
text(
|
||||
"SELECT rowid FROM papers_fts "
|
||||
"WHERE rowid IN (:id1, :id2)"
|
||||
),
|
||||
{"id1": target_ids[0], "id2": target_ids[1]},
|
||||
).fetchall()
|
||||
assert remaining == []
|
||||
|
||||
def test_batch_summarize_sets_pending_status(
|
||||
self, auth_client, db_session, sample_papers_range
|
||||
):
|
||||
paper = sample_papers_range[0]
|
||||
paper.summary_status.status = SummaryState.DONE
|
||||
db_session.commit()
|
||||
|
||||
resp = auth_client.post(
|
||||
"/admin/papers-batch-action",
|
||||
json={"action": "summarize", "arxiv_ids": [paper.arxiv_id]},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
status = db_session.scalar(
|
||||
select(SummaryStatus).where(SummaryStatus.paper_id == paper.id)
|
||||
)
|
||||
assert status is not None
|
||||
assert status.status == SummaryState.PENDING
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Scheduler tests
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.services.crawler import (
     _parse_paper,
     crawl_daily,
     fetch_daily,
+    refresh_upvotes,
     upsert_papers,
 )
 
@@ -187,3 +188,62 @@ class TestCrawlDaily:
 
         assert result["status"] == "failed"
         assert "network error" in result["error"]
+
+
+class TestRefreshUpvotes:
+    @pytest.mark.asyncio
+    async def test_refresh_updates_existing_without_inserting_new(
+        self, db_session, sample_paper
+    ):
+        sample_paper.arxiv_id = "1706.03762"
+        sample_paper.upvotes = 10
+        db_session.commit()
+
+        with patch(
+            "app.services.crawler.fetch_daily",
+            new_callable=AsyncMock,
+            return_value=[
+                {
+                    "paper": {
+                        "id": "1706.03762",
+                        "upvotes": 999,
+                        "authors": [],
+                        "tags": [],
+                    }
+                },
+                {
+                    "paper": {
+                        "id": "2010.11929",
+                        "upvotes": 123,
+                        "authors": [],
+                        "tags": [],
+                    }
+                },
+            ],
+        ):
+            result = await refresh_upvotes(db_session, days=1)
+
+        db_session.refresh(sample_paper)
+        assert result["status"] == "success"
+        assert result["updated"] == 1
+        assert sample_paper.upvotes == 999
+        assert db_session.query(type(sample_paper)).count() == 1
+
+    @pytest.mark.asyncio
+    async def test_refresh_returns_partial_when_one_day_fails(self, db_session):
+        async def _fetch_daily(target_date):
+            if target_date.endswith("01"):
+                raise ConnectionError("hf down")
+            return []
+
+        with (
+            patch(
+                "app.services.crawler.recent_date_strs",
+                return_value=["2024-01-01", "2024-01-02"],
+            ),
+            patch("app.services.crawler.fetch_daily", side_effect=_fetch_daily),
+        ):
+            result = await refresh_upvotes(db_session, days=2)
+
+        assert result["status"] == "partial"
+        assert result["errors"] == ["2024-01-01: hf down"]
@@ -0,0 +1,80 @@
+"""Tests for derived index maintenance."""
+
+from __future__ import annotations
+
+from unittest.mock import patch
+
+from sqlalchemy import text
+
+from app.services.derived import reindex_chroma, reindex_fts
+
+
+class TestReindexFts:
+    def test_reindex_fts_rebuilds_missing_rows(self, db_session, sample_paper):
+        db_session.execute(
+            text("DELETE FROM papers_fts WHERE rowid = :id"),
+            {"id": sample_paper.id},
+        )
+        db_session.commit()
+
+        result = reindex_fts(db_session)
+
+        row = db_session.execute(
+            text("SELECT title_en, authors, tags FROM papers_fts WHERE rowid = :id"),
+            {"id": sample_paper.id},
+        ).fetchone()
+        assert result == {"status": "success", "indexed": 1}
+        assert row is not None
+        assert row[0] == sample_paper.title_en
+        assert "Alice Smith" in row[1]
+        assert "NLP" in row[2]
+
+    def test_reindex_fts_accepts_subset(self, db_session, sample_papers_range):
+        keep_id = sample_papers_range[0].id
+        skip_id = sample_papers_range[1].id
+        db_session.execute(text("DELETE FROM papers_fts"))
+        db_session.commit()
+
+        result = reindex_fts(db_session, paper_ids=[keep_id])
+
+        keep_row = db_session.execute(
+            text("SELECT rowid FROM papers_fts WHERE rowid = :id"),
+            {"id": keep_id},
+        ).fetchone()
+        skip_row = db_session.execute(
+            text("SELECT rowid FROM papers_fts WHERE rowid = :id"),
+            {"id": skip_id},
+        ).fetchone()
+        assert result["indexed"] == 1
+        assert keep_row is not None
+        assert skip_row is None
+
+
+class TestReindexChroma:
+    def test_reindex_chroma_indexes_only_summarized_papers(
+        self, db_session, sample_papers_with_summary
+    ):
+        with patch("app.services.embedder.index_paper", return_value=True) as mock_index:
+            result = reindex_chroma(db_session)
+
+        assert result["status"] == "success"
+        assert result["indexed"] == 4
+        assert mock_index.call_count == 4
+        indexed_ids = {call.args[0] for call in mock_index.call_args_list}
+        assert "2401.20001" in indexed_ids
+        assert "2401.20005" not in indexed_ids
+
+    def test_reindex_chroma_reports_partial_failures(
+        self, db_session, sample_papers_with_summary
+    ):
+        def _index_paper(arxiv_id, _texts):
+            if arxiv_id == "2401.20001":
+                raise RuntimeError("embedding failed")
+            return True
+
+        with patch("app.services.embedder.index_paper", side_effect=_index_paper):
+            result = reindex_chroma(db_session)
+
+        assert result["status"] == "partial"
+        assert result["indexed"] == 3
+        assert result["errors"] == ["2401.20001: embedding failed"]
@@ -0,0 +1,111 @@
+"""Tests for the background Job service."""
+
+from __future__ import annotations
+
+from datetime import timedelta
+from unittest.mock import patch
+
+import pytest
+from sqlalchemy import select
+
+from app.models import Job, JobEvent, JobStatus, TaskLock
+from app.services.jobs import create_job, recover_stale_jobs, run_job
+from app.utils import utc_now
+
+
+class TestJobs:
+    def test_create_job_writes_event(self, db_session):
+        job = create_job(
+            db_session,
+            "cleanup_tmp",
+            owner="test",
+            payload={"reason": "unit-test"},
+        )
+
+        assert job.id is not None
+        assert job.status == JobStatus.QUEUED
+
+        events = (
+            db_session.execute(select(JobEvent).where(JobEvent.job_id == job.id))
+            .scalars()
+            .all()
+        )
+        assert len(events) == 1
+        assert events[0].stage == "created"
+
+    @pytest.mark.asyncio
+    async def test_run_job_success(self, db_session):
+        job = create_job(db_session, "cleanup_tmp", owner="test", payload={})
+
+        with patch("app.services.cleaner.cleanup_tmp") as mock_cleanup:
+            mock_cleanup.return_value = {"scanned": 1, "removed": 1, "errors": []}
+            result = await run_job(db_session, job.id)
+
+        refreshed = db_session.get(Job, job.id)
+        assert result["removed"] == 1
+        assert refreshed.status == JobStatus.SUCCESS
+        assert refreshed.result_json is not None
+
+    @pytest.mark.asyncio
+    async def test_run_job_failure_records_error(self, db_session):
+        job = create_job(db_session, "missing_job_type", owner="test", payload={})
+
+        result = await run_job(db_session, job.id)
+
+        refreshed = db_session.get(Job, job.id)
+        assert result["status"] == "failed"
+        assert refreshed.status == JobStatus.FAILED
+        assert "Unsupported job type" in refreshed.error
+
+    @pytest.mark.asyncio
+    async def test_run_job_dispatches_refresh_upvotes(self, db_session):
+        job = create_job(
+            db_session,
+            "refresh_upvotes",
+            owner="test",
+            payload={"days": 3},
+        )
+
+        with patch("app.services.crawler.refresh_upvotes") as mock_refresh:
+            mock_refresh.return_value = {"status": "success", "updated": 2}
+            result = await run_job(db_session, job.id)
+
+        mock_refresh.assert_awaited_once_with(db_session, days=3)
+        assert result["updated"] == 2
+
+    @pytest.mark.asyncio
+    async def test_run_job_dispatches_reindex_fts(self, db_session):
+        job = create_job(db_session, "reindex_fts", owner="test", payload={})
+
+        with patch("app.services.derived.reindex_fts") as mock_reindex:
+            mock_reindex.return_value = {"status": "success", "indexed": 5}
+            result = await run_job(db_session, job.id)
+
+        mock_reindex.assert_called_once_with(db_session)
+        assert result["indexed"] == 5
+
+    def test_recover_stale_jobs_and_locks(self, db_session):
+        old = utc_now() - timedelta(hours=7)
+        job = Job(
+            type="cleanup_tmp",
+            status=JobStatus.RUNNING,
+            owner="test",
+            created_at=old,
+            started_at=old,
+            heartbeat_at=old,
+        )
+        lock = TaskLock(
+            task="cleanup",
+            lock_key="tmp",
+            status="running",
+            owner="test",
+            acquired_at=old,
+        )
+        db_session.add_all([job, lock])
+        db_session.commit()
+
+        recovered = recover_stale_jobs(db_session)
+
+        assert recovered == 2
+        assert db_session.get(Job, job.id).status == JobStatus.STALE
+        assert db_session.get(TaskLock, lock.id).status == "stale"
@@ -6,9 +6,6 @@ from datetime import date
|
||||
from unittest.mock import patch as upatch
|
||||
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Detail page & similar papers
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -37,29 +34,6 @@ class TestDetailPage:
|
||||
class TestTrendsDashboard:
|
||||
"""Tests for the trends dashboard."""
|
||||
|
||||
def test_trends_page_renders(self, client, sample_papers_with_summary):
|
||||
"""The trends dashboard page renders correctly."""
|
||||
resp = client.get("/trends")
|
||||
assert resp.status_code == 200
|
||||
assert "趋势看板" in resp.text
|
||||
assert "chart" in resp.text.lower() or "Chart" in resp.text
|
||||
|
||||
def test_trends_api_returns_data(self, client, sample_papers_with_summary):
|
||||
"""The trends API returns the correct data structure."""
|
||||
resp = client.get("/api/stats/trends")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
assert "daily_counts" in data
|
||||
assert "top_tags" in data
|
||||
assert "upvotes_dist" in data
|
||||
assert "summary_completion" in data
|
||||
|
||||
assert isinstance(data["daily_counts"], list)
|
||||
assert isinstance(data["top_tags"], list)
|
||||
assert isinstance(data["upvotes_dist"], list)
|
||||
assert isinstance(data["summary_completion"], list)
|
||||
|
||||
def test_trends_api_daily_counts(self, client, sample_papers_with_summary):
|
||||
"""Daily paper counts are correct."""
|
||||
# Use the date range of the test data
|
||||
@@ -108,12 +82,6 @@ class TestTrendsDashboard:
|
||||
class TestComparePage:
|
||||
"""Tests for the paper comparison page."""
|
||||
|
||||
def test_compare_page_no_ids(self, client):
|
||||
"""Shows the input form when no IDs are given."""
|
||||
resp = client.get("/compare")
|
||||
assert resp.status_code == 200
|
||||
assert "对比" in resp.text
|
||||
|
||||
def test_compare_page_with_ids(self, client, sample_papers_with_summary):
|
||||
"""Comparing multiple papers renders correctly."""
|
||||
resp = client.get("/compare?ids=2401.20001,2401.20002")
|
||||
@@ -124,23 +92,6 @@ class TestComparePage:
|
||||
assert "一句话摘要" in resp.text
|
||||
assert "研究问题" in resp.text
|
||||
|
||||
def test_compare_page_max_5(self, client, sample_papers_with_summary):
|
||||
"""At most 5 papers."""
|
||||
ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005"
|
||||
resp = client.get(f"/compare?ids={ids}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_compare_page_over_5_truncates(self, client, sample_papers_with_summary):
|
||||
"""More than 5 papers are truncated."""
|
||||
ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006"
|
||||
resp = client.get(f"/compare?ids={ids}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_compare_page_invalid_ids(self, client):
|
||||
"""Shows an empty result for invalid IDs."""
|
||||
resp = client.get("/compare?ids=nonexistent.99999")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_compare_page_shows_no_summary_placeholder(
|
||||
self, client, sample_papers_with_summary
|
||||
):
|
||||
@@ -149,65 +100,3 @@ class TestComparePage:
|
||||
resp = client.get("/compare?ids=2401.20005")
|
||||
assert resp.status_code == 200
|
||||
assert "暂无总结" in resp.text
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Nav Bar
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestNavBar:
|
||||
"""Tests for the navigation bar."""
|
||||
|
||||
def test_nav_includes_trends_link(self, client):
|
||||
"""The navigation bar should include the trends link."""
|
||||
resp = client.get("/search")
|
||||
assert resp.status_code == 200
|
||||
assert "/trends" in resp.text
|
||||
|
||||
def test_nav_includes_compare_implicitly(self, client):
|
||||
"""The compare page is accessible."""
|
||||
resp = client.get("/compare")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Graceful Degradation (CHROMA_ENABLED=false)
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestGracefulDegradation:
|
||||
"""Tests for graceful degradation when CHROMA_ENABLED=false."""
|
||||
|
||||
def test_search_works_without_chroma(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""FTS5 search works when CHROMA is disabled."""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/search?q=Test")
|
||||
assert resp.status_code == 200
|
||||
assert "Test Paper" in resp.text or "测试论文" in resp.text
|
||||
|
||||
def test_detail_works_without_chroma(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""The detail page works when CHROMA is disabled."""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/paper/2401.20001")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_trends_works_without_chroma(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""The trends dashboard works when CHROMA is disabled."""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/trends")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_compare_works_without_chroma(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""The comparison page works when CHROMA is disabled."""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/compare?ids=2401.20001,2401.20002")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@@ -123,38 +123,12 @@ class TestSearchSemanticMode:
|
||||
class TestSearchRoutes:
|
||||
"""Tests for the search page and JSON API routes."""
|
||||
|
||||
def test_search_page_renders(self, client):
|
||||
"""GET /search returns 200."""
|
||||
resp = client.get("/search")
|
||||
assert resp.status_code == 200
|
||||
assert "搜索" in resp.text
|
||||
|
||||
def test_search_page_with_query(self, client, sample_paper):
|
||||
"""GET /search?q=Test returns search results."""
|
||||
resp = client.get("/search?q=Test")
|
||||
assert resp.status_code == 200
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
def test_search_page_with_tag(self, client, sample_paper):
|
||||
"""GET /search?tag=NLP returns tag-filtered results."""
|
||||
resp = client.get("/search?tag=NLP")
|
||||
assert resp.status_code == 200
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
def test_search_page_keyword_mode(self, client, sample_papers_with_summary):
|
||||
"""Search page in keyword mode."""
|
||||
resp = client.get("/search?q=Test&mode=keyword")
|
||||
assert resp.status_code == 200
|
||||
assert "Test" in resp.text or "测试" in resp.text
|
||||
|
||||
def test_search_page_semantic_disabled(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""Semantic mode still works when CHROMA_ENABLED=false."""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/search?q=Test&mode=semantic")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_search_api_json(self, client, sample_paper):
|
||||
"""GET /api/search?q=Test returns JSON."""
|
||||
resp = client.get("/api/search?q=Test")
|
||||
@@ -170,14 +144,6 @@ class TestSearchRoutes:
|
||||
data = resp.json()
|
||||
assert data["total"] == 1
|
||||
|
||||
def test_search_api_with_mode(self, client, sample_papers_with_summary):
|
||||
"""The search API supports the mode parameter."""
|
||||
resp = client.get("/api/search?q=Test&mode=keyword")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "results" in data
|
||||
assert "total" in data
|
||||
|
||||
def test_search_api_empty(self, client, sample_paper):
|
||||
"""GET /api/search?q=nonexistent returns an empty result."""
|
||||
resp = client.get("/api/search?q=nonexistent")
|
||||
@@ -185,13 +151,6 @@ class TestSearchRoutes:
|
||||
data = resp.json()
|
||||
assert data["total"] == 0
|
||||
|
||||
def test_search_api_sort_by_date(self, client, sample_paper):
|
||||
"""GET /api/search?q=Test&sort=date sorts by date."""
|
||||
resp = client.get("/api/search?q=Test&sort=date")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Similar Paper API tests
|
||||
@@ -211,21 +170,6 @@ class TestSimilarAPI:
|
||||
data = resp.json()
|
||||
assert data["results"] == []
|
||||
|
||||
def test_similar_api_paper_not_found(self, client, monkeypatch):
|
||||
"""A nonexistent paper returns an empty result."""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/api/similar/nonexistent.99999")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["results"] == []
|
||||
|
||||
def test_similar_api_with_top_k(
|
||||
self, client, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""The top_k parameter controls how many results are returned."""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
resp = client.get("/api/similar/2401.20001?top_k=3")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Reading list route tests
|
||||
@@ -235,12 +179,6 @@ class TestSimilarAPI:
|
||||
class TestReadingListRoute:
|
||||
"""Tests for the reading list page."""
|
||||
|
||||
def test_reading_list_empty(self, client):
|
||||
"""Shows the empty state when there are no bookmarks."""
|
||||
resp = client.get("/reading-list")
|
||||
assert resp.status_code == 200
|
||||
assert "阅读列表" in resp.text
|
||||
|
||||
def test_reading_list_with_bookmark(self, client, sample_paper):
|
||||
"""Shows papers when there are bookmarks."""
|
||||
# Bookmark first
|
||||
@@ -302,13 +240,6 @@ class TestRssFeed:
|
||||
assert "<channel>" in resp.text
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
def test_rss_has_paper_item(self, client, sample_paper):
|
||||
"""The RSS feed contains paper items."""
|
||||
resp = client.get("/rss.xml")
|
||||
assert "<item>" in resp.text
|
||||
assert "<title>" in resp.text
|
||||
assert "/paper/2401.12345" in resp.text
|
||||
|
||||
def test_rss_with_tag_filter(self, client, sample_paper):
|
||||
"""GET /rss.xml?tag=NLP filters by tag."""
|
||||
resp = client.get("/rss.xml?tag=NLP")
|
||||
|
||||