refactor: split monolithic phase tests into per-module test files

- rename test_admin_phase4.py -> test_admin.py, test_search.py -> test_searcher.py
- split test_phase5.py into test_cleaner, test_embedder, test_image_extractor, test_pages
- move schema tests from test_summarizer.py into dedicated test_schemas.py
- add sample_papers_range and sample_papers_with_summary fixtures in conftest
- update .gitignore to exclude all of data/
This commit is contained in:
2026-06-06 00:34:30 +08:00
parent 85c4cfb9e8
commit f7f1a4c0cb
11 changed files with 1245 additions and 1249 deletions
+1 -4
View File
@@ -2,10 +2,7 @@
__pycache__/
*.pyc
*.pyo
data/db/*.db
data/papers/
data/tmp/
data/chroma/
data/
logs/*.log
.venv/
venv/
+119
View File
@@ -209,3 +209,122 @@ def admin_token():
def admin_headers(admin_token):
"""带 Bearer token 的请求头。"""
return {"Authorization": f"Bearer {admin_token}"}
@pytest.fixture
def wrong_admin_headers():
    """Return request headers carrying an incorrect Authorization bearer token."""
    invalid_auth = {"Authorization": "Bearer wrong-token"}
    return invalid_auth
# ── 多样例数据 ────────────────────────────────────────────────────────────
@pytest.fixture
def sample_papers_range(db_session):
    """Insert five papers with consecutive dates (used by admin / cleaner tests).

    The papers are dated 2024-01-10 through 2024-01-14. Each one gets a single
    author, a single ``hf`` tag, a ``pending`` summary status and a row in the
    ``papers_fts`` table keyed by the paper id.

    Args:
        db_session: session fixture the rows are added to and committed on.

    Returns:
        The inserted ``Paper`` objects, in insertion order.
    """
    # Function-scope import instead of ``__import__("sqlalchemy")``: readable,
    # and does not depend on what the module imports at top level.
    from sqlalchemy import text

    # Built once and reused for every paper instead of once per iteration.
    fts_insert = text(
        "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
        "VALUES (:id, :title, :abstract, :authors, :tags)"
    )

    now = datetime.now(timezone.utc)
    papers = []
    for i, (arxiv_id, paper_date_str) in enumerate([
        ("2401.10001", "2024-01-10"),
        ("2401.10002", "2024-01-11"),
        ("2401.10003", "2024-01-12"),
        ("2401.10004", "2024-01-13"),
        ("2401.10005", "2024-01-14"),
    ]):
        p = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {i+1}",
            abstract=f"Abstract for paper {i+1}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=now,
            upvotes=i * 10,
        )
        db_session.add(p)
        # Flush before the child rows below, which all reference p.id.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
        db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
        db_session.add(SummaryStatus(paper_id=p.id, status="pending"))

        # FTS5 row, written with raw SQL.
        db_session.execute(
            fts_insert,
            {
                "id": p.id,
                "title": p.title_en,
                "abstract": p.abstract,
                "authors": f"Author {i+1}",
                "tags": f"Tag{i+1}",
            },
        )
        papers.append(p)

    db_session.commit()
    return papers
@pytest.fixture
def sample_papers_with_summary(db_session):
    """Insert five papers, four of them summarized (search / pages / trends tests).

    The papers are dated 2024-01-10 through 2024-01-14. Every paper gets one
    author, the shared ``NLP`` tag plus one individual tag, and a row in
    ``papers_fts``. The first four papers have summary status ``done`` and a
    ``PaperSummary`` row; the fifth is left ``pending`` without a summary.

    Args:
        db_session: session fixture the rows are added to and committed on.

    Returns:
        The inserted ``Paper`` objects, in insertion order.
    """
    # Function-scope import instead of ``__import__("sqlalchemy")``: readable,
    # and does not depend on what the module imports at top level.
    from sqlalchemy import text

    # Built once and reused for every paper instead of once per iteration.
    fts_insert = text(
        "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
        "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
    )

    now = datetime.now(timezone.utc)
    papers = []
    for i, (arxiv_id, paper_date_str) in enumerate([
        ("2401.20001", "2024-01-10"),
        ("2401.20002", "2024-01-11"),
        ("2401.20003", "2024-01-12"),
        ("2401.20004", "2024-01-13"),
        ("2401.20005", "2024-01-14"),
    ]):
        # Only the first four papers are summarized.
        has_summary = i < 4

        p = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {i+1}",
            title_zh=f"测试论文 {i+1}",
            abstract=f"Abstract for paper {i+1}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=now,
            upvotes=i * 10 + 5,
        )
        db_session.add(p)
        # Flush before the child rows below, which all reference p.id.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
        db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf"))
        db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
        db_session.add(SummaryStatus(
            paper_id=p.id,
            status="done" if has_summary else "pending",
            quality="normal",
        ))

        if has_summary:
            db_session.add(PaperSummary(
                paper_id=p.id,
                one_line=f"这是论文{i+1}的一句话摘要",
                difficulty="中级",
                motivation_problem=f"论文{i+1}的研究问题",
                motivation_goal=f"论文{i+1}的研究目标",
                method_key_idea=f"论文{i+1}的关键思路",
                method_overview=f"论文{i+1}的方法概述",
                updated_at=now,
                full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}),
            ))

        # FTS5 row, written with raw SQL.
        db_session.execute(
            fts_insert,
            {
                "id": p.id,
                "title_en": p.title_en,
                "title_zh": p.title_zh or "",
                "abstract": p.abstract or "",
                "authors": f"Author {i+1}",
                "tags": f"NLP, Tag{i+1}",
            },
        )
        papers.append(p)

    db_session.commit()
    return papers
@@ -1,350 +1,36 @@
"""Phase 4 管理和自动化测试 — cleaner、admin routes、scheduler。"""
"""管理接口测试 — admin routes、auth、scheduler、task locks"""
from __future__ import annotations
import os
import shutil
import time
import logging
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import select
from app.database import get_db
from app.config import settings
from app.models import (
CrawlLog,
DataDeleteJob,
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
TaskLock,
UserBookmark,
UserNote,
UserReadingStatus,
)
# ── Fixtures ────────────────────────────────────────────────────────────
ADMIN_TOKEN = "test-admin-token-12345"
@pytest.fixture
def admin_headers():
    """Return request headers carrying the module's ADMIN_TOKEN as a bearer token."""
    bearer = f"Bearer {ADMIN_TOKEN}"
    return {"Authorization": bearer}
@pytest.fixture
def wrong_admin_headers():
    """Return request headers carrying an incorrect bearer token."""
    invalid_auth = {"Authorization": "Bearer wrong-token"}
    return invalid_auth
@pytest.fixture
def auth_client(client, monkeypatch):
    """Return the TestClient after patching the admin token and disabling Chroma."""
    overrides = (
        ("ADMIN_TOKEN", ADMIN_TOKEN),
        ("CHROMA_ENABLED", False),
    )
    # monkeypatch undoes both overrides when the test finishes.
    for attr_name, attr_value in overrides:
        monkeypatch.setattr(settings, attr_name, attr_value)
    return client
@pytest.fixture
def sample_papers(db_session):
    """Insert five papers with different dates and return them in insertion order."""
    import sqlalchemy

    fts_insert = sqlalchemy.text(
        "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
        "VALUES (:id, :title, :abstract, :authors, :tags)"
    )
    crawl_time = datetime.now(timezone.utc)
    specs = (
        ("2401.10001", "2024-01-10"),
        ("2401.10002", "2024-01-11"),
        ("2401.10003", "2024-01-12"),
        ("2401.10004", "2024-01-13"),
        ("2401.10005", "2024-01-14"),
    )

    inserted = []
    for i, (arxiv_id, iso_day) in enumerate(specs):
        paper = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {i+1}",
            abstract=f"Abstract for paper {i+1}.",
            paper_date=date.fromisoformat(iso_day),
            crawled_at=crawl_time,
            upvotes=i * 10,
        )
        db_session.add(paper)
        # Flush before the child rows below, which all reference paper.id.
        db_session.flush()

        children = (
            PaperAuthor(paper_id=paper.id, name=f"Author {i+1}", position=0),
            PaperTag(paper_id=paper.id, tag=f"Tag{i+1}", source="hf"),
            SummaryStatus(paper_id=paper.id, status="pending"),
        )
        for child in children:
            db_session.add(child)

        # FTS5 row, written with raw SQL.
        fts_params = {
            "id": paper.id,
            "title": paper.title_en,
            "abstract": paper.abstract,
            "authors": f"Author {i+1}",
            "tags": f"Tag{i+1}",
        }
        db_session.execute(fts_insert, fts_params)
        inserted.append(paper)

    db_session.commit()
    return inserted
@pytest.fixture
def sample_paper_with_user_data(db_session, sample_papers):
    """Attach a bookmark, a reading status and a note to the first sample paper."""
    target = sample_papers[0]
    stamp = datetime.now(timezone.utc)
    user_rows = (
        UserBookmark(paper_id=target.id, created_at=stamp),
        UserReadingStatus(paper_id=target.id, status="read_summary", updated_at=stamp),
        UserNote(
            paper_id=target.id,
            content="My notes on this paper",
            created_at=stamp,
            updated_at=stamp,
        ),
    )
    for row in user_rows:
        db_session.add(row)
    db_session.commit()
    return target
@pytest.fixture
def tmp_data_dir(tmp_path):
    """Create ``data/tmp`` and ``data/papers`` under tmp_path and return ``data``."""
    data_root = tmp_path / "data"
    for child in ("tmp", "papers"):
        (data_root / child).mkdir(parents=True)
    return data_root
# ═══════════════════════════════════════════════════════════════════════
# Cleaner 服务测试
# ═══════════════════════════════════════════════════════════════════════
class TestCleanupTmp:
    """Tests for cleanup_tmp in app/services/cleaner.py."""

    def test_cleanup_removes_old_dirs(self, tmp_path, monkeypatch):
        """A temp directory older than 24 hours is removed."""
        scratch = tmp_path / "tmp"
        scratch.mkdir()

        # One stale directory holding a single file.
        stale = scratch / "2401.00001"
        stale.mkdir()
        (stale / "paper.pdf").write_text("fake pdf")

        # Back-date the directory by 25 hours.
        stamp = time.time() - 25 * 3600
        os.utime(stale, (stamp, stamp))

        monkeypatch.setattr("app.services.cleaner.TMP_DIR", scratch)
        from app.services.cleaner import cleanup_tmp

        outcome = cleanup_tmp()

        assert outcome["scanned"] == 1
        assert outcome["removed"] == 1
        assert not stale.exists()

    def test_cleanup_keeps_recent_dirs(self, tmp_path, monkeypatch):
        """A temp directory younger than 24 hours is kept."""
        scratch = tmp_path / "tmp"
        scratch.mkdir()

        fresh = scratch / "2401.00002"
        fresh.mkdir()
        (fresh / "paper.pdf").write_text("fake pdf")

        monkeypatch.setattr("app.services.cleaner.TMP_DIR", scratch)
        from app.services.cleaner import cleanup_tmp

        outcome = cleanup_tmp()

        assert outcome["scanned"] == 1
        assert outcome["removed"] == 0
        assert fresh.exists()

    def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
        """A missing temp directory is handled without error."""
        missing = tmp_path / "nonexistent"
        monkeypatch.setattr("app.services.cleaner.TMP_DIR", missing)
        from app.services.cleaner import cleanup_tmp

        outcome = cleanup_tmp()

        assert outcome["scanned"] == 0
        assert outcome["removed"] == 0

    def test_cleanup_mixed_ages(self, tmp_path, monkeypatch):
        """With old and recent directories side by side, only the old one goes."""
        scratch = tmp_path / "tmp"
        scratch.mkdir()

        stale = scratch / "2401.old"
        stale.mkdir()
        stamp = time.time() - 30 * 3600
        os.utime(stale, (stamp, stamp))

        fresh = scratch / "2401.new"
        fresh.mkdir()

        monkeypatch.setattr("app.services.cleaner.TMP_DIR", scratch)
        from app.services.cleaner import cleanup_tmp

        outcome = cleanup_tmp()

        assert outcome["scanned"] == 2
        assert outcome["removed"] == 1
        assert not stale.exists()
        assert fresh.exists()
class TestDeletePapersByDateRange:
    """Tests for delete_papers_by_date_range in app/services/cleaner.py."""

    @pytest.mark.asyncio
    async def test_delete_by_date_range(self, db_session, sample_papers):
        """Only the papers inside the requested date range are deleted."""
        from app.services.cleaner import delete_papers_by_date_range

        # Jan 11 .. Jan 13 covers three of the five sample papers.
        result = await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 11),
            date(2024, 1, 13),
        )
        assert result["deleted"] == 3
        assert result["total"] == 3
        assert result["status"] == "success"

        # Exactly the two papers outside the range must remain.
        remaining = db_session.execute(select(Paper)).scalars().all()
        assert len(remaining) == 2
        dates = {p.paper_date for p in remaining}
        assert dates == {date(2024, 1, 10), date(2024, 1, 14)}

    @pytest.mark.asyncio
    async def test_delete_creates_job_record(self, db_session, sample_papers):
        """A deletion creates one DataDeleteJob record describing the run."""
        from app.services.cleaner import delete_papers_by_date_range

        await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 14),
        )

        jobs = db_session.execute(select(DataDeleteJob)).scalars().all()
        assert len(jobs) == 1
        assert jobs[0].status == "success"
        assert jobs[0].date_start == date(2024, 1, 10)
        assert jobs[0].date_end == date(2024, 1, 14)
        assert jobs[0].paper_count == 5
        assert jobs[0].completed_at is not None

    @pytest.mark.asyncio
    async def test_delete_creates_crawl_log(self, db_session, sample_papers):
        """A deletion writes one CrawlLog row with task ``delete``."""
        from app.services.cleaner import delete_papers_by_date_range

        await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 14),
        )

        logs = db_session.execute(
            select(CrawlLog).where(CrawlLog.task == "delete")
        ).scalars().all()
        assert len(logs) == 1
        assert logs[0].status == "success"

    @pytest.mark.asyncio
    async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data):
        """Deleting a paper also removes its bookmark, reading status and note."""
        from app.services.cleaner import delete_papers_by_date_range

        paper = sample_paper_with_user_data
        user_models = (UserBookmark, UserReadingStatus, UserNote)

        # Precondition: the user data really exists before the delete. The
        # previous check ended in ``or True`` and therefore could never fail.
        for model in user_models:
            assert db_session.execute(
                select(model).where(model.paper_id == paper.id)
            ).scalar_one_or_none() is not None

        result = await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 10),
        )
        assert result["deleted"] == 1

        # Postcondition: all user data for the paper is gone.
        for model in user_models:
            assert db_session.execute(
                select(model).where(model.paper_id == paper.id)
            ).scalar_one_or_none() is None

    @pytest.mark.asyncio
    async def test_delete_removes_fts(self, db_session, sample_papers):
        """Deleting every paper leaves the papers_fts table empty."""
        import sqlalchemy
        from app.services.cleaner import delete_papers_by_date_range

        await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 14),
        )

        rows = db_session.execute(
            sqlalchemy.text("SELECT count(*) FROM papers_fts")
        ).scalar()
        assert rows == 0

    @pytest.mark.asyncio
    async def test_delete_removes_local_files(self, db_session, sample_papers, tmp_path, monkeypatch):
        """Deleting a paper also removes its directory under PAPERS_DIR."""
        from app.services.cleaner import delete_papers_by_date_range

        papers_dir = tmp_path / "papers"
        papers_dir.mkdir()
        (papers_dir / "2401.10001").mkdir()
        (papers_dir / "2401.10001" / "meta.json").write_text("{}")

        monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir)

        result = await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 10),
        )
        assert result["deleted"] == 1
        assert not (papers_dir / "2401.10001").exists()

    @pytest.mark.asyncio
    async def test_delete_empty_range(self, db_session, sample_papers):
        """A date range containing no papers reports zero deletions."""
        from app.services.cleaner import delete_papers_by_date_range

        result = await delete_papers_by_date_range(
            db_session,
            date(2025, 1, 1),
            date(2025, 1, 31),
        )
        assert result["total"] == 0
        assert result["deleted"] == 0
        assert result["status"] == "success"
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes 测试
# Admin Routes — 鉴权测试
# ═══════════════════════════════════════════════════════════════════════
@@ -363,12 +49,65 @@ class TestAdminAuth:
def test_correct_token_accepted(self, auth_client, admin_headers):
    """A correct token is not answered with 401 (the crawl itself may still fail)."""
    # Stub crawl_daily so the request triggers no real API call.
    crawl_stub = AsyncMock(return_value={"found": 0, "new": 0, "status": "success"})
    with patch("app.routes.admin.crawl_daily", crawl_stub):
        response = auth_client.post("/admin/crawl", headers=admin_headers)
    assert response.status_code != 401
# ── summarize route auth ────────────────────────────────────────
def test_no_token_returns_401_for_summarize(self, client):
    """A request without a bearer token is rejected with 401 or 403."""
    response = client.post("/admin/summarize")
    rejected_codes = (401, 403)
    assert response.status_code in rejected_codes
def test_wrong_token_returns_401_for_summarize(self, client):
    """A request with an incorrect bearer token is rejected with 401."""
    bad_headers = {"Authorization": "Bearer wrong-token"}
    response = client.post("/admin/summarize", headers=bad_headers)
    assert response.status_code == 401
def test_correct_token_batch_summarize(self, client, admin_headers):
    """Batch summarize with the correct token returns 200; the service layer is mocked."""
    import app.config as config_mod

    saved_token = config_mod.settings.ADMIN_TOKEN
    config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN
    try:
        batch_stub = AsyncMock(
            return_value={"status": "success", "done": 0, "failed": 0, "total": 0}
        )
        with patch("app.routes.admin.summarize_batch", batch_stub):
            response = client.post("/admin/summarize", headers=admin_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "success"
    finally:
        # Put the previous token back even when the request or an assert fails.
        config_mod.settings.ADMIN_TOKEN = saved_token
def test_single_paper_not_found(self, client, admin_headers):
    """Summarizing a single paper that does not exist returns 404."""
    import app.config as config_mod

    saved_token = config_mod.settings.ADMIN_TOKEN
    config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN
    try:
        single_stub = AsyncMock(
            return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"}
        )
        with patch("app.routes.admin.summarize_single", single_stub):
            response = client.post(
                "/admin/summarize/nonexistent.99999",
                headers=admin_headers,
            )
        assert response.status_code == 404
    finally:
        # Put the previous token back even when the request or the assert fails.
        config_mod.settings.ADMIN_TOKEN = saved_token
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — Crawl
# ═══════════════════════════════════════════════════════════════════════
class TestAdminCrawl:
"""POST /admin/crawl 测试。"""
@@ -381,7 +120,6 @@ class TestAdminCrawl:
assert resp.status_code == 200
data = resp.json()
assert data["status"] == "success"
# 验证调用了 crawl_daily
mock_crawl.assert_called_once()
def test_crawl_specific_date(self, auth_client, admin_headers):
@@ -395,6 +133,11 @@ class TestAdminCrawl:
assert call_args[0][1] == "2024-01-15"
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — Cleanup
# ═══════════════════════════════════════════════════════════════════════
class TestAdminCleanup:
"""POST /admin/cleanup 测试。"""
@@ -421,6 +164,11 @@ class TestAdminCleanup:
assert logs[-1].status == "success"
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — Delete
# ═══════════════════════════════════════════════════════════════════════
class TestAdminDelete:
"""POST /admin/delete 测试。"""
@@ -438,7 +186,7 @@ class TestAdminDelete:
)
assert resp.status_code == 422
def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers):
def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers_range):
"""confirm='DELETE' 时应执行删除。"""
resp = auth_client.post(
"/admin/delete",
@@ -480,6 +228,11 @@ class TestAdminDelete:
assert resp.status_code == 422
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — Logs
# ═══════════════════════════════════════════════════════════════════════
class TestAdminLogs:
"""GET /admin/logs 测试。"""
@@ -494,7 +247,7 @@ class TestAdminLogs:
resp = auth_client.get("/admin/logs")
assert resp.status_code in (403, 401)
def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers):
def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers_range):
"""日志页面应包含日志数据。"""
# 先创建一条日志
now = datetime.now(timezone.utc)
@@ -519,11 +272,10 @@ class TestScheduler:
def test_scheduler_disabled_by_default(self, monkeypatch):
"""SCHEDULER_ENABLED=false 时不应启动调度器。"""
monkeypatch.setattr(settings, "SCHEDULER_ENABLED", False)
from app.services.scheduler import start_scheduler
# 重置模块级变量
import app.services.scheduler as sched_mod
sched_mod._scheduler = None
from app.services.scheduler import start_scheduler
result = start_scheduler()
assert result is None
@@ -550,7 +302,6 @@ class TestScheduler:
@pytest.mark.asyncio
async def test_scheduler_warns_multi_worker(self, monkeypatch, caplog):
"""APP_WORKERS > 1 时应打印警告。"""
import logging
monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
monkeypatch.setattr(settings, "APP_WORKERS", 4)
import app.services.scheduler as sched_mod
+279
View File
@@ -0,0 +1,279 @@
"""Cleaner 服务测试 — cleanup_tmp、delete_papers_by_date_range。"""
from __future__ import annotations
import os
import time
from datetime import date, datetime, timezone
import pytest
from sqlalchemy import select
from app.config import settings
from app.models import (
CrawlLog,
DataDeleteJob,
Paper,
UserBookmark,
UserNote,
UserReadingStatus,
)
# ── Fixtures ────────────────────────────────────────────────────────────
@pytest.fixture
def sample_paper_with_user_data(db_session, sample_papers_range):
    """Attach a bookmark, a reading status and a note to the first sample paper."""
    target = sample_papers_range[0]
    stamp = datetime.now(timezone.utc)
    user_rows = (
        UserBookmark(paper_id=target.id, created_at=stamp),
        UserReadingStatus(paper_id=target.id, status="read_summary", updated_at=stamp),
        UserNote(
            paper_id=target.id,
            content="My notes on this paper",
            created_at=stamp,
            updated_at=stamp,
        ),
    )
    for row in user_rows:
        db_session.add(row)
    db_session.commit()
    return target
# ═══════════════════════════════════════════════════════════════════════
# cleanup_tmp 测试
# ═══════════════════════════════════════════════════════════════════════
class TestCleanupTmp:
    """Tests for cleanup_tmp in app/services/cleaner.py."""

    def test_cleanup_removes_old_dirs(self, tmp_path, monkeypatch):
        """A temp directory older than 24 hours is removed."""
        scratch = tmp_path / "tmp"
        scratch.mkdir()

        # One stale directory holding a single file.
        stale = scratch / "2401.00001"
        stale.mkdir()
        (stale / "paper.pdf").write_text("fake pdf")

        # Back-date the directory by 25 hours.
        stamp = time.time() - 25 * 3600
        os.utime(stale, (stamp, stamp))

        monkeypatch.setattr("app.services.cleaner.TMP_DIR", scratch)
        from app.services.cleaner import cleanup_tmp

        outcome = cleanup_tmp()

        assert outcome["scanned"] == 1
        assert outcome["removed"] == 1
        assert not stale.exists()

    def test_cleanup_keeps_recent_dirs(self, tmp_path, monkeypatch):
        """A temp directory younger than 24 hours is kept."""
        scratch = tmp_path / "tmp"
        scratch.mkdir()

        fresh = scratch / "2401.00002"
        fresh.mkdir()
        (fresh / "paper.pdf").write_text("fake pdf")

        monkeypatch.setattr("app.services.cleaner.TMP_DIR", scratch)
        from app.services.cleaner import cleanup_tmp

        outcome = cleanup_tmp()

        assert outcome["scanned"] == 1
        assert outcome["removed"] == 0
        assert fresh.exists()

    def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
        """A missing temp directory is handled without error."""
        missing = tmp_path / "nonexistent"
        monkeypatch.setattr("app.services.cleaner.TMP_DIR", missing)
        from app.services.cleaner import cleanup_tmp

        outcome = cleanup_tmp()

        assert outcome["scanned"] == 0
        assert outcome["removed"] == 0

    def test_cleanup_mixed_ages(self, tmp_path, monkeypatch):
        """With old and recent directories side by side, only the old one goes."""
        scratch = tmp_path / "tmp"
        scratch.mkdir()

        stale = scratch / "2401.old"
        stale.mkdir()
        stamp = time.time() - 30 * 3600
        os.utime(stale, (stamp, stamp))

        fresh = scratch / "2401.new"
        fresh.mkdir()

        monkeypatch.setattr("app.services.cleaner.TMP_DIR", scratch)
        from app.services.cleaner import cleanup_tmp

        outcome = cleanup_tmp()

        assert outcome["scanned"] == 2
        assert outcome["removed"] == 1
        assert not stale.exists()
        assert fresh.exists()
# ═══════════════════════════════════════════════════════════════════════
# delete_papers_by_date_range 测试
# ═══════════════════════════════════════════════════════════════════════
class TestDeletePapersByDateRange:
    """Tests for delete_papers_by_date_range in app/services/cleaner.py."""

    @pytest.mark.asyncio
    async def test_delete_by_date_range(self, db_session, sample_papers_range):
        """Only the papers inside the requested date range are deleted."""
        from app.services.cleaner import delete_papers_by_date_range

        # Jan 11 .. Jan 13 covers three of the five sample papers.
        first_day = date(2024, 1, 11)
        last_day = date(2024, 1, 13)
        outcome = await delete_papers_by_date_range(db_session, first_day, last_day)

        assert outcome["deleted"] == 3
        assert outcome["total"] == 3
        assert outcome["status"] == "success"

        # Exactly the two papers outside the range must remain.
        survivors = db_session.execute(select(Paper)).scalars().all()
        assert len(survivors) == 2
        surviving_dates = {paper.paper_date for paper in survivors}
        assert surviving_dates == {date(2024, 1, 10), date(2024, 1, 14)}

    @pytest.mark.asyncio
    async def test_delete_creates_job_record(self, db_session, sample_papers_range):
        """A deletion creates one DataDeleteJob record describing the run."""
        from app.services.cleaner import delete_papers_by_date_range

        first_day = date(2024, 1, 10)
        last_day = date(2024, 1, 14)
        await delete_papers_by_date_range(db_session, first_day, last_day)

        job_rows = db_session.execute(select(DataDeleteJob)).scalars().all()
        assert len(job_rows) == 1
        job = job_rows[0]
        assert job.status == "success"
        assert job.date_start == date(2024, 1, 10)
        assert job.date_end == date(2024, 1, 14)
        assert job.paper_count == 5
        assert job.completed_at is not None

    @pytest.mark.asyncio
    async def test_delete_creates_crawl_log(self, db_session, sample_papers_range):
        """A deletion writes one CrawlLog row with task ``delete``."""
        from app.services.cleaner import delete_papers_by_date_range

        first_day = date(2024, 1, 10)
        last_day = date(2024, 1, 14)
        await delete_papers_by_date_range(db_session, first_day, last_day)

        delete_logs_query = select(CrawlLog).where(CrawlLog.task == "delete")
        log_rows = db_session.execute(delete_logs_query).scalars().all()
        assert len(log_rows) == 1
        assert log_rows[0].status == "success"

    @pytest.mark.asyncio
    async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data):
        """Deleting a paper also removes its bookmark, reading status and note."""
        from app.services.cleaner import delete_papers_by_date_range

        paper = sample_paper_with_user_data

        only_day = date(2024, 1, 10)
        outcome = await delete_papers_by_date_range(db_session, only_day, only_day)
        assert outcome["deleted"] == 1

        # None of the three user-data tables may still reference the paper.
        for model in (UserBookmark, UserReadingStatus, UserNote):
            leftover_query = select(model).where(model.paper_id == paper.id)
            leftover = db_session.execute(leftover_query).scalar_one_or_none()
            assert leftover is None

    @pytest.mark.asyncio
    async def test_delete_removes_fts(self, db_session, sample_papers_range):
        """Deleting every paper leaves the papers_fts table empty."""
        import sqlalchemy
        from app.services.cleaner import delete_papers_by_date_range

        first_day = date(2024, 1, 10)
        last_day = date(2024, 1, 14)
        await delete_papers_by_date_range(db_session, first_day, last_day)

        count_query = sqlalchemy.text("SELECT count(*) FROM papers_fts")
        fts_row_count = db_session.execute(count_query).scalar()
        assert fts_row_count == 0

    @pytest.mark.asyncio
    async def test_delete_removes_local_files(self, db_session, sample_papers_range, tmp_path, monkeypatch):
        """Deleting a paper also removes its directory under PAPERS_DIR."""
        from app.services.cleaner import delete_papers_by_date_range

        papers_root = tmp_path / "papers"
        papers_root.mkdir()
        paper_folder = papers_root / "2401.10001"
        paper_folder.mkdir()
        (paper_folder / "meta.json").write_text("{}")

        monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_root)

        only_day = date(2024, 1, 10)
        outcome = await delete_papers_by_date_range(db_session, only_day, only_day)

        assert outcome["deleted"] == 1
        assert not paper_folder.exists()

    @pytest.mark.asyncio
    async def test_delete_empty_range(self, db_session, sample_papers_range):
        """A date range containing no papers reports zero deletions."""
        from app.services.cleaner import delete_papers_by_date_range

        first_day = date(2025, 1, 1)
        last_day = date(2025, 1, 31)
        outcome = await delete_papers_by_date_range(db_session, first_day, last_day)

        assert outcome["total"] == 0
        assert outcome["deleted"] == 0
        assert outcome["status"] == "success"

    @pytest.mark.asyncio
    async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch):
        """Deleting papers still succeeds when CHROMA_ENABLED is off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()

        from app.services.cleaner import delete_papers_by_date_range

        only_day = date(2024, 1, 10)
        outcome = await delete_papers_by_date_range(db_session, only_day, only_day)

        assert outcome["status"] == "success"
        assert outcome["deleted"] == 1
+163
View File
@@ -0,0 +1,163 @@
"""Embedder / Chroma 服务测试 — 初始化、索引、embedding API。"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from app.config import settings
# ═══════════════════════════════════════════════════════════════════════
# 初始化
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderInit:
    """Initialisation tests for app.services.embedder."""

    def test_chroma_disabled_skip_init(self, monkeypatch):
        """With CHROMA_ENABLED off, init_chroma leaves the client unset."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        emb.init_chroma()
        assert emb._chroma._client is None

    def test_chroma_init_success(self, monkeypatch, tmp_path):
        """With CHROMA_ENABLED on, init_chroma sets both client and collection."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb
        emb._chroma.reset()
        try:
            emb.init_chroma()
            assert emb._chroma._client is not None
            assert emb._chroma._collection is not None
        finally:
            # Reset even when an assertion fails, so the initialised state in
            # the module-level ``_chroma`` object does not leak into later tests.
            emb._chroma.reset()

    def test_get_collection_returns_none_when_disabled(self, monkeypatch):
        """With CHROMA_ENABLED off, get_collection returns None."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        assert emb.get_collection() is None
# ═══════════════════════════════════════════════════════════════════════
# 索引
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderIndexing:
    """Indexing tests for app.services.embedder."""

    def test_index_paper_disabled(self, monkeypatch):
        """With CHROMA_ENABLED off, index_paper returns False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        assert emb.index_paper("test-id") is False

    def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
        """With EMBED_API_BASE and EMBED_MODEL empty, index_paper returns False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb
        emb._chroma.reset()
        try:
            emb.init_chroma()
            result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
            assert result is False
        finally:
            # Reset even when the assertion fails, so the initialised state in
            # the module-level ``_chroma`` object does not leak into later tests.
            emb._chroma.reset()

    def test_index_batch_disabled(self, monkeypatch):
        """With CHROMA_ENABLED off, index_batch reports every item as failed."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        result = emb.index_batch(["a", "b"])
        assert result["success"] == 0
        assert result["failed"] == 2

    def test_index_batch_empty(self, monkeypatch):
        """An empty id list yields a total of 0."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        result = emb.index_batch([])
        assert result["total"] == 0

    def test_delete_paper_disabled(self, monkeypatch):
        """With CHROMA_ENABLED off, delete_paper returns False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        assert emb.delete_paper("test-id") is False

    def test_search_similar_disabled(self, monkeypatch):
        """With CHROMA_ENABLED off, search_similar returns an empty list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        assert emb.search_similar("test query") == []
# ═══════════════════════════════════════════════════════════════════════
# Embedding API
# ═══════════════════════════════════════════════════════════════════════
class TestEmbeddingApi:
    """Tests for embedder._get_embedding."""

    def test_no_api_base_returns_none(self, monkeypatch):
        """With EMBED_API_BASE and EMBED_MODEL empty, _get_embedding returns None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb
        assert emb._get_embedding("test") is None

    def test_dimension_mismatch_returns_none(self, monkeypatch):
        """A 64-dim embedding is rejected when EMBED_DIMENSIONS is 128."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb

        mock_resp = MagicMock()
        mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]}
        mock_resp.raise_for_status = MagicMock()

        # The fake response has to come back from ``client.post(...)``. It was
        # previously returned from ``__enter__`` itself, so ``.post()`` produced
        # an auto-generated mock and the 64-dim payload was never read.
        # NOTE(review): assumes _get_embedding calls ``.post()`` on the client
        # yielded by ``with httpx.Client(...)``, the same path the failure test
        # below wires its side effect on — confirm against the implementation.
        mock_http = MagicMock()
        mock_http.post.return_value = mock_resp

        with patch("httpx.Client") as mock_client:
            mock_client.return_value.__enter__ = MagicMock(return_value=mock_http)
            mock_client.return_value.__exit__ = MagicMock(return_value=False)
            result = emb._get_embedding("test")
        assert result is None

    def test_api_failure_returns_none(self, monkeypatch):
        """An exception raised by the HTTP call makes _get_embedding return None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb

        with patch("httpx.Client") as mock_client:
            mock_client.return_value.__enter__ = MagicMock()
            mock_client.return_value.__exit__ = MagicMock(return_value=False)
            mock_client.return_value.__enter__.return_value.post.side_effect = Exception("timeout")
            result = emb._get_embedding("test")
        assert result is None
+88
View File
@@ -0,0 +1,88 @@
"""LaTeX 图片提取测试 — 从 .tex 源码中提取图片文件。"""
from __future__ import annotations
import pytest
# ═══════════════════════════════════════════════════════════════════════
# Image Extraction
# ═══════════════════════════════════════════════════════════════════════
class TestImageExtraction:
"""LaTeX 图片提取测试。"""
@pytest.mark.asyncio
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
    """Returns 0 when the source directory does not exist."""
    def fake_tmp_dir(arxiv_id):
        return tmp_path / "tmp" / arxiv_id

    def fake_paper_dir(arxiv_id):
        return tmp_path / "papers" / arxiv_id

    # Patch the helpers in the module that defines them and in
    # image_extractor, which the other tests in this class patch. Names that
    # were imported with ``from ... import`` are only replaced by patching the
    # module that uses them.
    # NOTE(review): download_source_zip is not stubbed here, unlike in the
    # sibling tests — confirm this path cannot reach the network.
    for module_path in ("app.services.pdf_downloader", "app.services.image_extractor"):
        monkeypatch.setattr(f"{module_path}.tmp_dir", fake_tmp_dir)
        monkeypatch.setattr(f"{module_path}.paper_dir", fake_paper_dir)

    from app.services.image_extractor import extract_images_from_source

    result = await extract_images_from_source("2401.99999")
    assert result == 0
@pytest.mark.asyncio
async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
    """Images referenced by a .tex file are extracted into the paper's images dir."""
    from app.services.image_extractor import extract_images_from_source

    source_root = tmp_path / "tmp" / "2401.00001" / "source"
    source_root.mkdir(parents=True)

    figs_dir = source_root / "figs"
    figs_dir.mkdir()
    for filename, payload in (
        ("figure1.png", b"\x89PNG\r\n"),
        ("figure2.jpg", b"\xff\xd8\xff\xe0"),
    ):
        (figs_dir / filename).write_bytes(payload)

    # The .tex file references the two images above plus one missing file.
    tex_content = r"""
\documentclass{article}
\begin{document}
\begin{figure}
\includegraphics[width=0.8\textwidth]{figs/figure1.png}
\includegraphics{figs/figure2.jpg}
\includegraphics[angle=90]{figs/nonexistent.pdf}
\end{figure}
\end{document}
"""
    (source_root / "main.tex").write_text(tex_content)

    paper_root = tmp_path / "papers" / "2401.00001"

    monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
    monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)

    # Stub download_source_zip to avoid a real network call; the source
    # directory has already been created above.
    async def _noop_download(*args, **kwargs):
        pass

    monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)

    extracted = await extract_images_from_source("2401.00001")

    assert extracted == 2
    images_out = paper_root / "images"
    assert images_out.exists()
    assert (images_out / "figure1.png").exists()
    assert (images_out / "figure2.jpg").exists()
@pytest.mark.asyncio
async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
""".tex 文件无图片时返回 0。"""
from app.services.image_extractor import extract_images_from_source
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
tmp_source.mkdir(parents=True)
(tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
# Mock download_source_zip to avoid real network call
async def _noop_download(*args, **kwargs):
pass
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
result = await extract_images_from_source("2401.00002")
assert result == 0
+224
View File
@@ -0,0 +1,224 @@
"""页面路由测试 — detail、trends、compare、graceful degradation。"""
from __future__ import annotations
from datetime import date
from unittest.mock import patch as upatch
import pytest
from app.config import settings
# ═══════════════════════════════════════════════════════════════════════
# Detail 页 & 相似论文
# ═══════════════════════════════════════════════════════════════════════
class TestDetailPage:
    """Paper detail page tests."""

    def test_detail_page_renders(self, client, sample_papers_with_summary):
        """The detail page renders for an existing paper."""
        page = client.get("/paper/2401.20001")
        assert page.status_code == 200
        html = page.text
        # Either the Chinese or the English title prefix must be present.
        assert any(title in html for title in ("测试论文", "Test Paper"))

    def test_detail_page_not_found(self, client):
        """An unknown paper id yields 404."""
        page = client.get("/paper/nonexistent.99999")
        assert page.status_code == 404
# ═══════════════════════════════════════════════════════════════════════
# Similar API(详情页内联)
# ═══════════════════════════════════════════════════════════════════════
class TestDetailSimilarPapers:
    """Detail page similar-papers module tests with CHROMA disabled."""

    def test_detail_page_renders_with_similar(self, client, monkeypatch, sample_papers_with_summary):
        """The detail page still renders when CHROMA is switched off."""
        # The class documents the CHROMA-off path, so disable it explicitly;
        # the plain `client` fixture is not shown to do this.
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = client.get("/paper/2401.20001")
        assert resp.status_code == 200
        assert "测试论文" in resp.text or "Test Paper" in resp.text

    def test_detail_page_not_found_similar(self, client, monkeypatch):
        """An unknown paper id yields 404 with CHROMA switched off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = client.get("/paper/nonexistent.99999")
        assert resp.status_code == 404
# ═══════════════════════════════════════════════════════════════════════
# Trends Dashboard
# ═══════════════════════════════════════════════════════════════════════
class TestTrendsDashboard:
    """Trends dashboard tests."""

    def test_trends_page_renders(self, client, sample_papers_with_summary):
        """The trends dashboard page renders."""
        resp = client.get("/trends")
        assert resp.status_code == 200
        assert "趋势看板" in resp.text
        # A case-insensitive check already covers "Chart".
        assert "chart" in resp.text.lower()

    def test_trends_api_returns_data(self, client, sample_papers_with_summary):
        """The trends API returns the expected top-level structure."""
        resp = client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        for key in ("daily_counts", "top_tags", "upvotes_dist", "summary_completion"):
            assert key in data
        for key in ("daily_counts", "top_tags", "upvotes_dist", "summary_completion"):
            assert isinstance(data[key], list)

    def test_trends_api_daily_counts(self, client, sample_papers_with_summary):
        """Daily paper counts match the fixture data."""
        # Pin "today" so the fixture's January 2024 dates are in range.
        with upatch("app.services.trends.date") as mock_date:
            mock_date.today.return_value = date(2024, 1, 20)
            # Keep date(...) construction working while the name is mocked.
            mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
            resp = client.get("/api/stats/trends")
        # Fail on the status code rather than on an opaque JSON decode error.
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["daily_counts"]) == 5
        for item in data["daily_counts"]:
            assert "date" in item
            assert "count" in item
            assert item["count"] == 1

    def test_trends_api_top_tags(self, client, sample_papers_with_summary):
        """Top-tag counts match the fixture data."""
        resp = client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        tags = {t["tag"]: t["count"] for t in data["top_tags"]}
        assert "NLP" in tags
        assert tags["NLP"] == 5  # every fixture paper is expected to carry NLP

    def test_trends_api_summary_completion(self, client, sample_papers_with_summary):
        """Summary completion counts match the fixture data."""
        resp = client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        statuses = {s["status"]: s["count"] for s in data["summary_completion"]}
        assert "done" in statuses
        assert statuses["done"] == 4  # four fixture papers are expected to be done

    def test_trends_empty_db(self, client):
        """The API does not crash on an empty database."""
        resp = client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        assert data["daily_counts"] == []
        assert data["top_tags"] == []
# ═══════════════════════════════════════════════════════════════════════
# Compare Page
# ═══════════════════════════════════════════════════════════════════════
class TestComparePage:
    """Paper comparison page tests."""

    def test_compare_page_no_ids(self, client):
        """Without ids the page shows the input form."""
        page = client.get("/compare")
        assert page.status_code == 200
        assert "对比" in page.text

    def test_compare_page_with_ids(self, client, sample_papers_with_summary):
        """Comparing several papers renders them side by side."""
        page = client.get("/compare?ids=2401.20001,2401.20002")
        assert page.status_code == 200
        html = page.text
        for arxiv_id in ("2401.20001", "2401.20002"):
            assert arxiv_id in html
        # The comparison field labels should be present.
        for label in ("一句话摘要", "研究问题"):
            assert label in html

    def test_compare_page_max_5(self, client, sample_papers_with_summary):
        """Five ids are accepted."""
        wanted = ",".join(f"2401.2000{n}" for n in range(1, 6))
        page = client.get(f"/compare?ids={wanted}")
        assert page.status_code == 200

    def test_compare_page_over_5_truncates(self, client, sample_papers_with_summary):
        """More than five ids still renders."""
        wanted = ",".join(f"2401.2000{n}" for n in range(1, 7))
        page = client.get(f"/compare?ids={wanted}")
        assert page.status_code == 200

    def test_compare_page_invalid_ids(self, client):
        """An unknown id renders an empty result instead of failing."""
        page = client.get("/compare?ids=nonexistent.99999")
        assert page.status_code == 200

    def test_compare_page_shows_no_summary_placeholder(self, client, sample_papers_with_summary):
        """A paper without a summary shows the placeholder text."""
        # 2401.20005 is the paper requested here as the one without a summary.
        page = client.get("/compare?ids=2401.20005")
        assert page.status_code == 200
        assert "暂无总结" in page.text
# ═══════════════════════════════════════════════════════════════════════
# Nav Bar
# ═══════════════════════════════════════════════════════════════════════
class TestNavBar:
    """Navigation bar tests."""

    def test_nav_includes_trends_link(self, client):
        """The navigation bar contains a link to the trends page."""
        page = client.get("/search")
        assert page.status_code == 200
        assert "/trends" in page.text

    def test_nav_includes_compare_implicitly(self, client):
        """The compare page is reachable."""
        page = client.get("/compare")
        assert page.status_code == 200
# ═══════════════════════════════════════════════════════════════════════
# Graceful Degradation (CHROMA_ENABLED=false)
# ═══════════════════════════════════════════════════════════════════════
class TestGracefulDegradation:
    """Graceful-degradation tests with CHROMA_ENABLED=false."""

    @staticmethod
    def _chroma_off(monkeypatch):
        """Switch the CHROMA feature flag off for the current test."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)

    def test_search_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
        """Search still works with CHROMA off."""
        self._chroma_off(monkeypatch)
        page = client.get("/search?q=Test")
        assert page.status_code == 200
        html = page.text
        assert any(title in html for title in ("Test Paper", "测试论文"))

    def test_detail_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
        """The detail page still works with CHROMA off."""
        self._chroma_off(monkeypatch)
        page = client.get("/paper/2401.20001")
        assert page.status_code == 200

    def test_trends_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
        """The trends dashboard still works with CHROMA off."""
        self._chroma_off(monkeypatch)
        page = client.get("/trends")
        assert page.status_code == 200

    def test_compare_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
        """The compare page still works with CHROMA off."""
        self._chroma_off(monkeypatch)
        page = client.get("/compare?ids=2401.20001,2401.20002")
        assert page.status_code == 200
-660
View File
@@ -1,660 +0,0 @@
"""Phase 5 后续增强测试 — embedder、semantic search、trends、compare、image extraction。"""
from __future__ import annotations
import json
import shutil
import time
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import select
from app.config import settings
from app.database import get_db
from app.models import (
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
)
# ── Fixtures ────────────────────────────────────────────────────────────
# Token installed as settings.ADMIN_TOKEN by the auth_client fixture.
ADMIN_TOKEN = "test-admin-token-12345"


@pytest.fixture
def admin_headers():
    """Return request headers carrying the test admin bearer token."""
    # Interpolate the token instead of concatenating onto a placeholder-less f-string.
    return {"Authorization": f"Bearer {ADMIN_TOKEN}"}
@pytest.fixture
def auth_client(client, monkeypatch):
    """Return the test client with the admin token set and CHROMA disabled."""
    for option, value in (("ADMIN_TOKEN", ADMIN_TOKEN), ("CHROMA_ENABLED", False)):
        monkeypatch.setattr(settings, option, value)
    return client
@pytest.fixture
def sample_papers_with_summary(db_session):
    """Insert five papers dated 2024-01-10..2024-01-14; the first four have summaries.

    Every paper gets one author, the tags ``NLP`` and ``Tag<n>``, a
    SummaryStatus row ("done" for the first four, "pending" for the last)
    and a matching row in the ``papers_fts`` table. Returns the Paper objects.
    """
    # Imported once here instead of on every loop iteration; the unused
    # SummarySchema import from the original loop body is dropped.
    from sqlalchemy import text

    now = datetime.now(timezone.utc)
    papers = []
    for i, (arxiv_id, paper_date_str) in enumerate([
        ("2401.20001", "2024-01-10"),
        ("2401.20002", "2024-01-11"),
        ("2401.20003", "2024-01-12"),
        ("2401.20004", "2024-01-13"),
        ("2401.20005", "2024-01-14"),
    ]):
        paper_date = date.fromisoformat(paper_date_str)
        p = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {i+1}",
            title_zh=f"测试论文 {i+1}",
            abstract=f"Abstract for paper {i+1}.",
            paper_date=paper_date,
            crawled_at=now,
            upvotes=i * 10 + 5,
        )
        db_session.add(p)
        # Flush so p.id is assigned before the dependent rows reference it.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
        db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf"))
        db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
        db_session.add(SummaryStatus(
            paper_id=p.id,
            status="done" if i < 4 else "pending",
            quality="normal",
        ))

        # Only the first four papers receive a summary row.
        if i < 4:
            summary = PaperSummary(
                paper_id=p.id,
                one_line=f"这是论文{i+1}的一句话摘要",
                difficulty="中级",
                motivation_problem=f"论文{i+1}的研究问题",
                motivation_goal=f"论文{i+1}的研究目标",
                method_key_idea=f"论文{i+1}的关键思路",
                method_overview=f"论文{i+1}的方法概述",
                updated_at=now,
                full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}),
            )
            db_session.add(summary)

        # Mirror the paper into the FTS5 table.
        db_session.execute(
            text(
                "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
                "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
            ),
            {
                "id": p.id,
                "title_en": p.title_en,
                "title_zh": p.title_zh or "",
                "abstract": p.abstract or "",
                "authors": f"Author {i+1}",
                "tags": f"NLP, Tag{i+1}",
            },
        )
        papers.append(p)

    db_session.commit()
    return papers
# ═══════════════════════════════════════════════════════════════════════
# Embedder 服务测试
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderInit:
    """Initialisation tests for embedder.py."""

    def test_chroma_disabled_skip_init(self, monkeypatch):
        """init_chroma leaves the client unset when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        emb.init_chroma()
        assert emb._chroma._client is None

    def test_chroma_init_success(self, monkeypatch, tmp_path):
        """init_chroma sets both client and collection when CHROMA_ENABLED is true."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb
        emb._chroma.reset()
        try:
            emb.init_chroma()
            assert emb._chroma._client is not None
            assert emb._chroma._collection is not None
        finally:
            # Reset even when an assertion fails so the initialised module-level
            # state does not carry over into later tests.
            emb._chroma.reset()

    def test_get_collection_returns_none_when_disabled(self, monkeypatch):
        """get_collection returns None when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        assert emb.get_collection() is None
class TestEmbedderIndexing:
    """Indexing tests for embedder.py."""

    def test_index_paper_disabled(self, monkeypatch):
        """index_paper returns False when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        assert emb.index_paper("test-id") is False

    def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
        """index_paper returns False when EMBED_API_BASE is empty."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb
        emb._chroma.reset()
        try:
            emb.init_chroma()
            result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
            assert result is False
        finally:
            # Reset even on failure so the initialised state does not carry
            # over into later tests.
            emb._chroma.reset()

    def test_index_batch_disabled(self, monkeypatch):
        """index_batch reports every item as failed when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        result = emb.index_batch(["a", "b"])
        assert result["success"] == 0
        assert result["failed"] == 2

    def test_index_batch_empty(self, monkeypatch):
        """index_batch reports a total of 0 for an empty list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        result = emb.index_batch([])
        assert result["total"] == 0

    def test_delete_paper_disabled(self, monkeypatch):
        """delete_paper returns False when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        assert emb.delete_paper("test-id") is False

    def test_search_similar_disabled(self, monkeypatch):
        """search_similar returns an empty list when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb
        emb._chroma.reset()
        assert emb.search_similar("test query") == []
class TestEmbeddingApi:
    """Tests for _get_embedding."""

    def test_no_api_base_returns_none(self, monkeypatch):
        """Returns None when EMBED_API_BASE is empty."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb
        assert emb._get_embedding("test") is None

    def test_dimension_mismatch_returns_none(self, monkeypatch):
        """Returns None when the embedding length differs from EMBED_DIMENSIONS."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb

        # A 64-dim embedding against a configured dimension of 128.
        mock_resp = MagicMock()
        mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]}
        mock_resp.raise_for_status = MagicMock()

        # The context manager must yield the HTTP client, whose .post() returns
        # the response. Yielding the response itself (as before) meant .post()
        # produced an unrelated auto-created mock and the payload was never seen.
        # NOTE(review): assumes the embedder calls client.post(...) on the
        # context-managed client, as the failure test below wires it — confirm.
        mock_http = MagicMock()
        mock_http.post.return_value = mock_resp

        with patch("httpx.Client") as mock_client:
            mock_client.return_value.__enter__ = MagicMock(return_value=mock_http)
            mock_client.return_value.__exit__ = MagicMock(return_value=False)
            result = emb._get_embedding("test")
        assert result is None

    def test_api_failure_returns_none(self, monkeypatch):
        """Returns None when the API call raises."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb
        with patch("httpx.Client") as mock_client:
            mock_client.return_value.__enter__ = MagicMock()
            mock_client.return_value.__exit__ = MagicMock(return_value=False)
            mock_client.return_value.__enter__.return_value.post.side_effect = Exception("timeout")
            result = emb._get_embedding("test")
        assert result is None
# ═══════════════════════════════════════════════════════════════════════
# Searcher 语义模式测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchSemanticMode:
    """Search-mode tests for searcher.py."""

    def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
        """Keyword mode finds the fixture papers and reports no distances."""
        from app.services.searcher import search_papers

        found = search_papers(db_session, query="Test Paper", mode="keyword")
        assert found["total"] >= 1
        assert found["distances"] == {}

    def test_semantic_mode_disabled_fallback(self, db_session, monkeypatch, sample_papers_with_summary):
        """Semantic mode still returns hits when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        from app.services.searcher import search_papers

        found = search_papers(db_session, query="Test", mode="semantic")
        # A fallback path is expected to produce results.
        assert found["total"] >= 1

    def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary):
        """The result carries a dict under the "distances" key."""
        from app.services.searcher import search_papers

        found = search_papers(db_session, query="Test Paper")
        assert "distances" in found
        assert isinstance(found["distances"], dict)

    def test_empty_query_returns_empty(self, db_session):
        """No query and no tag yields an empty result."""
        from app.services.searcher import search_papers

        found = search_papers(db_session)
        assert found["total"] == 0
        assert found["results"] == []

    def test_tag_only_search(self, db_session, sample_papers_with_summary):
        """Searching by tag alone returns hits."""
        from app.services.searcher import search_papers

        found = search_papers(db_session, tag="NLP")
        assert found["total"] >= 1
# ═══════════════════════════════════════════════════════════════════════
# Search Routes 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchRoutes:
    """Search route tests."""

    def test_search_page_keyword(self, auth_client, sample_papers_with_summary):
        """The search page renders in keyword mode."""
        page = auth_client.get("/search?q=Test&mode=keyword")
        assert page.status_code == 200
        html = page.text
        assert any(needle in html for needle in ("Test", "测试"))

    def test_search_page_semantic_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
        """Semantic mode still responds when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        page = auth_client.get("/search?q=Test&mode=semantic")
        assert page.status_code == 200

    def test_search_api_with_mode(self, auth_client, sample_papers_with_summary):
        """The search API accepts the mode parameter."""
        reply = auth_client.get("/api/search?q=Test&mode=keyword")
        assert reply.status_code == 200
        payload = reply.json()
        for key in ("results", "total"):
            assert key in payload
# ═══════════════════════════════════════════════════════════════════════
# Similar Paper API 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSimilarAPI:
    """Similar-papers API tests."""

    def test_similar_api_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
        """Returns an empty list when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        reply = auth_client.get("/api/similar/2401.20001")
        assert reply.status_code == 200
        payload = reply.json()
        assert payload["results"] == []

    def test_similar_api_paper_not_found(self, auth_client, monkeypatch):
        """An unknown paper id yields an empty result."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        reply = auth_client.get("/api/similar/nonexistent.99999")
        assert reply.status_code == 200
        assert reply.json()["results"] == []

    def test_similar_api_with_top_k(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The top_k query parameter is accepted."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        reply = auth_client.get("/api/similar/2401.20001?top_k=3")
        assert reply.status_code == 200
# ═══════════════════════════════════════════════════════════════════════
# Detail Page 相似论文测试
# ═══════════════════════════════════════════════════════════════════════
class TestDetailSimilarPapers:
    """Detail page similar-papers module tests."""

    def test_detail_page_renders(self, auth_client, sample_papers_with_summary):
        """The detail page renders for an existing paper."""
        page = auth_client.get("/paper/2401.20001")
        assert page.status_code == 200
        html = page.text
        assert any(title in html for title in ("测试论文", "Test Paper"))

    def test_detail_page_not_found(self, auth_client):
        """An unknown paper id yields 404."""
        page = auth_client.get("/paper/nonexistent.99999")
        assert page.status_code == 404
# ═══════════════════════════════════════════════════════════════════════
# Trends Dashboard 测试
# ═══════════════════════════════════════════════════════════════════════
class TestTrendsDashboard:
    """Trends dashboard tests."""

    def test_trends_page_renders(self, auth_client, sample_papers_with_summary):
        """The trends dashboard page renders."""
        resp = auth_client.get("/trends")
        assert resp.status_code == 200
        assert "趋势看板" in resp.text
        # A case-insensitive check already covers "Chart".
        assert "chart" in resp.text.lower()

    def test_trends_api_returns_data(self, auth_client, sample_papers_with_summary):
        """The trends API returns the expected top-level structure."""
        resp = auth_client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        assert "daily_counts" in data
        assert "top_tags" in data
        assert "upvotes_dist" in data
        assert "summary_completion" in data
        assert isinstance(data["daily_counts"], list)
        assert isinstance(data["top_tags"], list)
        assert isinstance(data["upvotes_dist"], list)
        assert isinstance(data["summary_completion"], list)

    def test_trends_api_daily_counts(self, auth_client, sample_papers_with_summary):
        """Daily paper counts match the fixture data."""
        # Pin "today" so the fixture's January 2024 dates are in range. Uses the
        # module-level `patch` import; the unused trends_mod import and the
        # unused monkeypatch parameter are dropped.
        with patch("app.services.trends.date") as mock_date:
            mock_date.today.return_value = date(2024, 1, 20)
            # Keep date(...) construction working while the name is mocked.
            mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
            resp = auth_client.get("/api/stats/trends")
        data = resp.json()
        assert len(data["daily_counts"]) == 5
        for item in data["daily_counts"]:
            assert "date" in item
            assert "count" in item
            assert item["count"] == 1

    def test_trends_api_top_tags(self, auth_client, sample_papers_with_summary):
        """Top-tag counts match the fixture data."""
        resp = auth_client.get("/api/stats/trends")
        data = resp.json()
        tags = {t["tag"]: t["count"] for t in data["top_tags"]}
        assert "NLP" in tags
        assert tags["NLP"] == 5  # every fixture paper carries the NLP tag

    def test_trends_api_summary_completion(self, auth_client, sample_papers_with_summary):
        """Summary completion counts match the fixture data."""
        resp = auth_client.get("/api/stats/trends")
        data = resp.json()
        statuses = {s["status"]: s["count"] for s in data["summary_completion"]}
        assert "done" in statuses
        assert statuses["done"] == 4  # four fixture papers have status "done"

    def test_trends_empty_db(self, auth_client):
        """The API does not crash on an empty database."""
        resp = auth_client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        assert data["daily_counts"] == []
        assert data["top_tags"] == []
# ═══════════════════════════════════════════════════════════════════════
# Compare Page 测试
# ═══════════════════════════════════════════════════════════════════════
class TestComparePage:
    """Paper comparison page tests."""

    def test_compare_page_no_ids(self, auth_client):
        """Without ids the page shows the input form."""
        resp = auth_client.get("/compare")
        assert resp.status_code == 200
        assert "对比" in resp.text

    def test_compare_page_with_ids(self, auth_client, sample_papers_with_summary):
        """Comparing several papers renders them side by side."""
        resp = auth_client.get("/compare?ids=2401.20001,2401.20002")
        assert resp.status_code == 200
        assert "2401.20001" in resp.text
        assert "2401.20002" in resp.text
        # The comparison field labels should be present.
        assert "一句话摘要" in resp.text
        assert "研究问题" in resp.text

    def test_compare_page_max_5(self, auth_client, sample_papers_with_summary):
        """Five ids are accepted."""
        ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005"
        resp = auth_client.get(f"/compare?ids={ids}")
        assert resp.status_code == 200

    def test_compare_page_over_5_truncates(self, auth_client, sample_papers_with_summary):
        """More than five ids still renders."""
        # The sixth id does not exist in the fixture data.
        ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006"
        resp = auth_client.get(f"/compare?ids={ids}")
        assert resp.status_code == 200

    def test_compare_page_invalid_ids(self, auth_client):
        """An unknown id renders an empty result instead of failing."""
        resp = auth_client.get("/compare?ids=nonexistent.99999")
        # The former follow-up assertion ended in `or resp.status_code == 200`,
        # so it could never fail after this line; it has been removed.
        assert resp.status_code == 200

    def test_compare_page_shows_no_summary_placeholder(self, auth_client, sample_papers_with_summary):
        """A paper without a summary shows the placeholder text."""
        # 2401.20005 has no summary (status=pending).
        resp = auth_client.get("/compare?ids=2401.20005")
        assert resp.status_code == 200
        assert "暂无总结" in resp.text
# ═══════════════════════════════════════════════════════════════════════
# Image Extraction 测试
# ═══════════════════════════════════════════════════════════════════════
class TestImageExtraction:
    """Tests for extracting figures referenced by LaTeX sources."""

    @staticmethod
    def _sandbox_extractor(monkeypatch, tmp_path):
        """Redirect the extractor's path helpers to tmp_path and stub the download."""
        monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
        monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)

        # Replace download_source_zip so no test performs a real network call.
        async def _noop_download(*args, **kwargs):
            pass

        monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)

    @pytest.mark.asyncio
    async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
        """Returns 0 when the source directory does not exist."""
        # Original patches, kept: redirect the downloader module's path helpers.
        monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x)
        monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x)
        # Previously this test neither patched the extractor's own names nor
        # stubbed the download, unlike the two tests below.
        # NOTE(review): assumes image_extractor resolves tmp_dir / paper_dir /
        # download_source_zip from its own namespace, as the sibling tests imply.
        self._sandbox_extractor(monkeypatch, tmp_path)

        from app.services.image_extractor import extract_images_from_source

        result = await extract_images_from_source("2401.99999")
        assert result == 0

    @pytest.mark.asyncio
    async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
        """Copies the existing images referenced by a .tex file."""
        from app.services.image_extractor import extract_images_from_source

        tmp_source = tmp_path / "tmp" / "2401.00001" / "source"
        tmp_source.mkdir(parents=True)
        images_dir = tmp_source / "figs"
        images_dir.mkdir()
        (images_dir / "figure1.png").write_bytes(b"\x89PNG\r\n")
        (images_dir / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0")

        # Two referenced files exist on disk; the third reference does not.
        tex_content = r"""
\documentclass{article}
\begin{document}
\begin{figure}
\includegraphics[width=0.8\textwidth]{figs/figure1.png}
\includegraphics{figs/figure2.jpg}
\includegraphics[angle=90]{figs/nonexistent.pdf}
\end{figure}
\end{document}
"""
        (tmp_source / "main.tex").write_text(tex_content)
        papers_dir = tmp_path / "papers" / "2401.00001"

        self._sandbox_extractor(monkeypatch, tmp_path)

        result = await extract_images_from_source("2401.00001")
        assert result == 2

        dest_images = papers_dir / "images"
        assert dest_images.exists()
        assert (dest_images / "figure1.png").exists()
        assert (dest_images / "figure2.jpg").exists()

    @pytest.mark.asyncio
    async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
        """Returns 0 when the .tex file references no images."""
        from app.services.image_extractor import extract_images_from_source

        tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
        tmp_source.mkdir(parents=True)
        (tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")

        self._sandbox_extractor(monkeypatch, tmp_path)

        result = await extract_images_from_source("2401.00002")
        assert result == 0
# ═══════════════════════════════════════════════════════════════════════
# Nav Bar 测试
# ═══════════════════════════════════════════════════════════════════════
class TestNavBar:
    """Navigation bar tests."""

    def test_nav_includes_trends_link(self, auth_client):
        """The navigation bar contains a link to the trends page."""
        page = auth_client.get("/search")
        assert page.status_code == 200
        assert "/trends" in page.text

    def test_nav_includes_compare_implicitly(self, auth_client):
        """The compare page is reachable."""
        page = auth_client.get("/compare")
        assert page.status_code == 200
# ═══════════════════════════════════════════════════════════════════════
# Graceful Degradation 测试
# ═══════════════════════════════════════════════════════════════════════
class TestGracefulDegradation:
    """Graceful-degradation tests with CHROMA_ENABLED=false."""

    @staticmethod
    def _chroma_off(monkeypatch):
        """Switch the CHROMA feature flag off for the current test."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)

    def test_search_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """Search still works with CHROMA off."""
        self._chroma_off(monkeypatch)
        page = auth_client.get("/search?q=Test")
        assert page.status_code == 200
        html = page.text
        assert any(title in html for title in ("Test Paper", "测试论文"))

    def test_detail_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The detail page still works with CHROMA off."""
        self._chroma_off(monkeypatch)
        page = auth_client.get("/paper/2401.20001")
        assert page.status_code == 200

    def test_trends_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The trends dashboard still works with CHROMA off."""
        self._chroma_off(monkeypatch)
        page = auth_client.get("/trends")
        assert page.status_code == 200

    def test_compare_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The compare page still works with CHROMA off."""
        self._chroma_off(monkeypatch)
        page = auth_client.get("/compare?ids=2401.20001,2401.20002")
        assert page.status_code == 200

    @pytest.mark.asyncio
    async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch):
        """Deleting papers by date range still works with CHROMA off."""
        self._chroma_off(monkeypatch)
        import app.services.embedder as emb
        emb._chroma.reset()
        from app.services.cleaner import delete_papers_by_date_range

        # Start and end are the same day.
        only_day = date(2024, 1, 10)
        outcome = await delete_papers_by_date_range(db_session, only_day, date(2024, 1, 10))
        assert outcome["status"] == "success"
        assert outcome["deleted"] == 1
+189
View File
@@ -0,0 +1,189 @@
"""SummarySchema 校验、quality 分级、JSON 提取、错误分类测试。"""
from __future__ import annotations
import json
import pytest
from pydantic import ValidationError
from app.services.pi_client import (
JsonNotFoundError,
PiProcessError,
PiTimeoutError,
extract_json as _extract_json,
)
from app.services.pdf_downloader import PdfDownloadError
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.services.summarizer import _classify_error
# ═══════════════════════════════════════════════════════════════════════
# SummarySchema 校验
# ═══════════════════════════════════════════════════════════════════════
class TestSummarySchema:
    """Validation tests for the SummarySchema pydantic model."""

    def test_valid_summary(self, sample_summary_dict):
        """A complete payload validates and exposes its fields."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert parsed.title_zh == "测试论文中文标题"
        assert len(parsed.tags) == 3
        assert parsed.motivation.problem

    def test_missing_title_zh(self, sample_summary_dict):
        """A missing title_zh is classified as "field_missing"."""
        del sample_summary_dict["title_zh"]
        with pytest.raises(ValidationError) as caught:
            SummarySchema.model_validate(sample_summary_dict)
        assert classify_validation_error(caught.value) == "field_missing"

    def test_empty_one_line(self, sample_summary_dict):
        """An empty one_line is rejected."""
        sample_summary_dict["one_line"] = ""
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_empty_tags(self, sample_summary_dict):
        """An empty tag list is rejected."""
        sample_summary_dict["tags"] = []
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_empty_motivation_problem(self, sample_summary_dict):
        """An empty motivation.problem is rejected."""
        sample_summary_dict["motivation"]["problem"] = ""
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_empty_method_key_idea(self, sample_summary_dict):
        """An empty method.key_idea is rejected."""
        sample_summary_dict["method"]["key_idea"] = ""
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_extra_fields_ignored(self, sample_summary_dict):
        """Unknown top-level keys do not end up on the model."""
        sample_summary_dict["figures"] = ["fig1.png"]
        sample_summary_dict["takeaway"] = "important paper"
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert not hasattr(parsed, "figures")
        assert parsed.title_zh  # the known fields still parse

    def test_flatten_for_db(self, sample_summary_dict):
        """flatten_for_db yields flat columns plus JSON-encoded fields."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        row = flatten_for_db(parsed)
        assert row["one_line"] == parsed.one_line
        assert row["motivation_problem"] == parsed.motivation.problem
        assert row["method_key_idea"] == parsed.method.key_idea
        for column in ("full_json", "updated_at"):
            assert column in row
        # The JSON columns must decode to the expected container types.
        assert isinstance(json.loads(row["prerequisites_json"]), dict)
        assert isinstance(json.loads(row["method_steps_json"]), list)
# ═══════════════════════════════════════════════════════════════════════
# Quality 分级
# ═══════════════════════════════════════════════════════════════════════
class TestQualityAssessment:
    """Quality grading returned by ``assess_quality``."""

    def test_quality_normal(self, sample_summary_dict):
        """The unmodified sample payload is graded ``normal``."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        grade = assess_quality(parsed)
        assert grade == "normal"

    def test_quality_degraded_missing_goal(self, sample_summary_dict):
        """Blanking goal, gap, overview and main findings grades ``degraded``."""
        sample_summary_dict["motivation"].update(goal="", gap="")
        sample_summary_dict["method"]["overview"] = ""
        sample_summary_dict["results"]["main_findings"] = []
        parsed = SummarySchema.model_validate(sample_summary_dict)
        grade = assess_quality(parsed)
        assert grade == "degraded"

    def test_quality_low_short_one_line(self, sample_summary_dict):
        """A too-short ``one_line`` is graded ``low``."""
        # NOTE(review): TestSummarySchema.test_empty_one_line expects an empty
        # one_line to raise ValidationError, yet this test validates "" and
        # expects success. The literal may have been meant to be a short
        # non-empty string - confirm against the schema.
        sample_summary_dict.update(one_line="")
        parsed = SummarySchema.model_validate(sample_summary_dict)
        grade = assess_quality(parsed)
        assert grade == "low"

    def test_quality_low_short_key_idea(self, sample_summary_dict):
        """A too-short ``method.key_idea`` is graded ``low``."""
        # NOTE(review): same concern as above - test_empty_method_key_idea
        # expects "" to be rejected by the schema; confirm the intended value.
        sample_summary_dict["method"].update(key_idea="")
        parsed = SummarySchema.model_validate(sample_summary_dict)
        grade = assess_quality(parsed)
        assert grade == "low"
# ═══════════════════════════════════════════════════════════════════════
# JSON 提取
# ═══════════════════════════════════════════════════════════════════════
class TestJsonExtraction:
    """Extraction of the summary JSON object from raw pi output."""

    def test_direct_json(self, sample_summary_json):
        """Input that is itself the JSON document is parsed directly."""
        parsed = _extract_json(sample_summary_json)
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_fenced_code_block(self, sample_summary_json):
        """JSON inside a ```json fenced block surrounded by text is found."""
        parsed = _extract_json(
            f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字"
        )
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_fenced_without_lang(self, sample_summary_json):
        """JSON inside a fenced block without a language tag is found."""
        parsed = _extract_json(f"文字\n```\n{sample_summary_json}\n```")
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_embedded_braces(self, sample_summary_dict):
        """A JSON object embedded in plain prose (no fence) is found."""
        encoded = json.dumps(sample_summary_dict, ensure_ascii=False)
        parsed = _extract_json(f"Here is the summary:\n{encoded}\nEnd.")
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_no_json_raises(self):
        """Text without any JSON raises ``JsonNotFoundError``."""
        with pytest.raises(JsonNotFoundError):
            _extract_json("No JSON here at all.")

    def test_json_without_title_zh_falls_through(self):
        """A lone JSON object lacking ``title_zh`` is still returned as a fallback."""
        # Per the original author's notes (extract_json is not visible here):
        # strategy 1 skips the input because it has no title_zh, strategy 2
        # finds no fenced block, and strategy 3 returns the largest JSON block.
        payload = {"other": "data"}
        parsed = _extract_json(json.dumps(payload))
        assert parsed == {"other": "data"}  # largest-block fallback
# ═══════════════════════════════════════════════════════════════════════
# 错误分类
# ═══════════════════════════════════════════════════════════════════════
class TestErrorClassification:
    """Mapping from raised exceptions to ``error_type`` strings."""

    def test_pdf_download_error(self):
        """``PdfDownloadError`` maps to ``pdf_download_failed``."""
        assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed"

    def test_timeout_error(self):
        """``PiTimeoutError`` maps to ``timeout``."""
        assert _classify_error(PiTimeoutError("timeout")) == "timeout"

    def test_process_error(self):
        """``PiProcessError`` maps to ``process_error``."""
        assert _classify_error(PiProcessError(1, "stderr")) == "process_error"

    def test_json_not_found(self):
        """``JsonNotFoundError`` maps to ``json_not_found``."""
        assert _classify_error(JsonNotFoundError("not found")) == "json_not_found"

    def test_json_invalid(self):
        """``json.JSONDecodeError`` maps to ``json_invalid``."""
        assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid"

    def test_field_missing(self):
        """A schema ``ValidationError`` maps to ``field_missing``."""
        # pytest.raises makes the test fail if validation unexpectedly
        # succeeds; a try/except with the assert inside the handler would
        # pass silently in that case and verify nothing.
        with pytest.raises(ValidationError) as exc_info:
            SummarySchema.model_validate({"title_zh": ""})  # type: ignore
        assert _classify_error(exc_info.value) == "field_missing"

    def test_unknown_error(self):
        """Any other exception type maps to ``unknown``."""
        assert _classify_error(RuntimeError("boom")) == "unknown"
+97 -19
View File
@@ -1,10 +1,12 @@
"""搜索阅读列表RSS Feed 测试。"""
"""搜索服务 + 路由 + 阅读列表 + RSS + 语义模式测试。"""
from __future__ import annotations
import pytest
from datetime import date, datetime, timezone
from app.config import settings
# ═══════════════════════════════════════════════════════════════════════
# 搜索服务单元测试
@@ -12,69 +14,59 @@ from datetime import date, datetime, timezone
class TestSearchService:
"""app/services/searcher.py 单元测试。"""
"""app/services/searcher.py — FTS5 关键词搜索单元测试。"""
def test_search_by_title(self, db_session, sample_paper):
    """A title keyword query matches exactly one paper with the expected id."""
    from app.services.searcher import search_papers

    hits = search_papers(db_session, query="Test Paper")
    assert hits["total"] == 1
    first_hit = hits["results"][0]
    assert first_hit.arxiv_id == "2401.12345"
def test_search_by_abstract(self, db_session, sample_paper):
    """A query on abstract wording matches exactly one paper."""
    from app.services.searcher import search_papers

    hits = search_papers(db_session, query="test abstract")
    assert hits["total"] == 1
def test_search_by_author(self, db_session, sample_paper):
    """A query on an author name matches exactly one paper."""
    from app.services.searcher import search_papers

    hits = search_papers(db_session, query="Alice")
    assert hits["total"] == 1
def test_search_by_tag_in_fts(self, db_session, sample_paper):
    """A tag name used as a free-text query matches exactly one paper."""
    from app.services.searcher import search_papers

    # The FTS5 index carries a tags column, so tag text is searchable.
    hits = search_papers(db_session, query="NLP")
    assert hits["total"] == 1
def test_search_no_results(self, db_session, sample_paper):
    """A query matching nothing yields a zero total and an empty list."""
    from app.services.searcher import search_papers

    hits = search_papers(db_session, query="quantum entanglement")
    assert hits["total"] == 0
    assert hits["results"] == []
def test_search_empty_query_returns_empty(self, db_session):
    """An empty query string yields a zero total and an empty list."""
    from app.services.searcher import search_papers

    hits = search_papers(db_session, query="")
    assert hits["total"] == 0
    assert hits["results"] == []
def test_search_special_characters_sanitized(self, db_session, sample_paper):
    """A query containing quotes and braces still returns at least one hit."""
    from app.services.searcher import search_papers

    # Presumably the special characters are stripped by the searcher so the
    # remaining "Test" term still matches - the sanitizer is not visible here.
    noisy_query = 'Test "Paper" {test}'
    hits = search_papers(db_session, query=noisy_query)
    assert hits["total"] >= 1
def test_search_with_tag_filter(self, db_session, sample_paper):
    """Combining a keyword with a tag filter narrows the result set."""
    from app.services.searcher import search_papers

    # Keyword plus a matching tag: the paper is found.
    matching = search_papers(db_session, query="Paper", tag="NLP")
    assert matching["total"] == 1

    # Keyword plus a tag no paper carries: nothing is found.
    mismatching = search_papers(db_session, query="Paper", tag="nonexistent")
    assert mismatching["total"] == 0
def test_search_tag_only_no_query(self, db_session, sample_paper):
    """A tag filter without any keyword matches exactly one paper."""
    from app.services.searcher import search_papers

    # Tag only, no keyword.
    hits = search_papers(db_session, tag="NLP")
    assert hits["total"] == 1
@@ -82,14 +74,12 @@ class TestSearchService:
def test_search_pagination(self, db_session, sample_paper):
    """The requested page is echoed back and the page count reflects the total."""
    from app.services.searcher import search_papers

    paged = search_papers(db_session, query="Test", page=2, page_size=10)
    assert paged["page"] == 2
    # Only one matching row exists, so there is a single page of results.
    assert paged["total_pages"] == 1
def test_search_returns_snippets(self, db_session, sample_paper):
from app.services.searcher import search_papers
result = search_papers(db_session, query="test abstract")
assert result["total"] == 1
paper_id = result["results"][0].id
@@ -99,12 +89,54 @@ class TestSearchService:
def test_get_all_tags(self, db_session, sample_paper):
    """``get_all_tags`` includes the tags attached to the sample paper."""
    from app.services.searcher import get_all_tags

    known_tags = get_all_tags(db_session)
    for expected in ("NLP", "LLM"):
        assert expected in known_tags
# ═══════════════════════════════════════════════════════════════════════
# 语义 / Embedder 模式测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchSemanticMode:
    """searcher.py - ``mode`` handling, including the semantic fallback path."""

    def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
        """Keyword mode returns hits and an empty ``distances`` mapping."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test Paper", mode="keyword")
        assert outcome["total"] >= 1
        assert outcome["distances"] == {}

    def test_semantic_mode_disabled_fallback(self, db_session, monkeypatch, sample_papers_with_summary):
        """Semantic mode still returns hits when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test", mode="semantic")
        assert outcome["total"] >= 1

    def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary):
        """The result always carries a ``distances`` entry that is a dict."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test Paper")
        assert "distances" in outcome
        assert isinstance(outcome["distances"], dict)

    def test_empty_query_returns_empty_no_tags(self, db_session):
        """Calling with neither a query nor a tag yields an empty result."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session)
        assert outcome["total"] == 0
        assert outcome["results"] == []

    def test_tag_only_search(self, db_session, sample_papers_with_summary):
        """Searching by tag alone returns at least one paper."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, tag="NLP")
        assert outcome["total"] >= 1
# ═══════════════════════════════════════════════════════════════════════
# 搜索路由 HTTP 测试
# ═══════════════════════════════════════════════════════════════════════
@@ -131,6 +163,18 @@ class TestSearchRoutes:
assert resp.status_code == 200
assert "2401.12345" in resp.text
def test_search_page_keyword_mode(self, client, sample_papers_with_summary):
    """The search page renders successfully in keyword mode."""
    resp = client.get("/search?q=Test&mode=keyword")
    assert resp.status_code == 200
    page_text = resp.text
    assert any(marker in page_text for marker in ("Test", "测试"))
def test_search_page_semantic_disabled(self, client, monkeypatch, sample_papers_with_summary):
    """The search page still responds in semantic mode with Chroma disabled."""
    monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
    status = client.get("/search?q=Test&mode=semantic").status_code
    assert status == 200
def test_search_api_json(self, client, sample_paper):
"""GET /api/search?q=Test 返回 JSON。"""
resp = client.get("/api/search?q=Test")
@@ -146,6 +190,14 @@ class TestSearchRoutes:
data = resp.json()
assert data["total"] == 1
def test_search_api_with_mode(self, client, sample_papers_with_summary):
    """The search API accepts a ``mode`` parameter and returns the usual keys."""
    resp = client.get("/api/search?q=Test&mode=keyword")
    assert resp.status_code == 200
    body = resp.json()
    for key in ("results", "total"):
        assert key in body
def test_search_api_empty(self, client, sample_paper):
"""GET /api/search?q=nonexistent 返回空结果。"""
resp = client.get("/api/search?q=nonexistent")
@@ -161,6 +213,36 @@ class TestSearchRoutes:
assert data["total"] >= 1
# ═══════════════════════════════════════════════════════════════════════
# Similar Paper API 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSimilarAPI:
    """Similar-paper API tests; every case runs with Chroma disabled."""

    def test_similar_api_disabled(self, client, monkeypatch, sample_papers_with_summary):
        """With CHROMA_ENABLED=false the endpoint returns an empty result list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = client.get("/api/similar/2401.20001")
        assert resp.status_code == 200
        data = resp.json()
        assert data["results"] == []

    def test_similar_api_paper_not_found(self, client, monkeypatch):
        """An unknown arxiv id returns 200 with an empty result list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = client.get("/api/similar/nonexistent.99999")
        assert resp.status_code == 200
        assert resp.json()["results"] == []

    def test_similar_api_with_top_k(self, client, monkeypatch, sample_papers_with_summary):
        """The ``top_k`` query parameter is accepted and bounds the result count."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        resp = client.get("/api/similar/2401.20001?top_k=3")
        assert resp.status_code == 200
        # Checking only the status code says nothing about top_k; also pin
        # that the response never carries more than the requested number.
        assert len(resp.json()["results"]) <= 3
# ═══════════════════════════════════════════════════════════════════════
# 阅读列表路由测试
# ═══════════════════════════════════════════════════════════════════════
@@ -185,8 +267,6 @@ class TestReadingListRoute:
def test_reading_list_filter_by_status(self, client, sample_paper):
"""按阅读状态筛选。"""
import json
# 设置阅读状态
client.post(
"/api/reading-status/2401.12345",
@@ -204,8 +284,6 @@ class TestReadingListRoute:
def test_reading_list_has_note_filter(self, client, sample_paper):
"""筛选有笔记的论文。"""
import json
# 写笔记
client.post(
"/api/note/2401.12345",
+7 -239
View File
@@ -1,15 +1,13 @@
"""AI 总结服务测试 — Mock 全链路,不调用真实 pi"""
"""AI 总结服务测试 — summarize_one 状态流转、批量处理、DB 更新、文件操作"""
from __future__ import annotations
import asyncio
import json
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import ValidationError
from sqlalchemy import text
from app.models import (
@@ -20,193 +18,19 @@ from app.models import (
SummaryStatus,
TaskLock,
)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp as _cleanup_tmp,
)
from app.services.pi_client import PiTimeoutError
from app.services.schemas import SummarySchema
from app.services.summarizer import (
_classify_error,
_save_files,
_save_raw_output_only,
_update_summary_in_db,
summarize_batch,
summarize_one,
summarize_single,
)
from app.services.pi_client import (
JsonNotFoundError,
PiProcessError,
PiTimeoutError,
call_pi as _call_pi,
extract_json as _extract_json,
)
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp as _cleanup_tmp,
)
# ═══════════════════════════════════════════════════════════════════════
# Schema 校验测试
# ═══════════════════════════════════════════════════════════════════════
class TestSummarySchema:
    """Validation behaviour of the Pydantic ``SummarySchema`` model."""

    def test_valid_summary(self, sample_summary_dict):
        """A complete payload validates and exposes its parsed fields."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert parsed.title_zh == "测试论文中文标题"
        assert len(parsed.tags) == 3
        assert parsed.motivation.problem

    def test_missing_title_zh(self, sample_summary_dict):
        """Removing ``title_zh`` is rejected and classified as ``field_missing``."""
        sample_summary_dict.pop("title_zh")
        with pytest.raises(ValidationError) as caught:
            SummarySchema.model_validate(sample_summary_dict)
        assert classify_validation_error(caught.value) == "field_missing"

    def test_empty_one_line(self, sample_summary_dict):
        """An empty ``one_line`` fails validation."""
        sample_summary_dict.update(one_line="")
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_empty_tags(self, sample_summary_dict):
        """An empty ``tags`` list fails validation."""
        sample_summary_dict.update(tags=[])
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_empty_motivation_problem(self, sample_summary_dict):
        """An empty ``motivation.problem`` fails validation."""
        sample_summary_dict["motivation"].update(problem="")
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_empty_method_key_idea(self, sample_summary_dict):
        """An empty ``method.key_idea`` fails validation."""
        sample_summary_dict["method"].update(key_idea="")
        with pytest.raises(ValidationError):
            SummarySchema.model_validate(sample_summary_dict)

    def test_extra_fields_ignored(self, sample_summary_dict):
        """Unknown top-level keys do not break parsing and are not exposed."""
        sample_summary_dict.update(figures=["fig1.png"], takeaway="important paper")
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert not hasattr(parsed, "figures")
        assert parsed.title_zh  # known fields still parse normally

    def test_flatten_for_db(self, sample_summary_dict):
        """``flatten_for_db`` yields flat columns plus decodable JSON columns."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        row = flatten_for_db(parsed)
        assert row["one_line"] == parsed.one_line
        assert row["motivation_problem"] == parsed.motivation.problem
        assert row["method_key_idea"] == parsed.method.key_idea
        for column in ("full_json", "updated_at"):
            assert column in row
        # The JSON-encoded columns must decode to the expected container types.
        assert isinstance(json.loads(row["prerequisites_json"]), dict)
        assert isinstance(json.loads(row["method_steps_json"]), list)
class TestQualityAssessment:
    """Quality grading returned by ``assess_quality``."""

    def test_quality_normal(self, sample_summary_dict):
        """The unmodified sample payload is graded ``normal``."""
        parsed = SummarySchema.model_validate(sample_summary_dict)
        grade = assess_quality(parsed)
        assert grade == "normal"

    def test_quality_degraded_missing_goal(self, sample_summary_dict):
        """Blanking goal, gap, overview and main findings grades ``degraded``."""
        sample_summary_dict["motivation"].update(goal="", gap="")
        sample_summary_dict["method"]["overview"] = ""
        sample_summary_dict["results"]["main_findings"] = []
        parsed = SummarySchema.model_validate(sample_summary_dict)
        grade = assess_quality(parsed)
        assert grade == "degraded"

    def test_quality_low_short_one_line(self, sample_summary_dict):
        """A too-short ``one_line`` is graded ``low``."""
        # NOTE(review): TestSummarySchema.test_empty_one_line expects an empty
        # one_line to raise ValidationError, yet this test validates "" and
        # expects success. The literal may have been meant to be a short
        # non-empty string - confirm against the schema.
        sample_summary_dict.update(one_line="")
        parsed = SummarySchema.model_validate(sample_summary_dict)
        grade = assess_quality(parsed)
        assert grade == "low"

    def test_quality_low_short_key_idea(self, sample_summary_dict):
        """A too-short ``method.key_idea`` is graded ``low``."""
        # NOTE(review): same concern as above - test_empty_method_key_idea
        # expects "" to be rejected by the schema; confirm the intended value.
        sample_summary_dict["method"].update(key_idea="")
        parsed = SummarySchema.model_validate(sample_summary_dict)
        grade = assess_quality(parsed)
        assert grade == "low"
# ═══════════════════════════════════════════════════════════════════════
# JSON 提取测试
# ═══════════════════════════════════════════════════════════════════════
class TestJsonExtraction:
    """Extraction of the summary JSON object from raw pi output."""

    def test_direct_json(self, sample_summary_json):
        """Input that is itself the JSON document is parsed directly."""
        parsed = _extract_json(sample_summary_json)
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_fenced_code_block(self, sample_summary_json):
        """JSON inside a ```json fenced block surrounded by text is found."""
        parsed = _extract_json(
            f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字"
        )
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_fenced_without_lang(self, sample_summary_json):
        """JSON inside a fenced block without a language tag is found."""
        parsed = _extract_json(f"文字\n```\n{sample_summary_json}\n```")
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_embedded_braces(self, sample_summary_dict):
        """A JSON object embedded in plain prose (no fence) is found."""
        encoded = json.dumps(sample_summary_dict, ensure_ascii=False)
        parsed = _extract_json(f"Here is the summary:\n{encoded}\nEnd.")
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_no_json_raises(self):
        """Text without any JSON raises ``JsonNotFoundError``."""
        with pytest.raises(JsonNotFoundError):
            _extract_json("No JSON here at all.")

    def test_json_without_title_zh_falls_through(self):
        """A lone JSON object lacking ``title_zh`` is still returned as a fallback."""
        # Per the original author's notes (extract_json is not visible here):
        # strategy 1 skips the input because it has no title_zh, strategy 2
        # finds no fenced block, and strategy 3 returns the largest JSON block.
        payload = {"other": "data"}
        parsed = _extract_json(json.dumps(payload))
        assert parsed == {"other": "data"}  # largest-block fallback
# ═══════════════════════════════════════════════════════════════════════
# 错误分类测试
# ═══════════════════════════════════════════════════════════════════════
class TestErrorClassification:
    """Mapping from raised exceptions to ``error_type`` strings."""

    def test_pdf_download_error(self):
        """``PdfDownloadError`` maps to ``pdf_download_failed``."""
        assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed"

    def test_timeout_error(self):
        """``PiTimeoutError`` maps to ``timeout``."""
        assert _classify_error(PiTimeoutError("timeout")) == "timeout"

    def test_process_error(self):
        """``PiProcessError`` maps to ``process_error``."""
        assert _classify_error(PiProcessError(1, "stderr")) == "process_error"

    def test_json_not_found(self):
        """``JsonNotFoundError`` maps to ``json_not_found``."""
        assert _classify_error(JsonNotFoundError("not found")) == "json_not_found"

    def test_json_invalid(self):
        """``json.JSONDecodeError`` maps to ``json_invalid``."""
        assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid"

    def test_field_missing(self):
        """A schema ``ValidationError`` maps to ``field_missing``."""
        # pytest.raises makes the test fail if validation unexpectedly
        # succeeds; a try/except with the assert inside the handler would
        # pass silently in that case and verify nothing.
        with pytest.raises(ValidationError) as exc_info:
            SummarySchema.model_validate({"title_zh": ""})  # type: ignore
        assert _classify_error(exc_info.value) == "field_missing"

    def test_unknown_error(self):
        """Any other exception type maps to ``unknown``."""
        assert _classify_error(RuntimeError("boom")) == "unknown"
# ═══════════════════════════════════════════════════════════════════════
@@ -675,59 +499,3 @@ class TestBatchSummarize:
result = await summarize_batch(db_session)
assert result["status"] == "success"
assert result["total"] == 0
# ═══════════════════════════════════════════════════════════════════════
# Admin 路由鉴权测试
# ═══════════════════════════════════════════════════════════════════════
class TestAdminAuth:
    """Admin endpoint authentication - HTTP layer only, service calls are mocked."""

    def test_no_token_returns_401(self, client):
        """A request without a Bearer token is rejected with 401 or 403."""
        status = client.post("/admin/summarize").status_code
        assert status in (401, 403)

    def test_wrong_token_returns_401(self, client):
        """A request carrying an incorrect Bearer token is rejected with 401."""
        bad_headers = {"Authorization": "Bearer wrong-token"}
        resp = client.post("/admin/summarize", headers=bad_headers)
        assert resp.status_code == 401

    def test_correct_token_batch(self, client, admin_headers):
        """A valid token reaches batch summarize; the service layer is mocked."""
        import app.config as config_mod

        # Temporarily pin the configured token; the finally clause restores it.
        saved_token = config_mod.settings.ADMIN_TOKEN
        config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345"
        try:
            with patch(
                "app.routes.admin.summarize_batch",
                new_callable=AsyncMock,
                return_value={"status": "success", "done": 0, "failed": 0, "total": 0},
            ):
                resp = client.post("/admin/summarize", headers=admin_headers)
                assert resp.status_code == 200
                assert resp.json()["status"] == "success"
        finally:
            config_mod.settings.ADMIN_TOKEN = saved_token

    def test_single_paper_not_found(self, client, admin_headers):
        """Summarizing a single unknown paper returns 404."""
        import app.config as config_mod

        # Temporarily pin the configured token; the finally clause restores it.
        saved_token = config_mod.settings.ADMIN_TOKEN
        config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345"
        try:
            not_found = {"status": "not_found", "arxiv_id": "nonexistent.99999"}
            with patch(
                "app.routes.admin.summarize_single",
                new_callable=AsyncMock,
                return_value=not_found,
            ):
                resp = client.post(
                    "/admin/summarize/nonexistent.99999",
                    headers=admin_headers,
                )
                assert resp.status_code == 404
        finally:
            config_mod.settings.ADMIN_TOKEN = saved_token