Files
daily-paper/tests/test_searcher.py
T
Rain-Bus f7f1a4c0cb refactor: split monolithic phase tests into per-module test files
- rename test_admin_phase4.py -> test_admin.py, test_search.py -> test_searcher.py
- split test_phase5.py into test_cleaner, test_embedder, test_image_extractor, test_pages
- move schema tests from test_summarizer.py into dedicated test_schemas.py
- add sample_papers_range and sample_papers_with_summary fixtures in conftest
- update .gitignore to exclude all of data/
2026-06-06 00:34:30 +08:00

346 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""搜索服务 + 路由 + 阅读列表 + RSS + 语义模式测试。"""
from __future__ import annotations
import pytest
from datetime import date, datetime, timezone
from app.config import settings
# ═══════════════════════════════════════════════════════════════════════
# 搜索服务单元测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchService:
    """Unit tests for the FTS5 keyword search in app/services/searcher.py."""

    def test_search_by_title(self, db_session, sample_paper):
        """Querying "Test Paper" finds exactly one paper, arXiv id 2401.12345."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test Paper")
        assert outcome["total"] == 1
        first_hit = outcome["results"][0]
        assert first_hit.arxiv_id == "2401.12345"

    def test_search_by_abstract(self, db_session, sample_paper):
        """Querying "test abstract" finds exactly one paper."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="test abstract")
        assert outcome["total"] == 1

    def test_search_by_author(self, db_session, sample_paper):
        """Querying "Alice" finds exactly one paper."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Alice")
        assert outcome["total"] == 1

    def test_search_by_tag_in_fts(self, db_session, sample_paper):
        """Querying a tag name as a plain keyword finds exactly one paper."""
        from app.services.searcher import search_papers

        # The FTS5 index includes the tags column, so a tag is searchable as text.
        outcome = search_papers(db_session, query="NLP")
        assert outcome["total"] == 1

    def test_search_no_results(self, db_session, sample_paper):
        """A query that matches nothing reports total 0 and an empty result list."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="quantum entanglement")
        assert outcome["total"] == 0
        assert outcome["results"] == []

    def test_search_empty_query_returns_empty(self, db_session):
        """An empty query string reports total 0 and an empty result list."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="")
        assert outcome["total"] == 0
        assert outcome["results"] == []

    def test_search_special_characters_sanitized(self, db_session, sample_paper):
        """A query containing quotes and braces still returns at least one hit."""
        from app.services.searcher import search_papers

        # Once the special characters are stripped, "Test" is left and still matches.
        raw_query = 'Test "Paper" {test}'
        outcome = search_papers(db_session, query=raw_query)
        assert outcome["total"] >= 1

    def test_search_with_tag_filter(self, db_session, sample_paper):
        """A keyword combined with a tag filter only matches when the tag matches."""
        from app.services.searcher import search_papers

        # Keyword plus a tag filter that matches.
        matching = search_papers(db_session, query="Paper", tag="NLP")
        assert matching["total"] == 1

        # Tag does not match -> zero results.
        mismatched = search_papers(db_session, query="Paper", tag="nonexistent")
        assert mismatched["total"] == 0

    def test_search_tag_only_no_query(self, db_session, sample_paper):
        """A tag with no keyword finds exactly one paper, arXiv id 2401.12345."""
        from app.services.searcher import search_papers

        # Tag only, no keyword.
        outcome = search_papers(db_session, tag="NLP")
        assert outcome["total"] == 1
        first_hit = outcome["results"][0]
        assert first_hit.arxiv_id == "2401.12345"

    def test_search_pagination(self, db_session, sample_paper):
        """Requesting page 2 echoes page 2 while total_pages stays at 1."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test", page=2, page_size=10)
        assert outcome["page"] == 2
        assert outcome["total_pages"] == 1  # only one result, hence one page

    def test_search_returns_snippets(self, db_session, sample_paper):
        """The snippets mapping is keyed by paper id and contains the matched word."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="test abstract")
        assert outcome["total"] == 1

        hit_id = outcome["results"][0].id
        snippets = outcome["snippets"]
        assert hit_id in snippets
        assert "abstract" in snippets[hit_id]

    def test_get_all_tags(self, db_session, sample_paper):
        """get_all_tags returns a collection containing both "NLP" and "LLM"."""
        from app.services.searcher import get_all_tags

        known_tags = get_all_tags(db_session)
        for expected in ("NLP", "LLM"):
            assert expected in known_tags
