feat: add compare, trends routes, embedder service, and phase5 tests

This commit is contained in:
2026-06-05 23:32:06 +08:00
parent 2cfd1a8a9f
commit ba9afa212c
17 changed files with 2122 additions and 27 deletions
+657
View File
@@ -0,0 +1,657 @@
"""Phase 5 后续增强测试 — embedder、semantic search、trends、compare、image extraction。"""
from __future__ import annotations
import json
import shutil
import time
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import select
from app.config import settings
from app.database import get_db
from app.models import (
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
)
# ── Fixtures ────────────────────────────────────────────────────────────
ADMIN_TOKEN = "test-admin-token-12345"
@pytest.fixture
def admin_headers():
    """Return the HTTP Authorization header carrying the test admin bearer token."""
    # Interpolate directly; the original mixed a placeholder-less f-string with "+".
    return {"Authorization": f"Bearer {ADMIN_TOKEN}"}
@pytest.fixture
def auth_client(client, monkeypatch):
    """Return the shared test client after overriding two settings.

    ``ADMIN_TOKEN`` is set to the module-level test token and
    ``CHROMA_ENABLED`` is forced to False; monkeypatch undoes both overrides
    at teardown.
    """
    overrides = {"ADMIN_TOKEN": ADMIN_TOKEN, "CHROMA_ENABLED": False}
    for attr_name, attr_value in overrides.items():
        monkeypatch.setattr(settings, attr_name, attr_value)
    return client
@pytest.fixture
def sample_papers_with_summary(db_session):
    """Insert five papers (2024-01-10 .. 2024-01-14) with related rows and return them.

    Every paper gets one author, the shared ``NLP`` tag plus a unique
    ``Tag<n>`` tag, a ``SummaryStatus`` row and a matching ``papers_fts`` row.
    The first four papers are marked ``done`` and receive a ``PaperSummary``;
    the fifth is left ``pending`` with no summary row.

    Returns:
        The list of inserted ``Paper`` objects, in insertion order.
    """
    from sqlalchemy import text

    # Built once: the statement text is identical for every paper.
    fts_insert = text(
        "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
        "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
    )
    now = datetime.now(timezone.utc)
    seed_rows = [
        ("2401.20001", "2024-01-10"),
        ("2401.20002", "2024-01-11"),
        ("2401.20003", "2024-01-12"),
        ("2401.20004", "2024-01-13"),
        ("2401.20005", "2024-01-14"),
    ]
    papers = []
    for i, (arxiv_id, paper_date_str) in enumerate(seed_rows):
        n = i + 1  # 1-based label used in titles, authors and tags
        has_summary = i < 4  # only the first four papers are summarised
        p = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {n}",
            title_zh=f"测试论文 {n}",
            abstract=f"Abstract for paper {n}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=now,
            upvotes=i * 10 + 5,
        )
        db_session.add(p)
        # Flush before creating the related rows, which all reference p.id.
        db_session.flush()
        db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {n}", position=0))
        db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf"))
        db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{n}", source="hf"))
        db_session.add(SummaryStatus(
            paper_id=p.id,
            status="done" if has_summary else "pending",
            quality="normal",
        ))
        if has_summary:
            db_session.add(PaperSummary(
                paper_id=p.id,
                one_line=f"这是论文{n}的一句话摘要",
                difficulty="中级",
                motivation_problem=f"论文{n}的研究问题",
                motivation_goal=f"论文{n}的研究目标",
                method_key_idea=f"论文{n}的关键思路",
                method_overview=f"论文{n}的方法概述",
                updated_at=now,
                full_json=json.dumps({"title_zh": f"测试论文 {n}"}),
            ))
        # Mirror the paper into the FTS5 table so keyword search can find it.
        db_session.execute(
            fts_insert,
            {
                "id": p.id,
                "title_en": p.title_en,
                "title_zh": p.title_zh or "",
                "abstract": p.abstract or "",
                "authors": f"Author {n}",
                "tags": f"NLP, Tag{n}",
            },
        )
        papers.append(p)
    db_session.commit()
    return papers
