refactor: extract admin business logic to services, introduce job queue, add derived index helpers

- Move DB operations from routes/admin.py to services/admin.py (get_logs_context, query_summary_statuses, retry_failed, delete/reset operations)
- Add services/jobs.py with Job/JobEvent-based async job queue (create_job, run_job, enqueue_job)
- Add services/derived.py with FTS5 reindex and paper index deletion helpers
- Refactor scheduler to use job queue instead of direct pipeline calls
- Add heartbeat_at/expires_at to TaskLock for lock health tracking
- Remove DESIGN_REVIEW.md
- Update tests: remove redundant integration tests, add unit tests for new services
This commit is contained in:
2026-06-13 18:31:43 +08:00
parent 21f16e6756
commit 743d69efd0
20 changed files with 1391 additions and 1063 deletions
+148 -91
View File
@@ -3,14 +3,17 @@
from __future__ import annotations
import logging
from unittest.mock import AsyncMock, patch
from unittest.mock import patch
import pytest
from sqlalchemy import select
from sqlalchemy import select, text
from app.config import settings
from app.models import (
CrawlLog,
Job,
SummaryState,
SummaryStatus,
TaskLock,
)
from app.utils import utc_now
@@ -64,47 +67,13 @@ class TestAdminAuth:
resp = auth_client.get("/admin/logs", follow_redirects=False)
assert resp.status_code == 303
def test_correct_session_accepted(self, auth_client):
"""已登录 session 应被接受(crawl 可能会失败但不是 303)。"""
with patch(
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
resp = auth_client.post("/admin/crawl")
assert resp.status_code != 303
# ── summarize route auth ────────────────────────────────────────
def test_no_session_returns_303_for_summarize(self, client, monkeypatch):
"""无 session 返回 303。"""
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
resp = client.post("/admin/summarize", follow_redirects=False)
assert resp.status_code == 303
def test_correct_session_batch_summarize(self, auth_client):
"""已登录调用 batch summarizemock 掉服务层"""
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
mock.return_value = {
"status": "success",
"done": 0,
"failed": 0,
"total": 0,
}
"""已登录调用 batch summarize应创建后台任务"""
with patch("app.routes.admin.enqueue_job"):
resp = auth_client.post("/admin/summarize")
assert resp.status_code == 200
assert resp.json()["status"] == "success"
def test_single_paper_not_found(self, auth_client):
"""单篇总结不存在的论文返回 404。"""
from app.exceptions import NotFoundError
with patch(
"app.routes.admin.summarize_single",
new_callable=AsyncMock,
side_effect=NotFoundError("Paper not found: nonexistent.99999"),
):
resp = auth_client.post("/admin/summarize/nonexistent.99999")
assert resp.status_code == 404
assert resp.json()["status"] == "queued"
assert "job_id" in resp.json()
# ═══════════════════════════════════════════════════════════════════════
@@ -115,29 +84,12 @@ class TestAdminAuth:
class TestAdminCrawl:
"""POST /admin/crawl 测试。"""
def test_crawl_default_today(self, auth_client):
"""不指定日期时默认抓取今天。"""
with patch(
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
resp = auth_client.post("/admin/crawl")
assert resp.status_code == 200
data = resp.json()
assert data["status"] == "success"
mock_crawl.assert_called_once()
def test_crawl_specific_date(self, auth_client):
"""指定日期抓取。"""
with patch(
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
with patch("app.routes.admin.enqueue_job"):
resp = auth_client.post("/admin/crawl?date=2024-01-15")
assert resp.status_code == 200
mock_crawl.assert_called_once()
call_args = mock_crawl.call_args
assert call_args[0][1] == "2024-01-15"
assert resp.json()["target_date"] == "2024-01-15"
# ═══════════════════════════════════════════════════════════════════════
@@ -149,20 +101,20 @@ class TestAdminCleanup:
"""POST /admin/cleanup 测试。"""
def test_cleanup_returns_stats(self, auth_client):
"""清理应返回统计信息。"""
"""同步清理排障接口应返回统计信息。"""
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
mock_cleanup.return_value = {"scanned": 3, "removed": 1, "errors": []}
resp = auth_client.post("/admin/cleanup")
resp = auth_client.post("/admin/cleanup-now")
assert resp.status_code == 200
data = resp.json()
assert data["scanned"] == 3
assert data["removed"] == 1
def test_cleanup_writes_log(self, auth_client, db_session):
"""清理应写入 crawl_logs。"""
"""同步清理排障接口应写入 crawl_logs。"""
with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
mock_cleanup.return_value = {"scanned": 0, "removed": 0, "errors": []}
auth_client.post("/admin/cleanup")
auth_client.post("/admin/cleanup-now")
logs = (
db_session.execute(select(CrawlLog).where(CrawlLog.task == "cleanup"))
@@ -195,19 +147,21 @@ class TestAdminDelete:
assert resp.status_code == 422
def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range):
"""confirm='DELETE' 时应执行删除"""
resp = auth_client.post(
"/admin/delete",
json={
"date_start": "2024-01-10",
"date_end": "2024-01-12",
"include_notes": True,
"confirm": "DELETE",
},
)
"""confirm='DELETE' 时应创建后台删除 job"""
with patch("app.routes.admin.enqueue_job"):
resp = auth_client.post(
"/admin/delete",
json={
"date_start": "2024-01-10",
"date_end": "2024-01-12",
"include_notes": True,
"confirm": "DELETE",
},
)
assert resp.status_code == 200
data = resp.json()
assert data["deleted"] == 3
assert data["status"] == "queued"
assert db_session.get(Job, data["job_id"]) is not None
def test_delete_invalid_date_range(self, auth_client):
"""date_start > date_end 应返回 400。"""
@@ -221,17 +175,6 @@ class TestAdminDelete:
)
assert resp.status_code == 400
def test_delete_without_confirm_field(self, auth_client):
"""缺少 confirm 字段应返回 422。"""
resp = auth_client.post(
"/admin/delete",
json={
"date_start": "2024-01-10",
"date_end": "2024-01-12",
},
)
assert resp.status_code == 422
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — Logs
@@ -241,12 +184,6 @@ class TestAdminDelete:
class TestAdminLogs:
"""GET /admin/logs 测试。"""
def test_logs_returns_page(self, auth_client):
"""应返回管理日志页面。"""
resp = auth_client.get("/admin/logs")
assert resp.status_code == 200
assert "text/html" in resp.headers.get("content-type", "")
def test_logs_requires_auth(self, client, monkeypatch):
"""日志页面需要鉴权。"""
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
@@ -272,6 +209,126 @@ class TestAdminLogs:
assert "crawl" in resp.text.lower() or "日志" in resp.text
class TestAdminJobs:
    """Tests for the background job lookup endpoint."""

    def test_job_detail_returns_payload_and_events(self, auth_client, db_session):
        """GET /admin/jobs/{id} returns the job record together with its events."""
        # enqueue_job is swapped for a mock while the crawl request is made.
        with patch("app.routes.admin.enqueue_job"):
            created = auth_client.post("/admin/crawl?date=2024-01-15")
        job_id = created.json()["job_id"]

        detail = auth_client.get(f"/admin/jobs/{job_id}")
        assert detail.status_code == 200

        body = detail.json()
        assert body["id"] == job_id
        assert body["type"] == "crawl_daily"
        assert body["payload"] == {"target_date": "2024-01-15"}
        first_event = body["events"][0]
        assert first_event["stage"] == "created"

    def test_job_detail_not_found(self, auth_client):
        """Requesting a job id that does not exist yields 404."""
        missing = auth_client.get("/admin/jobs/999999")
        assert missing.status_code == 404
class TestAdminSummaryStatus:
    """Tests for the summary-status management endpoints."""

    def test_summary_status_json_filters_failed(
        self, auth_client, db_session, sample_paper
    ):
        """GET /admin/summary-status?status=failed lists the failed paper."""
        # Put the fixture paper's summary record into the FAILED state.
        record = sample_paper.summary_status
        record.status = SummaryState.FAILED
        record.retry_count = 2
        record.error_type = "timeout"
        db_session.commit()

        resp = auth_client.get("/admin/summary-status?status=failed")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["total"] == 1
        first_item = payload["items"][0]
        assert first_item["arxiv_id"] == sample_paper.arxiv_id
        assert first_item["retry_count"] == 2

    def test_retry_failed_resets_failed_statuses(
        self, auth_client, db_session, sample_paper
    ):
        """POST /admin/summary-retry-failed puts a failed record back to PENDING."""
        record = sample_paper.summary_status
        record.status = SummaryState.PERMANENT_FAILURE
        record.error = "bad json"
        record.error_type = "json_invalid"
        db_session.commit()

        resp = auth_client.post("/admin/summary-retry-failed")
        assert resp.status_code == 200
        assert resp.json()["count"] == 1

        # Reload the record from the database before checking the reset fields.
        reloaded = sample_paper.summary_status
        db_session.refresh(reloaded)
        assert reloaded.status == SummaryState.PENDING
        assert reloaded.error is None
        assert reloaded.error_type is None
class TestAdminPapers:
    """Tests for single and batch paper-management actions."""

    def test_single_delete_removes_paper_and_fts(
        self, auth_client, db_session, sample_paper
    ):
        """Deleting one paper removes both its row and its papers_fts entry."""
        paper_model = type(sample_paper)
        paper_id = sample_paper.id

        resp = auth_client.post(f"/admin/paper-delete/{sample_paper.arxiv_id}")
        assert resp.status_code == 200
        assert db_session.get(paper_model, paper_id) is None

        lookup = text("SELECT rowid FROM papers_fts WHERE rowid = :id")
        fts_row = db_session.execute(lookup, {"id": paper_id}).fetchone()
        assert fts_row is None

    def test_batch_delete_removes_papers_and_fts(
        self, auth_client, db_session, sample_papers_range
    ):
        """The batch 'delete' action reports two deletions and clears their FTS rows."""
        targets = sample_papers_range[:2]
        target_ids = [paper.id for paper in targets]
        target_arxiv_ids = [paper.arxiv_id for paper in targets]

        request_body = {"action": "delete", "arxiv_ids": target_arxiv_ids}
        resp = auth_client.post("/admin/papers-batch-action", json=request_body)
        assert resp.status_code == 200
        assert resp.json()["count"] == 2

        lookup = text("SELECT rowid FROM papers_fts WHERE rowid IN (:id1, :id2)")
        first_id, second_id = target_ids[0], target_ids[1]
        remaining = db_session.execute(
            lookup, {"id1": first_id, "id2": second_id}
        ).fetchall()
        assert remaining == []

    def test_batch_summarize_sets_pending_status(
        self, auth_client, db_session, sample_papers_range
    ):
        """The batch 'summarize' action moves a DONE paper back to PENDING."""
        paper = sample_papers_range[0]
        paper.summary_status.status = SummaryState.DONE
        db_session.commit()

        request_body = {"action": "summarize", "arxiv_ids": [paper.arxiv_id]}
        resp = auth_client.post("/admin/papers-batch-action", json=request_body)
        assert resp.status_code == 200

        stmt = select(SummaryStatus).where(SummaryStatus.paper_id == paper.id)
        status = db_session.scalar(stmt)
        assert status is not None
        assert status.status == SummaryState.PENDING
# ═══════════════════════════════════════════════════════════════════════
# Scheduler 测试
# ═══════════════════════════════════════════════════════════════════════
+60
View File
@@ -10,6 +10,7 @@ from app.services.crawler import (
_parse_paper,
crawl_daily,
fetch_daily,
refresh_upvotes,
upsert_papers,
)
@@ -187,3 +188,62 @@ class TestCrawlDaily:
assert result["status"] == "failed"
assert "network error" in result["error"]
class TestRefreshUpvotes:
    """Tests for crawler.refresh_upvotes."""

    @pytest.mark.asyncio
    async def test_refresh_updates_existing_without_inserting_new(
        self, db_session, sample_paper
    ):
        """Only the already-stored paper is updated; the unknown one is not inserted."""

        def _entry(arxiv_id, upvotes):
            # One item of the mocked fetch_daily result.
            return {
                "paper": {
                    "id": arxiv_id,
                    "upvotes": upvotes,
                    "authors": [],
                    "tags": [],
                }
            }

        sample_paper.arxiv_id = "1706.03762"
        sample_paper.upvotes = 10
        db_session.commit()

        fetched = [_entry("1706.03762", 999), _entry("2010.11929", 123)]
        with patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=fetched,
        ):
            result = await refresh_upvotes(db_session, days=1)

        db_session.refresh(sample_paper)
        assert result["status"] == "success"
        assert result["updated"] == 1
        assert sample_paper.upvotes == 999
        # Still exactly one paper row: the second fetched entry was not added.
        paper_model = type(sample_paper)
        assert db_session.query(paper_model).count() == 1

    @pytest.mark.asyncio
    async def test_refresh_returns_partial_when_one_day_fails(self, db_session):
        """A fetch failure for one date gives status 'partial' and a recorded error."""

        async def _fetch_daily(target_date):
            # Fail only for the date ending in "01"; other dates return nothing.
            if not target_date.endswith("01"):
                return []
            raise ConnectionError("hf down")

        dates_patch = patch(
            "app.services.crawler.recent_date_strs",
            return_value=["2024-01-01", "2024-01-02"],
        )
        fetch_patch = patch(
            "app.services.crawler.fetch_daily", side_effect=_fetch_daily
        )
        with dates_patch, fetch_patch:
            result = await refresh_upvotes(db_session, days=2)

        assert result["status"] == "partial"
        assert result["errors"] == ["2024-01-01: hf down"]
+80
View File
@@ -0,0 +1,80 @@
"""派生索引维护测试。"""
from __future__ import annotations
from unittest.mock import patch
from sqlalchemy import text
from app.services.derived import reindex_chroma, reindex_fts
class TestReindexFts:
    """Tests for derived.reindex_fts."""

    @staticmethod
    def _fts_rowid(db_session, paper_id):
        # Return the papers_fts row for paper_id, or None when it is absent.
        stmt = text("SELECT rowid FROM papers_fts WHERE rowid = :id")
        return db_session.execute(stmt, {"id": paper_id}).fetchone()

    def test_reindex_fts_rebuilds_missing_rows(self, db_session, sample_paper):
        """A deleted FTS row is recreated with the paper's title, authors and tags."""
        paper_id = sample_paper.id
        db_session.execute(
            text("DELETE FROM papers_fts WHERE rowid = :id"),
            {"id": paper_id},
        )
        db_session.commit()

        result = reindex_fts(db_session)

        stmt = text("SELECT title_en, authors, tags FROM papers_fts WHERE rowid = :id")
        row = db_session.execute(stmt, {"id": sample_paper.id}).fetchone()

        assert result == {"status": "success", "indexed": 1}
        assert row is not None
        title_en, authors, tags = row[0], row[1], row[2]
        assert title_en == sample_paper.title_en
        assert "Alice Smith" in authors
        assert "NLP" in tags

    def test_reindex_fts_accepts_subset(self, db_session, sample_papers_range):
        """With paper_ids given, only those papers are written to papers_fts."""
        keep_id = sample_papers_range[0].id
        skip_id = sample_papers_range[1].id
        # Start from an empty FTS table so any row found afterwards was reindexed.
        db_session.execute(text("DELETE FROM papers_fts"))
        db_session.commit()

        result = reindex_fts(db_session, paper_ids=[keep_id])

        keep_row = self._fts_rowid(db_session, keep_id)
        skip_row = self._fts_rowid(db_session, skip_id)

        assert result["indexed"] == 1
        assert keep_row is not None
        assert skip_row is None
class TestReindexChroma:
    """Tests for derived.reindex_chroma with embedder.index_paper mocked."""

    def test_reindex_chroma_indexes_only_summarized_papers(
        self, db_session, sample_papers_with_summary
    ):
        """Four papers are indexed; 2401.20005 is not among them."""
        with patch(
            "app.services.embedder.index_paper", return_value=True
        ) as mock_index:
            result = reindex_chroma(db_session)

        assert result["status"] == "success"
        assert result["indexed"] == 4
        assert mock_index.call_count == 4

        # The first positional argument of each call is the arxiv id.
        seen_ids = set()
        for recorded in mock_index.call_args_list:
            seen_ids.add(recorded.args[0])
        assert "2401.20001" in seen_ids
        assert "2401.20005" not in seen_ids

    def test_reindex_chroma_reports_partial_failures(
        self, db_session, sample_papers_with_summary
    ):
        """One failing paper gives status 'partial' and an error entry for that id."""

        def _index_paper(arxiv_id, _texts):
            # Fail for a single paper; every other paper indexes successfully.
            if arxiv_id != "2401.20001":
                return True
            raise RuntimeError("embedding failed")

        with patch("app.services.embedder.index_paper", side_effect=_index_paper):
            result = reindex_chroma(db_session)

        assert result["status"] == "partial"
        assert result["indexed"] == 3
        assert result["errors"] == ["2401.20001: embedding failed"]
+111
View File
@@ -0,0 +1,111 @@
"""后台 Job 服务测试。"""
from __future__ import annotations
from datetime import timedelta
from unittest.mock import patch
import pytest
from sqlalchemy import select
from app.models import Job, JobEvent, JobStatus, TaskLock
from app.services.jobs import create_job, recover_stale_jobs, run_job
from app.utils import utc_now
class TestJobs:
    """Tests for the job service: create_job, run_job and recover_stale_jobs."""

    def test_create_job_writes_event(self, db_session):
        """create_job returns a QUEUED job and stores one 'created' event."""
        job = create_job(
            db_session,
            "cleanup_tmp",
            owner="test",
            payload={"reason": "unit-test"},
        )

        assert job.id is not None
        assert job.status == JobStatus.QUEUED

        stmt = select(JobEvent).where(JobEvent.job_id == job.id)
        events = db_session.execute(stmt).scalars().all()
        assert len(events) == 1
        assert events[0].stage == "created"

    @pytest.mark.asyncio
    async def test_run_job_success(self, db_session):
        """A cleanup_tmp job returns the cleaner's result and ends in SUCCESS."""
        job = create_job(db_session, "cleanup_tmp", owner="test", payload={})
        cleanup_stats = {"scanned": 1, "removed": 1, "errors": []}

        with patch("app.services.cleaner.cleanup_tmp", return_value=cleanup_stats):
            result = await run_job(db_session, job.id)

        refreshed = db_session.get(Job, job.id)
        assert result["removed"] == 1
        assert refreshed.status == JobStatus.SUCCESS
        assert refreshed.result_json is not None

    @pytest.mark.asyncio
    async def test_run_job_failure_records_error(self, db_session):
        """An unknown job type produces a failed result and a FAILED job with an error."""
        job = create_job(db_session, "missing_job_type", owner="test", payload={})

        result = await run_job(db_session, job.id)

        refreshed = db_session.get(Job, job.id)
        assert result["status"] == "failed"
        assert refreshed.status == JobStatus.FAILED
        assert "Unsupported job type" in refreshed.error

    @pytest.mark.asyncio
    async def test_run_job_dispatches_refresh_upvotes(self, db_session):
        """A refresh_upvotes job awaits crawler.refresh_upvotes with the payload's days."""
        job = create_job(
            db_session,
            "refresh_upvotes",
            owner="test",
            payload={"days": 3},
        )
        refresh_result = {"status": "success", "updated": 2}

        with patch(
            "app.services.crawler.refresh_upvotes", return_value=refresh_result
        ) as mock_refresh:
            result = await run_job(db_session, job.id)

        mock_refresh.assert_awaited_once_with(db_session, days=3)
        assert result["updated"] == 2

    @pytest.mark.asyncio
    async def test_run_job_dispatches_reindex_fts(self, db_session):
        """A reindex_fts job calls derived.reindex_fts with the session only."""
        job = create_job(db_session, "reindex_fts", owner="test", payload={})
        reindex_result = {"status": "success", "indexed": 5}

        with patch(
            "app.services.derived.reindex_fts", return_value=reindex_result
        ) as mock_reindex:
            result = await run_job(db_session, job.id)

        mock_reindex.assert_called_once_with(db_session)
        assert result["indexed"] == 5

    def test_recover_stale_jobs_and_locks(self, db_session):
        """A RUNNING job and a running lock dated 7 hours ago are both marked stale."""
        stale_ts = utc_now() - timedelta(hours=7)
        job = Job(
            type="cleanup_tmp",
            status=JobStatus.RUNNING,
            owner="test",
            created_at=stale_ts,
            started_at=stale_ts,
            heartbeat_at=stale_ts,
        )
        lock = TaskLock(
            task="cleanup",
            lock_key="tmp",
            status="running",
            owner="test",
            acquired_at=stale_ts,
        )
        db_session.add_all([job, lock])
        db_session.commit()

        recovered = recover_stale_jobs(db_session)

        # The returned count covers the job and the lock together.
        assert recovered == 2
        stored_job = db_session.get(Job, job.id)
        stored_lock = db_session.get(TaskLock, lock.id)
        assert stored_job.status == JobStatus.STALE
        assert stored_lock.status == "stale"
-111
View File
@@ -6,9 +6,6 @@ from datetime import date
from unittest.mock import patch as upatch
from app.config import settings
# ═══════════════════════════════════════════════════════════════════════
# Detail 页 & 相似论文
# ═══════════════════════════════════════════════════════════════════════
@@ -37,29 +34,6 @@ class TestDetailPage:
class TestTrendsDashboard:
"""趋势看板测试。"""
def test_trends_page_renders(self, client, sample_papers_with_summary):
"""趋势看板页面正常渲染。"""
resp = client.get("/trends")
assert resp.status_code == 200
assert "趋势看板" in resp.text
assert "chart" in resp.text.lower() or "Chart" in resp.text
def test_trends_api_returns_data(self, client, sample_papers_with_summary):
"""趋势 API 返回正确数据结构。"""
resp = client.get("/api/stats/trends")
assert resp.status_code == 200
data = resp.json()
assert "daily_counts" in data
assert "top_tags" in data
assert "upvotes_dist" in data
assert "summary_completion" in data
assert isinstance(data["daily_counts"], list)
assert isinstance(data["top_tags"], list)
assert isinstance(data["upvotes_dist"], list)
assert isinstance(data["summary_completion"], list)
def test_trends_api_daily_counts(self, client, sample_papers_with_summary):
"""每日论文数量数据正确。"""
# 使用测试数据的日期范围
@@ -108,12 +82,6 @@ class TestTrendsDashboard:
class TestComparePage:
"""论文对比页测试。"""
def test_compare_page_no_ids(self, client):
"""无 ID 时显示输入表单。"""
resp = client.get("/compare")
assert resp.status_code == 200
assert "对比" in resp.text
def test_compare_page_with_ids(self, client, sample_papers_with_summary):
"""对比多篇论文正常渲染。"""
resp = client.get("/compare?ids=2401.20001,2401.20002")
@@ -124,23 +92,6 @@ class TestComparePage:
assert "一句话摘要" in resp.text
assert "研究问题" in resp.text
def test_compare_page_max_5(self, client, sample_papers_with_summary):
"""最多 5 篇。"""
ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005"
resp = client.get(f"/compare?ids={ids}")
assert resp.status_code == 200
def test_compare_page_over_5_truncates(self, client, sample_papers_with_summary):
"""超过 5 篇截断。"""
ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006"
resp = client.get(f"/compare?ids={ids}")
assert resp.status_code == 200
def test_compare_page_invalid_ids(self, client):
"""无效 ID 时显示空结果。"""
resp = client.get("/compare?ids=nonexistent.99999")
assert resp.status_code == 200
def test_compare_page_shows_no_summary_placeholder(
self, client, sample_papers_with_summary
):
@@ -149,65 +100,3 @@ class TestComparePage:
resp = client.get("/compare?ids=2401.20005")
assert resp.status_code == 200
assert "暂无总结" in resp.text
# ═══════════════════════════════════════════════════════════════════════
# Nav Bar
# ═══════════════════════════════════════════════════════════════════════
class TestNavBar:
"""导航栏测试。"""
def test_nav_includes_trends_link(self, client):
"""导航栏应包含趋势链接。"""
resp = client.get("/search")
assert resp.status_code == 200
assert "/trends" in resp.text
def test_nav_includes_compare_implicitly(self, client):
"""compare 页面可访问。"""
resp = client.get("/compare")
assert resp.status_code == 200
# ═══════════════════════════════════════════════════════════════════════
# Graceful Degradation (CHROMA_ENABLED=false)
# ═══════════════════════════════════════════════════════════════════════
class TestGracefulDegradation:
"""CHROMA_ENABLED=false 时优雅降级测试。"""
def test_search_works_without_chroma(
self, client, monkeypatch, sample_papers_with_summary
):
"""CHROMA 关闭时 FTS5 搜索正常工作。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
resp = client.get("/search?q=Test")
assert resp.status_code == 200
assert "Test Paper" in resp.text or "测试论文" in resp.text
def test_detail_works_without_chroma(
self, client, monkeypatch, sample_papers_with_summary
):
"""CHROMA 关闭时详情页正常工作。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
resp = client.get("/paper/2401.20001")
assert resp.status_code == 200
def test_trends_works_without_chroma(
self, client, monkeypatch, sample_papers_with_summary
):
"""CHROMA 关闭时趋势看板正常工作。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
resp = client.get("/trends")
assert resp.status_code == 200
def test_compare_works_without_chroma(
self, client, monkeypatch, sample_papers_with_summary
):
"""CHROMA 关闭时对比页正常工作。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
resp = client.get("/compare?ids=2401.20001,2401.20002")
assert resp.status_code == 200
-69
View File
@@ -123,38 +123,12 @@ class TestSearchSemanticMode:
class TestSearchRoutes:
"""搜索页面和 JSON API 路由测试。"""
def test_search_page_renders(self, client):
"""GET /search 返回 200。"""
resp = client.get("/search")
assert resp.status_code == 200
assert "搜索" in resp.text
def test_search_page_with_query(self, client, sample_paper):
"""GET /search?q=Test 返回搜索结果。"""
resp = client.get("/search?q=Test")
assert resp.status_code == 200
assert "2401.12345" in resp.text
def test_search_page_with_tag(self, client, sample_paper):
"""GET /search?tag=NLP 返回标签筛选结果。"""
resp = client.get("/search?tag=NLP")
assert resp.status_code == 200
assert "2401.12345" in resp.text
def test_search_page_keyword_mode(self, client, sample_papers_with_summary):
"""搜索页 keyword 模式。"""
resp = client.get("/search?q=Test&mode=keyword")
assert resp.status_code == 200
assert "Test" in resp.text or "测试" in resp.text
def test_search_page_semantic_disabled(
self, client, monkeypatch, sample_papers_with_summary
):
"""语义模式 CHROMA_ENABLED=false 时仍能工作。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
resp = client.get("/search?q=Test&mode=semantic")
assert resp.status_code == 200
def test_search_api_json(self, client, sample_paper):
"""GET /api/search?q=Test 返回 JSON。"""
resp = client.get("/api/search?q=Test")
@@ -170,14 +144,6 @@ class TestSearchRoutes:
data = resp.json()
assert data["total"] == 1
def test_search_api_with_mode(self, client, sample_papers_with_summary):
"""搜索 API 支持 mode 参数。"""
resp = client.get("/api/search?q=Test&mode=keyword")
assert resp.status_code == 200
data = resp.json()
assert "results" in data
assert "total" in data
def test_search_api_empty(self, client, sample_paper):
"""GET /api/search?q=nonexistent 返回空结果。"""
resp = client.get("/api/search?q=nonexistent")
@@ -185,13 +151,6 @@ class TestSearchRoutes:
data = resp.json()
assert data["total"] == 0
def test_search_api_sort_by_date(self, client, sample_paper):
"""GET /api/search?q=Test&sort=date 按日期排序。"""
resp = client.get("/api/search?q=Test&sort=date")
assert resp.status_code == 200
data = resp.json()
assert data["total"] >= 1
# ═══════════════════════════════════════════════════════════════════════
# Similar Paper API 测试
@@ -211,21 +170,6 @@ class TestSimilarAPI:
data = resp.json()
assert data["results"] == []
def test_similar_api_paper_not_found(self, client, monkeypatch):
"""不存在的论文返回空。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
resp = client.get("/api/similar/nonexistent.99999")
assert resp.status_code == 200
assert resp.json()["results"] == []
def test_similar_api_with_top_k(
self, client, monkeypatch, sample_papers_with_summary
):
"""top_k 参数控制返回数量。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
resp = client.get("/api/similar/2401.20001?top_k=3")
assert resp.status_code == 200
# ═══════════════════════════════════════════════════════════════════════
# 阅读列表路由测试
@@ -235,12 +179,6 @@ class TestSimilarAPI:
class TestReadingListRoute:
"""阅读列表页面测试。"""
def test_reading_list_empty(self, client):
"""无收藏时显示空状态。"""
resp = client.get("/reading-list")
assert resp.status_code == 200
assert "阅读列表" in resp.text
def test_reading_list_with_bookmark(self, client, sample_paper):
"""有收藏时显示论文。"""
# 先收藏
@@ -302,13 +240,6 @@ class TestRssFeed:
assert "<channel>" in resp.text
assert "2401.12345" in resp.text
def test_rss_has_paper_item(self, client, sample_paper):
"""RSS 包含论文条目。"""
resp = client.get("/rss.xml")
assert "<item>" in resp.text
assert "<title>" in resp.text
assert "/paper/2401.12345" in resp.text
def test_rss_with_tag_filter(self, client, sample_paper):
"""GET /rss.xml?tag=NLP 按标签筛选。"""
resp = client.get("/rss.xml?tag=NLP")