# ═══════════════════════════════════════════════════════════════════════
# 语义 / Embedder 模式测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchSemanticMode:
    """Tests for the mode argument of search_papers, including the semantic fallback."""

    def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
        """Explicit mode="keyword" returns hits and an empty distances dict."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test Paper", mode="keyword")
        assert outcome["total"] >= 1
        assert outcome["distances"] == {}

    def test_semantic_mode_disabled_fallback(
        self, db_session, monkeypatch, sample_papers_with_summary
    ):
        """With CHROMA_ENABLED=False, mode="semantic" still returns hits."""
        # Patch the flag before the searcher module is imported and called.
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test", mode="semantic")
        assert outcome["total"] >= 1

    def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary):
        """The search result carries a "distances" entry whose value is a dict."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test Paper")
        assert "distances" in outcome
        distances = outcome["distances"]
        assert isinstance(distances, dict)

    def test_empty_query_returns_empty_no_tags(self, db_session):
        """Calling search_papers with neither query nor tag returns nothing."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session)
        assert outcome["total"] == 0
        assert outcome["results"] == []

    def test_tag_only_search(self, db_session, sample_papers_with_summary):
        """A tag-only search returns at least one hit."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, tag="NLP")
        assert outcome["total"] >= 1
# ═══════════════════════════════════════════════════════════════════════
# 搜索路由 HTTP 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchRoutes:
    """HTTP tests for the /search page and the /api/search JSON endpoint."""

    def test_search_page_renders(self, client):
        """GET /search answers 200 and the page contains the search label."""
        response = client.get("/search")
        assert response.status_code == 200
        assert "搜索" in response.text

    def test_search_page_with_query(self, client, sample_paper):
        """GET /search?q=Test answers 200 and lists arXiv id 2401.12345."""
        response = client.get("/search?q=Test")
        assert response.status_code == 200
        assert "2401.12345" in response.text

    def test_search_page_with_tag(self, client, sample_paper):
        """GET /search?tag=NLP answers 200 and lists arXiv id 2401.12345."""
        response = client.get("/search?tag=NLP")
        assert response.status_code == 200
        assert "2401.12345" in response.text

    def test_search_page_keyword_mode(self, client, sample_papers_with_summary):
        """The search page accepts mode=keyword and renders matching text."""
        response = client.get("/search?q=Test&mode=keyword")
        assert response.status_code == 200
        page = response.text
        assert "Test" in page or "测试" in page

    def test_search_page_semantic_disabled(
        self, client, monkeypatch, sample_papers_with_summary
    ):
        """The search page still answers 200 for mode=semantic with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = client.get("/search?q=Test&mode=semantic")
        assert response.status_code == 200

    def test_search_api_json(self, client, sample_paper):
        """GET /api/search?q=Test returns JSON that includes arXiv id 2401.12345."""
        response = client.get("/api/search?q=Test")
        assert response.status_code == 200

        payload = response.json()
        assert payload["total"] >= 1
        assert any(entry["arxiv_id"] == "2401.12345" for entry in payload["results"])

    def test_search_api_with_tag(self, client, sample_paper):
        """GET /api/search?q=Test&tag=NLP reports exactly one hit."""
        response = client.get("/api/search?q=Test&tag=NLP")
        assert response.status_code == 200

        payload = response.json()
        assert payload["total"] == 1

    def test_search_api_with_mode(self, client, sample_papers_with_summary):
        """The search API accepts a mode parameter and keeps its payload keys."""
        response = client.get("/api/search?q=Test&mode=keyword")
        assert response.status_code == 200

        payload = response.json()
        for required_key in ("results", "total"):
            assert required_key in payload

    def test_search_api_empty(self, client, sample_paper):
        """GET /api/search?q=nonexistent answers 200 with total 0."""
        response = client.get("/api/search?q=nonexistent")
        assert response.status_code == 200

        payload = response.json()
        assert payload["total"] == 0

    def test_search_api_sort_by_date(self, client, sample_paper):
        """GET /api/search?q=Test&sort=date answers 200 with at least one hit."""
        response = client.get("/api/search?q=Test&sort=date")
        assert response.status_code == 200

        payload = response.json()
        assert payload["total"] >= 1
# ═══════════════════════════════════════════════════════════════════════
# Similar Paper API 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSimilarAPI:
    """HTTP tests for the similar-papers endpoint, /api/similar/<arxiv_id>.

    Every test sets settings.CHROMA_ENABLED to False, so only the
    disabled-backend path is exercised here.
    """

    def test_similar_api_disabled(self, client, monkeypatch, sample_papers_with_summary):
        """With CHROMA_ENABLED=False the endpoint answers 200 with no results."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = client.get("/api/similar/2401.20001")
        assert resp.status_code == 200
        data = resp.json()
        assert data["results"] == []

    def test_similar_api_paper_not_found(self, client, monkeypatch):
        """An unknown arXiv id answers 200 with an empty result list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = client.get("/api/similar/nonexistent.99999")
        assert resp.status_code == 200
        assert resp.json()["results"] == []

    def test_similar_api_with_top_k(self, client, monkeypatch, sample_papers_with_summary):
        """The top_k query parameter is accepted and bounds the result count."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = client.get("/api/similar/2401.20001?top_k=3")
        assert resp.status_code == 200
        # Checking only the status code would let an ignored or broken top_k
        # pass unnoticed; the payload must never exceed the requested size.
        results = resp.json()["results"]
        assert len(results) <= 3