# ═══════════════════════════════════════════════════════════════════════
# Embedder 服务测试
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderInit:
    """Initialisation tests for the embedder service module."""

    def test_chroma_disabled_skip_init(self, monkeypatch):
        """init_chroma() must leave the client unset when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._client = None
        emb._collection = None
        emb.init_chroma()
        assert emb._client is None

    def test_chroma_init_success(self, monkeypatch, tmp_path):
        """init_chroma() must populate client and collection when CHROMA_ENABLED is true."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._client = None
        emb._collection = None
        try:
            emb.init_chroma()
            assert emb._client is not None
            assert emb._collection is not None
        finally:
            # Reset the module globals even if an assertion above fails, so an
            # initialised client never leaks into later tests.
            emb._client = None
            emb._collection = None

    def test_get_collection_returns_none_when_disabled(self, monkeypatch):
        """get_collection() must return None when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._client = None
        emb._collection = None
        assert emb.get_collection() is None
class TestEmbedderIndexing:
    """Indexing, deletion and similarity-search tests for the embedder service."""

    def test_index_paper_disabled(self, monkeypatch):
        """index_paper() must return False when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._client = None
        emb._collection = None
        assert emb.index_paper("test-id") is False

    def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
        """index_paper() must return False when EMBED_API_BASE / EMBED_MODEL are empty."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb

        emb._client = None
        emb._collection = None
        try:
            emb.init_chroma()
            result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
            assert result is False
        finally:
            # Reset the module globals even if the assertion fails, so the
            # initialised client never leaks into later tests.
            emb._client = None
            emb._collection = None

    def test_index_batch_disabled(self, monkeypatch):
        """index_batch() must report every item as failed when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._client = None
        emb._collection = None
        result = emb.index_batch(["a", "b"])
        assert result["success"] == 0
        assert result["failed"] == 2

    def test_index_batch_empty(self, monkeypatch):
        """index_batch() on an empty list must report a total of 0."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        result = emb.index_batch([])
        assert result["total"] == 0

    def test_delete_paper_disabled(self, monkeypatch):
        """delete_paper() must return False when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._client = None
        emb._collection = None
        assert emb.delete_paper("test-id") is False

    def test_search_similar_disabled(self, monkeypatch):
        """search_similar() must return an empty list when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._client = None
        emb._collection = None
        assert emb.search_similar("test query") == []
class TestEmbeddingApi:
    """Tests for the embedder's private ``_get_embedding`` helper."""

    def test_no_api_base_returns_none(self, monkeypatch):
        """An empty EMBED_API_BASE must yield None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb

        assert emb._get_embedding("test") is None

    def test_dimension_mismatch_returns_none(self, monkeypatch):
        """A 64-dim embedding must be rejected when EMBED_DIMENSIONS is 128."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb

        fake_response = MagicMock()
        fake_response.json.return_value = {"data": [{"embedding": [0.1] * 64}]}
        fake_response.raise_for_status = MagicMock()
        with patch("httpx.Client") as mock_client:
            # The original made __enter__ return the *response*, so client.post()
            # produced an unconfigured MagicMock and the 64-dim payload was never
            # read. Return the client from __enter__ (as httpx.Client does) and
            # hand the fake response back from post().
            # NOTE(review): assumes _get_embedding calls .post() on the client,
            # as the API-failure test below also does — confirm against embedder.py.
            http_client = mock_client.return_value
            http_client.__enter__ = MagicMock(return_value=http_client)
            http_client.__exit__ = MagicMock(return_value=False)
            http_client.post.return_value = fake_response
            result = emb._get_embedding("test")
        assert result is None

    def test_api_failure_returns_none(self, monkeypatch):
        """An exception raised by the HTTP call must yield None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb

        with patch("httpx.Client") as mock_client:
            http_client = mock_client.return_value
            http_client.__enter__ = MagicMock(return_value=http_client)
            http_client.__exit__ = MagicMock(return_value=False)
            http_client.post.side_effect = Exception("timeout")
            result = emb._get_embedding("test")
        assert result is None
# ═══════════════════════════════════════════════════════════════════════
# Searcher 语义模式测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchSemanticMode:
    """Tests for the ``mode`` handling of ``search_papers`` in the searcher service."""

    def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
        """Keyword mode finds the seeded papers and reports no distances."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test Paper", mode="keyword")
        assert outcome["total"] >= 1
        assert outcome["distances"] == {}

    def test_semantic_mode_disabled_fallback(self, db_session, monkeypatch, sample_papers_with_summary):
        """Semantic mode still returns hits when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test", mode="semantic")
        # Expected to fall back to the FTS5 keyword path.
        assert outcome["total"] >= 1

    def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary):
        """The result always carries a ``distances`` dict."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test Paper")
        assert "distances" in outcome
        assert isinstance(outcome["distances"], dict)

    def test_empty_query_returns_empty(self, db_session):
        """No query and no tag yields an empty result."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session)
        assert outcome["total"] == 0
        assert outcome["results"] == []

    def test_tag_only_search(self, db_session, sample_papers_with_summary):
        """Searching by tag alone finds the seeded papers."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, tag="NLP")
        assert outcome["total"] >= 1
