Files
daily-paper/tests/test_searcher.py
T
Rain-Bus 743d69efd0 refactor: extract admin business logic to services, introduce job queue, add derived index helpers
- Move DB operations from routes/admin.py to services/admin.py (get_logs_context, query_summary_statuses, retry_failed, delete/reset operations)
- Add services/jobs.py with Job/JobEvent-based async job queue (create_job, run_job, enqueue_job)
- Add services/derived.py with FTS5 reindex and paper index deletion helpers
- Refactor scheduler to use job queue instead of direct pipeline calls
- Add heartbeat_at/expires_at to TaskLock for lock health tracking
- Remove DESIGN_REVIEW.md
- Update tests: remove redundant integration tests, add unit tests for new services
2026-06-13 18:31:43 +08:00

261 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""搜索服务 + 路由 + 阅读列表 + RSS + 语义模式测试。"""
from __future__ import annotations
from datetime import date
import pytest
from app.config import settings
from app.services.searcher import get_all_tags, search_papers
# ═══════════════════════════════════════════════════════════════════════
# 搜索服务单元测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchService:
    """Unit tests for keyword search in app/services/searcher.py (FTS5)."""

    def test_search_by_title(self, db_session, sample_paper):
        """Query "Test Paper" yields exactly one hit: arXiv id 2401.12345."""
        outcome = search_papers(db_session, query="Test Paper")
        assert outcome["total"] == 1
        first_hit = outcome["results"][0]
        assert first_hit.arxiv_id == "2401.12345"

    def test_search_by_abstract(self, db_session, sample_paper):
        """Query "test abstract" yields exactly one hit."""
        outcome = search_papers(db_session, query="test abstract")
        assert outcome["total"] == 1

    def test_search_by_author(self, db_session, sample_paper):
        """Query "Alice" yields exactly one hit."""
        outcome = search_papers(db_session, query="Alice")
        assert outcome["total"] == 1

    def test_search_by_tag_in_fts(self, db_session, sample_paper):
        """Query "NLP" (a tag name used elsewhere in this file) yields one hit."""
        outcome = search_papers(db_session, query="NLP")
        assert outcome["total"] == 1

    def test_search_no_results(self, db_session, sample_paper):
        """A non-matching query reports zero total and an empty result list."""
        outcome = search_papers(db_session, query="quantum entanglement")
        assert outcome["total"] == 0
        assert outcome["results"] == []

    def test_search_empty_query_returns_empty(self, db_session):
        """An empty query string reports zero total and an empty result list."""
        outcome = search_papers(db_session, query="")
        assert outcome["total"] == 0
        assert outcome["results"] == []

    def test_search_special_characters_sanitized(self, db_session, sample_paper):
        """Quotes and braces in the query do not prevent at least one match."""
        raw_query = 'Test "Paper" {test}'
        outcome = search_papers(db_session, query=raw_query)
        assert outcome["total"] >= 1

    def test_search_with_tag_filter(self, db_session, sample_paper):
        """Tag "NLP" keeps the single hit; tag "nonexistent" removes it."""
        matching = search_papers(db_session, query="Paper", tag="NLP")
        assert matching["total"] == 1

        filtered_out = search_papers(db_session, query="Paper", tag="nonexistent")
        assert filtered_out["total"] == 0

    def test_search_tag_only_no_query(self, db_session, sample_paper):
        """Passing only tag="NLP" (no query) yields the 2401.12345 paper."""
        outcome = search_papers(db_session, tag="NLP")
        assert outcome["total"] == 1
        first_hit = outcome["results"][0]
        assert first_hit.arxiv_id == "2401.12345"

    def test_search_pagination(self, db_session, sample_paper):
        """Requesting page 2 echoes page == 2 while total_pages stays 1."""
        outcome = search_papers(db_session, query="Test", page=2, page_size=10)
        assert outcome["page"] == 2
        assert outcome["total_pages"] == 1

    def test_search_returns_snippets(self, db_session, sample_paper):
        """The snippets mapping is keyed by paper id and contains "abstract"."""
        outcome = search_papers(db_session, query="test abstract")
        assert outcome["total"] == 1

        hit_id = outcome["results"][0].id
        snippets = outcome["snippets"]
        assert hit_id in snippets
        assert "abstract" in snippets[hit_id]

    def test_get_all_tags(self, db_session, sample_paper):
        """get_all_tags() includes both "NLP" and "LLM"."""
        known_tags = get_all_tags(db_session)
        for expected_tag in ("NLP", "LLM"):
            assert expected_tag in known_tags
# ═══════════════════════════════════════════════════════════════════════
# 语义 / Embedder 模式测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchSemanticMode:
    """Tests for searcher.py search modes, including the embedder fallback."""

    def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
        """mode="keyword" finds at least one paper and reports no distances."""
        outcome = search_papers(db_session, query="Test Paper", mode="keyword")
        assert outcome["total"] >= 1
        assert outcome["distances"] == {}

    def test_semantic_mode_disabled_fallback(
        self, db_session, monkeypatch, sample_papers_with_summary
    ):
        """mode="semantic" still returns hits when CHROMA_ENABLED is False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        outcome = search_papers(db_session, query="Test", mode="semantic")
        assert outcome["total"] >= 1

    def test_search_returns_distances_dict(
        self, db_session, sample_papers_with_summary
    ):
        """The result always carries a "distances" key holding a dict."""
        outcome = search_papers(db_session, query="Test Paper")
        assert "distances" in outcome
        distances = outcome["distances"]
        assert isinstance(distances, dict)

    def test_empty_query_returns_empty_no_tags(self, db_session):
        """Calling with neither query nor tag yields an empty result set."""
        outcome = search_papers(db_session)
        assert outcome["total"] == 0
        assert outcome["results"] == []

    def test_tag_only_search(self, db_session, sample_papers_with_summary):
        """Passing only tag="NLP" yields at least one paper."""
        outcome = search_papers(db_session, tag="NLP")
        assert outcome["total"] >= 1
