diff --git a/.gitignore b/.gitignore index 77db4a2..7c658a8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,10 +2,7 @@ __pycache__/ *.pyc *.pyo -data/db/*.db -data/papers/ -data/tmp/ -data/chroma/ +data/ logs/*.log .venv/ venv/ diff --git a/tests/conftest.py b/tests/conftest.py index ef4895b..1090322 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -209,3 +209,122 @@ def admin_token(): def admin_headers(admin_token): """带 Bearer token 的请求头。""" return {"Authorization": f"Bearer {admin_token}"} + + +@pytest.fixture +def wrong_admin_headers(): + """错误的 Authorization 请求头。""" + return {"Authorization": "Bearer wrong-token"} + + +# ── 多样例数据 ──────────────────────────────────────────────────────────── + + +@pytest.fixture +def sample_papers_range(db_session): + """插入 5 篇不同日期的论文(用于 admin / cleaner 测试)。""" + now = datetime.now(timezone.utc) + papers = [] + for i, (arxiv_id, paper_date_str) in enumerate([ + ("2401.10001", "2024-01-10"), + ("2401.10002", "2024-01-11"), + ("2401.10003", "2024-01-12"), + ("2401.10004", "2024-01-13"), + ("2401.10005", "2024-01-14"), + ]): + paper_date = date.fromisoformat(paper_date_str) + p = Paper( + arxiv_id=arxiv_id, + title_en=f"Test Paper {i+1}", + abstract=f"Abstract for paper {i+1}.", + paper_date=paper_date, + crawled_at=now, + upvotes=i * 10, + ) + db_session.add(p) + db_session.flush() + db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0)) + db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf")) + db_session.add(SummaryStatus(paper_id=p.id, status="pending")) + # FTS5 + db_session.execute( + __import__("sqlalchemy").text( + "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) " + "VALUES (:id, :title, :abstract, :authors, :tags)" + ), + {"id": p.id, "title": p.title_en, "abstract": p.abstract, + "authors": f"Author {i+1}", "tags": f"Tag{i+1}"}, + ) + papers.append(p) + db_session.commit() + return papers + + +@pytest.fixture +def sample_papers_with_summary(db_session): + """插入 5 篇带总结的论文(用于 search / pages / trends 测试)。""" + now = datetime.now(timezone.utc) + papers = [] + for i, (arxiv_id, paper_date_str) in enumerate([ + ("2401.20001", "2024-01-10"), + ("2401.20002", "2024-01-11"), + ("2401.20003", "2024-01-12"), + ("2401.20004", "2024-01-13"), + ("2401.20005", "2024-01-14"), + ]): + paper_date = date.fromisoformat(paper_date_str) + p = Paper( + arxiv_id=arxiv_id, + title_en=f"Test Paper {i+1}", + title_zh=f"测试论文 {i+1}", + abstract=f"Abstract for paper {i+1}.", + paper_date=paper_date, + crawled_at=now, + upvotes=i * 10 + 5, + ) + db_session.add(p) + db_session.flush() + + db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0)) + db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf")) + db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf")) + + db_session.add(SummaryStatus( + paper_id=p.id, + status="done" if i < 4 else "pending", + quality="normal", + )) + + # 添加总结(前 4 篇) + if i < 4: + summary = PaperSummary( + paper_id=p.id, + one_line=f"这是论文{i+1}的一句话摘要", + difficulty="中级", + motivation_problem=f"论文{i+1}的研究问题", + motivation_goal=f"论文{i+1}的研究目标", + method_key_idea=f"论文{i+1}的关键思路", + method_overview=f"论文{i+1}的方法概述", + updated_at=now, + full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}), + ) + db_session.add(summary) + + # FTS5 + db_session.execute( + __import__("sqlalchemy").text( + "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) " + "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)" + ), + { + "id": p.id, + "title_en": p.title_en, + "title_zh": p.title_zh or "", + "abstract": p.abstract or "", + "authors": f"Author {i+1}", + "tags": f"NLP, Tag{i+1}", + }, + ) + papers.append(p) + db_session.commit() + return papers diff --git a/tests/test_admin_phase4.py b/tests/test_admin.py similarity index 53% rename from tests/test_admin_phase4.py rename to tests/test_admin.py index 305e608..8c77da5 100644 --- a/tests/test_admin_phase4.py +++ b/tests/test_admin.py @@ -1,350 +1,36 @@ -"""Phase 4 管理和自动化测试 — cleaner、admin routes、scheduler。""" +"""管理接口测试 — admin routes、auth、scheduler、task locks。""" from __future__ import annotations -import os -import shutil -import time +import logging from datetime import date, datetime, timezone -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest -from fastapi.testclient import TestClient from sqlalchemy import select -from app.database import get_db from app.config import settings from app.models import ( CrawlLog, - DataDeleteJob, - Paper, - PaperAuthor, - PaperSummary, - PaperTag, - SummaryStatus, TaskLock, - UserBookmark, - UserNote, - UserReadingStatus, ) + # ── Fixtures ──────────────────────────────────────────────────────────── ADMIN_TOKEN = "test-admin-token-12345" -@pytest.fixture -def admin_headers(): - return {"Authorization": f"Bearer {ADMIN_TOKEN}"} - - -@pytest.fixture -def wrong_admin_headers(): - return {"Authorization": "Bearer wrong-token"} - - @pytest.fixture def auth_client(client, monkeypatch): """带 admin token monkeypatch 的 TestClient。""" monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN) + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) return client -@pytest.fixture -def sample_papers(db_session): - """插入多篇不同日期的论文。""" - now = datetime.now(timezone.utc) - papers = [] - for i, (arxiv_id, paper_date_str) in enumerate([ - ("2401.10001", "2024-01-10"), - ("2401.10002", "2024-01-11"), - ("2401.10003", "2024-01-12"), - ("2401.10004", "2024-01-13"), - ("2401.10005", "2024-01-14"), - ]): - paper_date = date.fromisoformat(paper_date_str) - p = Paper( - arxiv_id=arxiv_id, - title_en=f"Test Paper {i+1}", - abstract=f"Abstract for paper {i+1}.", - paper_date=paper_date, - crawled_at=now, - upvotes=i * 10, - ) - db_session.add(p) - db_session.flush() - db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0)) - db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf")) - db_session.add(SummaryStatus(paper_id=p.id, status="pending")) - # FTS5 - import sqlalchemy - db_session.execute( - sqlalchemy.text( - "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) " - "VALUES (:id, :title, :abstract, :authors, :tags)" - ), - {"id": p.id, "title": p.title_en, "abstract": p.abstract, - "authors": f"Author {i+1}", "tags": f"Tag{i+1}"}, - ) - papers.append(p) - db_session.commit() - return papers - - -@pytest.fixture -def sample_paper_with_user_data(db_session, sample_papers): - """给第一篇论文添加用户数据(收藏、阅读状态、笔记)。""" - paper = sample_papers[0] - now = datetime.now(timezone.utc) - db_session.add(UserBookmark(paper_id=paper.id, created_at=now)) - db_session.add(UserReadingStatus(paper_id=paper.id, status="read_summary", updated_at=now)) - db_session.add(UserNote( - paper_id=paper.id, - content="My notes on this paper", - created_at=now, - updated_at=now, - )) - db_session.commit() - return paper - - -@pytest.fixture -def tmp_data_dir(tmp_path): - """创建临时 data 目录结构。""" - tmp_dir = tmp_path / "data" / "tmp" - papers_dir = tmp_path / "data" / "papers" - tmp_dir.mkdir(parents=True) - papers_dir.mkdir(parents=True) - return tmp_path / "data" - - # ═══════════════════════════════════════════════════════════════════════ -# Cleaner 服务测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestCleanupTmp: - """app/services/cleaner.py — cleanup_tmp 测试。""" - - def test_cleanup_removes_old_dirs(self, tmp_path, monkeypatch): - """超过 24 小时的临时目录应被删除。""" - tmp_dir = tmp_path / "tmp" - tmp_dir.mkdir() - - # 创建一个旧目录 - old_dir = tmp_dir / "2401.00001" - old_dir.mkdir() - (old_dir / "paper.pdf").write_text("fake pdf") - - # 修改目录时间为 25 小时前 - old_mtime = time.time() - 25 * 3600 - os.utime(old_dir, (old_mtime, old_mtime)) - - monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir) - from app.services.cleaner import cleanup_tmp - result = cleanup_tmp() - - assert result["scanned"] == 1 - assert result["removed"] == 1 - assert not old_dir.exists() - - def test_cleanup_keeps_recent_dirs(self, tmp_path, monkeypatch): - """24 小时内的临时目录应保留。""" - tmp_dir = tmp_path / "tmp" - tmp_dir.mkdir() - - recent_dir = tmp_dir / "2401.00002" - recent_dir.mkdir() - (recent_dir / "paper.pdf").write_text("fake pdf") - - monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir) - from app.services.cleaner import cleanup_tmp - result = cleanup_tmp() - - assert result["scanned"] == 1 - assert result["removed"] == 0 - assert recent_dir.exists() - - def test_cleanup_empty_dir(self, tmp_path, monkeypatch): - """data/tmp/ 不存在时安全返回。""" - monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent") - from app.services.cleaner import cleanup_tmp - result = cleanup_tmp() - assert result["scanned"] == 0 - assert result["removed"] == 0 - - def test_cleanup_mixed_ages(self, tmp_path, monkeypatch): - """混合新旧目录时只删除旧的。""" - tmp_dir = tmp_path / "tmp" - tmp_dir.mkdir() - - old_dir = tmp_dir / "2401.old" - old_dir.mkdir() - old_mtime = time.time() - 30 * 3600 - os.utime(old_dir, (old_mtime, old_mtime)) - - recent_dir = tmp_dir / "2401.new" - recent_dir.mkdir() - - monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir) - from app.services.cleaner import cleanup_tmp - result = cleanup_tmp() - - assert result["scanned"] == 2 - assert result["removed"] == 1 - assert not old_dir.exists() - assert recent_dir.exists() - - -class TestDeletePapersByDateRange: - """app/services/cleaner.py — delete_papers_by_date_range 测试。""" - - @pytest.mark.asyncio - async def test_delete_by_date_range(self, db_session, sample_papers): - """删除指定日期范围的论文。""" - from app.services.cleaner import delete_papers_by_date_range - - # 删除 1月11日 ~ 1月13日(3篇) - result = await delete_papers_by_date_range( - db_session, - date(2024, 1, 11), - date(2024, 1, 13), - ) - - assert result["deleted"] == 3 - assert result["total"] == 3 - assert result["status"] == "success" - - # 确认数据库中只剩 2 篇 - remaining = db_session.execute(select(Paper)).scalars().all() - assert len(remaining) == 2 - dates = {p.paper_date for p in remaining} - assert dates == {date(2024, 1, 10), date(2024, 1, 14)} - - @pytest.mark.asyncio - async def test_delete_creates_job_record(self, db_session, sample_papers): - """删除操作应创建 data_delete_jobs 记录。""" - from app.services.cleaner import delete_papers_by_date_range - - await delete_papers_by_date_range( - db_session, - date(2024, 1, 10), - date(2024, 1, 14), - ) - - jobs = db_session.execute(select(DataDeleteJob)).scalars().all() - assert len(jobs) == 1 - assert jobs[0].status == "success" - assert jobs[0].date_start == date(2024, 1, 10) - assert jobs[0].date_end == date(2024, 1, 14) - assert jobs[0].paper_count == 5 - assert jobs[0].completed_at is not None - - @pytest.mark.asyncio - async def test_delete_creates_crawl_log(self, db_session, sample_papers): - """删除操作应写入 crawl_logs。""" - from app.services.cleaner import delete_papers_by_date_range - - await delete_papers_by_date_range( - db_session, - date(2024, 1, 10), - date(2024, 1, 14), - ) - - logs = db_session.execute( - select(CrawlLog).where(CrawlLog.task == "delete") - ).scalars().all() - assert len(logs) == 1 - assert logs[0].status == "success" - - @pytest.mark.asyncio - async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data): - """删除论文时应 cascade 删除关联的用户数据。""" - from app.services.cleaner import delete_papers_by_date_range - - paper = sample_paper_with_user_data - # 确认用户数据存在 - assert db_session.get(UserBookmark, db_session.execute( - select(UserBookmark).where(UserBookmark.paper_id == paper.id) - ).scalar_one_or_none().id if db_session.execute( - select(UserBookmark).where(UserBookmark.paper_id == paper.id) - ).scalar_one_or_none() else None) is not None or True - - # 删除 - result = await delete_papers_by_date_range( - db_session, - date(2024, 1, 10), - date(2024, 1, 10), - ) - assert result["deleted"] == 1 - - # 确认用户数据被 cascade 删除 - assert db_session.execute( - select(UserBookmark).where(UserBookmark.paper_id == paper.id) - ).scalar_one_or_none() is None - assert db_session.execute( - select(UserReadingStatus).where(UserReadingStatus.paper_id == paper.id) - ).scalar_one_or_none() is None - assert db_session.execute( - select(UserNote).where(UserNote.paper_id == paper.id) - ).scalar_one_or_none() is None - - @pytest.mark.asyncio - async def test_delete_removes_fts(self, db_session, sample_papers): - """删除论文时应同步删除 FTS5 索引。""" - import sqlalchemy - from app.services.cleaner import delete_papers_by_date_range - - await delete_papers_by_date_range( - db_session, - date(2024, 1, 10), - date(2024, 1, 14), - ) - - # FTS5 应为空 - rows = db_session.execute( - sqlalchemy.text("SELECT count(*) FROM papers_fts") - ).scalar() - assert rows == 0 - - @pytest.mark.asyncio - async def test_delete_removes_local_files(self, db_session, sample_papers, tmp_path, monkeypatch): - """删除论文时应删除本地文件目录。""" - from app.services.cleaner import delete_papers_by_date_range - - papers_dir = tmp_path / "papers" - papers_dir.mkdir() - (papers_dir / "2401.10001").mkdir() - (papers_dir / "2401.10001" / "meta.json").write_text("{}") - - monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir) - - result = await delete_papers_by_date_range( - db_session, - date(2024, 1, 10), - date(2024, 1, 10), - ) - assert result["deleted"] == 1 - assert not (papers_dir / "2401.10001").exists() - - @pytest.mark.asyncio - async def test_delete_empty_range(self, db_session, sample_papers): - """日期范围内无论文时返回 0。""" - from app.services.cleaner import delete_papers_by_date_range - - result = await delete_papers_by_date_range( - db_session, - date(2025, 1, 1), - date(2025, 1, 31), - ) - assert result["total"] == 0 - assert result["deleted"] == 0 - assert result["status"] == "success" - - -# ═══════════════════════════════════════════════════════════════════════ -# Admin Routes 测试 +# Admin Routes — 鉴权测试 # ═══════════════════════════════════════════════════════════════════════ @@ -363,12 +49,65 @@ class TestAdminAuth: def test_correct_token_accepted(self, auth_client, admin_headers): """正确 token 应被接受(crawl 可能会失败但不是 401)。""" - # mock crawl_daily 避免 API 调用 with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl: mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"} resp = auth_client.post("/admin/crawl", headers=admin_headers) assert resp.status_code != 401 + # ── summarize route auth ──────────────────────────────────────── + + def test_no_token_returns_401_for_summarize(self, client): + """无 Bearer token 返回 401。""" + resp = client.post("/admin/summarize") + assert resp.status_code in (401, 403) + + def test_wrong_token_returns_401_for_summarize(self, client): + resp = client.post( + "/admin/summarize", + headers={"Authorization": "Bearer wrong-token"}, + ) + assert resp.status_code == 401 + + def test_correct_token_batch_summarize(self, client, admin_headers): + """正确 token 调用 batch summarize,mock 掉服务层。""" + import app.config as config_mod + + original = config_mod.settings.ADMIN_TOKEN + config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN + try: + with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock: + mock.return_value = {"status": "success", "done": 0, "failed": 0, "total": 0} + resp = client.post("/admin/summarize", headers=admin_headers) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + finally: + config_mod.settings.ADMIN_TOKEN = original + + def test_single_paper_not_found(self, client, admin_headers): + """单篇总结不存在的论文返回 404。""" + import app.config as config_mod + + original = config_mod.settings.ADMIN_TOKEN + config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN + try: + with patch( + "app.routes.admin.summarize_single", + new_callable=AsyncMock, + return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"}, + ): + resp = client.post( + "/admin/summarize/nonexistent.99999", + headers=admin_headers, + ) + assert resp.status_code == 404 + finally: + config_mod.settings.ADMIN_TOKEN = original + + +# ═══════════════════════════════════════════════════════════════════════ +# Admin Routes — Crawl +# ═══════════════════════════════════════════════════════════════════════ + class TestAdminCrawl: """POST /admin/crawl 测试。""" @@ -381,7 +120,6 @@ class TestAdminCrawl: assert resp.status_code == 200 data = resp.json() assert data["status"] == "success" - # 验证调用了 crawl_daily mock_crawl.assert_called_once() def test_crawl_specific_date(self, auth_client, admin_headers): @@ -395,6 +133,11 @@ class TestAdminCrawl: assert call_args[0][1] == "2024-01-15" +# ═══════════════════════════════════════════════════════════════════════ +# Admin Routes — Cleanup +# ═══════════════════════════════════════════════════════════════════════ + + class TestAdminCleanup: """POST /admin/cleanup 测试。""" @@ -421,6 +164,11 @@ class TestAdminCleanup: assert logs[-1].status == "success" +# ═══════════════════════════════════════════════════════════════════════ +# Admin Routes — Delete +# ═══════════════════════════════════════════════════════════════════════ + + class TestAdminDelete: """POST /admin/delete 测试。""" @@ -438,7 +186,7 @@ class TestAdminDelete: ) assert resp.status_code == 422 - def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers): + def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers_range): """confirm='DELETE' 时应执行删除。""" resp = auth_client.post( "/admin/delete", @@ -480,6 +228,11 @@ class TestAdminDelete: assert resp.status_code == 422 +# ═══════════════════════════════════════════════════════════════════════ +# Admin Routes — Logs +# ═══════════════════════════════════════════════════════════════════════ + + class TestAdminLogs: """GET /admin/logs 测试。""" @@ -494,7 +247,7 @@ class TestAdminLogs: resp = auth_client.get("/admin/logs") assert resp.status_code in (403, 401) - def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers): + def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers_range): """日志页面应包含日志数据。""" # 先创建一条日志 now = datetime.now(timezone.utc) @@ -519,11 +272,10 @@ class TestScheduler: def test_scheduler_disabled_by_default(self, monkeypatch): """SCHEDULER_ENABLED=false 时不应启动调度器。""" monkeypatch.setattr(settings, "SCHEDULER_ENABLED", False) - from app.services.scheduler import start_scheduler - # 重置模块级变量 import app.services.scheduler as sched_mod sched_mod._scheduler = None + from app.services.scheduler import start_scheduler result = start_scheduler() assert result is None @@ -550,7 +302,6 @@ class TestScheduler: @pytest.mark.asyncio async def test_scheduler_warns_multi_worker(self, monkeypatch, caplog): """APP_WORKERS > 1 时应打印警告。""" - import logging monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True) monkeypatch.setattr(settings, "APP_WORKERS", 4) import app.services.scheduler as sched_mod diff --git a/tests/test_cleaner.py b/tests/test_cleaner.py new file mode 100644 index 0000000..ab00b92 --- /dev/null +++ b/tests/test_cleaner.py @@ -0,0 +1,279 @@ +"""Cleaner 服务测试 — cleanup_tmp、delete_papers_by_date_range。""" + +from __future__ import annotations + +import os +import time +from datetime import date, datetime, timezone + +import pytest +from sqlalchemy import select + +from app.config import settings +from app.models import ( + CrawlLog, + DataDeleteJob, + Paper, + UserBookmark, + UserNote, + UserReadingStatus, +) + + +# ── Fixtures ──────────────────────────────────────────────────────────── + + +@pytest.fixture +def sample_paper_with_user_data(db_session, sample_papers_range): + """给第一篇论文添加用户数据(收藏、阅读状态、笔记)。""" + paper = sample_papers_range[0] + now = datetime.now(timezone.utc) + db_session.add(UserBookmark(paper_id=paper.id, created_at=now)) + db_session.add(UserReadingStatus(paper_id=paper.id, status="read_summary", updated_at=now)) + db_session.add(UserNote( + paper_id=paper.id, + content="My notes on this paper", + created_at=now, + updated_at=now, + )) + db_session.commit() + return paper + + +# ═══════════════════════════════════════════════════════════════════════ +# cleanup_tmp 测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestCleanupTmp: + """app/services/cleaner.py — cleanup_tmp 测试。""" + + def test_cleanup_removes_old_dirs(self, tmp_path, monkeypatch): + """超过 24 小时的临时目录应被删除。""" + tmp_dir = tmp_path / "tmp" + tmp_dir.mkdir() + + # 创建一个旧目录 + old_dir = tmp_dir / "2401.00001" + old_dir.mkdir() + (old_dir / "paper.pdf").write_text("fake pdf") + + # 修改目录时间为 25 小时前 + old_mtime = time.time() - 25 * 3600 + os.utime(old_dir, (old_mtime, old_mtime)) + + monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir) + from app.services.cleaner import cleanup_tmp + result = cleanup_tmp() + + assert result["scanned"] == 1 + assert result["removed"] == 1 + assert not old_dir.exists() + + def test_cleanup_keeps_recent_dirs(self, tmp_path, monkeypatch): + """24 小时内的临时目录应保留。""" + tmp_dir = tmp_path / "tmp" + tmp_dir.mkdir() + + recent_dir = tmp_dir / "2401.00002" + recent_dir.mkdir() + (recent_dir / "paper.pdf").write_text("fake pdf") + + monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir) + from app.services.cleaner import cleanup_tmp + result = cleanup_tmp() + + assert result["scanned"] == 1 + assert result["removed"] == 0 + assert recent_dir.exists() + + def test_cleanup_empty_dir(self, tmp_path, monkeypatch): + """data/tmp/ 不存在时安全返回。""" + monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent") + from app.services.cleaner import cleanup_tmp + result = cleanup_tmp() + assert result["scanned"] == 0 + assert result["removed"] == 0 + + def test_cleanup_mixed_ages(self, tmp_path, monkeypatch): + """混合新旧目录时只删除旧的。""" + tmp_dir = tmp_path / "tmp" + tmp_dir.mkdir() + + old_dir = tmp_dir / "2401.old" + old_dir.mkdir() + old_mtime = time.time() - 30 * 3600 + os.utime(old_dir, (old_mtime, old_mtime)) + + recent_dir = tmp_dir / "2401.new" + recent_dir.mkdir() + + monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir) + from app.services.cleaner import cleanup_tmp + result = cleanup_tmp() + + assert result["scanned"] == 2 + assert result["removed"] == 1 + assert not old_dir.exists() + assert recent_dir.exists() + + +# ═══════════════════════════════════════════════════════════════════════ +# delete_papers_by_date_range 测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestDeletePapersByDateRange: + """app/services/cleaner.py — delete_papers_by_date_range 测试。""" + + @pytest.mark.asyncio + async def test_delete_by_date_range(self, db_session, sample_papers_range): + """删除指定日期范围的论文。""" + from app.services.cleaner import delete_papers_by_date_range + + # 删除 1月11日 ~ 1月13日(3篇) + result = await delete_papers_by_date_range( + db_session, + date(2024, 1, 11), + date(2024, 1, 13), + ) + + assert result["deleted"] == 3 + assert result["total"] == 3 + assert result["status"] == "success" + + # 确认数据库中只剩 2 篇 + remaining = db_session.execute(select(Paper)).scalars().all() + assert len(remaining) == 2 + dates = {p.paper_date for p in remaining} + assert dates == {date(2024, 1, 10), date(2024, 1, 14)} + + @pytest.mark.asyncio + async def test_delete_creates_job_record(self, db_session, sample_papers_range): + """删除操作应创建 data_delete_jobs 记录。""" + from app.services.cleaner import delete_papers_by_date_range + + await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 14), + ) + + jobs = db_session.execute(select(DataDeleteJob)).scalars().all() + assert len(jobs) == 1 + assert jobs[0].status == "success" + assert jobs[0].date_start == date(2024, 1, 10) + assert jobs[0].date_end == date(2024, 1, 14) + assert jobs[0].paper_count == 5 + assert jobs[0].completed_at is not None + + @pytest.mark.asyncio + async def test_delete_creates_crawl_log(self, db_session, sample_papers_range): + """删除操作应写入 crawl_logs。""" + from app.services.cleaner import delete_papers_by_date_range + + await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 14), + ) + + logs = db_session.execute( + select(CrawlLog).where(CrawlLog.task == "delete") + ).scalars().all() + assert len(logs) == 1 + assert logs[0].status == "success" + + @pytest.mark.asyncio + async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data): + """删除论文时应 cascade 删除关联的用户数据。""" + from app.services.cleaner import delete_papers_by_date_range + + paper = sample_paper_with_user_data + + # 删除 + result = await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 10), + ) + assert result["deleted"] == 1 + + # 确认用户数据被 cascade 删除 + assert db_session.execute( + select(UserBookmark).where(UserBookmark.paper_id == paper.id) + ).scalar_one_or_none() is None + assert db_session.execute( + select(UserReadingStatus).where(UserReadingStatus.paper_id == paper.id) + ).scalar_one_or_none() is None + assert db_session.execute( + select(UserNote).where(UserNote.paper_id == paper.id) + ).scalar_one_or_none() is None + + @pytest.mark.asyncio + async def test_delete_removes_fts(self, db_session, sample_papers_range): + """删除论文时应同步删除 FTS5 索引。""" + import sqlalchemy + from app.services.cleaner import delete_papers_by_date_range + + await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 14), + ) + + # FTS5 应为空 + rows = db_session.execute( + sqlalchemy.text("SELECT count(*) FROM papers_fts") + ).scalar() + assert rows == 0 + + @pytest.mark.asyncio + async def test_delete_removes_local_files(self, db_session, sample_papers_range, tmp_path, monkeypatch): + """删除论文时应删除本地文件目录。""" + from app.services.cleaner import delete_papers_by_date_range + + papers_dir = tmp_path / "papers" + papers_dir.mkdir() + (papers_dir / "2401.10001").mkdir() + (papers_dir / "2401.10001" / "meta.json").write_text("{}") + + monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir) + + result = await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 10), + ) + assert result["deleted"] == 1 + assert not (papers_dir / "2401.10001").exists() + + @pytest.mark.asyncio + async def test_delete_empty_range(self, db_session, sample_papers_range): + """日期范围内无论文时返回 0。""" + from app.services.cleaner import delete_papers_by_date_range + + result = await delete_papers_by_date_range( + db_session, + date(2025, 1, 1), + date(2025, 1, 31), + ) + assert result["total"] == 0 + assert result["deleted"] == 0 + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch): + """CHROMA 关闭时删除论文正常工作。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + import app.services.embedder as emb + emb._chroma.reset() + + from app.services.cleaner import delete_papers_by_date_range + result = await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 10), + ) + assert result["status"] == "success" + assert result["deleted"] == 1 diff --git a/tests/test_embedder.py b/tests/test_embedder.py new file mode 100644 index 0000000..f9b491e --- /dev/null +++ b/tests/test_embedder.py @@ -0,0 +1,163 @@ +"""Embedder / Chroma 服务测试 — 初始化、索引、embedding API。""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from app.config import settings + + +# ═══════════════════════════════════════════════════════════════════════ +# 初始化 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestEmbedderInit: + """embedder.py 初始化测试。""" + + def test_chroma_disabled_skip_init(self, monkeypatch): + """CHROMA_ENABLED=false 时不初始化。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + import app.services.embedder as emb + emb._chroma.reset() + emb.init_chroma() + assert emb._chroma._client is None + + def test_chroma_init_success(self, monkeypatch, tmp_path): + """CHROMA_ENABLED=true 时初始化成功。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", True) + monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma")) + + import app.services.embedder as emb + emb._chroma.reset() + emb.init_chroma() + + assert emb._chroma._client is not None + assert emb._chroma._collection is not None + + # 清理 + emb._chroma.reset() + + def test_get_collection_returns_none_when_disabled(self, monkeypatch): + """CHROMA_ENABLED=false 时 get_collection 返回 None。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + import app.services.embedder as emb + emb._chroma.reset() + assert emb.get_collection() is None + + +# ═══════════════════════════════════════════════════════════════════════ +# 索引 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestEmbedderIndexing: + """embedder.py 索引测试。""" + + def test_index_paper_disabled(self, monkeypatch): + """CHROMA_ENABLED=false 时 index_paper 返回 False。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + import app.services.embedder as emb + emb._chroma.reset() + assert emb.index_paper("test-id") is False + + def test_index_paper_no_api_config(self, monkeypatch, tmp_path): + """没有 EMBED_API_BASE 时返回 False。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", True) + monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma")) + monkeypatch.setattr(settings, "EMBED_API_BASE", "") + monkeypatch.setattr(settings, "EMBED_MODEL", "") + + import app.services.embedder as emb + emb._chroma.reset() + emb.init_chroma() + + result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"}) + assert result is False + + emb._chroma.reset() + + def test_index_batch_disabled(self, monkeypatch): + """CHROMA_ENABLED=false 时 index_batch 返回全失败。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + import app.services.embedder as emb + emb._chroma.reset() + result = emb.index_batch(["a", "b"]) + assert result["success"] == 0 + assert result["failed"] == 2 + + def test_index_batch_empty(self, monkeypatch): + """空列表时返回 0。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + import app.services.embedder as emb + result = emb.index_batch([]) + assert result["total"] == 0 + + def test_delete_paper_disabled(self, monkeypatch): + """CHROMA_ENABLED=false 时 delete_paper 返回 False。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + import app.services.embedder as emb + emb._chroma.reset() + assert emb.delete_paper("test-id") is False + + def test_search_similar_disabled(self, monkeypatch): + """CHROMA_ENABLED=false 时 search_similar 返回空列表。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + import app.services.embedder as emb + emb._chroma.reset() + assert emb.search_similar("test query") == [] + + +# ═══════════════════════════════════════════════════════════════════════ +# Embedding API +# ═══════════════════════════════════════════════════════════════════════ + + +class TestEmbeddingApi: + """_get_embedding 测试。""" + + def test_no_api_base_returns_none(self, monkeypatch): + """EMBED_API_BASE 为空时返回 None。""" + monkeypatch.setattr(settings, "EMBED_API_BASE", "") + monkeypatch.setattr(settings, "EMBED_MODEL", "") + import app.services.embedder as emb + assert emb._get_embedding("test") is None + + def test_dimension_mismatch_returns_none(self, monkeypatch): + """维度不匹配时返回 None。""" + monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake") + monkeypatch.setattr(settings, "EMBED_MODEL", "test-model") + monkeypatch.setattr(settings, "EMBED_API_KEY", "") + monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128) + monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5) + + import app.services.embedder as emb + + mock_resp = MagicMock() + mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]} + mock_resp.raise_for_status = MagicMock() + + with patch("httpx.Client") as mock_client: + mock_client.return_value.__enter__ = MagicMock(return_value=mock_resp) + mock_client.return_value.__exit__ = MagicMock(return_value=False) + result = emb._get_embedding("test") + assert result is None + + def test_api_failure_returns_none(self, monkeypatch): + """API 调用失败时返回 None。""" + monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake") + monkeypatch.setattr(settings, "EMBED_MODEL", "test-model") + monkeypatch.setattr(settings, "EMBED_API_KEY", "") + monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0) + monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5) + + import app.services.embedder as emb + + with patch("httpx.Client") as mock_client: + mock_client.return_value.__enter__ = MagicMock() + mock_client.return_value.__exit__ = MagicMock(return_value=False) + mock_client.return_value.__enter__.return_value.post.side_effect = Exception("timeout") + result = emb._get_embedding("test") + assert result is None diff --git a/tests/test_image_extractor.py b/tests/test_image_extractor.py new file mode 100644 index 0000000..3622f7c --- /dev/null +++ b/tests/test_image_extractor.py @@ -0,0 +1,88 @@ +"""LaTeX 图片提取测试 — 从 .tex 源码中提取图片文件。""" + +from __future__ import annotations + +import pytest + + +# ═══════════════════════════════════════════════════════════════════════ +# Image Extraction +# ═══════════════════════════════════════════════════════════════════════ + + +class TestImageExtraction: + """LaTeX 图片提取测试。""" + + @pytest.mark.asyncio + async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path): + """源码目录不存在时返回 0。""" + monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x) + monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x) + from app.services.image_extractor import extract_images_from_source + result = await extract_images_from_source("2401.99999") + assert result == 0 + + @pytest.mark.asyncio + async def test_extract_images_from_tex(self, monkeypatch, tmp_path): + """从 .tex 文件中提取图片。""" + from app.services.image_extractor import extract_images_from_source + + tmp_source = tmp_path / "tmp" / "2401.00001" / "source" + tmp_source.mkdir(parents=True) + + images_dir = tmp_source / "figs" + images_dir.mkdir() + (images_dir / "figure1.png").write_bytes(b"\x89PNG\r\n") + (images_dir / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0") + + # 创建 .tex 文件 + tex_content = r""" +\documentclass{article} +\begin{document} +\begin{figure} + \includegraphics[width=0.8\textwidth]{figs/figure1.png} + \includegraphics{figs/figure2.jpg} + \includegraphics[angle=90]{figs/nonexistent.pdf} +\end{figure} +\end{document} +""" + (tmp_source / "main.tex").write_text(tex_content) + + papers_dir = tmp_path / "papers" / "2401.00001" + monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x) + monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x) + + # Mock download_source_zip to avoid real network call (source dir already exists) + async def _noop_download(*args, **kwargs): + pass + + monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download) + + result = await extract_images_from_source("2401.00001") + + assert result == 2 + dest_images = papers_dir / "images" + assert dest_images.exists() + assert (dest_images / "figure1.png").exists() + assert (dest_images / "figure2.jpg").exists() + + @pytest.mark.asyncio + async def test_extract_images_empty_tex(self, monkeypatch, tmp_path): + """.tex 文件无图片时返回 0。""" + from app.services.image_extractor import extract_images_from_source + + tmp_source = tmp_path / "tmp" / "2401.00002" / "source" + tmp_source.mkdir(parents=True) + (tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}") + + monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x) + monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x) + + # Mock download_source_zip to avoid real network call + async def _noop_download(*args, **kwargs): + pass + + monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download) + + result = await extract_images_from_source("2401.00002") + assert result == 0 diff --git a/tests/test_pages.py b/tests/test_pages.py new file mode 100644 index 0000000..c7c6e3a --- /dev/null +++ b/tests/test_pages.py @@ -0,0 +1,224 @@ +"""页面路由测试 — detail、trends、compare、graceful degradation。""" + +from __future__ import annotations + +from datetime import date +from unittest.mock import patch as upatch + +import pytest + +from app.config import settings + + +# ═══════════════════════════════════════════════════════════════════════ +# Detail 页 & 相似论文 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestDetailPage: + """论文详情页测试。""" + + def test_detail_page_renders(self, client, sample_papers_with_summary): + """详情页正常渲染。""" + resp = client.get("/paper/2401.20001") + assert resp.status_code == 200 + assert "测试论文" in resp.text or "Test Paper" in resp.text + + def test_detail_page_not_found(self, client): + """不存在的论文返回 404。""" + resp = client.get("/paper/nonexistent.99999") + assert resp.status_code == 404 + + +# ═══════════════════════════════════════════════════════════════════════ +# Similar API(详情页内联) +# ═══════════════════════════════════════════════════════════════════════ + + +class TestDetailSimilarPapers: + """详情页相似论文模块测试(CHROMA 关闭时的降级行为)。""" + + def test_detail_page_renders_with_similar(self, client, sample_papers_with_summary): + """详情页正常渲染(含相似论文模块)。""" + resp = client.get("/paper/2401.20001") + assert resp.status_code == 200 + assert "测试论文" in resp.text or "Test Paper" in resp.text + + def test_detail_page_not_found_similar(self, client): + """不存在的论文返回 404。""" + resp = client.get("/paper/nonexistent.99999") + assert resp.status_code == 404 + + +# ═══════════════════════════════════════════════════════════════════════ +# Trends Dashboard +# ═══════════════════════════════════════════════════════════════════════ + + +class TestTrendsDashboard: + """趋势看板测试。""" + + def test_trends_page_renders(self, client, sample_papers_with_summary): + """趋势看板页面正常渲染。""" + resp = client.get("/trends") + assert resp.status_code == 200 + assert "趋势看板" in resp.text + assert "chart" in resp.text.lower() or "Chart" in resp.text + + def test_trends_api_returns_data(self, client, sample_papers_with_summary): + """趋势 API 返回正确数据结构。""" + resp = client.get("/api/stats/trends") + assert resp.status_code == 200 + data = resp.json() + + assert "daily_counts" in data + assert "top_tags" in data + assert "upvotes_dist" in data + assert "summary_completion" in data + + assert isinstance(data["daily_counts"], list) + assert isinstance(data["top_tags"], list) + assert isinstance(data["upvotes_dist"], list) + assert isinstance(data["summary_completion"], list) + + def test_trends_api_daily_counts(self, client, sample_papers_with_summary): + """每日论文数量数据正确。""" + # 使用测试数据的日期范围 + with upatch("app.services.trends.date") as mock_date: + mock_date.today.return_value = date(2024, 1, 20) + mock_date.side_effect = lambda *a, **kw: date(*a, **kw) + + resp = client.get("/api/stats/trends") + data = resp.json() + assert len(data["daily_counts"]) == 5 + for item in data["daily_counts"]: + assert "date" in item + assert "count" in item + assert item["count"] == 1 + + def test_trends_api_top_tags(self, client, sample_papers_with_summary): + """热门标签数据正确。""" + resp = client.get("/api/stats/trends") + data = resp.json() + tags = {t["tag"]: t["count"] for t in data["top_tags"]} + assert "NLP" in tags + assert tags["NLP"] == 5 # 所有论文都有 NLP + + def test_trends_api_summary_completion(self, client, sample_papers_with_summary): + """总结完成率数据正确。""" + resp = client.get("/api/stats/trends") + data = resp.json() + statuses = {s["status"]: s["count"] for s in data["summary_completion"]} + assert "done" in statuses + assert statuses["done"] == 4 # 4 篇已完成 + + def test_trends_empty_db(self, client): + """无数据时不崩溃。""" + resp = client.get("/api/stats/trends") + assert resp.status_code == 200 + data = resp.json() + assert data["daily_counts"] == [] + assert data["top_tags"] == [] + + +# ═══════════════════════════════════════════════════════════════════════ +# Compare Page +# ═══════════════════════════════════════════════════════════════════════ + + +class TestComparePage: + """论文对比页测试。""" + + def test_compare_page_no_ids(self, client): + """无 ID 时显示输入表单。""" + resp = client.get("/compare") + assert resp.status_code == 200 + assert "对比" in resp.text + + def test_compare_page_with_ids(self, client, sample_papers_with_summary): + """对比多篇论文正常渲染。""" + resp = client.get("/compare?ids=2401.20001,2401.20002") + assert resp.status_code == 200 + assert "2401.20001" in resp.text + assert "2401.20002" in resp.text + # 应包含对比字段 + assert "一句话摘要" in resp.text + assert "研究问题" in resp.text + + def test_compare_page_max_5(self, client, sample_papers_with_summary): + """最多 5 篇。""" + ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005" + resp = client.get(f"/compare?ids={ids}") + assert resp.status_code == 200 + + def test_compare_page_over_5_truncates(self, client, sample_papers_with_summary): + """超过 5 篇截断。""" + ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006" + resp = client.get(f"/compare?ids={ids}") + assert resp.status_code == 200 + + def test_compare_page_invalid_ids(self, client): + """无效 ID 时显示空结果。""" + resp = client.get("/compare?ids=nonexistent.99999") + assert resp.status_code == 200 + + def test_compare_page_shows_no_summary_placeholder(self, client, sample_papers_with_summary): + """无总结的论文显示占位文本。""" + # 2401.20005 没有 summary(status=pending) + resp = client.get("/compare?ids=2401.20005") + assert resp.status_code == 200 + assert "暂无总结" in resp.text + + +# ═══════════════════════════════════════════════════════════════════════ +# Nav Bar +# ═══════════════════════════════════════════════════════════════════════ + + +class TestNavBar: + """导航栏测试。""" + + def test_nav_includes_trends_link(self, client): + """导航栏应包含趋势链接。""" + resp = client.get("/search") + assert resp.status_code == 200 + assert "/trends" in resp.text + + def test_nav_includes_compare_implicitly(self, client): + """compare 页面可访问。""" + resp = client.get("/compare") + assert resp.status_code == 200 + + +# ═══════════════════════════════════════════════════════════════════════ +# Graceful Degradation(CHROMA_ENABLED=false) +# ═══════════════════════════════════════════════════════════════════════ + + +class TestGracefulDegradation: + """CHROMA_ENABLED=false 时优雅降级测试。""" + + def test_search_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary): + """CHROMA 关闭时 FTS5 搜索正常工作。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = client.get("/search?q=Test") + assert resp.status_code == 200 + assert "Test Paper" in resp.text or "测试论文" in resp.text + + def test_detail_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary): + """CHROMA 关闭时详情页正常工作。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = client.get("/paper/2401.20001") + assert resp.status_code == 200 + + def test_trends_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary): + """CHROMA 关闭时趋势看板正常工作。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = client.get("/trends") + assert resp.status_code == 200 + + def test_compare_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary): + """CHROMA 关闭时对比页正常工作。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = client.get("/compare?ids=2401.20001,2401.20002") + assert resp.status_code == 200 diff --git a/tests/test_phase5.py b/tests/test_phase5.py deleted file mode 100644 index 1ed2f68..0000000 --- a/tests/test_phase5.py +++ /dev/null @@ -1,660 +0,0 @@ -"""Phase 5 后续增强测试 — embedder、semantic search、trends、compare、image extraction。""" - -from __future__ import annotations - -import json -import shutil -import time -from datetime import date, datetime, timezone -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pytest -from fastapi.testclient import TestClient -from sqlalchemy import select - -from app.config import settings -from app.database import get_db -from app.models import ( - Paper, - PaperAuthor, - PaperSummary, - PaperTag, - SummaryStatus, -) - - -# ── Fixtures ──────────────────────────────────────────────────────────── - -ADMIN_TOKEN = "test-admin-token-12345" - - -@pytest.fixture -def admin_headers(): - return {"Authorization": f"Bearer " + ADMIN_TOKEN} - - -@pytest.fixture -def auth_client(client, monkeypatch): - monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN) - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - return client - - -@pytest.fixture -def sample_papers_with_summary(db_session): - """插入多篇带总结的论文。""" - now = datetime.now(timezone.utc) - papers = [] - for i, (arxiv_id, paper_date_str) in enumerate([ - ("2401.20001", "2024-01-10"), - ("2401.20002", "2024-01-11"), - ("2401.20003", "2024-01-12"), - ("2401.20004", "2024-01-13"), - ("2401.20005", "2024-01-14"), - ]): - paper_date = date.fromisoformat(paper_date_str) - p = Paper( - arxiv_id=arxiv_id, - title_en=f"Test Paper {i+1}", - title_zh=f"测试论文 {i+1}", - abstract=f"Abstract for paper {i+1}.", - paper_date=paper_date, - crawled_at=now, - upvotes=i * 10 + 5, - ) - db_session.add(p) - db_session.flush() - - db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0)) - db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf")) - db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf")) - - db_session.add(SummaryStatus( - paper_id=p.id, - status="done" if i < 4 else "pending", - quality="normal", - )) - - # 添加总结(前 4 篇) - if i < 4: - from app.services.schemas import SummarySchema - summary = PaperSummary( - paper_id=p.id, - one_line=f"这是论文{i+1}的一句话摘要", - difficulty="中级", - motivation_problem=f"论文{i+1}的研究问题", - motivation_goal=f"论文{i+1}的研究目标", - method_key_idea=f"论文{i+1}的关键思路", - method_overview=f"论文{i+1}的方法概述", - updated_at=now, - full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}), - ) - db_session.add(summary) - - # FTS5 - import sqlalchemy - db_session.execute( - sqlalchemy.text( - "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) " - "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)" - ), - { - "id": p.id, - "title_en": p.title_en, - "title_zh": p.title_zh or "", - "abstract": p.abstract or "", - "authors": f"Author {i+1}", - "tags": f"NLP, Tag{i+1}", - }, - ) - papers.append(p) - db_session.commit() - return papers - - -# ═══════════════════════════════════════════════════════════════════════ -# Embedder 服务测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestEmbedderInit: - """embedder.py 初始化测试。""" - - def test_chroma_disabled_skip_init(self, monkeypatch): - """CHROMA_ENABLED=false 时不初始化。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - import app.services.embedder as emb - emb._chroma.reset() - emb.init_chroma() - assert emb._chroma._client is None - - def test_chroma_init_success(self, monkeypatch, tmp_path): - """CHROMA_ENABLED=true 时初始化成功。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", True) - monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma")) - - import app.services.embedder as emb - emb._chroma.reset() - emb.init_chroma() - - assert emb._chroma._client is not None - assert emb._chroma._collection is not None - - # 清理 - emb._chroma.reset() - - def test_get_collection_returns_none_when_disabled(self, monkeypatch): - """CHROMA_ENABLED=false 时 get_collection 返回 None。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - import app.services.embedder as emb - emb._chroma.reset() - assert emb.get_collection() is None - - -class TestEmbedderIndexing: - """embedder.py 索引测试。""" - - def test_index_paper_disabled(self, monkeypatch): - """CHROMA_ENABLED=false 时 index_paper 返回 False。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - import app.services.embedder as emb - emb._chroma.reset() - assert emb.index_paper("test-id") is False - - def test_index_paper_no_api_config(self, monkeypatch, tmp_path): - """没有 EMBED_API_BASE 时返回 False。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", True) - monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma")) - monkeypatch.setattr(settings, "EMBED_API_BASE", "") - monkeypatch.setattr(settings, "EMBED_MODEL", "") - - import app.services.embedder as emb - emb._chroma.reset() - emb.init_chroma() - - result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"}) - assert result is False - - emb._chroma.reset() - - def test_index_batch_disabled(self, monkeypatch): - """CHROMA_ENABLED=false 时 index_batch 返回全失败。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - import app.services.embedder as emb - emb._chroma.reset() - result = emb.index_batch(["a", "b"]) - assert result["success"] == 0 - assert result["failed"] == 2 - - def test_index_batch_empty(self, monkeypatch): - """空列表时返回 0。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - import app.services.embedder as emb - result = emb.index_batch([]) - assert result["total"] == 0 - - def test_delete_paper_disabled(self, monkeypatch): - """CHROMA_ENABLED=false 时 delete_paper 返回 False。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - import app.services.embedder as emb - emb._chroma.reset() - assert emb.delete_paper("test-id") is False - - def test_search_similar_disabled(self, monkeypatch): - """CHROMA_ENABLED=false 时 search_similar 返回空列表。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - import app.services.embedder as emb - emb._chroma.reset() - assert emb.search_similar("test query") == [] - - -class TestEmbeddingApi: - """_get_embedding 测试。""" - - def test_no_api_base_returns_none(self, monkeypatch): - """EMBED_API_BASE 为空时返回 None。""" - monkeypatch.setattr(settings, "EMBED_API_BASE", "") - monkeypatch.setattr(settings, "EMBED_MODEL", "") - import app.services.embedder as emb - assert emb._get_embedding("test") is None - - def test_dimension_mismatch_returns_none(self, monkeypatch): - """维度不匹配时返回 None。""" - monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake") - monkeypatch.setattr(settings, "EMBED_MODEL", "test-model") - monkeypatch.setattr(settings, "EMBED_API_KEY", "") - monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128) - monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5) - - import app.services.embedder as emb - - mock_resp = MagicMock() - mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]} - mock_resp.raise_for_status = MagicMock() - - with patch("httpx.Client") as mock_client: - mock_client.return_value.__enter__ = MagicMock(return_value=mock_resp) - mock_client.return_value.__exit__ = MagicMock(return_value=False) - result = emb._get_embedding("test") - assert result is None - - def test_api_failure_returns_none(self, monkeypatch): - """API 调用失败时返回 None。""" - monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake") - monkeypatch.setattr(settings, "EMBED_MODEL", "test-model") - monkeypatch.setattr(settings, "EMBED_API_KEY", "") - monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0) - monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5) - - import app.services.embedder as emb - - with patch("httpx.Client") as mock_client: - mock_client.return_value.__enter__ = MagicMock() - mock_client.return_value.__exit__ = MagicMock(return_value=False) - mock_client.return_value.__enter__.return_value.post.side_effect = Exception("timeout") - result = emb._get_embedding("test") - assert result is None - - -# ═══════════════════════════════════════════════════════════════════════ -# Searcher 语义模式测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestSearchSemanticMode: - """searcher.py 语义搜索模式测试。""" - - def test_keyword_mode_default(self, db_session, sample_papers_with_summary): - """默认 keyword 模式走 FTS5。""" - from app.services.searcher import search_papers - result = search_papers(db_session, query="Test Paper", mode="keyword") - assert result["total"] >= 1 - assert result["distances"] == {} - - def test_semantic_mode_disabled_fallback(self, db_session, monkeypatch, sample_papers_with_summary): - """CHROMA_ENABLED=false + semantic 模式走 FTS5。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - from app.services.searcher import search_papers - result = search_papers(db_session, query="Test", mode="semantic") - # 应回退到 FTS5 - assert result["total"] >= 1 - - def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary): - """搜索结果应包含 distances 字段。""" - from app.services.searcher import search_papers - result = search_papers(db_session, query="Test Paper") - assert "distances" in result - assert isinstance(result["distances"], dict) - - def test_empty_query_returns_empty(self, db_session): - """空查询无标签时返回空。""" - from app.services.searcher import search_papers - result = search_papers(db_session) - assert result["total"] == 0 - assert result["results"] == [] - - def test_tag_only_search(self, db_session, sample_papers_with_summary): - """仅标签搜索。""" - from app.services.searcher import search_papers - result = search_papers(db_session, tag="NLP") - assert result["total"] >= 1 - - -# ═══════════════════════════════════════════════════════════════════════ -# Search Routes 测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestSearchRoutes: - """搜索路由测试。""" - - def test_search_page_keyword(self, auth_client, sample_papers_with_summary): - """搜索页 keyword 模式。""" - resp = auth_client.get("/search?q=Test&mode=keyword") - assert resp.status_code == 200 - assert "Test" in resp.text or "测试" in resp.text - - def test_search_page_semantic_disabled(self, auth_client, monkeypatch, sample_papers_with_summary): - """语义模式 CHROMA_ENABLED=false 时仍能工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = auth_client.get("/search?q=Test&mode=semantic") - assert resp.status_code == 200 - - def test_search_api_with_mode(self, auth_client, sample_papers_with_summary): - """搜索 API 支持 mode 参数。""" - resp = auth_client.get("/api/search?q=Test&mode=keyword") - assert resp.status_code == 200 - data = resp.json() - assert "results" in data - assert "total" in data - - -# ═══════════════════════════════════════════════════════════════════════ -# Similar Paper API 测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestSimilarAPI: - """相似论文 API 测试。""" - - def test_similar_api_disabled(self, auth_client, monkeypatch, sample_papers_with_summary): - """CHROMA_ENABLED=false 时返回空列表。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = auth_client.get("/api/similar/2401.20001") - assert resp.status_code == 200 - data = resp.json() - assert data["results"] == [] - - def test_similar_api_paper_not_found(self, auth_client, monkeypatch): - """不存在的论文返回空。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = auth_client.get("/api/similar/nonexistent.99999") - assert resp.status_code == 200 - assert resp.json()["results"] == [] - - def test_similar_api_with_top_k(self, auth_client, monkeypatch, sample_papers_with_summary): - """top_k 参数控制返回数量。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = auth_client.get("/api/similar/2401.20001?top_k=3") - assert resp.status_code == 200 - - -# ═══════════════════════════════════════════════════════════════════════ -# Detail Page 相似论文测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestDetailSimilarPapers: - """详情页相似论文模块测试。""" - - def test_detail_page_renders(self, auth_client, sample_papers_with_summary): - """详情页正常渲染。""" - resp = auth_client.get("/paper/2401.20001") - assert resp.status_code == 200 - assert "测试论文" in resp.text or "Test Paper" in resp.text - - def test_detail_page_not_found(self, auth_client): - """不存在的论文返回 404。""" - resp = auth_client.get("/paper/nonexistent.99999") - assert resp.status_code == 404 - - -# ═══════════════════════════════════════════════════════════════════════ -# Trends Dashboard 测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestTrendsDashboard: - """趋势看板测试。""" - - def test_trends_page_renders(self, auth_client, sample_papers_with_summary): - """趋势看板页面正常渲染。""" - resp = auth_client.get("/trends") - assert resp.status_code == 200 - assert "趋势看板" in resp.text - assert "chart" in resp.text.lower() or "Chart" in resp.text - - def test_trends_api_returns_data(self, auth_client, sample_papers_with_summary): - """趋势 API 返回正确数据结构。""" - resp = auth_client.get("/api/stats/trends") - assert resp.status_code == 200 - data = resp.json() - - assert "daily_counts" in data - assert "top_tags" in data - assert "upvotes_dist" in data - assert "summary_completion" in data - - assert isinstance(data["daily_counts"], list) - assert isinstance(data["top_tags"], list) - assert isinstance(data["upvotes_dist"], list) - assert isinstance(data["summary_completion"], list) - - def test_trends_api_daily_counts(self, auth_client, sample_papers_with_summary, monkeypatch): - """每日论文数量数据正确。""" - # 使用测试数据的日期范围 - from unittest.mock import patch as upatch - import app.routes.trends as trends_mod - - # monkeypatch get_trends_data 中的 date.today - with upatch("app.services.trends.date") as mock_date: - mock_date.today.return_value = date(2024, 1, 20) - mock_date.side_effect = lambda *a, **kw: date(*a, **kw) - - resp = auth_client.get("/api/stats/trends") - data = resp.json() - assert len(data["daily_counts"]) == 5 - for item in data["daily_counts"]: - assert "date" in item - assert "count" in item - assert item["count"] == 1 - - def test_trends_api_top_tags(self, auth_client, sample_papers_with_summary): - """热门标签数据正确。""" - resp = auth_client.get("/api/stats/trends") - data = resp.json() - tags = {t["tag"]: t["count"] for t in data["top_tags"]} - assert "NLP" in tags - assert tags["NLP"] == 5 # 所有论文都有 NLP - - def test_trends_api_summary_completion(self, auth_client, sample_papers_with_summary): - """总结完成率数据正确。""" - resp = auth_client.get("/api/stats/trends") - data = resp.json() - statuses = {s["status"]: s["count"] for s in data["summary_completion"]} - assert "done" in statuses - assert statuses["done"] == 4 # 4 篇已完成 - - def test_trends_empty_db(self, auth_client): - """无数据时不崩溃。""" - resp = auth_client.get("/api/stats/trends") - assert resp.status_code == 200 - data = resp.json() - assert data["daily_counts"] == [] - assert data["top_tags"] == [] - - -# ═══════════════════════════════════════════════════════════════════════ -# Compare Page 测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestComparePage: - """论文对比页测试。""" - - def test_compare_page_no_ids(self, auth_client): - """无 ID 时显示输入表单。""" - resp = auth_client.get("/compare") - assert resp.status_code == 200 - assert "对比" in resp.text - - def test_compare_page_with_ids(self, auth_client, sample_papers_with_summary): - """对比多篇论文正常渲染。""" - resp = auth_client.get("/compare?ids=2401.20001,2401.20002") - assert resp.status_code == 200 - assert "2401.20001" in resp.text - assert "2401.20002" in resp.text - # 应包含对比字段 - assert "一句话摘要" in resp.text - assert "研究问题" in resp.text - - def test_compare_page_max_5(self, auth_client, sample_papers_with_summary): - """最多 5 篇。""" - ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005" - resp = auth_client.get(f"/compare?ids={ids}") - assert resp.status_code == 200 - - def test_compare_page_over_5_truncates(self, auth_client, sample_papers_with_summary): - """超过 5 篇截断。""" - ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006" - resp = auth_client.get(f"/compare?ids={ids}") - assert resp.status_code == 200 - # 不应包含第 6 篇(不存在) - - def test_compare_page_invalid_ids(self, auth_client): - """无效 ID 时显示空结果。""" - resp = auth_client.get("/compare?ids=nonexistent.99999") - assert resp.status_code == 200 - # 不存在的论文 - assert "未找到" in resp.text or "暂无" in resp.text or resp.status_code == 200 - - def test_compare_page_shows_no_summary_placeholder(self, auth_client, sample_papers_with_summary): - """无总结的论文显示占位文本。""" - # 2401.20005 没有 summary(status=pending) - resp = auth_client.get("/compare?ids=2401.20005") - assert resp.status_code == 200 - assert "暂无总结" in resp.text - - -# ═══════════════════════════════════════════════════════════════════════ -# Image Extraction 测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestImageExtraction: - """LaTeX 图片提取测试。""" - - @pytest.mark.asyncio - async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path): - """源码目录不存在时返回 0。""" - monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x) - monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x) - from app.services.image_extractor import extract_images_from_source - result = await extract_images_from_source("2401.99999") - assert result == 0 - - @pytest.mark.asyncio - async def test_extract_images_from_tex(self, monkeypatch, tmp_path): - """从 .tex 文件中提取图片。""" - from app.services.image_extractor import extract_images_from_source - - tmp_source = tmp_path / "tmp" / "2401.00001" / "source" - tmp_source.mkdir(parents=True) - - images_dir = tmp_source / "figs" - images_dir.mkdir() - (images_dir / "figure1.png").write_bytes(b"\x89PNG\r\n") - (images_dir / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0") - - # 创建 .tex 文件 - tex_content = r""" -\documentclass{article} -\begin{document} -\begin{figure} - \includegraphics[width=0.8\textwidth]{figs/figure1.png} - \includegraphics{figs/figure2.jpg} - \includegraphics[angle=90]{figs/nonexistent.pdf} -\end{figure} -\end{document} -""" - (tmp_source / "main.tex").write_text(tex_content) - - papers_dir = tmp_path / "papers" / "2401.00001" - monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x) - monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x) - - # Mock download_source_zip to avoid real network call (source dir already exists) - async def _noop_download(*args, **kwargs): - pass - - monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download) - - result = await extract_images_from_source("2401.00001") - - assert result == 2 - dest_images = papers_dir / "images" - assert dest_images.exists() - assert (dest_images / "figure1.png").exists() - assert (dest_images / "figure2.jpg").exists() - - @pytest.mark.asyncio - async def test_extract_images_empty_tex(self, monkeypatch, tmp_path): - """.tex 文件无图片时返回 0。""" - from app.services.image_extractor import extract_images_from_source - - tmp_source = tmp_path / "tmp" / "2401.00002" / "source" - tmp_source.mkdir(parents=True) - (tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}") - - monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x) - monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x) - - # Mock download_source_zip to avoid real network call - async def _noop_download(*args, **kwargs): - pass - - monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download) - - result = await extract_images_from_source("2401.00002") - assert result == 0 - - -# ═══════════════════════════════════════════════════════════════════════ -# Nav Bar 测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestNavBar: - """导航栏测试。""" - - def test_nav_includes_trends_link(self, auth_client): - """导航栏应包含趋势链接。""" - resp = auth_client.get("/search") - assert resp.status_code == 200 - assert "/trends" in resp.text - - def test_nav_includes_compare_implicitly(self, auth_client): - """compare 页面可访问。""" - resp = auth_client.get("/compare") - assert resp.status_code == 200 - - -# ═══════════════════════════════════════════════════════════════════════ -# Graceful Degradation 测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestGracefulDegradation: - """CHROMA_ENABLED=false 时优雅降级测试。""" - - def test_search_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary): - """CHROMA 关闭时 FTS5 搜索正常工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = auth_client.get("/search?q=Test") - assert resp.status_code == 200 - assert "Test Paper" in resp.text or "测试论文" in resp.text - - def test_detail_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary): - """CHROMA 关闭时详情页正常工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = auth_client.get("/paper/2401.20001") - assert resp.status_code == 200 - - def test_trends_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary): - """CHROMA 关闭时趋势看板正常工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = auth_client.get("/trends") - assert resp.status_code == 200 - - def test_compare_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary): - """CHROMA 关闭时对比页正常工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - resp = auth_client.get("/compare?ids=2401.20001,2401.20002") - assert resp.status_code == 200 - - @pytest.mark.asyncio - async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch): - """CHROMA 关闭时删除论文正常工作。""" - monkeypatch.setattr(settings, "CHROMA_ENABLED", False) - import app.services.embedder as emb - emb._chroma.reset() - - from app.services.cleaner import delete_papers_by_date_range - result = await delete_papers_by_date_range( - db_session, - date(2024, 1, 10), - date(2024, 1, 10), - ) - assert result["status"] == "success" - assert result["deleted"] == 1 diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..48794ff --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,189 @@ +"""SummarySchema 校验、quality 分级、JSON 提取、错误分类测试。""" + +from __future__ import annotations + +import json + +import pytest +from pydantic import ValidationError + +from app.services.pi_client import ( + JsonNotFoundError, + PiProcessError, + PiTimeoutError, + extract_json as _extract_json, +) +from app.services.pdf_downloader import PdfDownloadError +from app.services.schemas import ( + SummarySchema, + assess_quality, + classify_validation_error, + flatten_for_db, +) +from app.services.summarizer import _classify_error + + +# ═══════════════════════════════════════════════════════════════════════ +# SummarySchema 校验 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestSummarySchema: + """Pydantic schema 校验。""" + + def test_valid_summary(self, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + assert schema.title_zh == "测试论文中文标题" + assert len(schema.tags) == 3 + assert schema.motivation.problem + + def test_missing_title_zh(self, sample_summary_dict): + del sample_summary_dict["title_zh"] + with pytest.raises(ValidationError) as exc_info: + SummarySchema.model_validate(sample_summary_dict) + assert classify_validation_error(exc_info.value) == "field_missing" + + def test_empty_one_line(self, sample_summary_dict): + sample_summary_dict["one_line"] = "" + with pytest.raises(ValidationError): + SummarySchema.model_validate(sample_summary_dict) + + def test_empty_tags(self, sample_summary_dict): + sample_summary_dict["tags"] = [] + with pytest.raises(ValidationError): + SummarySchema.model_validate(sample_summary_dict) + + def test_empty_motivation_problem(self, sample_summary_dict): + sample_summary_dict["motivation"]["problem"] = "" + with pytest.raises(ValidationError): + SummarySchema.model_validate(sample_summary_dict) + + def test_empty_method_key_idea(self, sample_summary_dict): + sample_summary_dict["method"]["key_idea"] = "" + with pytest.raises(ValidationError): + SummarySchema.model_validate(sample_summary_dict) + + def test_extra_fields_ignored(self, sample_summary_dict): + sample_summary_dict["figures"] = ["fig1.png"] + sample_summary_dict["takeaway"] = "important paper" + schema = SummarySchema.model_validate(sample_summary_dict) + assert not hasattr(schema, "figures") + assert schema.title_zh # 正常解析 + + def test_flatten_for_db(self, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + flat = flatten_for_db(schema) + assert flat["one_line"] == schema.one_line + assert flat["motivation_problem"] == schema.motivation.problem + assert flat["method_key_idea"] == schema.method.key_idea + assert "full_json" in flat + assert "updated_at" in flat + # JSON 字段可解析 + assert isinstance(json.loads(flat["prerequisites_json"]), dict) + assert isinstance(json.loads(flat["method_steps_json"]), list) + + +# ═══════════════════════════════════════════════════════════════════════ +# Quality 分级 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestQualityAssessment: + """质量分级测试。""" + + def test_quality_normal(self, sample_summary_dict): + schema = SummarySchema.model_validate(sample_summary_dict) + assert assess_quality(schema) == "normal" + + def test_quality_degraded_missing_goal(self, sample_summary_dict): + sample_summary_dict["motivation"]["goal"] = "" + sample_summary_dict["motivation"]["gap"] = "" + sample_summary_dict["method"]["overview"] = "" + sample_summary_dict["results"]["main_findings"] = [] + schema = SummarySchema.model_validate(sample_summary_dict) + assert assess_quality(schema) == "degraded" + + def test_quality_low_short_one_line(self, sample_summary_dict): + sample_summary_dict["one_line"] = "短" + schema = SummarySchema.model_validate(sample_summary_dict) + assert assess_quality(schema) == "low" + + def test_quality_low_short_key_idea(self, sample_summary_dict): + sample_summary_dict["method"]["key_idea"] = "短" + schema = SummarySchema.model_validate(sample_summary_dict) + assert assess_quality(schema) == "low" + + +# ═══════════════════════════════════════════════════════════════════════ +# JSON 提取 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestJsonExtraction: + """pi 输出的 JSON 提取。""" + + def test_direct_json(self, sample_summary_json): + result = _extract_json(sample_summary_json) + assert result["title_zh"] == "测试论文中文标题" + + def test_fenced_code_block(self, sample_summary_json): + raw = f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字" + result = _extract_json(raw) + assert result["title_zh"] == "测试论文中文标题" + + def test_fenced_without_lang(self, sample_summary_json): + raw = f"文字\n```\n{sample_summary_json}\n```" + result = _extract_json(raw) + assert result["title_zh"] == "测试论文中文标题" + + def test_embedded_braces(self, sample_summary_dict): + json_str = json.dumps(sample_summary_dict, ensure_ascii=False) + raw = f"Here is the summary:\n{json_str}\nEnd." + result = _extract_json(raw) + assert result["title_zh"] == "测试论文中文标题" + + def test_no_json_raises(self): + with pytest.raises(JsonNotFoundError): + _extract_json("No JSON here at all.") + + def test_json_without_title_zh_falls_through(self): + """不含 title_zh 的 JSON 不是我们要的。""" + raw = json.dumps({"other": "data"}) + # 如果有其他合法 JSON 块也能返回,但没有就直接找最大块 + # 此场景 raw 本身就是一个 JSON dict,但没有 title_zh + # 策略 1 会跳过(无 title_zh),策略 2 无代码块,策略 3 找到最大块 + result = _extract_json(raw) + assert result == {"other": "data"} # 最大块兜底 + + +# ═══════════════════════════════════════════════════════════════════════ +# 错误分类 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestErrorClassification: + """异常 → error_type 映射。""" + + def test_pdf_download_error(self): + assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed" + + def test_timeout_error(self): + assert _classify_error(PiTimeoutError("timeout")) == "timeout" + + def test_process_error(self): + assert _classify_error(PiProcessError(1, "stderr")) == "process_error" + + def test_json_not_found(self): + assert _classify_error(JsonNotFoundError("not found")) == "json_not_found" + + def test_json_invalid(self): + assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid" + + def test_field_missing(self): + try: + SummarySchema.model_validate({"title_zh": ""}) # type: ignore + except ValidationError as exc: + assert _classify_error(exc) == "field_missing" + + def test_unknown_error(self): + assert _classify_error(RuntimeError("boom")) == "unknown" diff --git a/tests/test_search.py b/tests/test_searcher.py similarity index 69% rename from tests/test_search.py rename to tests/test_searcher.py index 5942eef..f30ebda 100644 --- a/tests/test_search.py +++ b/tests/test_searcher.py @@ -1,10 +1,12 @@ -"""搜索、阅读列表、RSS Feed 测试。""" +"""搜索服务 + 路由 + 阅读列表 + RSS + 语义模式测试。""" from __future__ import annotations import pytest from datetime import date, datetime, timezone +from app.config import settings + # ═══════════════════════════════════════════════════════════════════════ # 搜索服务单元测试 @@ -12,69 +14,59 @@ from datetime import date, datetime, timezone class TestSearchService: - """app/services/searcher.py 单元测试。""" + """app/services/searcher.py — FTS5 关键词搜索单元测试。""" def test_search_by_title(self, db_session, sample_paper): from app.services.searcher import search_papers - result = search_papers(db_session, query="Test Paper") assert result["total"] == 1 assert result["results"][0].arxiv_id == "2401.12345" def test_search_by_abstract(self, db_session, sample_paper): from app.services.searcher import search_papers - result = search_papers(db_session, query="test abstract") assert result["total"] == 1 def test_search_by_author(self, db_session, sample_paper): from app.services.searcher import search_papers - result = search_papers(db_session, query="Alice") assert result["total"] == 1 def test_search_by_tag_in_fts(self, db_session, sample_paper): from app.services.searcher import search_papers - # FTS5 索引中包含 tags 列,可以搜到 result = search_papers(db_session, query="NLP") assert result["total"] == 1 def test_search_no_results(self, db_session, sample_paper): from app.services.searcher import search_papers - result = search_papers(db_session, query="quantum entanglement") assert result["total"] == 0 assert result["results"] == [] def test_search_empty_query_returns_empty(self, db_session): from app.services.searcher import search_papers - result = search_papers(db_session, query="") assert result["total"] == 0 assert result["results"] == [] def test_search_special_characters_sanitized(self, db_session, sample_paper): from app.services.searcher import search_papers - # 特殊字符被清除后,剩下 "Test" 仍然能搜到 result = search_papers(db_session, query='Test "Paper" {test}') assert result["total"] >= 1 def test_search_with_tag_filter(self, db_session, sample_paper): from app.services.searcher import search_papers - # 关键词 + 标签筛选 result = search_papers(db_session, query="Paper", tag="NLP") assert result["total"] == 1 - # 标签不匹配 → 0 result2 = search_papers(db_session, query="Paper", tag="nonexistent") assert result2["total"] == 0 def test_search_tag_only_no_query(self, db_session, sample_paper): from app.services.searcher import search_papers - # 只有标签,无关键词 result = search_papers(db_session, tag="NLP") assert result["total"] == 1 @@ -82,14 +74,12 @@ class TestSearchService: def test_search_pagination(self, db_session, sample_paper): from app.services.searcher import search_papers - result = search_papers(db_session, query="Test", page=2, page_size=10) assert result["page"] == 2 assert result["total_pages"] == 1 # 只有 1 条结果,1 页 def test_search_returns_snippets(self, db_session, sample_paper): from app.services.searcher import search_papers - result = search_papers(db_session, query="test abstract") assert result["total"] == 1 paper_id = result["results"][0].id @@ -99,12 +89,54 @@ class TestSearchService: def test_get_all_tags(self, db_session, sample_paper): from app.services.searcher import get_all_tags - tags = get_all_tags(db_session) assert "NLP" in tags assert "LLM" in tags +# ═══════════════════════════════════════════════════════════════════════ +# 语义 / Embedder 模式测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestSearchSemanticMode: + """searcher.py — semantic 模式(含 embedder 回退)测试。""" + + def test_keyword_mode_default(self, db_session, sample_papers_with_summary): + """默认 keyword 模式走 FTS5。""" + from app.services.searcher import search_papers + result = search_papers(db_session, query="Test Paper", mode="keyword") + assert result["total"] >= 1 + assert result["distances"] == {} + + def test_semantic_mode_disabled_fallback(self, db_session, monkeypatch, sample_papers_with_summary): + """CHROMA_ENABLED=false + semantic 模式走 FTS5。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + from app.services.searcher import search_papers + result = search_papers(db_session, query="Test", mode="semantic") + assert result["total"] >= 1 + + def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary): + """搜索结果应包含 distances 字段。""" + from app.services.searcher import search_papers + result = search_papers(db_session, query="Test Paper") + assert "distances" in result + assert isinstance(result["distances"], dict) + + def test_empty_query_returns_empty_no_tags(self, db_session): + """空查询无标签时返回空。""" + from app.services.searcher import search_papers + result = search_papers(db_session) + assert result["total"] == 0 + assert result["results"] == [] + + def test_tag_only_search(self, db_session, sample_papers_with_summary): + """仅标签搜索。""" + from app.services.searcher import search_papers + result = search_papers(db_session, tag="NLP") + assert result["total"] >= 1 + + # ═══════════════════════════════════════════════════════════════════════ # 搜索路由 HTTP 测试 # ═══════════════════════════════════════════════════════════════════════ @@ -131,6 +163,18 @@ class TestSearchRoutes: assert resp.status_code == 200 assert "2401.12345" in resp.text + def test_search_page_keyword_mode(self, client, sample_papers_with_summary): + """搜索页 keyword 模式。""" + resp = client.get("/search?q=Test&mode=keyword") + assert resp.status_code == 200 + assert "Test" in resp.text or "测试" in resp.text + + def test_search_page_semantic_disabled(self, client, monkeypatch, sample_papers_with_summary): + """语义模式 CHROMA_ENABLED=false 时仍能工作。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = client.get("/search?q=Test&mode=semantic") + assert resp.status_code == 200 + def test_search_api_json(self, client, sample_paper): """GET /api/search?q=Test 返回 JSON。""" resp = client.get("/api/search?q=Test") @@ -146,6 +190,14 @@ class TestSearchRoutes: data = resp.json() assert data["total"] == 1 + def test_search_api_with_mode(self, client, sample_papers_with_summary): + """搜索 API 支持 mode 参数。""" + resp = client.get("/api/search?q=Test&mode=keyword") + assert resp.status_code == 200 + data = resp.json() + assert "results" in data + assert "total" in data + def test_search_api_empty(self, client, sample_paper): """GET /api/search?q=nonexistent 返回空结果。""" resp = client.get("/api/search?q=nonexistent") @@ -161,6 +213,36 @@ class TestSearchRoutes: assert data["total"] >= 1 +# ═══════════════════════════════════════════════════════════════════════ +# Similar Paper API 测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestSimilarAPI: + """相似论文 API 测试。""" + + def test_similar_api_disabled(self, client, monkeypatch, sample_papers_with_summary): + """CHROMA_ENABLED=false 时返回空列表。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = client.get("/api/similar/2401.20001") + assert resp.status_code == 200 + data = resp.json() + assert data["results"] == [] + + def test_similar_api_paper_not_found(self, client, monkeypatch): + """不存在的论文返回空。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = client.get("/api/similar/nonexistent.99999") + assert resp.status_code == 200 + assert resp.json()["results"] == [] + + def test_similar_api_with_top_k(self, client, monkeypatch, sample_papers_with_summary): + """top_k 参数控制返回数量。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = client.get("/api/similar/2401.20001?top_k=3") + assert resp.status_code == 200 + + # ═══════════════════════════════════════════════════════════════════════ # 阅读列表路由测试 # ═══════════════════════════════════════════════════════════════════════ @@ -185,8 +267,6 @@ class TestReadingListRoute: def test_reading_list_filter_by_status(self, client, sample_paper): """按阅读状态筛选。""" - import json - # 设置阅读状态 client.post( "/api/reading-status/2401.12345", @@ -204,8 +284,6 @@ class TestReadingListRoute: def test_reading_list_has_note_filter(self, client, sample_paper): """筛选有笔记的论文。""" - import json - # 写笔记 client.post( "/api/note/2401.12345", diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py index 35cdb79..49ff237 100644 --- a/tests/test_summarizer.py +++ b/tests/test_summarizer.py @@ -1,15 +1,13 @@ -"""AI 总结服务测试 — Mock 全链路,不调用真实 pi。""" +"""AI 总结服务测试 — summarize_one 状态流转、批量处理、DB 更新、文件操作。""" from __future__ import annotations -import asyncio import json from datetime import date, datetime, timezone from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest -from pydantic import ValidationError from sqlalchemy import text from app.models import ( @@ -20,193 +18,19 @@ from app.models import ( SummaryStatus, TaskLock, ) -from app.services.schemas import ( - SummarySchema, - assess_quality, - classify_validation_error, - flatten_for_db, +from app.services.pdf_downloader import ( + PdfDownloadError, + cleanup_tmp as _cleanup_tmp, ) +from app.services.pi_client import PiTimeoutError +from app.services.schemas import SummarySchema from app.services.summarizer import ( - _classify_error, _save_files, _save_raw_output_only, _update_summary_in_db, summarize_batch, summarize_one, - summarize_single, ) -from app.services.pi_client import ( - JsonNotFoundError, - PiProcessError, - PiTimeoutError, - call_pi as _call_pi, - extract_json as _extract_json, -) -from app.services.pdf_downloader import ( - PdfDownloadError, - cleanup_tmp as _cleanup_tmp, -) - - -# ═══════════════════════════════════════════════════════════════════════ -# Schema 校验测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestSummarySchema: - """Pydantic schema 校验。""" - - def test_valid_summary(self, sample_summary_dict): - schema = SummarySchema.model_validate(sample_summary_dict) - assert schema.title_zh == "测试论文中文标题" - assert len(schema.tags) == 3 - assert schema.motivation.problem - - def test_missing_title_zh(self, sample_summary_dict): - del sample_summary_dict["title_zh"] - with pytest.raises(ValidationError) as exc_info: - SummarySchema.model_validate(sample_summary_dict) - assert classify_validation_error(exc_info.value) == "field_missing" - - def test_empty_one_line(self, sample_summary_dict): - sample_summary_dict["one_line"] = "" - with pytest.raises(ValidationError): - SummarySchema.model_validate(sample_summary_dict) - - def test_empty_tags(self, sample_summary_dict): - sample_summary_dict["tags"] = [] - with pytest.raises(ValidationError): - SummarySchema.model_validate(sample_summary_dict) - - def test_empty_motivation_problem(self, sample_summary_dict): - sample_summary_dict["motivation"]["problem"] = "" - with pytest.raises(ValidationError): - SummarySchema.model_validate(sample_summary_dict) - - def test_empty_method_key_idea(self, sample_summary_dict): - sample_summary_dict["method"]["key_idea"] = "" - with pytest.raises(ValidationError): - SummarySchema.model_validate(sample_summary_dict) - - def test_extra_fields_ignored(self, sample_summary_dict): - sample_summary_dict["figures"] = ["fig1.png"] - sample_summary_dict["takeaway"] = "important paper" - schema = SummarySchema.model_validate(sample_summary_dict) - assert not hasattr(schema, "figures") - assert schema.title_zh # 正常解析 - - def test_flatten_for_db(self, sample_summary_dict): - schema = SummarySchema.model_validate(sample_summary_dict) - flat = flatten_for_db(schema) - assert flat["one_line"] == schema.one_line - assert flat["motivation_problem"] == schema.motivation.problem - assert flat["method_key_idea"] == schema.method.key_idea - assert "full_json" in flat - assert "updated_at" in flat - # JSON 字段可解析 - assert isinstance(json.loads(flat["prerequisites_json"]), dict) - assert isinstance(json.loads(flat["method_steps_json"]), list) - - -class TestQualityAssessment: - """质量分级测试。""" - - def test_quality_normal(self, sample_summary_dict): - schema = SummarySchema.model_validate(sample_summary_dict) - assert assess_quality(schema) == "normal" - - def test_quality_degraded_missing_goal(self, sample_summary_dict): - sample_summary_dict["motivation"]["goal"] = "" - sample_summary_dict["motivation"]["gap"] = "" - sample_summary_dict["method"]["overview"] = "" - sample_summary_dict["results"]["main_findings"] = [] - schema = SummarySchema.model_validate(sample_summary_dict) - assert assess_quality(schema) == "degraded" - - def test_quality_low_short_one_line(self, sample_summary_dict): - sample_summary_dict["one_line"] = "短" - schema = SummarySchema.model_validate(sample_summary_dict) - assert assess_quality(schema) == "low" - - def test_quality_low_short_key_idea(self, sample_summary_dict): - sample_summary_dict["method"]["key_idea"] = "短" - schema = SummarySchema.model_validate(sample_summary_dict) - assert assess_quality(schema) == "low" - - -# ═══════════════════════════════════════════════════════════════════════ -# JSON 提取测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestJsonExtraction: - """pi 输出的 JSON 提取。""" - - def test_direct_json(self, sample_summary_json): - result = _extract_json(sample_summary_json) - assert result["title_zh"] == "测试论文中文标题" - - def test_fenced_code_block(self, sample_summary_json): - raw = f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字" - result = _extract_json(raw) - assert result["title_zh"] == "测试论文中文标题" - - def test_fenced_without_lang(self, sample_summary_json): - raw = f"文字\n```\n{sample_summary_json}\n```" - result = _extract_json(raw) - assert result["title_zh"] == "测试论文中文标题" - - def test_embedded_braces(self, sample_summary_dict): - json_str = json.dumps(sample_summary_dict, ensure_ascii=False) - raw = f"Here is the summary:\n{json_str}\nEnd." - result = _extract_json(raw) - assert result["title_zh"] == "测试论文中文标题" - - def test_no_json_raises(self): - with pytest.raises(JsonNotFoundError): - _extract_json("No JSON here at all.") - - def test_json_without_title_zh_falls_through(self): - """不含 title_zh 的 JSON 不是我们要的。""" - raw = json.dumps({"other": "data"}) - # 如果有其他合法 JSON 块也能返回,但没有就直接找最大块 - # 此场景 raw 本身就是一个 JSON dict,但没有 title_zh - # 策略 1 会跳过(无 title_zh),策略 2 无代码块,策略 3 找到最大块 - result = _extract_json(raw) - assert result == {"other": "data"} # 最大块兜底 - - -# ═══════════════════════════════════════════════════════════════════════ -# 错误分类测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestErrorClassification: - """异常 → error_type 映射。""" - - def test_pdf_download_error(self): - assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed" - - def test_timeout_error(self): - assert _classify_error(PiTimeoutError("timeout")) == "timeout" - - def test_process_error(self): - assert _classify_error(PiProcessError(1, "stderr")) == "process_error" - - def test_json_not_found(self): - assert _classify_error(JsonNotFoundError("not found")) == "json_not_found" - - def test_json_invalid(self): - assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid" - - def test_field_missing(self): - try: - SummarySchema.model_validate({"title_zh": ""}) # type: ignore - except ValidationError as exc: - assert _classify_error(exc) == "field_missing" - - def test_unknown_error(self): - assert _classify_error(RuntimeError("boom")) == "unknown" # ═══════════════════════════════════════════════════════════════════════ @@ -675,59 +499,3 @@ class TestBatchSummarize: result = await summarize_batch(db_session) assert result["status"] == "success" assert result["total"] == 0 - - -# ═══════════════════════════════════════════════════════════════════════ -# Admin 路由鉴权测试 -# ═══════════════════════════════════════════════════════════════════════ - - -class TestAdminAuth: - """管理接口鉴权 — 只测 HTTP 层,mock 掉实际服务调用。""" - - def test_no_token_returns_401(self, client): - """无 Bearer token 返回 401。""" - resp = client.post("/admin/summarize") - assert resp.status_code in (401, 403) - - def test_wrong_token_returns_401(self, client): - resp = client.post( - "/admin/summarize", - headers={"Authorization": "Bearer wrong-token"}, - ) - assert resp.status_code == 401 - - def test_correct_token_batch(self, client, admin_headers): - """正确 token 调用 batch summarize,mock 掉服务层。""" - import app.config as config_mod - - original = config_mod.settings.ADMIN_TOKEN - config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345" - try: - with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock: - mock.return_value = {"status": "success", "done": 0, "failed": 0, "total": 0} - resp = client.post("/admin/summarize", headers=admin_headers) - assert resp.status_code == 200 - assert resp.json()["status"] == "success" - finally: - config_mod.settings.ADMIN_TOKEN = original - - def test_single_paper_not_found(self, client, admin_headers): - """单篇总结不存在的论文返回 404。""" - import app.config as config_mod - - original = config_mod.settings.ADMIN_TOKEN - config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345" - try: - with patch( - "app.routes.admin.summarize_single", - new_callable=AsyncMock, - return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"}, - ): - resp = client.post( - "/admin/summarize/nonexistent.99999", - headers=admin_headers, - ) - assert resp.status_code == 404 - finally: - config_mod.settings.ADMIN_TOKEN = original