# ═══════════════════════════════════════════════════════════════════════
# Search Routes 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchRoutes:
    """HTTP-level tests for the search page and the search JSON API."""

    def test_search_page_keyword(self, auth_client, sample_papers_with_summary):
        """The search page renders seeded titles in keyword mode."""
        response = auth_client.get("/search?q=Test&mode=keyword")
        assert response.status_code == 200
        page = response.text
        assert "Test" in page or "测试" in page

    def test_search_page_semantic_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
        """Semantic mode still answers 200 when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/search?q=Test&mode=semantic")
        assert response.status_code == 200

    def test_search_api_with_mode(self, auth_client, sample_papers_with_summary):
        """The search API accepts ``mode`` and returns ``results`` and ``total``."""
        response = auth_client.get("/api/search?q=Test&mode=keyword")
        assert response.status_code == 200
        payload = response.json()
        assert "results" in payload
        assert "total" in payload
# ═══════════════════════════════════════════════════════════════════════
# Similar Paper API 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSimilarAPI:
    """Tests for the similar-papers JSON API."""

    def test_similar_api_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
        """With CHROMA_ENABLED false the API returns an empty result list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = auth_client.get("/api/similar/2401.20001")
        assert resp.status_code == 200
        data = resp.json()
        assert data["results"] == []

    def test_similar_api_paper_not_found(self, auth_client, monkeypatch):
        """An unknown paper ID returns 200 with an empty result list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = auth_client.get("/api/similar/nonexistent.99999")
        assert resp.status_code == 200
        assert resp.json()["results"] == []

    def test_similar_api_with_top_k(self, auth_client, monkeypatch, sample_papers_with_summary):
        """``top_k`` is accepted and bounds the number of returned results."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = auth_client.get("/api/similar/2401.20001?top_k=3")
        assert resp.status_code == 200
        # The original asserted only the status code, so the "top_k controls
        # the count" claim went unchecked. With Chroma disabled the list is
        # expected to be empty, so this is an upper-bound check only.
        assert len(resp.json()["results"]) <= 3
# ═══════════════════════════════════════════════════════════════════════
# Detail Page 相似论文测试
# ═══════════════════════════════════════════════════════════════════════
class TestDetailSimilarPapers:
    """Tests for the paper detail page."""

    def test_detail_page_renders(self, auth_client, sample_papers_with_summary):
        """The detail page of a seeded paper renders and shows its title."""
        response = auth_client.get("/paper/2401.20001")
        assert response.status_code == 200
        page = response.text
        assert "测试论文" in page or "Test Paper" in page

    def test_detail_page_not_found(self, auth_client):
        """An unknown paper ID yields 404."""
        response = auth_client.get("/paper/nonexistent.99999")
        assert response.status_code == 404
