refactor: extract admin business logic to services, introduce job queue, add derived index helpers

- Move DB operations from routes/admin.py to services/admin.py (get_logs_context, query_summary_statuses, retry_failed, delete/reset operations)
- Add services/jobs.py with Job/JobEvent-based async job queue (create_job, run_job, enqueue_job)
- Add services/derived.py with FTS5 reindex and paper index deletion helpers
- Refactor scheduler to use job queue instead of direct pipeline calls
- Add heartbeat_at/expires_at to TaskLock for lock health tracking
- Remove DESIGN_REVIEW.md
- Update tests: remove redundant integration tests, add unit tests for new services
2026-06-13 18:31:43 +08:00
parent 21f16e6756
commit 743d69efd0
20 changed files with 1391 additions and 1063 deletions
+1 -1
@@ -42,7 +42,7 @@ DATABASE_URL=sqlite:///data/db/papers.db
 # ─── 语义搜索 ─────────────────────────────
 CHROMA_ENABLED=false
 CHROMA_DIR=data/chroma
-EMBED_API_BASE=https://api.siliconflow.cn/v1/embeddings
+EMBED_API_BASE=https://api.siliconflow.cn
 EMBED_API_KEY=your_api_key_here
 EMBED_MODEL=Qwen/Qwen3-Embedding-4B
 EMBED_DIMENSIONS=2560
-468
@@ -1,468 +0,0 @@
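# Project Design Review and Optimization Recommendations

This document summarizes the conclusions of a review of the project's current system design, workflow design, and code structure. The focus is not local code style but stable long-term operation, failure recovery, data consistency, maintainability, and extensibility.

## Overall Assessment

The current structure is clear overall: the FastAPI routing layer is thin, and most business logic lives in `app/services/`; crawling, summarization, search, personal data, and the admin backend are already split into modules; SQLite + FTS5 is a reasonable choice for a local, personal paper-browsing site.

The main risks fall into the following categories:

- Long-running tasks have no unified lifecycle abstraction.
- The pipeline is not an explicit state machine, so stage results are hard to recover and trace.
- The database, file system, FTS5, and ChromaDB are kept in sync by hand in several places.
- Business policy risks diverging across the three entry points: Web, CLI, and Scheduler.
- Failure recovery and observability are insufficient.
- Some synchronous blocking work is mixed into the async Web flow.

## High-Priority Issues

### 1. No unified task system

Three entry points can currently trigger heavy tasks:

- Web admin backend: `/admin/trigger-pipeline`, `/admin/summarize`
- APScheduler: the daily automatic pipeline
- CLI: `app.cli crawl/summarize`

They share some services, but there is no unified model for task creation, queuing, execution, retry, cancellation, status queries, and failure recovery.

Impact:

- Web requests hang for a long time.
- The behavior of CLI, Web, and Scheduler may gradually diverge.
- When a task gets stuck, only a coarse-grained running lock is visible.
- The frontend can hardly display per-stage progress.
- Adding rerun, cancel, resume, and concurrency limits later will become increasingly complex.

Recommendation:

Introduce a unified Job model:

```text
Entry layer -> create Job -> Worker runs Job -> JobEvent records progress -> page/API queries status
```

Suggested job types:

- `crawl_daily`
- `summarize_one`
- `summarize_batch`
- `pipeline_daily`
- `refresh_upvotes`
- `delete_range`
- `reindex_fts`
- `reindex_chroma`
- `extract_images`

Suggested direction for the table structure:

- `jobs`: the main job record, containing `id/type/status/owner/created_at/started_at/completed_at/error`
- `job_events`: stage events, containing `job_id/stage/status/message/payload_json/created_at`
- `paper_processing_runs`: per-paper processing records, suited to tracking summarization, image extraction, and indexing status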
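
The two core tables suggested above can be sketched with the standard-library `sqlite3` module. This is only a minimal illustration of the column lists in the bullets; the helper names `insert_job` and `insert_job_event`, the `'queued'` default, and the text timestamps are assumptions made for the example, not the project's actual schema or API.

```python
import json
import sqlite3
from datetime import datetime, timezone

# Columns follow the bullet list above; types and defaults are assumptions.
SCHEMA = """
CREATE TABLE jobs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    type TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'queued',
    owner TEXT,
    created_at TEXT NOT NULL,
    started_at TEXT,
    completed_at TEXT,
    error TEXT
);
CREATE TABLE job_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    job_id INTEGER NOT NULL REFERENCES jobs(id) ON DELETE CASCADE,
    stage TEXT NOT NULL,
    status TEXT NOT NULL,
    message TEXT,
    payload_json TEXT,
    created_at TEXT NOT NULL
);
"""


def now_iso() -> str:
    """Current UTC time as an ISO 8601 string."""
    return datetime.now(timezone.utc).isoformat()


def insert_job(conn: sqlite3.Connection, job_type: str, owner: str) -> int:
    """Hypothetical helper: create a queued job and return its id."""
    cur = conn.execute(
        "INSERT INTO jobs (type, owner, created_at) VALUES (?, ?, ?)",
        (job_type, owner, now_iso()),
    )
    return cur.lastrowid


def insert_job_event(conn, job_id, stage, status, message="", payload=None) -> None:
    """Hypothetical helper: append one stage event to a job."""
    conn.execute(
        "INSERT INTO job_events "
        "(job_id, stage, status, message, payload_json, created_at) "
        "VALUES (?, ?, ?, ?, ?, ?)",
        (job_id, stage, status, message, json.dumps(payload or {}), now_iso()),
    )


conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")  # needed for ON DELETE CASCADE in SQLite
conn.executescript(SCHEMA)

job_id = insert_job(conn, "crawl_daily", "cli_crawl")
insert_job_event(conn, job_id, "crawl", "started", payload={"target_date": "2026-06-13"})
```

Keeping events in a separate append-only table means a status page can show progress by reading `job_events` without touching the job row that the worker is updating.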
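
### 2. The pipeline is not an explicit state machine

The current pipeline is essentially a linear function:

```text
crawl -> summarize -> cleanup
```

Stage outputs are not modeled as state that can be queried, replayed, or compensated.

Impact:

- When crawl succeeds but summarize partially fails, the overall status is not expressed precisely enough.
- It is unclear whether a cleanup failure should affect the pipeline's success or failure.
- "Fall back to yesterday when today has no data" is implicit logic, not a configurable policy.
- Image extraction and ChromaDB indexing failures are swallowed, so users cannot tell which enhancements are missing.

Recommendation:

Turn the pipeline into explicit stage states:

```text
created
-> crawling
-> crawled
-> summarizing
-> summarized | summarized_partial
-> postprocessing
-> completed | failed | cancelled
```

Record for each stage:

- Input parameters
- Output statistics
- Error type
- Error summary
- Start/end time
- Whether it can be retried

This makes it possible to rerun only the failed stage, for example only ChromaDB indexing, only image extraction, or only the summaries of failed papers.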
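
The stage diagram above can be encoded as a small transition table. The sketch below is one possible reading of it; in particular, allowing `failed` and `cancelled` from any non-terminal stage is an assumption of this example, and `Stage` and `can_transition` are hypothetical names.

```python
from enum import Enum


class Stage(str, Enum):
    CREATED = "created"
    CRAWLING = "crawling"
    CRAWLED = "crawled"
    SUMMARIZING = "summarizing"
    SUMMARIZED = "summarized"
    SUMMARIZED_PARTIAL = "summarized_partial"
    POSTPROCESSING = "postprocessing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


# Forward transitions taken directly from the diagram.
TRANSITIONS = {
    Stage.CREATED: {Stage.CRAWLING},
    Stage.CRAWLING: {Stage.CRAWLED},
    Stage.CRAWLED: {Stage.SUMMARIZING},
    Stage.SUMMARIZING: {Stage.SUMMARIZED, Stage.SUMMARIZED_PARTIAL},
    Stage.SUMMARIZED: {Stage.POSTPROCESSING},
    Stage.SUMMARIZED_PARTIAL: {Stage.POSTPROCESSING},
    Stage.POSTPROCESSING: {Stage.COMPLETED},
}

TERMINAL = {Stage.COMPLETED, Stage.FAILED, Stage.CANCELLED}


def can_transition(current: Stage, target: Stage) -> bool:
    """Return True if moving from `current` to `target` is allowed."""
    if current in TERMINAL:
        return False
    # Assumption: any non-terminal stage may fail or be cancelled.
    if target in (Stage.FAILED, Stage.CANCELLED):
        return True
    return target in TRANSITIONS.get(current, set())
```

Rejecting illegal transitions at a single choke point is what makes "rerun only the failed stage" safe: a rerun can create a new run starting at the stage that failed instead of mutating a finished one.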
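
### 3. The data lifecycle has no clear layering

The project has at least four kinds of data:

- Primary data: `papers`, `authors`, `tags`, `summaries`, and user notes/bookmarks/reading status
- Derived indexes: FTS5, ChromaDB
- Long-lived assets: summary, raw output, and images under `data/papers/{arxiv_id}`
- Temporary assets: PDFs, sources, and intermediate files under `data/tmp/{arxiv_id}`

These are currently maintained by hand in different flows. Deletion, summarization, index rebuilding, and temp-directory cleanup are scattered across different modules.

Impact:

- The DB and the file system may become inconsistent.
- The DB and FTS5/ChromaDB may become inconsistent.
- When a deletion removes some resources and fails on others, recovery is difficult.
- It is hard to tell whether a given paper is "complete and usable".

Recommendation:

Define the authoritative source of data explicitly:

- The DB is the authoritative source.
- FTS5, ChromaDB, summary.json, raw_output.txt, and images should all be treated as rebuildable derived artifacts.
- When deleting primary data, first make sure the DB state is correct; derived artifacts can be cleaned up asynchronously as compensation.
- Provide tasks that rebuild derived data.

Suggested additional commands or tasks:

```text
rebuild-derived --fts
rebuild-derived --chroma
rebuild-derived --files
rebuild-derived --images
```

### 4. Dual-write consistency risk between the database and the file system

Persisting a summary writes files, writes the DB, updates status, extracts images, and writes to ChromaDB. A failure at any step can leave an inconsistent state.

Typical risks:

- The file has been written, but the DB was not committed.
- The DB is marked done, but summary.json is missing.
- The DB is done, but images were not extracted or ChromaDB was not indexed.
- Old images or old figures linger after re-summarizing.

Recommendation:

- Treat the DB as the source of truth for structured summaries, and JSON files as an export/cache.
- Write files via a temporary file followed by an atomic rename.
- Record derived-artifact status in the DB, for example:
  - `summary_persisted_at`
  - `fts_indexed_at`
  - `chroma_indexed_at`
  - `images_extracted_at`
  - `derived_error`
- Provide a consistency-check task that scans for differences between the DB and the files/indexes and repairs them.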
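
The "temporary file plus atomic rename" recommendation above can be sketched with the standard library. `write_json_atomic` is a hypothetical helper name, and the example assumes the temporary file lives in the same directory as the target so that `os.replace` stays on one file system.

```python
import json
import os
import tempfile
from pathlib import Path


def write_json_atomic(path: Path, data: dict) -> None:
    """Write `data` to `path` so readers never observe a half-written file."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # Create the temp file next to the target; a rename across file systems
    # would not be atomic.
    fd, tmp_name = tempfile.mkstemp(
        dir=path.parent, prefix=path.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            json.dump(data, fh, ensure_ascii=False)
            fh.flush()
            os.fsync(fh.fileno())  # push the bytes to disk before renaming
        os.replace(tmp_name, path)  # atomically swaps in the new content
    except BaseException:
        # Remove the temp file if anything failed before the rename.
        try:
            os.unlink(tmp_name)
        except FileNotFoundError:
            pass
        raise
```

With this pattern a crash leaves either the old `summary.json` or the new one, never a truncated file, which also keeps any later consistency check simple.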
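
### 5. Failure recovery design is weak

Failure handling currently relies mainly on:

- `summary_status.status`
- `retry_count`
- `task_locks`
- `crawl_logs`

The following cases are not well supported:

- After a process crash, `processing` status lingers forever.
- A running lock in `task_locks` lingers forever.
- A task stops halfway, with some files or indexes already written.
- Some post-processing fails while the main task is reported as successful.

Recommendation:

- Add `heartbeat_at` and `expires_at` to `task_locks`.
- On application startup, scan for stale running tasks and mark them as `failed/stale`.
- Add finer-grained stage fields to `summary_status`, for example:
  - `downloading_pdf`
  - `generating`
  - `validating`
  - `persisting`
  - `extracting_images`
  - `indexing`
- Distinguish main-flow failures from derived-enhancement failures.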
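
A stale-lock check based on `heartbeat_at` and `expires_at` could look roughly like the sketch below. `LockRow`, `is_stale`, `recover_stale`, and the five-minute heartbeat timeout are all assumptions for illustration; the real `TaskLock` model and recovery logic may differ.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import List, Optional


@dataclass
class LockRow:
    """Hypothetical in-memory stand-in for one row of `task_locks`."""
    name: str
    status: str
    heartbeat_at: Optional[datetime] = None
    expires_at: Optional[datetime] = None


def is_stale(lock: LockRow, now: datetime,
             heartbeat_timeout: timedelta = timedelta(minutes=5)) -> bool:
    """A running lock is stale if it expired or its heartbeat went silent."""
    if lock.status != "running":
        return False
    if lock.expires_at is not None and lock.expires_at <= now:
        return True
    if lock.heartbeat_at is not None:
        return now - lock.heartbeat_at > heartbeat_timeout
    # No heartbeat and no expiry recorded: not enough information to decide.
    return False


def recover_stale(locks: List[LockRow], now: datetime) -> List[str]:
    """Mark stale running locks and return their names (startup scan)."""
    recovered = []
    for lock in locks:
        if is_stale(lock, now):
            lock.status = "stale"
            recovered.append(lock.name)
    return recovered
```

The two fields cover different failures: `expires_at` bounds the total run time, while `heartbeat_at` detects a worker that died well before the deadline.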
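
### 6. Long-running tasks run directly inside Web requests

When the admin UI triggers the pipeline, batch summarization, an upvote refresh, or data deletion, the heavy task runs directly within the request lifecycle.

Impact:

- HTTP requests are prone to timing out.
- Task status is unclear after the user closes the page.
- The Web process is tied up by heavy tasks.
- Cancellation, progress display, and failure recovery become hard to add later.

Recommendation:

- The Web layer only creates a job and returns a `task_id`.
- A background worker executes the task.
- The admin page polls `/admin/jobs/{id}` or uses SSE to display progress.

## Medium-Priority Issues

### 7. The FTS5 index is maintained by hand and drifts easily

`papers_fts` is a standalone virtual table. It is synchronized manually when papers are inserted, summaries are updated, and papers are deleted.

Impact:

- Updates to the title, abstract, authors, or tags of an existing paper may not be synchronized.
- After AI tags are added, search results may lack fields unless there is a unified rebuild.
- Index entries may linger after a failed delete or a rollback.

Recommendation:

- Wrap the logic in a single `reindex_paper(db, paper_id)`.
- Call it at the end of every code path that modifies papers, tags, or summaries.
- Provide a full `reindex_fts` task.
- Alternatively, maintain the basic fields with SQLite triggers, while still updating summary fields through the unified function.

### 8. ChromaDB indexing lacks a system-level consistency strategy

ChromaDB is currently an optional enhancement; indexing failures are swallowed into the logs and do not affect the main summarization flow.

Impact:

- The availability of semantic search is opaque.
- Some papers are done in the DB but never made it into ChromaDB.
- Stale index entries may remain after deletion or rebuild.

Recommendation:

- Treat ChromaDB explicitly as a derived index.
- Record `chroma_indexed_at` and `chroma_error` in the DB.
- Provide a `reindex_chroma` task.
- Show the health of the semantic index on the search page or in the admin backend.

### 9. Synchronous blocking operations are mixed into the async main flow

Some async functions call synchronous blocking operations internally; for example, PDF download uses `requests`, and image extraction and PDF parsing are also synchronous heavy tasks.

Impact:

- The FastAPI event loop may be blocked.
- Throughput is unpredictable when several papers are summarized concurrently.
- Web responsiveness is affected by background tasks.

Recommendation:

- Best option: move heavy tasks to a worker process.
- Transitional option: wrap synchronous blocking sections in `asyncio.to_thread()`.
- Use `httpx.AsyncClient` for all PDF downloads, or explicitly run them in a synchronous worker.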
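
The transitional option can be illustrated with a small, self-contained sketch. `parse_pdf_blocking` is a hypothetical stand-in for a synchronous step such as PDF parsing (simulated here with `time.sleep`); the `ticker` coroutine exists only to show that the event loop keeps running while the blocking work happens in threads.

```python
import asyncio
import time


def parse_pdf_blocking(arxiv_id: str) -> str:
    """Hypothetical blocking step; sleep stands in for CPU or I/O bound work."""
    time.sleep(0.2)
    return f"parsed:{arxiv_id}"


async def summarize_one(arxiv_id: str) -> str:
    # Run the blocking call in a worker thread so the event loop stays free.
    return await asyncio.to_thread(parse_pdf_blocking, arxiv_id)


async def main():
    ticks = 0

    async def ticker():
        # Counts how often the loop gets a turn while the threads are busy.
        nonlocal ticks
        while True:
            await asyncio.sleep(0.01)
            ticks += 1

    task = asyncio.create_task(ticker())
    results = await asyncio.gather(summarize_one("paper-a"), summarize_one("paper-b"))
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)
    return results, ticks


if __name__ == "__main__":
    print(asyncio.run(main()))
```

`asyncio.to_thread()` uses the loop's default thread pool, so it helps with blocking I/O but does not add CPU parallelism for pure-Python work; that limitation is one reason the worker-process option is listed first.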
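
### 10. Scheduling policy is coupled to the business flow

The scheduler hard-codes the daily pipeline plus the 30-minute upvote refresh; the pipeline hard-codes "fall back to yesterday when today has no data".

Impact:

- The policy is hard to test and adjust.
- CLI, Web, and Scheduler may each implement a different fallback.
- Supporting date ranges, backfill crawls, or summarizing only newly added papers will become complicated.

Recommendation: make the policy configurable:

- `crawl_target_policy`: `today` / `today_then_yesterday` / `date_range`
- `summarize_scope`: `new_only` / `pending_and_failed` / `date_range`
- `cleanup_policy`: `after_success` / `always` / `manual`
- `upvote_refresh_window_days`
- `pipeline_on_partial_failure`: `continue` / `stop`

### 11. The embedded APScheduler is sensitive to the deployment shape

The scheduler is currently embedded in the Web process, and with multiple workers it only prints a warning.

Impact:

- With multiple workers, multiple processes, or reload mode, jobs may run twice or be skipped.
- Restarting the Web process affects scheduling reliability.
- Scheduling state and task execution state are coupled.

Recommendation:

- For a local single machine, the embedded scheduler can stay.
- For long-term operation, split it into a standalone scheduler process.
- The scheduler only creates jobs and does not run heavy tasks directly.
- The Web admin backend only displays scheduler/job status.

### 12. Runtime migration capability is insufficient

The current `_migrate()` only supports adding missing columns.

Impact:

- Indexes, constraints, column renames, and data backfills cannot be handled reliably.
- Schema evolution is not traceable.
- Upgrading production data or long-lived local data is risky.

Recommendation:

- Introduce Alembic.
- Bring FTS5, indexes, some constraints, and data backfills into migration scripts.
- Do not silently run complex migrations at startup; run an explicit migration command instead.

## Low-to-Medium-Priority Issues

### 13. Configuration validation is insufficient

Configuration fields are currently mostly bare types, with no validation of legal values or combinations.

Risk examples:

- `SUMMARY_BACKEND` is set to an invalid value.
- `SUMMARY_PDF_MODE` is set to an invalid value.
- `SCHEDULE_MINUTE + 30` exceeds 59.
- `APP_WORKERS > 1` together with `SCHEDULER_ENABLED=true`.
- `CHROMA_ENABLED=true` but the embedding configuration is missing.

Recommendation:

- Use Pydantic validators to check enum values and combination constraints.
- Fail at startup on dangerous combinations instead of only warning.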
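
The combination checks listed above can be sketched as follows. To keep the example dependency-free it uses a plain dataclass rather than a Pydantic validator, so it is a stand-in for the recommended technique, not an instance of it. The `DemoSettings` class, its default values, and the decision to report `SCHEDULE_MINUTE + 30 > 59` as an error (rather than wrapping into the next hour) are assumptions.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class DemoSettings:
    """Hypothetical subset of the settings named in the risk examples."""
    SCHEDULE_MINUTE: int = 0
    APP_WORKERS: int = 1
    SCHEDULER_ENABLED: bool = True
    CHROMA_ENABLED: bool = False
    EMBED_API_KEY: str = ""


def validate_settings(cfg: DemoSettings) -> List[str]:
    """Return human-readable problems; an empty list means the config passed."""
    problems = []
    if not 0 <= cfg.SCHEDULE_MINUTE <= 59:
        problems.append("SCHEDULE_MINUTE must be between 0 and 59")
    elif cfg.SCHEDULE_MINUTE + 30 > 59:
        problems.append("SCHEDULE_MINUTE + 30 exceeds 59")
    if cfg.APP_WORKERS > 1 and cfg.SCHEDULER_ENABLED:
        problems.append("SCHEDULER_ENABLED requires APP_WORKERS == 1")
    if cfg.CHROMA_ENABLED and not cfg.EMBED_API_KEY:
        problems.append("CHROMA_ENABLED requires EMBED_API_KEY")
    return problems


def require_valid(cfg: DemoSettings) -> None:
    """Fail fast at startup instead of only logging a warning."""
    problems = validate_settings(cfg)
    if problems:
        raise ValueError("; ".join(problems))
```

Collecting every problem before raising lets a single startup attempt report all misconfigurations at once.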
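
### 14. Cross-module calls to private functions indicate unstable module boundaries

For example, `summarizer.py` calls underscore-prefixed functions such as `_generate_with_retry` and `_persist_summary`.

Impact:

- The public boundary of each module is unclear.
- Later refactoring can easily break callers by accident.
- Testing and swapping backends are inconvenient.

Recommendation:

- Define an explicit public API for summary generation, persistence, and post-processing.
- Keep underscore-prefixed functions as module-internal implementation details.
- Introduce clearer service classes or function boundaries, for example:
  - `SummaryGenerator.generate()`
  - `SummaryRepository.save()`
  - `DerivedAssetBuilder.extract_images()`

### 15. External dependencies lack a unified adapter

The project depends on several external systems:

- HuggingFace API
- arXiv PDF/source downloads
- `pi` CLI
- `claude` CLI
- Embedding API
- ChromaDB
- ONNX layout detector

Recommendation:

- Define an adapter for each kind of external dependency.
- Unify timeout, retry, error classification, and metrics.
- Replace adapters in tests instead of patching deeply nested functions.

### 16. Transaction boundaries in the delete flow are not ideal

The delete flow removes FTS5 entries, ChromaDB entries, local files, temporary files, and ORM data all at once.

Impact:

- If the DB rolls back after files were deleted, the data remains but the files are gone.
- If ChromaDB deletion fails after the DB delete succeeded, index entries are left behind.
- When a single paper fails, the rollback may affect previously flushed state, which is unintuitive.

Recommendation:

- Split deletion of primary data and deletion of derived artifacts into two phases.
- After the DB delete succeeds, create a derived-cleanup job.
- Derived cleanup can be retried on failure without affecting the consistency of primary data.

## Suggested Target Architecture

If the project is positioned as a personal local tool, a lightweight target architecture is sufficient:

```text
FastAPI Web
- Pages and API
- Admin backend
- Create jobs
- Query job status
Scheduler
- Create jobs according to policy
- Does not run heavy tasks directly
Worker
- crawl
- summarize
- PDF download
- image extraction
- FTS/Chroma reindex
- cleanup/delete
Storage/Repository
- DB as authoritative data
- File asset management
- Derived index rebuilding
- Consistency checks
```

If an independent worker is not desired for now, a Job table and a background task executor can first be implemented within a single process. That at least unifies task status and the recovery mechanism.

## Recommended Implementation Order

### Phase 1: Address tasks and state first

Goal: reduce the risk of uncontrolled long-running tasks.

- Add the `jobs` and `job_events` tables.
- Change Web admin actions to create a job and return a task_id.
- Change the Scheduler to create jobs.
- Have the CLI reuse the same job runner.
- Add recovery logic for stale running jobs.

### Phase 2: Unify derived-data management

Goal: reduce the risk of inconsistency among the DB, files, FTS, and Chroma.

- Establish the DB as the authoritative source.
- Wrap indexing in `reindex_paper()`.
- Add `reindex_fts` and `reindex_chroma` tasks.
- Add derived-data status fields or health checks.
- Change the delete flow to a primary DB delete plus a derived-cleanup job.

### Phase 3: Turn the pipeline into a state machine

Goal: make tasks recoverable, compensable, and observable.

- Split the pipeline into stages.
- Record input, output, errors, and duration for each stage.
- Support rerunning from the failed stage.
- Make the fallback policy configurable.

### Phase 4: Improve operational reliability

Goal: more stable long-term operation.

- Introduce Alembic.
- Add validators to the configuration.
- Move synchronous blocking operations off the Web event loop.
- Wrap external dependencies in adapters.
- Show jobs, derived indexes, failure types, and duration statistics in the admin backend.

## Minimum Viable Changes

If only the smallest changes with the largest payoff are made, prioritize these three:

1. Add a unified Job table and job runner.
2. Treat the DB as the authoritative source and handle FTS/Chroma/files as derived artifacts.
3. Add stale-task recovery and derived-data rebuild commands.

These three significantly reduce future complexity without forcing the project to split into multiple services right away.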
+66 -6
@@ -26,7 +26,7 @@ def crawl(
from app.database import SessionLocal, engine from app.database import SessionLocal, engine
from app.database import init_db as _init from app.database import init_db as _init
from app.models import Paper from app.models import Paper
from app.services.crawler import crawl_daily from app.services.jobs import create_job, run_job
from app.utils import today_str, yesterday_str from app.utils import today_str, yesterday_str
from sqlalchemy import func, select from sqlalchemy import func, select
@@ -55,7 +55,13 @@ def crawl(
return return
typer.echo(f"📡 开始抓取 {target} ...") typer.echo(f"📡 开始抓取 {target} ...")
result = asyncio.run(crawl_daily(db, target, top_n)) job = create_job(
db,
"crawl_daily",
owner="cli_crawl",
payload={"target_date": target, "top_n": top_n},
)
result = asyncio.run(run_job(db, job.id))
# 未指定日期且今天失败或无数据时,自动回退到昨天 # 未指定日期且今天失败或无数据时,自动回退到昨天
need_fallback = not date_str and ( need_fallback = not date_str and (
@@ -76,7 +82,13 @@ def crawl(
else: else:
typer.echo(f"🔄 {target} 无数据,尝试 {fallback} ...") typer.echo(f"🔄 {target} 无数据,尝试 {fallback} ...")
target = fallback target = fallback
result = asyncio.run(crawl_daily(db, target, top_n)) job = create_job(
db,
"crawl_daily",
owner="cli_crawl",
payload={"target_date": target, "top_n": top_n},
)
result = asyncio.run(run_job(db, job.id))
if result["status"] == "success": if result["status"] == "success":
typer.echo( typer.echo(
@@ -110,7 +122,7 @@ def summarize(
from app.config import settings from app.config import settings
from app.database import SessionLocal, engine from app.database import SessionLocal, engine
from app.database import init_db as _init from app.database import init_db as _init
from app.services.summarizer import summarize_batch, summarize_single from app.services.jobs import create_job, run_job
import os import os
@@ -142,11 +154,25 @@ def summarize(
try: try:
if arxiv_id: if arxiv_id:
typer.echo(f"🤖 开始总结 {arxiv_id} (mode={pdf_mode}) ...") typer.echo(f"🤖 开始总结 {arxiv_id} (mode={pdf_mode}) ...")
result = asyncio.run(summarize_single(db, arxiv_id, pdf_mode=pdf_mode)) job = create_job(
db,
"summarize_one",
owner="cli_summarize",
payload={"arxiv_id": arxiv_id, "pdf_mode": pdf_mode, "force": False},
)
else: else:
typer.echo(f"🤖 开始批量总结 pending 论文 (mode={pdf_mode}) ...") typer.echo(f"🤖 开始批量总结 pending 论文 (mode={pdf_mode}) ...")
result = asyncio.run(summarize_batch(db, pdf_mode=pdf_mode)) job = create_job(
db,
"summarize_batch",
owner="cli_summarize",
payload={"pdf_mode": pdf_mode},
)
result = asyncio.run(run_job(db, job.id))
if result.get("status") == "failed":
typer.echo(f"❌ 总结失败:{result.get('error')}", err=True)
raise typer.Exit(code=1)
typer.echo(f"✅ 总结完成:{result}") typer.echo(f"✅ 总结完成:{result}")
except NotFoundError as exc: except NotFoundError as exc:
typer.echo(f"{exc.message}", err=True) typer.echo(f"{exc.message}", err=True)
@@ -172,5 +198,39 @@ def init_db():
typer.echo(f"✅ 数据库已初始化:{settings.db_path}") typer.echo(f"✅ 数据库已初始化:{settings.db_path}")
@cli_app.command("rebuild-derived")
def rebuild_derived(
fts: bool = typer.Option(False, "--fts", help="重建 FTS5 全文索引"),
chroma: bool = typer.Option(False, "--chroma", help="重建 ChromaDB 语义索引"),
):
"""重建可派生数据索引。"""
from app.config import settings
from app.database import SessionLocal, engine
from app.database import init_db as _init
from app.services.jobs import create_job, run_job
import os
if not fts and not chroma:
fts = True
os.makedirs(settings.db_path.parent, exist_ok=True)
_init(engine)
db = SessionLocal()
try:
for job_type in [
*(["reindex_fts"] if fts else []),
*(["reindex_chroma"] if chroma else []),
]:
job = create_job(db, job_type, owner="cli_rebuild_derived", payload={})
result = asyncio.run(run_job(db, job.id))
typer.echo(f"{job_type}: {result}")
if result.get("status") == "failed":
raise typer.Exit(code=1)
finally:
db.close()
if __name__ == "__main__": if __name__ == "__main__":
cli_app() cli_app()
+16
@@ -76,10 +76,26 @@ def _migrate(engine) -> None:
         "crawl_logs": [
             ("details_json", "TEXT"),
         ],
+        "task_locks": [
+            ("heartbeat_at", "DATETIME"),
+            ("expires_at", "DATETIME"),
+        ],
+        "jobs": [
+            ("heartbeat_at", "DATETIME"),
+        ],
     }
 
     with engine.connect() as conn:
         for table, columns in _MIGRATIONS.items():
+            table_exists = conn.execute(
+                text(
+                    "SELECT name FROM sqlite_master "
+                    "WHERE type IN ('table', 'virtual table') AND name = :name"
+                ),
+                {"name": table},
+            ).fetchone()
+            if not table_exists:
+                continue
             # 获取已有列名
             existing = {
                 row[1] for row in conn.execute(text(f"PRAGMA table_info({table})"))
+7
@@ -32,7 +32,14 @@ async def lifespan(app: FastAPI):
     # ── startup ──
     from app.services.scheduler import start_scheduler
     from app.services.embedder import init_chroma
+    from app.services.jobs import recover_stale_jobs
+    from app.database import SessionLocal
+    db = SessionLocal()
+    try:
+        recover_stale_jobs(db)
+    finally:
+        db.close()
     start_scheduler()
     init_chroma()
+59
@@ -32,6 +32,26 @@ class SummaryState(StrEnum):
     PERMANENT_FAILURE = "permanent_failure"
 
 
+class JobStatus(StrEnum):
+    """后台任务状态枚举 — 对应 jobs.status 列。"""
+
+    QUEUED = "queued"
+    RUNNING = "running"
+    SUCCESS = "success"
+    FAILED = "failed"
+    STALE = "stale"
+    CANCELLED = "cancelled"
+
+
+class JobEventStatus(StrEnum):
+    """任务阶段事件状态枚举 — 对应 job_events.status 列。"""
+
+    STARTED = "started"
+    SUCCESS = "success"
+    FAILED = "failed"
+    INFO = "info"
+
+
 # ── papers ──────────────────────────────────────────────────────────────
 class Paper(Base):
     __tablename__ = "papers"
@@ -194,9 +214,48 @@ class TaskLock(Base):
     status = Column(String, nullable=False)
     owner = Column(String)
     acquired_at = Column(DateTime, nullable=False)
+    heartbeat_at = Column(DateTime)
+    expires_at = Column(DateTime)
     released_at = Column(DateTime)
 
 
+# ── jobs / job_events ──────────────────────────────────────────────────
+class Job(Base):
+    __tablename__ = "jobs"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    type = Column(String, nullable=False, index=True)
+    status = Column(String, nullable=False, default=JobStatus.QUEUED, index=True)
+    owner = Column(String)
+    payload_json = Column(Text)
+    result_json = Column(Text)
+    error = Column(Text)
+    created_at = Column(DateTime, nullable=False)
+    started_at = Column(DateTime)
+    heartbeat_at = Column(DateTime)
+    completed_at = Column(DateTime)
+
+    events = relationship(
+        "JobEvent", back_populates="job", cascade="all, delete-orphan"
+    )
+
+
+class JobEvent(Base):
+    __tablename__ = "job_events"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    job_id = Column(
+        Integer, ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, index=True
+    )
+    stage = Column(String, nullable=False)
+    status = Column(String, nullable=False)
+    message = Column(Text)
+    payload_json = Column(Text)
+    created_at = Column(DateTime, nullable=False)
+
+    job = relationship("Job", back_populates="events")
+
+
 # ── user data ──────────────────────────────────────────────────────────
 class UserBookmark(Base):
     __tablename__ = "user_bookmarks"
+106 -277
@@ -4,36 +4,20 @@ from __future__ import annotations
import hashlib import hashlib
import hmac import hmac
import json
import logging
from datetime import date from datetime import date
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, Query, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from sqlalchemy import bindparam, func, select, text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import settings from app.config import settings
from app.database import get_db from app.database import get_db
from app.models import (
CrawlLog,
DataDeleteJob,
Paper,
PaperTag,
SummaryState,
SummaryStatus,
)
from app.services import admin as admin_svc from app.services import admin as admin_svc
from app.services.admin import get_admin_stats from app.services.admin import get_admin_stats
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range from app.services.cleaner import cleanup_tmp
from app.services.crawler import refresh_upvotes from app.services.jobs import create_job, enqueue_job
from app.services.pipeline import run_crawl, run_pipeline from app.utils import templates, today_str
from app.services.scheduler import get_scheduler
from app.services.summarizer import summarize_batch, summarize_single
from app.utils import templates, today_str, utc_now
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin", tags=["admin"]) router = APIRouter(prefix="/admin", tags=["admin"])
@@ -103,18 +87,7 @@ async def admin_dashboard(
): ):
"""管理仪表盘 — 系统状态总览。""" """管理仪表盘 — 系统状态总览。"""
stats = get_admin_stats(db) stats = get_admin_stats(db)
scheduler_history = admin_svc.get_scheduler_history(db)
# 调度器历史(最近 10 条 task=scheduler 日志)
scheduler_history = (
db.execute(
select(CrawlLog)
.where(CrawlLog.task == "scheduler")
.order_by(CrawlLog.started_at.desc())
.limit(10)
)
.scalars()
.all()
)
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
@@ -129,53 +102,43 @@ async def admin_dashboard(
@router.get("/scheduler-status") @router.get("/scheduler-status")
async def admin_scheduler_status(_admin: None = Depends(verify_admin)): async def admin_scheduler_status(_admin: None = Depends(verify_admin)):
"""调度器运行状态(JSON)。""" """调度器运行状态(JSON)。"""
scheduler = get_scheduler() return admin_svc.get_scheduler_status()
next_run = None
upvote_next_run = None
if scheduler:
for job in scheduler.get_jobs():
if job.id == "daily_pipeline":
next_run = job.next_run_time
elif job.id == "upvote_refresh":
upvote_next_run = job.next_run_time
return {
"enabled": scheduler is not None,
"schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
"timezone": settings.APP_TIMEZONE,
"next_run": next_run.isoformat() if next_run else None,
"upvote_next_run": upvote_next_run.isoformat() if upvote_next_run else None,
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
}
@router.post("/trigger-pipeline") @router.post("/trigger-pipeline")
async def admin_trigger_pipeline( async def admin_trigger_pipeline(
background_tasks: BackgroundTasks,
_admin: None = Depends(verify_admin), _admin: None = Depends(verify_admin),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""手动触发一次完整流水线(crawl → summarize → cleanup)。""" """手动触发一次完整流水线(crawl → summarize → cleanup)。"""
today = today_str() today = today_str()
try: job = create_job(
result = await run_pipeline(db, today, owner="admin_trigger") db,
except RuntimeError as exc: "pipeline_daily",
raise HTTPException(status_code=409, detail=str(exc)) owner="admin_trigger",
payload={"target_date": today},
if result["status"] == "failed": )
raise HTTPException(status_code=500, detail=result.get("error")) enqueue_job(background_tasks, job.id)
return {"status": "success", "message": "流水线执行完成"} return {"status": "queued", "job_id": job.id, "message": "流水线任务已创建"}
@router.post("/refresh-upvotes") @router.post("/refresh-upvotes")
async def admin_refresh_upvotes( async def admin_refresh_upvotes(
background_tasks: BackgroundTasks,
_admin: None = Depends(verify_admin), _admin: None = Depends(verify_admin),
db: Session = Depends(get_db), db: Session = Depends(get_db),
days: int | None = Query(None, description="刷新最近 N 天,默认使用配置值"), days: int | None = Query(None, description="刷新最近 N 天,默认使用配置值"),
): ):
"""手动刷新最近 N 天论文的 upvotes。""" """手动刷新最近 N 天论文的 upvotes。"""
result = await refresh_upvotes(db, days=days) job = create_job(
if result["status"] == "failed": db,
raise HTTPException(status_code=500, detail=result.get("error")) "refresh_upvotes",
return result owner="admin_refresh",
payload={"days": days},
)
enqueue_job(background_tasks, job.id)
return {"status": "queued", "job_id": job.id}
# ── 请求模型 ────────────────────────────────────────────────────────── # ── 请求模型 ──────────────────────────────────────────────────────────
@@ -200,18 +163,21 @@ class DeleteRequest(BaseModel):
@router.post("/crawl") @router.post("/crawl")
async def admin_crawl( async def admin_crawl(
background_tasks: BackgroundTasks,
_admin: None = Depends(verify_admin), _admin: None = Depends(verify_admin),
db: Session = Depends(get_db), db: Session = Depends(get_db),
date: str | None = Query(None, description="YYYY-MM-DD,默认今天"), date: str | None = Query(None, description="YYYY-MM-DD,默认今天"),
): ):
"""手动抓取指定日期,默认今天。""" """手动抓取指定日期,默认今天。"""
target_date = date or today_str() target_date = date or today_str()
try: job = create_job(
return await run_crawl(db, target_date, owner="admin_crawl") db,
except RuntimeError as exc: "crawl_daily",
raise HTTPException(status_code=409, detail=str(exc)) owner="admin_crawl",
except Exception as exc: payload={"target_date": target_date},
raise HTTPException(status_code=500, detail=str(exc)) )
enqueue_job(background_tasks, job.id)
return {"status": "queued", "job_id": job.id, "target_date": target_date}
# ── 总结 ────────────────────────────────────────────────────────────── # ── 总结 ──────────────────────────────────────────────────────────────
@@ -219,23 +185,41 @@ async def admin_crawl(
@router.post("/summarize") @router.post("/summarize")
async def admin_summarize_batch( async def admin_summarize_batch(
background_tasks: BackgroundTasks,
_admin: None = Depends(verify_admin), _admin: None = Depends(verify_admin),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""批量总结所有 pending 论文。""" """批量总结所有 pending 论文。"""
return await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE) job = create_job(
db,
"summarize_batch",
owner="admin_summarize",
payload={"pdf_mode": settings.SUMMARY_PDF_MODE},
)
enqueue_job(background_tasks, job.id)
return {"status": "queued", "job_id": job.id}
@router.post("/summarize/{arxiv_id}") @router.post("/summarize/{arxiv_id}")
async def admin_summarize_single( async def admin_summarize_single(
arxiv_id: str, arxiv_id: str,
background_tasks: BackgroundTasks,
_admin: None = Depends(verify_admin), _admin: None = Depends(verify_admin),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""总结或重跑单篇论文。""" """总结或重跑单篇论文。"""
return await summarize_single( job = create_job(
db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE db,
"summarize_one",
owner="admin_summarize",
payload={
"arxiv_id": arxiv_id,
"force": True,
"pdf_mode": settings.SUMMARY_PDF_MODE,
},
) )
enqueue_job(background_tasks, job.id)
return {"status": "queued", "job_id": job.id, "arxiv_id": arxiv_id}
# ── 清理 ────────────────────────────────────────────────────────────── # ── 清理 ──────────────────────────────────────────────────────────────
@@ -243,39 +227,25 @@ async def admin_summarize_single(
@router.post("/cleanup") @router.post("/cleanup")
async def admin_cleanup( async def admin_cleanup(
background_tasks: BackgroundTasks,
_admin: None = Depends(verify_admin), _admin: None = Depends(verify_admin),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""清理 data/tmp/ 中超过 24 小时的临时文件。""" """清理 data/tmp/ 中超过 24 小时的临时文件。"""
now = utc_now() job = create_job(db, "cleanup_tmp", owner="admin_cleanup", payload={})
log_entry = CrawlLog( enqueue_job(background_tasks, job.id)
task="cleanup", return {"status": "queued", "job_id": job.id}
status="running",
started_at=now,
)
db.add(log_entry)
db.commit()
@router.post("/cleanup-now")
async def admin_cleanup_now(
_admin: None = Depends(verify_admin),
db: Session = Depends(get_db),
):
"""同步清理临时文件,保留给测试和本地排障使用。"""
try: try:
result = cleanup_tmp() return admin_svc.run_cleanup_now(db, cleanup_tmp)
log_entry.status = "success"
log_entry.completed_at = utc_now()
log_entry.details_json = json.dumps(
{
"scanned": result.get("scanned", 0),
"removed": result.get("removed", 0),
},
ensure_ascii=False,
)
if result.get("errors"):
log_entry.error = "; ".join(result["errors"])[:2000]
db.commit()
return result
except Exception as exc: except Exception as exc:
log_entry.status = "failed"
log_entry.error = str(exc)[:2000]
log_entry.completed_at = utc_now()
db.commit()
raise HTTPException(status_code=500, detail=str(exc)) raise HTTPException(status_code=500, detail=str(exc))
@@ -285,6 +255,7 @@ async def admin_cleanup(
@router.post("/delete") @router.post("/delete")
async def admin_delete( async def admin_delete(
body: DeleteRequest, body: DeleteRequest,
background_tasks: BackgroundTasks,
_admin: None = Depends(verify_admin), _admin: None = Depends(verify_admin),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
@@ -292,13 +263,31 @@ async def admin_delete(
if body.date_start > body.date_end: if body.date_start > body.date_end:
raise HTTPException(status_code=400, detail="date_start must be <= date_end") raise HTTPException(status_code=400, detail="date_start must be <= date_end")
result = await delete_papers_by_date_range( job = create_job(
db, db,
body.date_start, "delete_range",
body.date_end, owner="admin_delete",
include_notes=body.include_notes, payload={
"date_start": body.date_start.isoformat(),
"date_end": body.date_end.isoformat(),
"include_notes": body.include_notes,
},
) )
return result enqueue_job(background_tasks, job.id)
return {"status": "queued", "job_id": job.id}
@router.get("/jobs/{job_id}")
async def admin_job_detail(
job_id: int,
_admin: None = Depends(verify_admin),
db: Session = Depends(get_db),
):
"""查询后台任务状态和阶段事件。"""
detail = admin_svc.get_job_detail(db, job_id)
if not detail:
raise HTTPException(status_code=404, detail=f"Job not found: {job_id}")
return detail
# ── 日志 ────────────────────────────────────────────────────────────── # ── 日志 ──────────────────────────────────────────────────────────────
@@ -313,72 +302,10 @@ async def admin_logs(
per_page: int = Query(20, ge=1, le=100), per_page: int = Query(20, ge=1, le=100),
): ):
"""查看任务日志(CrawlLog + DataDeleteJob+ 总结状态统计。""" """查看任务日志(CrawlLog + DataDeleteJob+ 总结状态统计。"""
crawl_logs = (
db.execute(
select(CrawlLog)
.order_by(CrawlLog.started_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
)
.scalars()
.all()
)
delete_jobs = (
db.execute(
select(DataDeleteJob)
.order_by(DataDeleteJob.started_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
)
.scalars()
.all()
)
# 总结状态统计概要
summary_total = db.scalar(select(func.count(Paper.id))) or 0
summary_done = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status == SummaryState.DONE
)
)
or 0
)
summary_pending = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_(
[SummaryState.PENDING, SummaryState.PROCESSING]
)
)
)
or 0
)
summary_failed = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
)
or 0
)
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
"admin_logs.html", "admin_logs.html",
{ admin_svc.get_logs_context(db, page=page, per_page=per_page),
"crawl_logs": crawl_logs,
"delete_jobs": delete_jobs,
"page": page,
"per_page": per_page,
"summary_total": summary_total,
"summary_done": summary_done,
"summary_pending": summary_pending,
"summary_failed": summary_failed,
},
) )
@@ -395,22 +322,10 @@ async def admin_summary_status(
per_page: int = Query(20, ge=1, le=100), per_page: int = Query(20, ge=1, le=100),
): ):
"""总结状态列表(HTMX 片段或 JSON)。""" """总结状态列表(HTMX 片段或 JSON)。"""
results, total = admin_svc.query_summary_statuses(
query = ( db, status=status, page=page, per_page=per_page
select(Paper, SummaryStatus)
.outerjoin(SummaryStatus, SummaryStatus.paper_id == Paper.id)
.order_by(Paper.paper_date.desc())
) )
if status != "all":
if status == "none":
query = query.where(SummaryStatus.paper_id == None) # noqa: E711
else:
query = query.where(SummaryStatus.status == status)
total = db.scalar(select(func.count()).select_from(query.subquery()))
results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
# 判断是否 HTMX 请求 # 判断是否 HTMX 请求
is_htmx = request.headers.get("HX-Request") == "true" is_htmx = request.headers.get("HX-Request") == "true"
@@ -421,27 +336,16 @@ async def admin_summary_status(
"partials/summary_list.html", "partials/summary_list.html",
{ {
"results": results, "results": results,
"total": total or 0, "total": total,
"page": page, "page": page,
"per_page": per_page, "per_page": per_page,
"current_status": status, "current_status": status,
}, },
) )
# 非 HTMX 返回 JSON return admin_svc.serialize_summary_statuses(
items = [] results, total=total, page=page, per_page=per_page
for paper, ss in results: )
item = {
"arxiv_id": paper.arxiv_id,
"title": paper.title_zh or paper.title_en,
"paper_date": str(paper.paper_date),
"summary_status": ss.status if ss else "none",
"retry_count": ss.retry_count if ss else 0,
"error_type": ss.error_type if ss else None,
"error": ss.error if ss else None,
}
items.append(item)
return {"items": items, "total": total or 0, "page": page, "per_page": per_page}
@router.post("/summary-retry-failed") @router.post("/summary-retry-failed")
@@ -450,39 +354,14 @@ async def admin_summary_retry_failed(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""重试所有失败状态的总结任务。""" """重试所有失败状态的总结任务。"""
failed_ids = ( count = admin_svc.retry_failed_summaries(db)
db.execute( if not count:
select(Paper.arxiv_id)
.join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
.where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
)
.scalars()
.all()
)
if not failed_ids:
return {"status": "success", "message": "没有失败的任务需要重试", "count": 0} return {"status": "success", "message": "没有失败的任务需要重试", "count": 0}
# 重置失败任务的状态为 pending
db.execute(
SummaryStatus.__table__.update()
.where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
.values(status=SummaryState.PENDING, error=None, error_type=None)
)
db.commit()
return { return {
"status": "success", "status": "success",
"message": f"已重置 {len(failed_ids)} 个失败任务为待总结状态", "message": f"已重置 {count} 个失败任务为待总结状态",
"count": len(failed_ids), "count": count,
} }
@@ -545,23 +424,8 @@ async def admin_paper_delete(
     db: Session = Depends(get_db),
 ):
     """Delete a single paper."""
-    paper = db.scalar(select(Paper).where(Paper.arxiv_id == arxiv_id))
-    if not paper:
+    if not admin_svc.delete_paper_by_arxiv(db, arxiv_id):
         raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
-    # Delete related data (the ORM cascade handles the association tables)
-    db.delete(paper)
-    db.commit()
-    # Clean up the FTS index
-    try:
-        db.execute(
-            text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id}
-        )
-        db.commit()
-    except Exception:
-        logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
     return {"status": "success", "message": f"已删除 {arxiv_id}"}
@@ -588,28 +452,7 @@ async def admin_papers_batch_action(
         raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
     if body.action == "delete":
-        papers = (
-            db.execute(select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids)))
-            .scalars()
-            .all()
-        )
-        count = 0
-        for paper in papers:
-            db.delete(paper)
-            count += 1
-        db.commit()
-        # Clean up the FTS index
-        try:
-            stmt = text("DELETE FROM papers_fts WHERE arxiv_id IN :ids").bindparams(
-                bindparam("ids", expanding=True)
-            )
-            db.execute(stmt, {"ids": body.arxiv_ids})
-            db.commit()
-        except Exception:
-            logger.warning("Failed to clean FTS index for batch delete", exc_info=True)
+        count = admin_svc.delete_papers_by_arxiv_ids(db, body.arxiv_ids)
         return {
             "status": "success",
             "message": f"已删除 {count} 篇论文",
@@ -617,24 +460,10 @@ async def admin_papers_batch_action(
         }
     elif body.action == "summarize":
-        # Reset the summary status of the selected papers to pending
-        paper_ids = (
-            db.execute(select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids)))
-            .scalars()
-            .all()
-        )
-        if paper_ids:
-            # Delete the old status records so the papers re-enter the pipeline
-            db.execute(
-                SummaryStatus.__table__.delete().where(
-                    SummaryStatus.paper_id.in_(paper_ids)
-                )
-            )
-            db.commit()
+        count = admin_svc.reset_summaries_pending(db, body.arxiv_ids)
         return {
             "status": "success",
-            "message": f"已将 {len(paper_ids)} 篇论文重置为待总结",
-            "count": len(paper_ids),
+            "message": f"已将 {count} 篇论文重置为待总结",
+            "count": count,
         }
+324 -3
View File
@@ -1,17 +1,30 @@
-"""Admin backend service: statistics aggregation and system status."""
+"""Admin backend service: statistics aggregation, system status, and admin operations."""
 from __future__ import annotations
+import json
 from datetime import date
 from pathlib import Path
+from typing import Callable
 from sqlalchemy import func, select, text
 from sqlalchemy.orm import Session
 from app.config import settings
-from app.models import CrawlLog, Paper, PaperTag, SummaryState, SummaryStatus, TaskLock
+from app.models import (
+    CrawlLog,
+    DataDeleteJob,
+    Job,
+    JobEvent,
+    Paper,
+    PaperTag,
+    SummaryState,
+    SummaryStatus,
+    TaskLock,
+)
+from app.services.derived import delete_paper_indexes
 from app.services.scheduler import get_scheduler
-from app.utils import PAPERS_DIR, TMP_DIR
+from app.utils import PAPERS_DIR, TMP_DIR, utc_now
 # Sort mapping for admin_papers
 SORT_MAP = {
@@ -190,3 +203,311 @@ def query_papers(
         statuses[paper_id_to_arxiv.get(pid, "")] = st
     return papers, total or 0, statuses
def get_scheduler_history(db: Session, limit: int = 10) -> list[CrawlLog]:
    """Return the most recent scheduler run logs."""
    return (
        db.execute(
            select(CrawlLog)
            .where(CrawlLog.task == "scheduler")
            .order_by(CrawlLog.started_at.desc())
            .limit(limit)
        )
        .scalars()
        .all()
    )
def get_scheduler_status() -> dict:
    """Return the scheduler's runtime status."""
    scheduler = get_scheduler()
    next_run = None
    upvote_next_run = None
    if scheduler:
        for job in scheduler.get_jobs():
            if job.id == "daily_pipeline":
                next_run = job.next_run_time
            elif job.id == "upvote_refresh":
                upvote_next_run = job.next_run_time
    return {
        "enabled": scheduler is not None,
        "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
        "timezone": settings.APP_TIMEZONE,
        "next_run": next_run.isoformat() if next_run else None,
        "upvote_next_run": upvote_next_run.isoformat() if upvote_next_run else None,
        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
    }
def run_cleanup_now(db: Session, cleanup_func: Callable[[], dict]) -> dict:
    """Run the temp-directory cleanup synchronously and record it in CrawlLog."""
    log_entry = CrawlLog(task="cleanup", status="running", started_at=utc_now())
    db.add(log_entry)
    db.commit()
    try:
        result = cleanup_func()
        log_entry.status = "success"
        log_entry.completed_at = utc_now()
        log_entry.details_json = json.dumps(
            {
                "scanned": result.get("scanned", 0),
                "removed": result.get("removed", 0),
            },
            ensure_ascii=False,
        )
        # Per-file errors do not fail the run; they are stored truncated.
        if result.get("errors"):
            log_entry.error = "; ".join(result["errors"])[:2000]
        db.commit()
        return result
    except Exception as exc:
        log_entry.status = "failed"
        log_entry.error = str(exc)[:2000]
        log_entry.completed_at = utc_now()
        db.commit()
        raise
def get_job_detail(db: Session, job_id: int) -> dict | None:
    """Return a background job's details and stage events as a JSON-serializable dict."""
    job = db.get(Job, job_id)
    if not job:
        return None
    events = (
        db.execute(
            select(JobEvent)
            .where(JobEvent.job_id == job_id)
            .order_by(JobEvent.created_at.asc())
        )
        .scalars()
        .all()
    )
    return {
        "id": job.id,
        "type": job.type,
        "status": job.status,
        "owner": job.owner,
        "payload": json.loads(job.payload_json or "{}"),
        "result": json.loads(job.result_json) if job.result_json else None,
        "error": job.error,
        "created_at": job.created_at.isoformat(),
        "started_at": job.started_at.isoformat() if job.started_at else None,
        "completed_at": job.completed_at.isoformat() if job.completed_at else None,
        "events": [
            {
                "stage": event.stage,
                "status": event.status,
                "message": event.message,
                "payload": json.loads(event.payload_json)
                if event.payload_json
                else None,
                "created_at": event.created_at.isoformat(),
            }
            for event in events
        ],
    }
def get_logs_context(db: Session, *, page: int, per_page: int) -> dict:
    """Build the template context for the admin logs page."""
    crawl_logs = (
        db.execute(
            select(CrawlLog)
            .order_by(CrawlLog.started_at.desc())
            .limit(per_page)
            .offset((page - 1) * per_page)
        )
        .scalars()
        .all()
    )
    delete_jobs = (
        db.execute(
            select(DataDeleteJob)
            .order_by(DataDeleteJob.started_at.desc())
            .limit(per_page)
            .offset((page - 1) * per_page)
        )
        .scalars()
        .all()
    )
    summary_total = db.scalar(select(func.count(Paper.id))) or 0
    summary_done = (
        db.scalar(
            select(func.count(SummaryStatus.id)).where(
                SummaryStatus.status == SummaryState.DONE
            )
        )
        or 0
    )
    summary_pending = (
        db.scalar(
            select(func.count(SummaryStatus.id)).where(
                SummaryStatus.status.in_(
                    [SummaryState.PENDING, SummaryState.PROCESSING]
                )
            )
        )
        or 0
    )
    summary_failed = (
        db.scalar(
            select(func.count(SummaryStatus.id)).where(
                SummaryStatus.status.in_(
                    [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
                )
            )
        )
        or 0
    )
    return {
        "crawl_logs": crawl_logs,
        "delete_jobs": delete_jobs,
        "page": page,
        "per_page": per_page,
        "summary_total": summary_total,
        "summary_done": summary_done,
        "summary_pending": summary_pending,
        "summary_failed": summary_failed,
    }
def query_summary_statuses(
    db: Session,
    *,
    status: str,
    page: int,
    per_page: int,
) -> tuple[list[tuple[Paper, SummaryStatus | None]], int]:
    """Query the paginated summary status list."""
    query = (
        select(Paper, SummaryStatus)
        .outerjoin(SummaryStatus, SummaryStatus.paper_id == Paper.id)
        .order_by(Paper.paper_date.desc())
    )
    if status != "all":
        if status == "none":
            # Papers with no SummaryStatus row at all (outer join yields NULL).
            query = query.where(SummaryStatus.paper_id == None)  # noqa: E711
        else:
            query = query.where(SummaryStatus.status == status)
    total = db.scalar(select(func.count()).select_from(query.subquery())) or 0
    results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
    return results, total
def serialize_summary_statuses(
    results: list[tuple[Paper, SummaryStatus | None]],
    *,
    total: int,
    page: int,
    per_page: int,
) -> dict:
    """Serialize the summary status list into a JSON response."""
    items = []
    for paper, ss in results:
        items.append(
            {
                "arxiv_id": paper.arxiv_id,
                "title": paper.title_zh or paper.title_en,
                "paper_date": str(paper.paper_date),
                "summary_status": ss.status if ss else "none",
                "retry_count": ss.retry_count if ss else 0,
                "error_type": ss.error_type if ss else None,
                "error": ss.error if ss else None,
            }
        )
    return {"items": items, "total": total, "page": page, "per_page": per_page}
def retry_failed_summaries(db: Session) -> int:
    """Reset failed and permanently failed summary tasks to pending."""
    failed_ids = (
        db.execute(
            select(Paper.arxiv_id)
            .join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
            .where(
                SummaryStatus.status.in_(
                    [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
                )
            )
        )
        .scalars()
        .all()
    )
    if not failed_ids:
        return 0
    db.execute(
        SummaryStatus.__table__.update()
        .where(
            SummaryStatus.status.in_(
                [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
            )
        )
        .values(status=SummaryState.PENDING, error=None, error_type=None)
    )
    db.commit()
    return len(failed_ids)
def delete_paper_by_arxiv(db: Session, arxiv_id: str) -> bool:
    """Delete a single paper and its derived indexes."""
    paper = db.scalar(select(Paper).where(Paper.arxiv_id == arxiv_id))
    if not paper:
        return False
    # Capture the id before the row is deleted; the FTS rowid depends on it.
    paper_id = paper.id
    db.delete(paper)
    db.commit()
    delete_paper_indexes(db, paper_id=paper_id, arxiv_id=arxiv_id)
    db.commit()
    return True
def delete_papers_by_arxiv_ids(db: Session, arxiv_ids: list[str]) -> int:
    """Delete multiple papers and their derived indexes."""
    papers = (
        db.execute(select(Paper).where(Paper.arxiv_id.in_(arxiv_ids))).scalars().all()
    )
    # Record (id, arxiv_id) pairs before deleting the ORM objects.
    deleted = [(paper.id, paper.arxiv_id) for paper in papers]
    for paper in papers:
        db.delete(paper)
    db.commit()
    for paper_id, arxiv_id in deleted:
        delete_paper_indexes(db, paper_id=paper_id, arxiv_id=arxiv_id)
    db.commit()
    return len(deleted)
def reset_summaries_pending(db: Session, arxiv_ids: list[str]) -> int:
    """Reset the summary status of the given papers to pending, creating it if missing."""
    paper_ids = (
        db.execute(select(Paper.id).where(Paper.arxiv_id.in_(arxiv_ids)))
        .scalars()
        .all()
    )
    if not paper_ids:
        return 0
    existing_statuses = (
        db.execute(select(SummaryStatus).where(SummaryStatus.paper_id.in_(paper_ids)))
        .scalars()
        .all()
    )
    existing_ids = {status.paper_id for status in existing_statuses}
    for status in existing_statuses:
        status.status = SummaryState.PENDING
        status.quality = None
        status.error = None
        status.error_type = None
        status.raw_output_saved = False
        status.started_at = None
        status.completed_at = None
    for paper_id in paper_ids:
        if paper_id not in existing_ids:
            db.add(SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING))
    db.commit()
    return len(paper_ids)
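Review note on `retry_failed_summaries`: it runs a SELECT to count the failed rows and then a separate UPDATE, so the returned count reflects the SELECT rather than the rows actually changed. The sketch below shows one possible alternative that takes the count from the UPDATE itself. It uses the standard-library `sqlite3` module instead of SQLAlchemy; the `summary_status` table layout and the status strings `'failed'` and `'permanent_failure'` are assumptions made for illustration, not the project's actual schema.

```python
import sqlite3

# Hypothetical, simplified stand-in for the real SummaryStatus table.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE summary_status (paper_id INTEGER PRIMARY KEY, status TEXT, error TEXT)"
)
conn.executemany(
    "INSERT INTO summary_status VALUES (?, ?, ?)",
    [(1, "failed", "timeout"), (2, "done", None), (3, "permanent_failure", "bad pdf")],
)


def retry_failed(conn: sqlite3.Connection) -> int:
    """Reset failed rows to pending and return how many rows were changed."""
    cur = conn.execute(
        "UPDATE summary_status SET status = 'pending', error = NULL "
        "WHERE status IN ('failed', 'permanent_failure')"
    )
    conn.commit()
    # rowcount is the number of rows modified by the UPDATE statement.
    return cur.rowcount


count = retry_failed(conn)
```

Calling `retry_failed(conn)` a second time finds nothing left to reset, which matches the "no failed tasks" branch in the route.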
+3 -16
View File
@@ -4,7 +4,7 @@ import logging
 from datetime import date as date_type, datetime, timezone
 import httpx
-from sqlalchemy import select, text
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 from app.config import settings
@@ -16,6 +16,7 @@ from app.models import (
     SummaryState,
     SummaryStatus,
 )
+from app.services.derived import reindex_paper_fts
 from app.utils import make_http_client, recent_date_strs, utc_now
 logger = logging.getLogger(__name__)
@@ -143,21 +144,7 @@ def upsert_papers(db: Session, papers_raw: list[dict], paper_date: str) -> list[
         db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
-        authors_text = ", ".join(meta["authors"])
-        tags_text = ", ".join(meta["tags"])
-        db.execute(
-            text(
-                "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
-                "VALUES (:id, :title, :abstract, :authors, :tags)"
-            ),
-            {
-                "id": paper.id,
-                "title": meta["title_en"],
-                "abstract": meta["abstract"] or "",
-                "authors": authors_text,
-                "tags": tags_text,
-            },
-        )
+        reindex_paper_fts(db, paper)
         new_papers.append(paper)
         logger.debug("Inserted new paper: %s", arxiv_id)
+140
View File
@@ -0,0 +1,140 @@
"""Derived-data maintenance: rebuildable indexes such as FTS5 and ChromaDB."""
from __future__ import annotations
import logging
from sqlalchemy import select, text
from sqlalchemy.orm import Session
from app.models import Paper
logger = logging.getLogger(__name__)
def _summary_text(paper: Paper) -> str:
    summary = paper.summary
    if not summary:
        return ""
    parts = [
        summary.one_line,
        summary.motivation_problem,
        summary.motivation_goal,
        summary.method_overview,
        summary.method_key_idea,
        summary.results_main_json,
    ]
    # Skip empty fields so the joined text has no stray separators.
    return " ".join(p for p in parts if p)
def delete_fts_paper(db: Session, paper_id: int) -> None:
    """Delete one paper's FTS5 row. The FTS5 table uses papers.id as its rowid."""
    db.execute(
        text("DELETE FROM papers_fts WHERE rowid = :paper_id"),
        {"paper_id": paper_id},
    )
def delete_paper_indexes(db: Session, *, paper_id: int, arxiv_id: str) -> None:
    """Delete all derived indexes for one paper.

    Failures are logged but do not block the primary delete.
    """
    try:
        delete_fts_paper(db, paper_id)
    except Exception:
        logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
    try:
        # Imported lazily so the embedder is only loaded when it is needed.
        from app.services.embedder import delete_paper

        delete_paper(arxiv_id)
    except Exception:
        logger.warning("Failed to clean ChromaDB index for %s", arxiv_id, exc_info=True)
def reindex_paper_fts(db: Session, paper: Paper) -> None:
    """Rebuild one paper's FTS5 derived index from the authoritative DB data."""
    authors_text = ", ".join(
        a.name for a in sorted(paper.authors, key=lambda a: a.position or 0)
    )
    tags_text = ", ".join(t.tag for t in paper.tags)
    # Delete first, then insert, so the call is safe to repeat for the same paper.
    delete_fts_paper(db, paper.id)
    db.execute(
        text(
            """
            INSERT INTO papers_fts(
                rowid, title_en, title_zh, abstract, authors, tags, summary_text
            )
            VALUES (
                :id, :title_en, :title_zh, :abstract, :authors, :tags, :summary_text
            )
            """
        ),
        {
            "id": paper.id,
            "title_en": paper.title_en or "",
            "title_zh": paper.title_zh or "",
            "abstract": paper.abstract or "",
            "authors": authors_text,
            "tags": tags_text,
            "summary_text": _summary_text(paper),
        },
    )
def reindex_fts(db: Session, paper_ids: list[int] | None = None) -> dict:
    """Rebuild the FTS5 index, either fully or for selected papers."""
    query = select(Paper)
    # Note: an empty list is falsy, so it selects every paper just like None does.
    if paper_ids:
        query = query.where(Paper.id.in_(paper_ids))
    papers = db.execute(query).scalars().all()
    # Only a full rebuild (paper_ids is None) clears the whole table first.
    if paper_ids is None:
        db.execute(text("DELETE FROM papers_fts"))
    count = 0
    for paper in papers:
        reindex_paper_fts(db, paper)
        count += 1
    db.commit()
    logger.info("FTS reindexed: %d papers", count)
    return {"status": "success", "indexed": count}
def reindex_chroma(db: Session) -> dict:
    """Rebuild the ChromaDB semantic index from the authoritative DB data."""
    from app.services.embedder import index_paper

    # Only papers that already have a summary are indexed.
    papers = db.execute(select(Paper).where(Paper.summary.has())).scalars().all()
    indexed = 0
    errors: list[str] = []
    for paper in papers:
        try:
            texts_dict = {
                "arxiv_id": paper.arxiv_id,
                "title_zh": paper.title_zh or "",
                "title_en": paper.title_en or "",
                "tags": " ".join(t.tag for t in paper.tags),
                "one_line": paper.summary.one_line if paper.summary else "",
                "motivation_problem": (
                    paper.summary.motivation_problem if paper.summary else ""
                ),
                "method_key_idea": (
                    paper.summary.method_key_idea if paper.summary else ""
                ),
                "paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
            }
            index_paper(paper.arxiv_id, texts_dict)
            indexed += 1
        except Exception as exc:
            # One failing paper is recorded and skipped; the rest still run.
            errors.append(f"{paper.arxiv_id}: {exc}")
            logger.warning(
                "Failed to reindex ChromaDB for %s",
                paper.arxiv_id,
                exc_info=True,
            )
    return {
        "status": "success" if not errors else "partial",
        "indexed": indexed,
        "errors": errors or None,
    }
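Review note on `reindex_paper_fts`: the delete-then-insert pattern keyed on `rowid` is what makes the helper safe to call repeatedly for the same paper. A minimal sketch of that pattern follows, using the standard-library `sqlite3` module directly instead of SQLAlchemy. The two-column `papers_fts` table is a reduced, hypothetical version of the real one, and the sketch assumes the local SQLite build includes FTS5.

```python
import sqlite3


def reindex_row(conn: sqlite3.Connection, paper_id: int, title: str, abstract: str) -> None:
    """Replace the FTS5 row for one paper: delete by rowid, then insert."""
    conn.execute("DELETE FROM papers_fts WHERE rowid = ?", (paper_id,))
    conn.execute(
        "INSERT INTO papers_fts(rowid, title_en, abstract) VALUES (?, ?, ?)",
        (paper_id, title, abstract),
    )


conn = sqlite3.connect(":memory:")
# Reduced, hypothetical schema; the real table has more columns.
conn.execute("CREATE VIRTUAL TABLE papers_fts USING fts5(title_en, abstract)")

# Indexing the same paper twice leaves exactly one row with the latest text.
reindex_row(conn, 1, "Old title", "graph networks")
reindex_row(conn, 1, "New title", "graph networks")

rows = conn.execute(
    "SELECT rowid, title_en FROM papers_fts WHERE papers_fts MATCH 'graph'"
).fetchall()
```

Because the row is always rebuilt from the authoritative tables, callers such as the crawler and the summarizer no longer need to know which FTS columns they are responsible for.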
+244
View File
@@ -0,0 +1,244 @@
"""Unified background job system: creation, execution, event recording, and failure recovery."""
from __future__ import annotations
import json
import logging
from datetime import date, timedelta
from typing import Any
from fastapi import BackgroundTasks
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from app.config import settings
from app.database import SessionLocal
from app.models import Job, JobEvent, JobEventStatus, JobStatus, TaskLock
from app.utils import truncate_error, utc_now
logger = logging.getLogger(__name__)
STALE_JOB_AFTER = timedelta(hours=6)
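def _dumps(value: Any) -> str:
    # default=str lets dates and other non-JSON values serialize as strings.
    return json.dumps(value, ensure_ascii=False, default=str)


def _loads(value: str | None) -> dict:
    # Any missing, malformed, or non-dict payload is treated as an empty dict.
    if not value:
        return {}
    try:
        data = json.loads(value)
        return data if isinstance(data, dict) else {}
    except json.JSONDecodeError:
        return {}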
def create_job(
    db: Session,
    job_type: str,
    *,
    owner: str,
    payload: dict | None = None,
) -> Job:
    """Create the main record for a background job."""
    job = Job(
        type=job_type,
        status=JobStatus.QUEUED,
        owner=owner,
        payload_json=_dumps(payload or {}),
        created_at=utc_now(),
    )
    db.add(job)
    db.commit()
    db.refresh(job)
    add_job_event(
        db,
        job,
        stage="created",
        status=JobEventStatus.INFO,
        message=f"Job queued: {job_type}",
        payload=payload or {},
    )
    return job
def add_job_event(
    db: Session,
    job: Job,
    *,
    stage: str,
    status: str,
    message: str | None = None,
    payload: dict | None = None,
) -> None:
    """Append one stage event to a job."""
    db.add(
        JobEvent(
            job_id=job.id,
            stage=stage,
            status=str(status),
            message=message,
            payload_json=_dumps(payload) if payload is not None else None,
            created_at=utc_now(),
        )
    )
    # Every event also refreshes the job heartbeat.
    job.heartbeat_at = utc_now()
    db.commit()
def enqueue_job(background_tasks: BackgroundTasks, job_id: int) -> None:
    """Submit the job to FastAPI BackgroundTasks."""
    background_tasks.add_task(run_job_by_id, job_id)


async def run_job_by_id(job_id: int) -> None:
    """Run an already created job using its own DB session."""
    db = SessionLocal()
    try:
        await run_job(db, job_id)
    finally:
        db.close()
async def run_job(db: Session, job_id: int) -> dict:
    """Run a job and write its status, result, and error back to jobs/job_events."""
    job = db.get(Job, job_id)
    if not job:
        raise ValueError(f"Job not found: {job_id}")
    if job.status == JobStatus.RUNNING:
        raise RuntimeError(f"Job already running: {job_id}")
    payload = _loads(job.payload_json)
    job.status = JobStatus.RUNNING
    job.started_at = utc_now()
    job.heartbeat_at = job.started_at
    db.commit()
    add_job_event(db, job, stage="run", status=JobEventStatus.STARTED)
    try:
        result = await _dispatch_job(db, job, payload)
    except Exception as exc:
        # Handler failures are recorded on the job and returned, not re-raised.
        logger.exception("Job failed: id=%s type=%s", job.id, job.type)
        error = truncate_error(exc, limit=4000)
        job.status = JobStatus.FAILED
        job.error = error
        job.completed_at = utc_now()
        db.commit()
        add_job_event(db, job, stage="run", status=JobEventStatus.FAILED, message=error)
        return {"status": "failed", "error": error}
    job.status = JobStatus.SUCCESS
    job.result_json = _dumps(result)
    job.completed_at = utc_now()
    job.error = None
    db.commit()
    add_job_event(
        db,
        job,
        stage="run",
        status=JobEventStatus.SUCCESS,
        payload=result if isinstance(result, dict) else {"result": result},
    )
    return result if isinstance(result, dict) else {"status": "success", "result": result}
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
    # Imported inside the function rather than at module level.
    from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
    from app.services.crawler import refresh_upvotes
    from app.services.derived import reindex_chroma, reindex_fts
    from app.services.pipeline import run_crawl, run_pipeline
    from app.services.summarizer import summarize_batch, summarize_single

    if job.type == "crawl_daily":
        return await run_crawl(
            db,
            payload["target_date"],
            owner=job.owner or f"job:{job.id}",
            top_n=payload.get("top_n"),
        )
    if job.type == "pipeline_daily":
        return await run_pipeline(
            db,
            payload["target_date"],
            owner=job.owner or f"job:{job.id}",
        )
    if job.type == "summarize_batch":
        return await summarize_batch(
            db,
            pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE),
        )
    if job.type == "summarize_one":
        return await summarize_single(
            db,
            payload["arxiv_id"],
            force=payload.get("force", True),
            pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE),
        )
    if job.type == "refresh_upvotes":
        return await refresh_upvotes(db, days=payload.get("days"))
    if job.type == "delete_range":
        return await delete_papers_by_date_range(
            db,
            date.fromisoformat(payload["date_start"]),
            date.fromisoformat(payload["date_end"]),
            include_notes=payload.get("include_notes", True),
        )
    if job.type == "cleanup_tmp":
        return cleanup_tmp()
    if job.type == "reindex_fts":
        return reindex_fts(db)
    if job.type == "reindex_chroma":
        return reindex_chroma(db)
    raise ValueError(f"Unsupported job type: {job.type}")
def recover_stale_jobs(db: Session) -> int:
    """On startup, mark expired running jobs and locks as stale so they cannot hang forever."""
    now = utc_now()
    cutoff = now - STALE_JOB_AFTER
    # A running job is stale if it never sent a heartbeat or the last one is too old.
    stale_jobs = (
        db.execute(
            select(Job).where(
                Job.status == JobStatus.RUNNING,
                or_(Job.heartbeat_at == None, Job.heartbeat_at < cutoff),  # noqa: E711
            )
        )
        .scalars()
        .all()
    )
    for job in stale_jobs:
        job.status = JobStatus.STALE
        job.error = "Marked stale after process restart or missed heartbeat"
        job.completed_at = now
        db.add(
            JobEvent(
                job_id=job.id,
                stage="recovery",
                status=JobEventStatus.FAILED,
                message=job.error,
                created_at=now,
            )
        )
    # Locks are judged by their acquisition time.
    stale_locks = (
        db.execute(
            select(TaskLock).where(
                TaskLock.status == "running",
                TaskLock.acquired_at < cutoff,
            )
        )
        .scalars()
        .all()
    )
    for lock in stale_locks:
        lock.status = "stale"
        lock.released_at = now
    db.commit()
    recovered = len(stale_jobs) + len(stale_locks)
    if recovered:
        logger.warning("Recovered stale runtime records: %d", recovered)
    return recovered
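Review note on `recover_stale_jobs`: the staleness rule is "still running, and either no heartbeat at all or a heartbeat older than the cutoff". The sketch below restates that rule on plain in-memory objects so it can be checked without a database. `DemoJob` and `mark_stale` are hypothetical names for illustration only; the six-hour window mirrors `STALE_JOB_AFTER`.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional

STALE_AFTER = timedelta(hours=6)  # mirrors STALE_JOB_AFTER in the diff


@dataclass
class DemoJob:
    """Hypothetical stand-in for the ORM Job model."""

    id: int
    status: str
    heartbeat_at: Optional[datetime] = None


def mark_stale(jobs: list, now: datetime) -> int:
    """Mark running jobs with a missing or expired heartbeat as stale."""
    cutoff = now - STALE_AFTER
    stale = [
        job
        for job in jobs
        if job.status == "running"
        and (job.heartbeat_at is None or job.heartbeat_at < cutoff)
    ]
    for job in stale:
        job.status = "stale"
    return len(stale)


now = datetime(2026, 6, 13, 12, 0, tzinfo=timezone.utc)
jobs = [
    DemoJob(1, "running", now - timedelta(hours=7)),    # heartbeat too old
    DemoJob(2, "running", now - timedelta(minutes=5)),  # still healthy
    DemoJob(3, "running", None),                        # never sent a heartbeat
    DemoJob(4, "success", now - timedelta(days=1)),     # finished, left alone
]
recovered = mark_stale(jobs, now)
```

Only jobs 1 and 3 are marked stale. Finished jobs are never touched, however old their heartbeat is.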
+12 -2
View File
@@ -7,6 +7,7 @@ from __future__ import annotations
 import logging
 from datetime import date as date_type
+from datetime import timedelta
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session
@@ -32,6 +33,8 @@ def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock:
         status="running",
         owner=owner,
         acquired_at=utc_now(),
+        heartbeat_at=utc_now(),
+        expires_at=utc_now() + timedelta(hours=6),
     )
     try:
         db.add(lock)
@@ -42,7 +45,12 @@ def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock:
     return lock
-async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") -> dict:
+async def run_crawl(
+    db: Session,
+    target_date: str,
+    owner: str = "admin_crawl",
+    top_n: int | None = None,
+) -> dict:
     """Run a single crawl (guarded by a re-entrancy lock).
     Args:
@@ -55,7 +63,7 @@ async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") -
     """
     lock = acquire_lock(db, "crawl", target_date, owner)
     try:
-        return await crawl_daily(db, target_date)
+        return await crawl_daily(db, target_date, top_n=top_n)
     finally:
         release_lock(db, lock)
@@ -83,6 +91,8 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
         status="running",
         owner=owner,
         acquired_at=now,
+        heartbeat_at=now,
+        expires_at=now + timedelta(hours=6),
     )
     try:
         db.add(lock)
+10 -4
View File
@@ -11,8 +11,7 @@ from zoneinfo import ZoneInfo
 from app.config import settings
 from app.database import SessionLocal
-from app.services.pipeline import run_pipeline
-from app.services.crawler import refresh_upvotes
+from app.services.jobs import create_job, run_job
 from app.utils import today_str
 logger = logging.getLogger(__name__)
@@ -112,7 +111,13 @@ async def _daily_pipeline() -> None:
     db: Session = SessionLocal()
     try:
-        await run_pipeline(db, today, owner="daily_pipeline")
+        job = create_job(
+            db,
+            "pipeline_daily",
+            owner="daily_pipeline",
+            payload={"target_date": today},
+        )
+        await run_job(db, job.id)
     except RuntimeError:
         logger.warning("Daily pipeline already running for %s, skipping", today)
     except Exception:
@@ -125,7 +130,8 @@ async def _upvote_refresh() -> None:
     """Refresh upvotes for papers from the last N days."""
     db: Session = SessionLocal()
     try:
-        result = await refresh_upvotes(db)
+        job = create_job(db, "refresh_upvotes", owner="upvote_refresh", payload={})
+        result = await run_job(db, job.id)
         logger.info(
             "Upvote refresh completed: status=%s updated=%d",
             result.get("status"),
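Review note on the scheduler change: `run_job` catches every exception raised by the dispatched handler and returns `{"status": "failed", ...}` instead of re-raising. The `except RuntimeError` branch in `_daily_pipeline` can therefore only fire for errors raised by `run_job` itself, such as "Job already running". It would no longer see a `RuntimeError` raised inside the pipeline, assuming that is how a lock conflict is signalled there. The sketch below reproduces that behaviour in isolation; `run_job_sketch` and `already_locked` are hypothetical names, not project code.

```python
import asyncio


async def run_job_sketch(handler) -> dict:
    """Mimic run_job: convert handler exceptions into a failed result dict."""
    try:
        result = await handler()
    except Exception as exc:
        return {"status": "failed", "error": str(exc)}
    return result


async def already_locked() -> dict:
    # Stand-in for a pipeline run that hits an existing lock (assumed behaviour).
    raise RuntimeError("pipeline already running")


# No exception reaches the caller; the failure is only visible in the result.
outcome = asyncio.run(run_job_sketch(already_locked))
```

Callers that need to tell "skipped because already running" apart from a real failure would have to inspect the returned `status` and `error` fields rather than rely on exception handling.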
+4 -15
View File
@@ -3,8 +3,6 @@
 from __future__ import annotations
 import logging
-from sqlalchemy import text
 from sqlalchemy.orm import Session
from app.models import ( from app.models import (
@@ -13,6 +11,7 @@ from app.models import (
PaperTag, PaperTag,
SummaryState, SummaryState,
) )
from app.services.derived import reindex_paper_fts
from app.services.pdf_downloader import paper_dir from app.services.pdf_downloader import paper_dir
from app.services.schemas import ( from app.services.schemas import (
SummarySchema, SummarySchema,
@@ -75,19 +74,9 @@ def _update_summary_in_db(
             db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
             existing_tag_names.add(tag_name)
-    # 4. FTS5 update
-    summary_text = _build_fts_summary_text(schema)
-    db.execute(
-        text(
-            "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
-            "WHERE rowid=:paper_id"
-        ),
-        {
-            "title_zh": schema.title_zh,
-            "summary_text": summary_text,
-            "paper_id": paper.id,
-        },
-    )
+    # 4. FTS5 derived index
+    db.flush()
+    reindex_paper_fts(db, paper)
     db.commit()
     logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
+148 -91
View File
@@ -3,14 +3,17 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from unittest.mock import AsyncMock, patch from unittest.mock import patch
import pytest import pytest
from sqlalchemy import select from sqlalchemy import select, text
from app.config import settings from app.config import settings
from app.models import ( from app.models import (
CrawlLog, CrawlLog,
Job,
SummaryState,
SummaryStatus,
TaskLock, TaskLock,
) )
from app.utils import utc_now from app.utils import utc_now
@@ -64,47 +67,13 @@ class TestAdminAuth:
         resp = auth_client.get("/admin/logs", follow_redirects=False)
         assert resp.status_code == 303
-    def test_correct_session_accepted(self, auth_client):
-        """A logged-in session should be accepted (the crawl may fail, but not with a 303)."""
-        with patch(
-            "app.routes.admin.run_crawl", new_callable=AsyncMock
-        ) as mock_crawl:
-            mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
-            resp = auth_client.post("/admin/crawl")
-        assert resp.status_code != 303
-    # ── summarize route auth ────────────────────────────────────────
-    def test_no_session_returns_303_for_summarize(self, client, monkeypatch):
-        """Without a session, the route returns 303."""
-        monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
-        resp = client.post("/admin/summarize", follow_redirects=False)
-        assert resp.status_code == 303
     def test_correct_session_batch_summarize(self, auth_client):
-        """A logged-in batch summarize call, with the service layer mocked out."""
-        with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
-            mock.return_value = {
-                "status": "success",
-                "done": 0,
-                "failed": 0,
-                "total": 0,
-            }
+        """A logged-in batch summarize call should create a background job."""
+        with patch("app.routes.admin.enqueue_job"):
             resp = auth_client.post("/admin/summarize")
         assert resp.status_code == 200
-        assert resp.json()["status"] == "success"
+        assert resp.json()["status"] == "queued"
+        assert "job_id" in resp.json()
-    def test_single_paper_not_found(self, auth_client):
-        """Summarizing a single paper that does not exist returns 404."""
-        from app.exceptions import NotFoundError
-        with patch(
-            "app.routes.admin.summarize_single",
-            new_callable=AsyncMock,
-            side_effect=NotFoundError("Paper not found: nonexistent.99999"),
-        ):
-            resp = auth_client.post("/admin/summarize/nonexistent.99999")
-        assert resp.status_code == 404
# ═══════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════
@@ -115,29 +84,12 @@ class TestAdminAuth:
class TestAdminCrawl: class TestAdminCrawl:
"""POST /admin/crawl 测试。""" """POST /admin/crawl 测试。"""
def test_crawl_default_today(self, auth_client):
"""不指定日期时默认抓取今天。"""
with patch(
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
resp = auth_client.post("/admin/crawl")
assert resp.status_code == 200
data = resp.json()
assert data["status"] == "success"
mock_crawl.assert_called_once()
def test_crawl_specific_date(self, auth_client): def test_crawl_specific_date(self, auth_client):
"""指定日期抓取。""" """指定日期抓取。"""
with patch( with patch("app.routes.admin.enqueue_job"):
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
resp = auth_client.post("/admin/crawl?date=2024-01-15") resp = auth_client.post("/admin/crawl?date=2024-01-15")
assert resp.status_code == 200 assert resp.status_code == 200
mock_crawl.assert_called_once() assert resp.json()["target_date"] == "2024-01-15"
call_args = mock_crawl.call_args
assert call_args[0][1] == "2024-01-15"
# ═══════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════
@@ -149,20 +101,20 @@ class TestAdminCleanup:
"""POST /admin/cleanup 测试。""" """POST /admin/cleanup 测试。"""
def test_cleanup_returns_stats(self, auth_client): def test_cleanup_returns_stats(self, auth_client):
"""清理应返回统计信息。""" """同步清理排障接口应返回统计信息。"""
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup: with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
mock_cleanup.return_value = {"scanned": 3, "removed": 1, "errors": []} mock_cleanup.return_value = {"scanned": 3, "removed": 1, "errors": []}
resp = auth_client.post("/admin/cleanup") resp = auth_client.post("/admin/cleanup-now")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["scanned"] == 3 assert data["scanned"] == 3
assert data["removed"] == 1 assert data["removed"] == 1
def test_cleanup_writes_log(self, auth_client, db_session): def test_cleanup_writes_log(self, auth_client, db_session):
"""清理应写入 crawl_logs。""" """同步清理排障接口应写入 crawl_logs。"""
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup: with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
mock_cleanup.return_value = {"scanned": 0, "removed": 0, "errors": []} mock_cleanup.return_value = {"scanned": 0, "removed": 0, "errors": []}
auth_client.post("/admin/cleanup") auth_client.post("/admin/cleanup-now")
logs = ( logs = (
db_session.execute(select(CrawlLog).where(CrawlLog.task == "cleanup")) db_session.execute(select(CrawlLog).where(CrawlLog.task == "cleanup"))
@@ -195,19 +147,21 @@ class TestAdminDelete:
assert resp.status_code == 422 assert resp.status_code == 422
def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range): def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range):
"""confirm='DELETE' 时应执行删除""" """confirm='DELETE' 时应创建后台删除 job"""
resp = auth_client.post( with patch("app.routes.admin.enqueue_job"):
"/admin/delete", resp = auth_client.post(
json={ "/admin/delete",
"date_start": "2024-01-10", json={
"date_end": "2024-01-12", "date_start": "2024-01-10",
"include_notes": True, "date_end": "2024-01-12",
"confirm": "DELETE", "include_notes": True,
}, "confirm": "DELETE",
) },
)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["deleted"] == 3 assert data["status"] == "queued"
assert db_session.get(Job, data["job_id"]) is not None
def test_delete_invalid_date_range(self, auth_client): def test_delete_invalid_date_range(self, auth_client):
"""date_start > date_end 应返回 400。""" """date_start > date_end 应返回 400。"""
@@ -221,17 +175,6 @@ class TestAdminDelete:
         )
         assert resp.status_code == 400

-    def test_delete_without_confirm_field(self, auth_client):
-        """A missing confirm field should return 422."""
-        resp = auth_client.post(
-            "/admin/delete",
-            json={
-                "date_start": "2024-01-10",
-                "date_end": "2024-01-12",
-            },
-        )
-        assert resp.status_code == 422
-

 # ═══════════════════════════════════════════════════════════════════════
 # Admin Routes — Logs
@@ -241,12 +184,6 @@ class TestAdminDelete:
 class TestAdminLogs:
     """Tests for GET /admin/logs."""

-    def test_logs_returns_page(self, auth_client):
-        """Should return the admin logs page."""
-        resp = auth_client.get("/admin/logs")
-        assert resp.status_code == 200
-        assert "text/html" in resp.headers.get("content-type", "")
-
     def test_logs_requires_auth(self, client, monkeypatch):
         """The logs page requires authentication."""
         monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
@@ -272,6 +209,126 @@ class TestAdminLogs:
         assert "crawl" in resp.text.lower() or "日志" in resp.text

+
+class TestAdminJobs:
+    """Tests for the background job query endpoints."""
+
+    def test_job_detail_returns_payload_and_events(self, auth_client, db_session):
+        """GET /admin/jobs/{id} returns the job record and its events."""
+        with patch("app.routes.admin.enqueue_job"):
+            resp = auth_client.post("/admin/crawl?date=2024-01-15")
+        job_id = resp.json()["job_id"]
+
+        resp = auth_client.get(f"/admin/jobs/{job_id}")
+
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["id"] == job_id
+        assert data["type"] == "crawl_daily"
+        assert data["payload"] == {"target_date": "2024-01-15"}
+        assert data["events"][0]["stage"] == "created"
+
+    def test_job_detail_not_found(self, auth_client):
+        resp = auth_client.get("/admin/jobs/999999")
+        assert resp.status_code == 404
+
+
+class TestAdminSummaryStatus:
+    """Tests for the summary status management endpoints."""
+
+    def test_summary_status_json_filters_failed(
+        self, auth_client, db_session, sample_paper
+    ):
+        sample_paper.summary_status.status = SummaryState.FAILED
+        sample_paper.summary_status.retry_count = 2
+        sample_paper.summary_status.error_type = "timeout"
+        db_session.commit()
+
+        resp = auth_client.get("/admin/summary-status?status=failed")
+
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["total"] == 1
+        assert data["items"][0]["arxiv_id"] == sample_paper.arxiv_id
+        assert data["items"][0]["retry_count"] == 2
+
+    def test_retry_failed_resets_failed_statuses(
+        self, auth_client, db_session, sample_paper
+    ):
+        sample_paper.summary_status.status = SummaryState.PERMANENT_FAILURE
+        sample_paper.summary_status.error = "bad json"
+        sample_paper.summary_status.error_type = "json_invalid"
+        db_session.commit()
+
+        resp = auth_client.post("/admin/summary-retry-failed")
+
+        assert resp.status_code == 200
+        assert resp.json()["count"] == 1
+        db_session.refresh(sample_paper.summary_status)
+        assert sample_paper.summary_status.status == SummaryState.PENDING
+        assert sample_paper.summary_status.error is None
+        assert sample_paper.summary_status.error_type is None
+
+
+class TestAdminPapers:
+    """Tests for batch operations in paper management."""
+
+    def test_single_delete_removes_paper_and_fts(
+        self, auth_client, db_session, sample_paper
+    ):
+        paper_id = sample_paper.id
+
+        resp = auth_client.post(f"/admin/paper-delete/{sample_paper.arxiv_id}")
+
+        assert resp.status_code == 200
+        assert db_session.get(type(sample_paper), paper_id) is None
+        fts_row = db_session.execute(
+            text("SELECT rowid FROM papers_fts WHERE rowid = :id"),
+            {"id": paper_id},
+        ).fetchone()
+        assert fts_row is None
+
+    def test_batch_delete_removes_papers_and_fts(
+        self, auth_client, db_session, sample_papers_range
+    ):
+        target_ids = [p.id for p in sample_papers_range[:2]]
+        target_arxiv_ids = [p.arxiv_id for p in sample_papers_range[:2]]
+
+        resp = auth_client.post(
+            "/admin/papers-batch-action",
+            json={"action": "delete", "arxiv_ids": target_arxiv_ids},
+        )
+
+        assert resp.status_code == 200
+        assert resp.json()["count"] == 2
+        remaining = db_session.execute(
+            text(
+                "SELECT rowid FROM papers_fts "
+                "WHERE rowid IN (:id1, :id2)"
+            ),
+            {"id1": target_ids[0], "id2": target_ids[1]},
+        ).fetchall()
+        assert remaining == []
+
+    def test_batch_summarize_sets_pending_status(
+        self, auth_client, db_session, sample_papers_range
+    ):
+        paper = sample_papers_range[0]
+        paper.summary_status.status = SummaryState.DONE
+        db_session.commit()
+
+        resp = auth_client.post(
+            "/admin/papers-batch-action",
+            json={"action": "summarize", "arxiv_ids": [paper.arxiv_id]},
+        )
+
+        assert resp.status_code == 200
+        status = db_session.scalar(
+            select(SummaryStatus).where(SummaryStatus.paper_id == paper.id)
+        )
+        assert status is not None
+        assert status.status == SummaryState.PENDING
+

 # ═══════════════════════════════════════════════════════════════════════
 # Scheduler tests
 # ═══════════════════════════════════════════════════════════════════════
+60
@@ -10,6 +10,7 @@ from app.services.crawler import (
     _parse_paper,
     crawl_daily,
     fetch_daily,
+    refresh_upvotes,
     upsert_papers,
 )
@@ -187,3 +188,62 @@ class TestCrawlDaily:
         assert result["status"] == "failed"
         assert "network error" in result["error"]
+
+
+class TestRefreshUpvotes:
+    @pytest.mark.asyncio
+    async def test_refresh_updates_existing_without_inserting_new(
+        self, db_session, sample_paper
+    ):
+        sample_paper.arxiv_id = "1706.03762"
+        sample_paper.upvotes = 10
+        db_session.commit()
+
+        with patch(
+            "app.services.crawler.fetch_daily",
+            new_callable=AsyncMock,
+            return_value=[
+                {
+                    "paper": {
+                        "id": "1706.03762",
+                        "upvotes": 999,
+                        "authors": [],
+                        "tags": [],
+                    }
+                },
+                {
+                    "paper": {
+                        "id": "2010.11929",
+                        "upvotes": 123,
+                        "authors": [],
+                        "tags": [],
+                    }
+                },
+            ],
+        ):
+            result = await refresh_upvotes(db_session, days=1)
+
+        db_session.refresh(sample_paper)
+        assert result["status"] == "success"
+        assert result["updated"] == 1
+        assert sample_paper.upvotes == 999
+        assert db_session.query(type(sample_paper)).count() == 1
+
+    @pytest.mark.asyncio
+    async def test_refresh_returns_partial_when_one_day_fails(self, db_session):
+        async def _fetch_daily(target_date):
+            if target_date.endswith("01"):
+                raise ConnectionError("hf down")
+            return []
+
+        with (
+            patch(
+                "app.services.crawler.recent_date_strs",
+                return_value=["2024-01-01", "2024-01-02"],
+            ),
+            patch("app.services.crawler.fetch_daily", side_effect=_fetch_daily),
+        ):
+            result = await refresh_upvotes(db_session, days=2)
+
+        assert result["status"] == "partial"
+        assert result["errors"] == ["2024-01-01: hf down"]
+80
@@ -0,0 +1,80 @@
+"""Tests for derived index maintenance."""
+
+from __future__ import annotations
+
+from unittest.mock import patch
+
+from sqlalchemy import text
+
+from app.services.derived import reindex_chroma, reindex_fts
+
+
+class TestReindexFts:
+    def test_reindex_fts_rebuilds_missing_rows(self, db_session, sample_paper):
+        db_session.execute(
+            text("DELETE FROM papers_fts WHERE rowid = :id"),
+            {"id": sample_paper.id},
+        )
+        db_session.commit()
+
+        result = reindex_fts(db_session)
+
+        row = db_session.execute(
+            text("SELECT title_en, authors, tags FROM papers_fts WHERE rowid = :id"),
+            {"id": sample_paper.id},
+        ).fetchone()
+        assert result == {"status": "success", "indexed": 1}
+        assert row is not None
+        assert row[0] == sample_paper.title_en
+        assert "Alice Smith" in row[1]
+        assert "NLP" in row[2]
+
+    def test_reindex_fts_accepts_subset(self, db_session, sample_papers_range):
+        keep_id = sample_papers_range[0].id
+        skip_id = sample_papers_range[1].id
+        db_session.execute(text("DELETE FROM papers_fts"))
+        db_session.commit()
+
+        result = reindex_fts(db_session, paper_ids=[keep_id])
+
+        keep_row = db_session.execute(
+            text("SELECT rowid FROM papers_fts WHERE rowid = :id"),
+            {"id": keep_id},
+        ).fetchone()
+        skip_row = db_session.execute(
+            text("SELECT rowid FROM papers_fts WHERE rowid = :id"),
+            {"id": skip_id},
+        ).fetchone()
+        assert result["indexed"] == 1
+        assert keep_row is not None
+        assert skip_row is None
+
+
+class TestReindexChroma:
+    def test_reindex_chroma_indexes_only_summarized_papers(
+        self, db_session, sample_papers_with_summary
+    ):
+        with patch("app.services.embedder.index_paper", return_value=True) as mock_index:
+            result = reindex_chroma(db_session)
+
+        assert result["status"] == "success"
+        assert result["indexed"] == 4
+        assert mock_index.call_count == 4
+        indexed_ids = {call.args[0] for call in mock_index.call_args_list}
+        assert "2401.20001" in indexed_ids
+        assert "2401.20005" not in indexed_ids
+
+    def test_reindex_chroma_reports_partial_failures(
+        self, db_session, sample_papers_with_summary
+    ):
+        def _index_paper(arxiv_id, _texts):
+            if arxiv_id == "2401.20001":
+                raise RuntimeError("embedding failed")
+            return True
+
+        with patch("app.services.embedder.index_paper", side_effect=_index_paper):
+            result = reindex_chroma(db_session)
+
+        assert result["status"] == "partial"
+        assert result["indexed"] == 3
+        assert result["errors"] == ["2401.20001: embedding failed"]
+111
@@ -0,0 +1,111 @@
+"""Tests for the background job service."""
+
+from __future__ import annotations
+
+from datetime import timedelta
+from unittest.mock import patch
+
+import pytest
+from sqlalchemy import select
+
+from app.models import Job, JobEvent, JobStatus, TaskLock
+from app.services.jobs import create_job, recover_stale_jobs, run_job
+from app.utils import utc_now
+
+
+class TestJobs:
+    def test_create_job_writes_event(self, db_session):
+        job = create_job(
+            db_session,
+            "cleanup_tmp",
+            owner="test",
+            payload={"reason": "unit-test"},
+        )
+
+        assert job.id is not None
+        assert job.status == JobStatus.QUEUED
+        events = (
+            db_session.execute(select(JobEvent).where(JobEvent.job_id == job.id))
+            .scalars()
+            .all()
+        )
+        assert len(events) == 1
+        assert events[0].stage == "created"
+
+    @pytest.mark.asyncio
+    async def test_run_job_success(self, db_session):
+        job = create_job(db_session, "cleanup_tmp", owner="test", payload={})
+
+        with patch("app.services.cleaner.cleanup_tmp") as mock_cleanup:
+            mock_cleanup.return_value = {"scanned": 1, "removed": 1, "errors": []}
+            result = await run_job(db_session, job.id)
+
+        refreshed = db_session.get(Job, job.id)
+        assert result["removed"] == 1
+        assert refreshed.status == JobStatus.SUCCESS
+        assert refreshed.result_json is not None
+
+    @pytest.mark.asyncio
+    async def test_run_job_failure_records_error(self, db_session):
+        job = create_job(db_session, "missing_job_type", owner="test", payload={})
+
+        result = await run_job(db_session, job.id)
+
+        refreshed = db_session.get(Job, job.id)
+        assert result["status"] == "failed"
+        assert refreshed.status == JobStatus.FAILED
+        assert "Unsupported job type" in refreshed.error
+
+    @pytest.mark.asyncio
+    async def test_run_job_dispatches_refresh_upvotes(self, db_session):
+        job = create_job(
+            db_session,
+            "refresh_upvotes",
+            owner="test",
+            payload={"days": 3},
+        )
+
+        with patch("app.services.crawler.refresh_upvotes") as mock_refresh:
+            mock_refresh.return_value = {"status": "success", "updated": 2}
+            result = await run_job(db_session, job.id)
+
+        mock_refresh.assert_awaited_once_with(db_session, days=3)
+        assert result["updated"] == 2
+
+    @pytest.mark.asyncio
+    async def test_run_job_dispatches_reindex_fts(self, db_session):
+        job = create_job(db_session, "reindex_fts", owner="test", payload={})
+
+        with patch("app.services.derived.reindex_fts") as mock_reindex:
+            mock_reindex.return_value = {"status": "success", "indexed": 5}
+            result = await run_job(db_session, job.id)
+
+        mock_reindex.assert_called_once_with(db_session)
+        assert result["indexed"] == 5
+
+    def test_recover_stale_jobs_and_locks(self, db_session):
+        old = utc_now() - timedelta(hours=7)
+        job = Job(
+            type="cleanup_tmp",
+            status=JobStatus.RUNNING,
+            owner="test",
+            created_at=old,
+            started_at=old,
+            heartbeat_at=old,
+        )
+        lock = TaskLock(
+            task="cleanup",
+            lock_key="tmp",
+            status="running",
+            owner="test",
+            acquired_at=old,
+        )
+        db_session.add_all([job, lock])
+        db_session.commit()
+
+        recovered = recover_stale_jobs(db_session)
+
+        assert recovered == 2
+        assert db_session.get(Job, job.id).status == JobStatus.STALE
+        assert db_session.get(TaskLock, lock.id).status == "stale"
-111
@@ -6,9 +6,6 @@ from datetime import date
 from unittest.mock import patch as upatch

-from app.config import settings
-

 # ═══════════════════════════════════════════════════════════════════════
 # Detail page & similar papers
 # ═══════════════════════════════════════════════════════════════════════
@@ -37,29 +34,6 @@ class TestDetailPage:
 class TestTrendsDashboard:
     """Trends dashboard tests."""

-    def test_trends_page_renders(self, client, sample_papers_with_summary):
-        """The trends dashboard page renders normally."""
-        resp = client.get("/trends")
-        assert resp.status_code == 200
-        assert "趋势看板" in resp.text
-        assert "chart" in resp.text.lower() or "Chart" in resp.text
-
-    def test_trends_api_returns_data(self, client, sample_papers_with_summary):
-        """The trends API returns the correct data structure."""
-        resp = client.get("/api/stats/trends")
-        assert resp.status_code == 200
-        data = resp.json()
-        assert "daily_counts" in data
-        assert "top_tags" in data
-        assert "upvotes_dist" in data
-        assert "summary_completion" in data
-        assert isinstance(data["daily_counts"], list)
-        assert isinstance(data["top_tags"], list)
-        assert isinstance(data["upvotes_dist"], list)
-        assert isinstance(data["summary_completion"], list)
-
     def test_trends_api_daily_counts(self, client, sample_papers_with_summary):
         """Daily paper counts are correct."""
         # Use the date range of the test data
@@ -108,12 +82,6 @@ class TestTrendsDashboard:
 class TestComparePage:
     """Paper comparison page tests."""

-    def test_compare_page_no_ids(self, client):
-        """Shows the input form when no IDs are given."""
-        resp = client.get("/compare")
-        assert resp.status_code == 200
-        assert "对比" in resp.text
-
     def test_compare_page_with_ids(self, client, sample_papers_with_summary):
         """Comparing multiple papers renders normally."""
         resp = client.get("/compare?ids=2401.20001,2401.20002")
@@ -124,23 +92,6 @@ class TestComparePage:
         assert "一句话摘要" in resp.text
         assert "研究问题" in resp.text

-    def test_compare_page_max_5(self, client, sample_papers_with_summary):
-        """At most 5 papers."""
-        ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005"
-        resp = client.get(f"/compare?ids={ids}")
-        assert resp.status_code == 200
-
-    def test_compare_page_over_5_truncates(self, client, sample_papers_with_summary):
-        """More than 5 papers are truncated."""
-        ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006"
-        resp = client.get(f"/compare?ids={ids}")
-        assert resp.status_code == 200
-
-    def test_compare_page_invalid_ids(self, client):
-        """Invalid IDs show an empty result."""
-        resp = client.get("/compare?ids=nonexistent.99999")
-        assert resp.status_code == 200
-
     def test_compare_page_shows_no_summary_placeholder(
         self, client, sample_papers_with_summary
     ):
@@ -149,65 +100,3 @@ class TestComparePage:
         resp = client.get("/compare?ids=2401.20005")
         assert resp.status_code == 200
         assert "暂无总结" in resp.text
-
-
-# ═══════════════════════════════════════════════════════════════════════
-# Nav Bar
-# ═══════════════════════════════════════════════════════════════════════
-
-
-class TestNavBar:
-    """Navigation bar tests."""
-
-    def test_nav_includes_trends_link(self, client):
-        """The navigation bar should include the trends link."""
-        resp = client.get("/search")
-        assert resp.status_code == 200
-        assert "/trends" in resp.text
-
-    def test_nav_includes_compare_implicitly(self, client):
-        """The compare page is accessible."""
-        resp = client.get("/compare")
-        assert resp.status_code == 200
-
-
-# ═══════════════════════════════════════════════════════════════════════
-# Graceful Degradation: CHROMA_ENABLED=false
-# ═══════════════════════════════════════════════════════════════════════
-
-
-class TestGracefulDegradation:
-    """Graceful degradation tests with CHROMA_ENABLED=false."""
-
-    def test_search_works_without_chroma(
-        self, client, monkeypatch, sample_papers_with_summary
-    ):
-        """FTS5 search works when Chroma is disabled."""
-        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
-        resp = client.get("/search?q=Test")
-        assert resp.status_code == 200
-        assert "Test Paper" in resp.text or "测试论文" in resp.text
-
-    def test_detail_works_without_chroma(
-        self, client, monkeypatch, sample_papers_with_summary
-    ):
-        """The detail page works when Chroma is disabled."""
-        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
-        resp = client.get("/paper/2401.20001")
-        assert resp.status_code == 200
-
-    def test_trends_works_without_chroma(
-        self, client, monkeypatch, sample_papers_with_summary
-    ):
-        """The trends dashboard works when Chroma is disabled."""
-        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
-        resp = client.get("/trends")
-        assert resp.status_code == 200
-
-    def test_compare_works_without_chroma(
-        self, client, monkeypatch, sample_papers_with_summary
-    ):
-        """The compare page works when Chroma is disabled."""
-        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
-        resp = client.get("/compare?ids=2401.20001,2401.20002")
-        assert resp.status_code == 200
-69
@@ -123,38 +123,12 @@ class TestSearchSemanticMode:
 class TestSearchRoutes:
     """Search page and JSON API route tests."""

-    def test_search_page_renders(self, client):
-        """GET /search returns 200."""
-        resp = client.get("/search")
-        assert resp.status_code == 200
-        assert "搜索" in resp.text
-
     def test_search_page_with_query(self, client, sample_paper):
         """GET /search?q=Test returns search results."""
         resp = client.get("/search?q=Test")
         assert resp.status_code == 200
         assert "2401.12345" in resp.text

-    def test_search_page_with_tag(self, client, sample_paper):
-        """GET /search?tag=NLP returns tag-filtered results."""
-        resp = client.get("/search?tag=NLP")
-        assert resp.status_code == 200
-        assert "2401.12345" in resp.text
-
-    def test_search_page_keyword_mode(self, client, sample_papers_with_summary):
-        """Search page in keyword mode."""
-        resp = client.get("/search?q=Test&mode=keyword")
-        assert resp.status_code == 200
-        assert "Test" in resp.text or "测试" in resp.text
-
-    def test_search_page_semantic_disabled(
-        self, client, monkeypatch, sample_papers_with_summary
-    ):
-        """Semantic mode still works when CHROMA_ENABLED=false."""
-        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
-        resp = client.get("/search?q=Test&mode=semantic")
-        assert resp.status_code == 200
-
     def test_search_api_json(self, client, sample_paper):
         """GET /api/search?q=Test returns JSON."""
         resp = client.get("/api/search?q=Test")
@@ -170,14 +144,6 @@ class TestSearchRoutes:
         data = resp.json()
         assert data["total"] == 1

-    def test_search_api_with_mode(self, client, sample_papers_with_summary):
-        """The search API supports the mode parameter."""
-        resp = client.get("/api/search?q=Test&mode=keyword")
-        assert resp.status_code == 200
-        data = resp.json()
-        assert "results" in data
-        assert "total" in data
-
     def test_search_api_empty(self, client, sample_paper):
         """GET /api/search?q=nonexistent returns an empty result."""
         resp = client.get("/api/search?q=nonexistent")
@@ -185,13 +151,6 @@ class TestSearchRoutes:
         data = resp.json()
         assert data["total"] == 0

-    def test_search_api_sort_by_date(self, client, sample_paper):
-        """GET /api/search?q=Test&sort=date sorts by date."""
-        resp = client.get("/api/search?q=Test&sort=date")
-        assert resp.status_code == 200
-        data = resp.json()
-        assert data["total"] >= 1
-

 # ═══════════════════════════════════════════════════════════════════════
 # Similar Paper API tests
@@ -211,21 +170,6 @@ class TestSimilarAPI:
         data = resp.json()
         assert data["results"] == []

-    def test_similar_api_paper_not_found(self, client, monkeypatch):
-        """A nonexistent paper returns an empty result."""
-        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
-        resp = client.get("/api/similar/nonexistent.99999")
-        assert resp.status_code == 200
-        assert resp.json()["results"] == []
-
-    def test_similar_api_with_top_k(
-        self, client, monkeypatch, sample_papers_with_summary
-    ):
-        """The top_k parameter controls the number of results returned."""
-        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
-        resp = client.get("/api/similar/2401.20001?top_k=3")
-        assert resp.status_code == 200
-

 # ═══════════════════════════════════════════════════════════════════════
 # Reading list route tests
@@ -235,12 +179,6 @@ class TestSimilarAPI:
 class TestReadingListRoute:
     """Reading list page tests."""

-    def test_reading_list_empty(self, client):
-        """Shows the empty state when there are no bookmarks."""
-        resp = client.get("/reading-list")
-        assert resp.status_code == 200
-        assert "阅读列表" in resp.text
-
     def test_reading_list_with_bookmark(self, client, sample_paper):
         """Shows the paper when it has been bookmarked."""
         # Bookmark first
@@ -302,13 +240,6 @@ class TestRssFeed:
         assert "<channel>" in resp.text
         assert "2401.12345" in resp.text

-    def test_rss_has_paper_item(self, client, sample_paper):
-        """The RSS feed contains a paper item."""
-        resp = client.get("/rss.xml")
-        assert "<item>" in resp.text
-        assert "<title>" in resp.text
-        assert "/paper/2401.12345" in resp.text
-
     def test_rss_with_tag_filter(self, client, sample_paper):
         """GET /rss.xml?tag=NLP filters by tag."""
         resp = client.get("/rss.xml?tag=NLP")