feat: add search and user data routes, services, and tests
This commit is contained in:
@@ -11,6 +11,7 @@ import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import get_db
|
||||
from app.main import create_app
|
||||
@@ -43,6 +44,7 @@ def db_engine():
|
||||
engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
"""搜索、阅读列表、RSS Feed 测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from datetime import date, datetime, timezone
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 搜索服务单元测试
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestSearchService:
|
||||
"""app/services/searcher.py 单元测试。"""
|
||||
|
||||
def test_search_by_title(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test Paper")
|
||||
assert result["total"] == 1
|
||||
assert result["results"][0].arxiv_id == "2401.12345"
|
||||
|
||||
def test_search_by_abstract(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="test abstract")
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_search_by_author(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Alice")
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_search_by_tag_in_fts(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# FTS5 索引中包含 tags 列,可以搜到
|
||||
result = search_papers(db_session, query="NLP")
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_search_no_results(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="quantum entanglement")
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
def test_search_empty_query_returns_empty(self, db_session):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="")
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
def test_search_special_characters_sanitized(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# 特殊字符被清除后,剩下 "Test" 仍然能搜到
|
||||
result = search_papers(db_session, query='Test "Paper" {test}')
|
||||
assert result["total"] >= 1
|
||||
|
||||
def test_search_with_tag_filter(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# 关键词 + 标签筛选
|
||||
result = search_papers(db_session, query="Paper", tag="NLP")
|
||||
assert result["total"] == 1
|
||||
|
||||
# 标签不匹配 → 0
|
||||
result2 = search_papers(db_session, query="Paper", tag="nonexistent")
|
||||
assert result2["total"] == 0
|
||||
|
||||
def test_search_tag_only_no_query(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# 只有标签,无关键词
|
||||
result = search_papers(db_session, tag="NLP")
|
||||
assert result["total"] == 1
|
||||
assert result["results"][0].arxiv_id == "2401.12345"
|
||||
|
||||
def test_search_pagination(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test", page=2, page_size=10)
|
||||
assert result["page"] == 2
|
||||
assert result["total_pages"] == 1 # 只有 1 条结果,1 页
|
||||
|
||||
def test_search_returns_snippets(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="test abstract")
|
||||
assert result["total"] == 1
|
||||
paper_id = result["results"][0].id
|
||||
assert paper_id in result["snippets"]
|
||||
snippet = result["snippets"][paper_id]
|
||||
assert "abstract" in snippet
|
||||
|
||||
def test_get_all_tags(self, db_session, sample_paper):
|
||||
from app.services.searcher import get_all_tags
|
||||
|
||||
tags = get_all_tags(db_session)
|
||||
assert "NLP" in tags
|
||||
assert "LLM" in tags
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 搜索路由 HTTP 测试
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestSearchRoutes:
|
||||
"""搜索页面和 JSON API 路由测试。"""
|
||||
|
||||
def test_search_page_renders(self, client):
|
||||
"""GET /search 返回 200。"""
|
||||
resp = client.get("/search")
|
||||
assert resp.status_code == 200
|
||||
assert "搜索" in resp.text
|
||||
|
||||
def test_search_page_with_query(self, client, sample_paper):
|
||||
"""GET /search?q=Test 返回搜索结果。"""
|
||||
resp = client.get("/search?q=Test")
|
||||
assert resp.status_code == 200
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
def test_search_page_with_tag(self, client, sample_paper):
|
||||
"""GET /search?tag=NLP 返回标签筛选结果。"""
|
||||
resp = client.get("/search?tag=NLP")
|
||||
assert resp.status_code == 200
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
def test_search_api_json(self, client, sample_paper):
|
||||
"""GET /api/search?q=Test 返回 JSON。"""
|
||||
resp = client.get("/api/search?q=Test")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] >= 1
|
||||
assert any(p["arxiv_id"] == "2401.12345" for p in data["results"])
|
||||
|
||||
def test_search_api_with_tag(self, client, sample_paper):
|
||||
"""GET /api/search?q=Test&tag=NLP 返回筛选结果。"""
|
||||
resp = client.get("/api/search?q=Test&tag=NLP")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 1
|
||||
|
||||
def test_search_api_empty(self, client, sample_paper):
|
||||
"""GET /api/search?q=nonexistent 返回空结果。"""
|
||||
resp = client.get("/api/search?q=nonexistent")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 0
|
||||
|
||||
def test_search_api_sort_by_date(self, client, sample_paper):
|
||||
"""GET /api/search?q=Test&sort=date 按日期排序。"""
|
||||
resp = client.get("/api/search?q=Test&sort=date")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 阅读列表路由测试
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestReadingListRoute:
|
||||
"""阅读列表页面测试。"""
|
||||
|
||||
def test_reading_list_empty(self, client):
|
||||
"""无收藏时显示空状态。"""
|
||||
resp = client.get("/reading-list")
|
||||
assert resp.status_code == 200
|
||||
assert "阅读列表" in resp.text
|
||||
|
||||
def test_reading_list_with_bookmark(self, client, sample_paper):
|
||||
"""有收藏时显示论文。"""
|
||||
# 先收藏
|
||||
client.post("/api/bookmark/2401.12345")
|
||||
resp = client.get("/reading-list")
|
||||
assert resp.status_code == 200
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
def test_reading_list_filter_by_status(self, client, sample_paper):
|
||||
"""按阅读状态筛选。"""
|
||||
import json
|
||||
|
||||
# 设置阅读状态
|
||||
client.post(
|
||||
"/api/reading-status/2401.12345",
|
||||
json={"status": "read_summary"},
|
||||
)
|
||||
# 筛选 read_summary
|
||||
resp = client.get("/reading-list?filter=read_summary")
|
||||
assert resp.status_code == 200
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
# 筛选 unread(不应出现,因为状态是 read_summary)
|
||||
resp2 = client.get("/reading-list?filter=unread")
|
||||
assert resp2.status_code == 200
|
||||
assert "2401.12345" not in resp2.text
|
||||
|
||||
def test_reading_list_has_note_filter(self, client, sample_paper):
|
||||
"""筛选有笔记的论文。"""
|
||||
import json
|
||||
|
||||
# 写笔记
|
||||
client.post(
|
||||
"/api/note/2401.12345",
|
||||
json={"content": "这是一条笔记"},
|
||||
)
|
||||
resp = client.get("/reading-list?filter=has_note")
|
||||
assert resp.status_code == 200
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# RSS Feed 测试
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestRssFeed:
|
||||
"""RSS Feed 路由测试。"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _recent_paper(self, db_session, sample_paper):
|
||||
"""将 sample_paper 的 paper_date 设为今天,确保在 RSS 7 天窗口内。"""
|
||||
sample_paper.paper_date = date.today()
|
||||
db_session.commit()
|
||||
|
||||
def test_rss_xml_structure(self, client, sample_paper):
|
||||
"""GET /rss.xml 返回有效 XML。"""
|
||||
resp = client.get("/rss.xml")
|
||||
assert resp.status_code == 200
|
||||
assert "application/xml" in resp.headers["content-type"]
|
||||
assert "<?xml" in resp.text
|
||||
assert "<rss" in resp.text
|
||||
assert "<channel>" in resp.text
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
def test_rss_has_paper_item(self, client, sample_paper):
|
||||
"""RSS 包含论文条目。"""
|
||||
resp = client.get("/rss.xml")
|
||||
assert "<item>" in resp.text
|
||||
assert "<title>" in resp.text
|
||||
assert "/paper/2401.12345" in resp.text
|
||||
|
||||
def test_rss_with_tag_filter(self, client, sample_paper):
|
||||
"""GET /rss.xml?tag=NLP 按标签筛选。"""
|
||||
resp = client.get("/rss.xml?tag=NLP")
|
||||
assert resp.status_code == 200
|
||||
assert "2401.12345" in resp.text
|
||||
|
||||
resp2 = client.get("/rss.xml?tag=nonexistent")
|
||||
assert resp2.status_code == 200
|
||||
assert "2401.12345" not in resp2.text
|
||||
|
||||
def test_rss_uses_chinese_title(self, client, db_session, sample_paper):
|
||||
"""RSS 使用中文标题(如果有的话)。"""
|
||||
sample_paper.title_zh = "测试中文标题"
|
||||
db_session.commit()
|
||||
|
||||
resp = client.get("/rss.xml")
|
||||
assert resp.status_code == 200
|
||||
assert "测试中文标题" in resp.text
|
||||
@@ -0,0 +1,214 @@
|
||||
"""用户数据服务 + 路由测试 — 收藏、阅读状态、笔记。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 收藏服务测试
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestBookmarkService:
|
||||
def test_toggle_bookmark_add(self, db_session, sample_paper):
|
||||
from app.services.user_data import toggle_bookmark
|
||||
|
||||
result = toggle_bookmark(db_session, "2401.12345")
|
||||
assert result["bookmarked"] is True
|
||||
assert result["arxiv_id"] == "2401.12345"
|
||||
|
||||
def test_toggle_bookmark_remove(self, db_session, sample_paper):
|
||||
from app.services.user_data import toggle_bookmark
|
||||
|
||||
toggle_bookmark(db_session, "2401.12345") # 添加
|
||||
result = toggle_bookmark(db_session, "2401.12345") # 移除
|
||||
assert result["bookmarked"] is False
|
||||
|
||||
def test_toggle_bookmark_not_found(self, db_session):
|
||||
from app.services.user_data import toggle_bookmark
|
||||
|
||||
result = toggle_bookmark(db_session, "nonexistent")
|
||||
assert "error" in result
|
||||
assert result["error"] == "not_found"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 阅读状态服务测试
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestReadingStatusService:
|
||||
def test_set_reading_status(self, db_session, sample_paper):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
result = set_reading_status(db_session, "2401.12345", "read_summary")
|
||||
assert result["status"] == "read_summary"
|
||||
assert result["arxiv_id"] == "2401.12345"
|
||||
|
||||
def test_set_reading_status_invalid(self, db_session, sample_paper):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
result = set_reading_status(db_session, "2401.12345", "invalid_status")
|
||||
assert "error" in result
|
||||
assert result["error"] == "invalid_status"
|
||||
|
||||
def test_update_existing_status(self, db_session, sample_paper):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
set_reading_status(db_session, "2401.12345", "skimmed")
|
||||
result = set_reading_status(db_session, "2401.12345", "read_full")
|
||||
assert result["status"] == "read_full"
|
||||
|
||||
def test_set_reading_status_not_found(self, db_session):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
result = set_reading_status(db_session, "nonexistent", "unread")
|
||||
assert "error" in result
|
||||
assert result["error"] == "not_found"
|
||||
|
||||
def test_all_valid_statuses(self, db_session, sample_paper):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
for status in ("unread", "skimmed", "read_summary", "read_full"):
|
||||
result = set_reading_status(db_session, "2401.12345", status)
|
||||
assert result["status"] == status
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 笔记服务测试
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestNoteService:
|
||||
def test_save_and_get_note(self, db_session, sample_paper):
|
||||
from app.services.user_data import get_note, save_note
|
||||
|
||||
save_note(db_session, "2401.12345", "这是一条测试笔记")
|
||||
result = get_note(db_session, "2401.12345")
|
||||
assert result["content"] == "这是一条测试笔记"
|
||||
assert result["arxiv_id"] == "2401.12345"
|
||||
assert result["updated_at"] is not None
|
||||
|
||||
def test_update_note(self, db_session, sample_paper):
|
||||
from app.services.user_data import get_note, save_note
|
||||
|
||||
save_note(db_session, "2401.12345", "旧笔记")
|
||||
save_note(db_session, "2401.12345", "新笔记")
|
||||
result = get_note(db_session, "2401.12345")
|
||||
assert result["content"] == "新笔记"
|
||||
|
||||
def test_get_note_empty(self, db_session, sample_paper):
|
||||
from app.services.user_data import get_note
|
||||
|
||||
result = get_note(db_session, "2401.12345")
|
||||
assert result["content"] == ""
|
||||
assert result["updated_at"] is None
|
||||
|
||||
def test_get_note_paper_not_found(self, db_session):
|
||||
from app.services.user_data import get_note
|
||||
|
||||
result = get_note(db_session, "nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_save_note_paper_not_found(self, db_session):
|
||||
from app.services.user_data import save_note
|
||||
|
||||
result = save_note(db_session, "nonexistent", "内容")
|
||||
assert "error" in result
|
||||
assert result["error"] == "not_found"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 用户数据路由 HTTP 测试
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestUserDataRoutes:
|
||||
"""HTTP 级别的用户数据 API 测试。"""
|
||||
|
||||
def test_bookmark_toggle_api(self, client, sample_paper):
|
||||
"""POST /api/bookmark/{arxiv_id} 切换收藏。"""
|
||||
resp = client.post("/api/bookmark/2401.12345")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["bookmarked"] is True
|
||||
|
||||
# 再次切换 → 取消
|
||||
resp2 = client.post("/api/bookmark/2401.12345")
|
||||
assert resp2.status_code == 200
|
||||
data2 = resp2.json()
|
||||
assert data2["bookmarked"] is False
|
||||
|
||||
def test_bookmark_htmx_returns_html(self, client, sample_paper):
|
||||
"""HTMX 请求返回 HTML 片段。"""
|
||||
headers = {"HX-Request": "true"}
|
||||
resp = client.post("/api/bookmark/2401.12345", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
assert "btn-bookmark" in resp.text
|
||||
assert "★" in resp.text
|
||||
|
||||
def test_bookmark_not_found(self, client):
|
||||
"""收藏不存在的论文返回 404。"""
|
||||
resp = client.post("/api/bookmark/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_reading_status_api(self, client, sample_paper):
|
||||
"""POST /api/reading-status/{arxiv_id} 更新状态。"""
|
||||
resp = client.post(
|
||||
"/api/reading-status/2401.12345",
|
||||
json={"status": "read_summary"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "read_summary"
|
||||
|
||||
def test_reading_status_invalid(self, client, sample_paper):
|
||||
"""无效状态返回 422。"""
|
||||
resp = client.post(
|
||||
"/api/reading-status/2401.12345",
|
||||
json={"status": "invalid"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_reading_status_not_found(self, client):
|
||||
"""不存在的论文返回 404。"""
|
||||
resp = client.post(
|
||||
"/api/reading-status/nonexistent",
|
||||
json={"status": "unread"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_note_get_api(self, client, sample_paper):
|
||||
"""GET /api/note/{arxiv_id} 返回笔记。"""
|
||||
resp = client.get("/api/note/2401.12345")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["content"] == ""
|
||||
|
||||
def test_note_save_api(self, client, sample_paper):
|
||||
"""POST /api/note/{arxiv_id} 保存笔记。"""
|
||||
resp = client.post(
|
||||
"/api/note/2401.12345",
|
||||
json={"content": "Markdown **笔记**"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["content"] == "Markdown **笔记**"
|
||||
assert data["updated_at"] is not None
|
||||
|
||||
# 读取确认
|
||||
resp2 = client.get("/api/note/2401.12345")
|
||||
assert resp2.json()["content"] == "Markdown **笔记**"
|
||||
|
||||
def test_note_not_found(self, client):
|
||||
"""不存在的论文返回 404。"""
|
||||
resp = client.get("/api/note/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
resp2 = client.post(
|
||||
"/api/note/nonexistent",
|
||||
json={"content": "test"},
|
||||
)
|
||||
assert resp2.status_code == 404
|
||||
Reference in New Issue
Block a user