# ═══════════════════════════════════════════════════════════════════════
# Trends Dashboard 测试
# ═══════════════════════════════════════════════════════════════════════
class TestTrendsDashboard:
    """Tests for the trends dashboard page and its JSON stats API."""

    def test_trends_page_renders(self, auth_client, sample_papers_with_summary):
        """The dashboard page renders and mentions a chart."""
        resp = auth_client.get("/trends")
        assert resp.status_code == 200
        assert "趋势看板" in resp.text
        # A case-insensitive check already covers "Chart"; the original's extra
        # `or "Chart" in resp.text` clause could never change the outcome.
        assert "chart" in resp.text.lower()

    def test_trends_api_returns_data(self, auth_client, sample_papers_with_summary):
        """The trends API returns the four expected list-valued keys."""
        resp = auth_client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        expected_keys = ("daily_counts", "top_tags", "upvotes_dist", "summary_completion")
        for key in expected_keys:
            assert key in data
        for key in expected_keys:
            assert isinstance(data[key], list)

    def test_trends_api_daily_counts(self, auth_client, sample_papers_with_summary, monkeypatch):
        """Each seeded day shows up exactly once with a count of 1."""
        # Pin "today" near the seeded dates (2024-01-10 .. 2024-01-14) by
        # replacing the `date` name inside the trends route module; the mock
        # still builds real dates when called as a constructor.
        with patch("app.routes.trends.date") as mock_date:
            mock_date.today.return_value = date(2024, 1, 20)
            mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
            resp = auth_client.get("/api/stats/trends")
        data = resp.json()
        assert len(data["daily_counts"]) == 5
        for item in data["daily_counts"]:
            assert "date" in item
            assert "count" in item
            assert item["count"] == 1

    def test_trends_api_top_tags(self, auth_client, sample_papers_with_summary):
        """The shared NLP tag is counted once per seeded paper."""
        resp = auth_client.get("/api/stats/trends")
        data = resp.json()
        tags = {t["tag"]: t["count"] for t in data["top_tags"]}
        assert "NLP" in tags
        assert tags["NLP"] == 5  # every seeded paper carries the NLP tag

    def test_trends_api_summary_completion(self, auth_client, sample_papers_with_summary):
        """Four of the five seeded papers are reported with status ``done``."""
        resp = auth_client.get("/api/stats/trends")
        data = resp.json()
        statuses = {s["status"]: s["count"] for s in data["summary_completion"]}
        assert "done" in statuses
        assert statuses["done"] == 4  # four papers have a finished summary

    def test_trends_empty_db(self, auth_client):
        """An empty database yields empty lists rather than an error."""
        resp = auth_client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        assert data["daily_counts"] == []
        assert data["top_tags"] == []
# ═══════════════════════════════════════════════════════════════════════
# Compare Page 测试
# ═══════════════════════════════════════════════════════════════════════
class TestComparePage:
    """Tests for the paper comparison page."""

    def test_compare_page_no_ids(self, auth_client):
        """Without IDs the page renders (expected to show the input form)."""
        resp = auth_client.get("/compare")
        assert resp.status_code == 200
        assert "对比" in resp.text

    def test_compare_page_with_ids(self, auth_client, sample_papers_with_summary):
        """Comparing two seeded papers renders both IDs and the comparison fields."""
        resp = auth_client.get("/compare?ids=2401.20001,2401.20002")
        assert resp.status_code == 200
        assert "2401.20001" in resp.text
        assert "2401.20002" in resp.text
        # Comparison field labels must be present.
        assert "一句话摘要" in resp.text
        assert "研究问题" in resp.text

    def test_compare_page_max_5(self, auth_client, sample_papers_with_summary):
        """Five IDs (the stated maximum) are accepted."""
        ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005"
        resp = auth_client.get(f"/compare?ids={ids}")
        assert resp.status_code == 200

    def test_compare_page_over_5_truncates(self, auth_client, sample_papers_with_summary):
        """More than five IDs still renders; the sixth ID does not exist in the DB."""
        ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006"
        resp = auth_client.get(f"/compare?ids={ids}")
        assert resp.status_code == 200

    def test_compare_page_invalid_ids(self, auth_client):
        """An unknown ID must not crash the page."""
        resp = auth_client.get("/compare?ids=nonexistent.99999")
        assert resp.status_code == 200
        # NOTE(review): the original followed this with
        #   assert "未找到" in text or "暂无" in text or status_code == 200
        # which could never fail once the status check above passed, so it was
        # dropped. Add a real empty-state assertion once the template wording
        # is confirmed.

    def test_compare_page_shows_no_summary_placeholder(self, auth_client, sample_papers_with_summary):
        """A paper without a summary shows the placeholder text."""
        # 2401.20005 has no summary (its status is "pending").
        resp = auth_client.get("/compare?ids=2401.20005")
        assert resp.status_code == 200
        assert "暂无总结" in resp.text