# ═══════════════════════════════════════════════════════════════════════
# 搜索路由 HTTP 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchRoutes:
    """HTTP tests for the search page and the JSON search API."""

    def test_search_page_with_query(self, client, sample_paper):
        """GET /search?q=Test renders a page that contains the sample paper id."""
        response = client.get("/search?q=Test")
        assert response.status_code == 200
        assert "2401.12345" in response.text

    def test_search_api_json(self, client, sample_paper):
        """GET /api/search?q=Test returns JSON listing the sample paper."""
        response = client.get("/api/search?q=Test")
        assert response.status_code == 200

        payload = response.json()
        assert payload["total"] >= 1
        assert any(
            entry["arxiv_id"] == "2401.12345" for entry in payload["results"]
        )

    def test_search_api_with_tag(self, client, sample_paper):
        """GET /api/search?q=Test&tag=NLP returns exactly one filtered hit."""
        response = client.get("/api/search?q=Test&tag=NLP")
        assert response.status_code == 200

        payload = response.json()
        assert payload["total"] == 1

    def test_search_api_empty(self, client, sample_paper):
        """GET /api/search?q=nonexistent returns an empty result set."""
        response = client.get("/api/search?q=nonexistent")
        assert response.status_code == 200

        payload = response.json()
        assert payload["total"] == 0
# ═══════════════════════════════════════════════════════════════════════
# Similar Paper API 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSimilarAPI:
    """Tests for the similar-papers API."""

    def test_similar_api_disabled(
        self, client, monkeypatch, sample_papers_with_summary
    ):
        """With CHROMA_ENABLED set to False the endpoint returns an empty list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)

        response = client.get("/api/similar/2401.20001")
        assert response.status_code == 200

        payload = response.json()
        assert payload["results"] == []
# ═══════════════════════════════════════════════════════════════════════
# 阅读列表路由测试
# ═══════════════════════════════════════════════════════════════════════
class TestReadingListRoute:
    """Tests for the reading-list page."""

    def test_reading_list_with_bookmark(self, client, sample_paper):
        """A bookmarked paper is shown on the reading list."""
        # Bookmark the paper first so the list has something to show.
        client.post("/api/bookmark/2401.12345")

        page = client.get("/reading-list")
        assert page.status_code == 200
        assert "2401.12345" in page.text

    def test_reading_list_filter_by_status(self, client, sample_paper):
        """The reading list can be filtered by reading status."""
        # Mark the paper as read_summary.
        status_body = {"status": "read_summary"}
        client.post("/api/reading-status/2401.12345", json=status_body)

        # Filtering on read_summary must include the paper.
        matching_page = client.get("/reading-list?filter=read_summary")
        assert matching_page.status_code == 200
        assert "2401.12345" in matching_page.text

        # Filtering on unread must exclude it, because its status is
        # read_summary.
        unread_page = client.get("/reading-list?filter=unread")
        assert unread_page.status_code == 200
        assert "2401.12345" not in unread_page.text

    def test_reading_list_has_note_filter(self, client, sample_paper):
        """The has_note filter shows papers that carry a note."""
        # Attach a note to the paper.
        note_body = {"content": "这是一条笔记"}
        client.post("/api/note/2401.12345", json=note_body)

        page = client.get("/reading-list?filter=has_note")
        assert page.status_code == 200
        assert "2401.12345" in page.text
# ═══════════════════════════════════════════════════════════════════════
# RSS Feed 测试
# ═══════════════════════════════════════════════════════════════════════
class TestRssFeed:
    """Tests for the RSS feed route."""

    @pytest.fixture(autouse=True)
    def _recent_paper(self, db_session, sample_paper):
        """Date sample_paper today so it falls inside the RSS 7-day window."""
        today = date.today()
        sample_paper.paper_date = today
        db_session.commit()

    def test_rss_xml_structure(self, client, sample_paper):
        """GET /rss.xml returns an XML document with the expected skeleton."""
        feed = client.get("/rss.xml")
        assert feed.status_code == 200
        assert "application/xml" in feed.headers["content-type"]

        # Checked in the same order as the document is expected to read.
        for marker in ("<?xml", "<rss", "<channel>", "2401.12345"):
            assert marker in feed.text

    def test_rss_with_tag_filter(self, client, sample_paper):
        """GET /rss.xml?tag=... filters feed entries by tag."""
        tagged_feed = client.get("/rss.xml?tag=NLP")
        assert tagged_feed.status_code == 200
        assert "2401.12345" in tagged_feed.text

        unmatched_feed = client.get("/rss.xml?tag=nonexistent")
        assert unmatched_feed.status_code == 200
        assert "2401.12345" not in unmatched_feed.text

    def test_rss_uses_chinese_title(self, client, db_session, sample_paper):
        """The feed contains the paper's Chinese title once title_zh is set."""
        zh_title = "测试中文标题"
        sample_paper.title_zh = zh_title
        db_session.commit()

        feed = client.get("/rss.xml")
        assert feed.status_code == 200
        assert zh_title in feed.text