# ═══════════════════════════════════════════════════════════════════════
# 阅读列表路由测试
# ═══════════════════════════════════════════════════════════════════════
class TestReadingListRoute:
    """HTTP tests for the /reading-list page.

    Each setup POST is asserted to succeed (status below 400) so that a
    rejected write fails at the setup step instead of surfacing later as a
    misleading page-content failure.
    """

    def test_reading_list_empty(self, client):
        """With no stored state the page answers 200 and shows its title."""
        resp = client.get("/reading-list")
        assert resp.status_code == 200
        assert "阅读列表" in resp.text

    def test_reading_list_with_bookmark(self, client, sample_paper):
        """After bookmarking a paper, the page lists its arXiv id."""
        # Bookmark first, and make sure the write was accepted.
        setup = client.post("/api/bookmark/2401.12345")
        assert setup.status_code < 400

        resp = client.get("/reading-list")
        assert resp.status_code == 200
        assert "2401.12345" in resp.text

    def test_reading_list_filter_by_status(self, client, sample_paper):
        """The filter parameter selects papers by reading status."""
        # Set the reading status, and make sure the write was accepted.
        setup = client.post(
            "/api/reading-status/2401.12345",
            json={"status": "read_summary"},
        )
        assert setup.status_code < 400

        # Filtering on read_summary shows the paper.
        resp = client.get("/reading-list?filter=read_summary")
        assert resp.status_code == 200
        assert "2401.12345" in resp.text

        # Filtering on unread must not show it, because its status is read_summary.
        resp2 = client.get("/reading-list?filter=unread")
        assert resp2.status_code == 200
        assert "2401.12345" not in resp2.text

    def test_reading_list_has_note_filter(self, client, sample_paper):
        """filter=has_note lists a paper once a note has been written for it."""
        # Write a note, and make sure the write was accepted.
        setup = client.post(
            "/api/note/2401.12345",
            json={"content": "这是一条笔记"},
        )
        assert setup.status_code < 400

        resp = client.get("/reading-list?filter=has_note")
        assert resp.status_code == 200
        assert "2401.12345" in resp.text
# ═══════════════════════════════════════════════════════════════════════
# RSS Feed 测试
# ═══════════════════════════════════════════════════════════════════════
class TestRssFeed:
    """HTTP tests for the /rss.xml feed route."""

    @pytest.fixture(autouse=True)
    def _recent_paper(self, db_session, sample_paper):
        """Date sample_paper today so it sits inside the feed's 7-day window."""
        today = date.today()
        sample_paper.paper_date = today
        db_session.commit()

    def test_rss_xml_structure(self, client, sample_paper):
        """GET /rss.xml answers 200 as XML with the expected RSS skeleton."""
        response = client.get("/rss.xml")
        assert response.status_code == 200
        assert "application/xml" in response.headers["content-type"]

        feed = response.text
        for marker in ("<?xml", "<rss", "<channel>", "2401.12345"):
            assert marker in feed

    def test_rss_has_paper_item(self, client, sample_paper):
        """The feed contains an item, a title element and the paper's link path."""
        response = client.get("/rss.xml")

        feed = response.text
        for fragment in ("<item>", "<title>", "/paper/2401.12345"):
            assert fragment in feed

    def test_rss_with_tag_filter(self, client, sample_paper):
        """The tag parameter includes or excludes the paper depending on the tag."""
        tagged = client.get("/rss.xml?tag=NLP")
        assert tagged.status_code == 200
        assert "2401.12345" in tagged.text

        untagged = client.get("/rss.xml?tag=nonexistent")
        assert untagged.status_code == 200
        assert "2401.12345" not in untagged.text

    def test_rss_uses_chinese_title(self, client, db_session, sample_paper):
        """When title_zh is set on the paper, that title appears in the feed."""
        zh_title = "测试中文标题"
        sample_paper.title_zh = zh_title
        db_session.commit()

        response = client.get("/rss.xml")
        assert response.status_code == 200
        assert zh_title in response.text