# ═══════════════════════════════════════════════════════════════════════
# Image Extraction 测试
# ═══════════════════════════════════════════════════════════════════════
class TestImageExtraction:
    """Tests for extracting images referenced by LaTeX sources."""

    @pytest.mark.asyncio
    async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
        """A missing source directory yields 0."""
        monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
        monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
        from app.services.summarizer import _extract_images_from_source

        extracted = await _extract_images_from_source("2401.99999")
        assert extracted == 0

    @pytest.mark.asyncio
    async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
        """Images referenced by a .tex file and present on disk are extracted."""
        source_dir = tmp_path / "tmp" / "2401.00001" / "source"
        source_dir.mkdir(parents=True)
        figs_dir = source_dir / "figs"
        figs_dir.mkdir()
        (figs_dir / "figure1.png").write_bytes(b"\x89PNG\r\n")
        (figs_dir / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0")
        # The .tex file references two existing images and one missing PDF.
        tex_content = r"""
\documentclass{article}
\begin{document}
\begin{figure}
\includegraphics[width=0.8\textwidth]{figs/figure1.png}
\includegraphics{figs/figure2.jpg}
\includegraphics[angle=90]{figs/nonexistent.pdf}
\end{figure}
\end{document}
"""
        (source_dir / "main.tex").write_text(tex_content)
        papers_dir = tmp_path / "papers" / "2401.00001"
        monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
        monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
        from app.services.summarizer import _extract_images_from_source

        extracted = await _extract_images_from_source("2401.00001")
        assert extracted == 2
        dest_images = papers_dir / "images"
        assert dest_images.exists()
        for image_name in ("figure1.png", "figure2.jpg"):
            assert (dest_images / image_name).exists()

    @pytest.mark.asyncio
    async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
        """A .tex file with no image references yields 0."""
        source_dir = tmp_path / "tmp" / "2401.00002" / "source"
        source_dir.mkdir(parents=True)
        (source_dir / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
        monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
        monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
        from app.services.summarizer import _extract_images_from_source

        extracted = await _extract_images_from_source("2401.00002")
        assert extracted == 0
# ═══════════════════════════════════════════════════════════════════════
# Nav Bar 测试
# ═══════════════════════════════════════════════════════════════════════
class TestNavBar:
    """Tests for navigation links and page reachability."""

    def test_nav_includes_trends_link(self, auth_client):
        """The search page markup contains a link to /trends."""
        response = auth_client.get("/search")
        assert response.status_code == 200
        assert "/trends" in response.text

    def test_nav_includes_compare_implicitly(self, auth_client):
        """The compare page is reachable."""
        response = auth_client.get("/compare")
        assert response.status_code == 200
# ═══════════════════════════════════════════════════════════════════════
# Graceful Degradation 测试
# ═══════════════════════════════════════════════════════════════════════
class TestGracefulDegradation:
    """Checks that core features keep working when CHROMA_ENABLED is false."""

    def test_search_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The search page still returns the seeded papers with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/search?q=Test")
        assert response.status_code == 200
        page = response.text
        assert "Test Paper" in page or "测试论文" in page

    def test_detail_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The detail page still renders with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/paper/2401.20001")
        assert response.status_code == 200

    def test_trends_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The trends dashboard still renders with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/trends")
        assert response.status_code == 200

    def test_compare_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The compare page still renders with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/compare?ids=2401.20001,2401.20002")
        assert response.status_code == 200

    @pytest.mark.asyncio
    async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch):
        """Deleting papers by date range still succeeds with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._client = None
        emb._collection = None
        from app.services.cleaner import delete_papers_by_date_range

        # A one-day range matching exactly one seeded paper (2024-01-10).
        target_day = date(2024, 1, 10)
        outcome = await delete_papers_by_date_range(
            db_session,
            target_day,
            target_day,
        )
        assert outcome["status"] == "success"
        assert outcome["deleted"] == 1