refactor: split monolithic phase tests into per-module test files
- rename test_admin_phase4.py -> test_admin.py, test_search.py -> test_searcher.py - split test_phase5.py into test_cleaner, test_embedder, test_image_extractor, test_pages - move schema tests from test_summarizer.py into dedicated test_schemas.py - add sample_papers_range and sample_papers_with_summary fixtures in conftest - update .gitignore to exclude all of data/
This commit is contained in:
+1
-4
@@ -2,10 +2,7 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
*.pyo
|
*.pyo
|
||||||
data/db/*.db
|
data/
|
||||||
data/papers/
|
|
||||||
data/tmp/
|
|
||||||
data/chroma/
|
|
||||||
logs/*.log
|
logs/*.log
|
||||||
.venv/
|
.venv/
|
||||||
venv/
|
venv/
|
||||||
|
|||||||
@@ -209,3 +209,122 @@ def admin_token():
|
|||||||
def admin_headers(admin_token):
|
def admin_headers(admin_token):
|
||||||
"""带 Bearer token 的请求头。"""
|
"""带 Bearer token 的请求头。"""
|
||||||
return {"Authorization": f"Bearer {admin_token}"}
|
return {"Authorization": f"Bearer {admin_token}"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def wrong_admin_headers():
|
||||||
|
"""错误的 Authorization 请求头。"""
|
||||||
|
return {"Authorization": "Bearer wrong-token"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── 多样例数据 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_papers_range(db_session):
|
||||||
|
"""插入 5 篇不同日期的论文(用于 admin / cleaner 测试)。"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
papers = []
|
||||||
|
for i, (arxiv_id, paper_date_str) in enumerate([
|
||||||
|
("2401.10001", "2024-01-10"),
|
||||||
|
("2401.10002", "2024-01-11"),
|
||||||
|
("2401.10003", "2024-01-12"),
|
||||||
|
("2401.10004", "2024-01-13"),
|
||||||
|
("2401.10005", "2024-01-14"),
|
||||||
|
]):
|
||||||
|
paper_date = date.fromisoformat(paper_date_str)
|
||||||
|
p = Paper(
|
||||||
|
arxiv_id=arxiv_id,
|
||||||
|
title_en=f"Test Paper {i+1}",
|
||||||
|
abstract=f"Abstract for paper {i+1}.",
|
||||||
|
paper_date=paper_date,
|
||||||
|
crawled_at=now,
|
||||||
|
upvotes=i * 10,
|
||||||
|
)
|
||||||
|
db_session.add(p)
|
||||||
|
db_session.flush()
|
||||||
|
db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
|
||||||
|
db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
|
||||||
|
db_session.add(SummaryStatus(paper_id=p.id, status="pending"))
|
||||||
|
# FTS5
|
||||||
|
db_session.execute(
|
||||||
|
__import__("sqlalchemy").text(
|
||||||
|
"INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
|
||||||
|
"VALUES (:id, :title, :abstract, :authors, :tags)"
|
||||||
|
),
|
||||||
|
{"id": p.id, "title": p.title_en, "abstract": p.abstract,
|
||||||
|
"authors": f"Author {i+1}", "tags": f"Tag{i+1}"},
|
||||||
|
)
|
||||||
|
papers.append(p)
|
||||||
|
db_session.commit()
|
||||||
|
return papers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_papers_with_summary(db_session):
|
||||||
|
"""插入 5 篇带总结的论文(用于 search / pages / trends 测试)。"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
papers = []
|
||||||
|
for i, (arxiv_id, paper_date_str) in enumerate([
|
||||||
|
("2401.20001", "2024-01-10"),
|
||||||
|
("2401.20002", "2024-01-11"),
|
||||||
|
("2401.20003", "2024-01-12"),
|
||||||
|
("2401.20004", "2024-01-13"),
|
||||||
|
("2401.20005", "2024-01-14"),
|
||||||
|
]):
|
||||||
|
paper_date = date.fromisoformat(paper_date_str)
|
||||||
|
p = Paper(
|
||||||
|
arxiv_id=arxiv_id,
|
||||||
|
title_en=f"Test Paper {i+1}",
|
||||||
|
title_zh=f"测试论文 {i+1}",
|
||||||
|
abstract=f"Abstract for paper {i+1}.",
|
||||||
|
paper_date=paper_date,
|
||||||
|
crawled_at=now,
|
||||||
|
upvotes=i * 10 + 5,
|
||||||
|
)
|
||||||
|
db_session.add(p)
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
|
||||||
|
db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf"))
|
||||||
|
db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
|
||||||
|
|
||||||
|
db_session.add(SummaryStatus(
|
||||||
|
paper_id=p.id,
|
||||||
|
status="done" if i < 4 else "pending",
|
||||||
|
quality="normal",
|
||||||
|
))
|
||||||
|
|
||||||
|
# 添加总结(前 4 篇)
|
||||||
|
if i < 4:
|
||||||
|
summary = PaperSummary(
|
||||||
|
paper_id=p.id,
|
||||||
|
one_line=f"这是论文{i+1}的一句话摘要",
|
||||||
|
difficulty="中级",
|
||||||
|
motivation_problem=f"论文{i+1}的研究问题",
|
||||||
|
motivation_goal=f"论文{i+1}的研究目标",
|
||||||
|
method_key_idea=f"论文{i+1}的关键思路",
|
||||||
|
method_overview=f"论文{i+1}的方法概述",
|
||||||
|
updated_at=now,
|
||||||
|
full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}),
|
||||||
|
)
|
||||||
|
db_session.add(summary)
|
||||||
|
|
||||||
|
# FTS5
|
||||||
|
db_session.execute(
|
||||||
|
__import__("sqlalchemy").text(
|
||||||
|
"INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
|
||||||
|
"VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": p.id,
|
||||||
|
"title_en": p.title_en,
|
||||||
|
"title_zh": p.title_zh or "",
|
||||||
|
"abstract": p.abstract or "",
|
||||||
|
"authors": f"Author {i+1}",
|
||||||
|
"tags": f"NLP, Tag{i+1}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
papers.append(p)
|
||||||
|
db_session.commit()
|
||||||
|
return papers
|
||||||
|
|||||||
@@ -1,350 +1,36 @@
|
|||||||
"""Phase 4 管理和自动化测试 — cleaner、admin routes、scheduler。"""
|
"""管理接口测试 — admin routes、auth、scheduler、task locks。"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import logging
|
||||||
import shutil
|
|
||||||
import time
|
|
||||||
from datetime import date, datetime, timezone
|
from datetime import date, datetime, timezone
|
||||||
from pathlib import Path
|
from unittest.mock import AsyncMock, patch
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.database import get_db
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.models import (
|
from app.models import (
|
||||||
CrawlLog,
|
CrawlLog,
|
||||||
DataDeleteJob,
|
|
||||||
Paper,
|
|
||||||
PaperAuthor,
|
|
||||||
PaperSummary,
|
|
||||||
PaperTag,
|
|
||||||
SummaryStatus,
|
|
||||||
TaskLock,
|
TaskLock,
|
||||||
UserBookmark,
|
|
||||||
UserNote,
|
|
||||||
UserReadingStatus,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ────────────────────────────────────────────────────────────
|
# ── Fixtures ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
ADMIN_TOKEN = "test-admin-token-12345"
|
ADMIN_TOKEN = "test-admin-token-12345"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def admin_headers():
|
|
||||||
return {"Authorization": f"Bearer {ADMIN_TOKEN}"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def wrong_admin_headers():
|
|
||||||
return {"Authorization": "Bearer wrong-token"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def auth_client(client, monkeypatch):
|
def auth_client(client, monkeypatch):
|
||||||
"""带 admin token monkeypatch 的 TestClient。"""
|
"""带 admin token monkeypatch 的 TestClient。"""
|
||||||
monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN)
|
monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN)
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_papers(db_session):
|
|
||||||
"""插入多篇不同日期的论文。"""
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
papers = []
|
|
||||||
for i, (arxiv_id, paper_date_str) in enumerate([
|
|
||||||
("2401.10001", "2024-01-10"),
|
|
||||||
("2401.10002", "2024-01-11"),
|
|
||||||
("2401.10003", "2024-01-12"),
|
|
||||||
("2401.10004", "2024-01-13"),
|
|
||||||
("2401.10005", "2024-01-14"),
|
|
||||||
]):
|
|
||||||
paper_date = date.fromisoformat(paper_date_str)
|
|
||||||
p = Paper(
|
|
||||||
arxiv_id=arxiv_id,
|
|
||||||
title_en=f"Test Paper {i+1}",
|
|
||||||
abstract=f"Abstract for paper {i+1}.",
|
|
||||||
paper_date=paper_date,
|
|
||||||
crawled_at=now,
|
|
||||||
upvotes=i * 10,
|
|
||||||
)
|
|
||||||
db_session.add(p)
|
|
||||||
db_session.flush()
|
|
||||||
db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
|
|
||||||
db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
|
|
||||||
db_session.add(SummaryStatus(paper_id=p.id, status="pending"))
|
|
||||||
# FTS5
|
|
||||||
import sqlalchemy
|
|
||||||
db_session.execute(
|
|
||||||
sqlalchemy.text(
|
|
||||||
"INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
|
|
||||||
"VALUES (:id, :title, :abstract, :authors, :tags)"
|
|
||||||
),
|
|
||||||
{"id": p.id, "title": p.title_en, "abstract": p.abstract,
|
|
||||||
"authors": f"Author {i+1}", "tags": f"Tag{i+1}"},
|
|
||||||
)
|
|
||||||
papers.append(p)
|
|
||||||
db_session.commit()
|
|
||||||
return papers
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_paper_with_user_data(db_session, sample_papers):
|
|
||||||
"""给第一篇论文添加用户数据(收藏、阅读状态、笔记)。"""
|
|
||||||
paper = sample_papers[0]
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
db_session.add(UserBookmark(paper_id=paper.id, created_at=now))
|
|
||||||
db_session.add(UserReadingStatus(paper_id=paper.id, status="read_summary", updated_at=now))
|
|
||||||
db_session.add(UserNote(
|
|
||||||
paper_id=paper.id,
|
|
||||||
content="My notes on this paper",
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
))
|
|
||||||
db_session.commit()
|
|
||||||
return paper
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def tmp_data_dir(tmp_path):
|
|
||||||
"""创建临时 data 目录结构。"""
|
|
||||||
tmp_dir = tmp_path / "data" / "tmp"
|
|
||||||
papers_dir = tmp_path / "data" / "papers"
|
|
||||||
tmp_dir.mkdir(parents=True)
|
|
||||||
papers_dir.mkdir(parents=True)
|
|
||||||
return tmp_path / "data"
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
# Cleaner 服务测试
|
# Admin Routes — 鉴权测试
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestCleanupTmp:
|
|
||||||
"""app/services/cleaner.py — cleanup_tmp 测试。"""
|
|
||||||
|
|
||||||
def test_cleanup_removes_old_dirs(self, tmp_path, monkeypatch):
|
|
||||||
"""超过 24 小时的临时目录应被删除。"""
|
|
||||||
tmp_dir = tmp_path / "tmp"
|
|
||||||
tmp_dir.mkdir()
|
|
||||||
|
|
||||||
# 创建一个旧目录
|
|
||||||
old_dir = tmp_dir / "2401.00001"
|
|
||||||
old_dir.mkdir()
|
|
||||||
(old_dir / "paper.pdf").write_text("fake pdf")
|
|
||||||
|
|
||||||
# 修改目录时间为 25 小时前
|
|
||||||
old_mtime = time.time() - 25 * 3600
|
|
||||||
os.utime(old_dir, (old_mtime, old_mtime))
|
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
|
||||||
from app.services.cleaner import cleanup_tmp
|
|
||||||
result = cleanup_tmp()
|
|
||||||
|
|
||||||
assert result["scanned"] == 1
|
|
||||||
assert result["removed"] == 1
|
|
||||||
assert not old_dir.exists()
|
|
||||||
|
|
||||||
def test_cleanup_keeps_recent_dirs(self, tmp_path, monkeypatch):
|
|
||||||
"""24 小时内的临时目录应保留。"""
|
|
||||||
tmp_dir = tmp_path / "tmp"
|
|
||||||
tmp_dir.mkdir()
|
|
||||||
|
|
||||||
recent_dir = tmp_dir / "2401.00002"
|
|
||||||
recent_dir.mkdir()
|
|
||||||
(recent_dir / "paper.pdf").write_text("fake pdf")
|
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
|
||||||
from app.services.cleaner import cleanup_tmp
|
|
||||||
result = cleanup_tmp()
|
|
||||||
|
|
||||||
assert result["scanned"] == 1
|
|
||||||
assert result["removed"] == 0
|
|
||||||
assert recent_dir.exists()
|
|
||||||
|
|
||||||
def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
|
|
||||||
"""data/tmp/ 不存在时安全返回。"""
|
|
||||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent")
|
|
||||||
from app.services.cleaner import cleanup_tmp
|
|
||||||
result = cleanup_tmp()
|
|
||||||
assert result["scanned"] == 0
|
|
||||||
assert result["removed"] == 0
|
|
||||||
|
|
||||||
def test_cleanup_mixed_ages(self, tmp_path, monkeypatch):
|
|
||||||
"""混合新旧目录时只删除旧的。"""
|
|
||||||
tmp_dir = tmp_path / "tmp"
|
|
||||||
tmp_dir.mkdir()
|
|
||||||
|
|
||||||
old_dir = tmp_dir / "2401.old"
|
|
||||||
old_dir.mkdir()
|
|
||||||
old_mtime = time.time() - 30 * 3600
|
|
||||||
os.utime(old_dir, (old_mtime, old_mtime))
|
|
||||||
|
|
||||||
recent_dir = tmp_dir / "2401.new"
|
|
||||||
recent_dir.mkdir()
|
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
|
||||||
from app.services.cleaner import cleanup_tmp
|
|
||||||
result = cleanup_tmp()
|
|
||||||
|
|
||||||
assert result["scanned"] == 2
|
|
||||||
assert result["removed"] == 1
|
|
||||||
assert not old_dir.exists()
|
|
||||||
assert recent_dir.exists()
|
|
||||||
|
|
||||||
|
|
||||||
class TestDeletePapersByDateRange:
|
|
||||||
"""app/services/cleaner.py — delete_papers_by_date_range 测试。"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_by_date_range(self, db_session, sample_papers):
|
|
||||||
"""删除指定日期范围的论文。"""
|
|
||||||
from app.services.cleaner import delete_papers_by_date_range
|
|
||||||
|
|
||||||
# 删除 1月11日 ~ 1月13日(3篇)
|
|
||||||
result = await delete_papers_by_date_range(
|
|
||||||
db_session,
|
|
||||||
date(2024, 1, 11),
|
|
||||||
date(2024, 1, 13),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["deleted"] == 3
|
|
||||||
assert result["total"] == 3
|
|
||||||
assert result["status"] == "success"
|
|
||||||
|
|
||||||
# 确认数据库中只剩 2 篇
|
|
||||||
remaining = db_session.execute(select(Paper)).scalars().all()
|
|
||||||
assert len(remaining) == 2
|
|
||||||
dates = {p.paper_date for p in remaining}
|
|
||||||
assert dates == {date(2024, 1, 10), date(2024, 1, 14)}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_creates_job_record(self, db_session, sample_papers):
|
|
||||||
"""删除操作应创建 data_delete_jobs 记录。"""
|
|
||||||
from app.services.cleaner import delete_papers_by_date_range
|
|
||||||
|
|
||||||
await delete_papers_by_date_range(
|
|
||||||
db_session,
|
|
||||||
date(2024, 1, 10),
|
|
||||||
date(2024, 1, 14),
|
|
||||||
)
|
|
||||||
|
|
||||||
jobs = db_session.execute(select(DataDeleteJob)).scalars().all()
|
|
||||||
assert len(jobs) == 1
|
|
||||||
assert jobs[0].status == "success"
|
|
||||||
assert jobs[0].date_start == date(2024, 1, 10)
|
|
||||||
assert jobs[0].date_end == date(2024, 1, 14)
|
|
||||||
assert jobs[0].paper_count == 5
|
|
||||||
assert jobs[0].completed_at is not None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_creates_crawl_log(self, db_session, sample_papers):
|
|
||||||
"""删除操作应写入 crawl_logs。"""
|
|
||||||
from app.services.cleaner import delete_papers_by_date_range
|
|
||||||
|
|
||||||
await delete_papers_by_date_range(
|
|
||||||
db_session,
|
|
||||||
date(2024, 1, 10),
|
|
||||||
date(2024, 1, 14),
|
|
||||||
)
|
|
||||||
|
|
||||||
logs = db_session.execute(
|
|
||||||
select(CrawlLog).where(CrawlLog.task == "delete")
|
|
||||||
).scalars().all()
|
|
||||||
assert len(logs) == 1
|
|
||||||
assert logs[0].status == "success"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data):
|
|
||||||
"""删除论文时应 cascade 删除关联的用户数据。"""
|
|
||||||
from app.services.cleaner import delete_papers_by_date_range
|
|
||||||
|
|
||||||
paper = sample_paper_with_user_data
|
|
||||||
# 确认用户数据存在
|
|
||||||
assert db_session.get(UserBookmark, db_session.execute(
|
|
||||||
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
|
||||||
).scalar_one_or_none().id if db_session.execute(
|
|
||||||
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
|
||||||
).scalar_one_or_none() else None) is not None or True
|
|
||||||
|
|
||||||
# 删除
|
|
||||||
result = await delete_papers_by_date_range(
|
|
||||||
db_session,
|
|
||||||
date(2024, 1, 10),
|
|
||||||
date(2024, 1, 10),
|
|
||||||
)
|
|
||||||
assert result["deleted"] == 1
|
|
||||||
|
|
||||||
# 确认用户数据被 cascade 删除
|
|
||||||
assert db_session.execute(
|
|
||||||
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
|
||||||
).scalar_one_or_none() is None
|
|
||||||
assert db_session.execute(
|
|
||||||
select(UserReadingStatus).where(UserReadingStatus.paper_id == paper.id)
|
|
||||||
).scalar_one_or_none() is None
|
|
||||||
assert db_session.execute(
|
|
||||||
select(UserNote).where(UserNote.paper_id == paper.id)
|
|
||||||
).scalar_one_or_none() is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_removes_fts(self, db_session, sample_papers):
|
|
||||||
"""删除论文时应同步删除 FTS5 索引。"""
|
|
||||||
import sqlalchemy
|
|
||||||
from app.services.cleaner import delete_papers_by_date_range
|
|
||||||
|
|
||||||
await delete_papers_by_date_range(
|
|
||||||
db_session,
|
|
||||||
date(2024, 1, 10),
|
|
||||||
date(2024, 1, 14),
|
|
||||||
)
|
|
||||||
|
|
||||||
# FTS5 应为空
|
|
||||||
rows = db_session.execute(
|
|
||||||
sqlalchemy.text("SELECT count(*) FROM papers_fts")
|
|
||||||
).scalar()
|
|
||||||
assert rows == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_removes_local_files(self, db_session, sample_papers, tmp_path, monkeypatch):
|
|
||||||
"""删除论文时应删除本地文件目录。"""
|
|
||||||
from app.services.cleaner import delete_papers_by_date_range
|
|
||||||
|
|
||||||
papers_dir = tmp_path / "papers"
|
|
||||||
papers_dir.mkdir()
|
|
||||||
(papers_dir / "2401.10001").mkdir()
|
|
||||||
(papers_dir / "2401.10001" / "meta.json").write_text("{}")
|
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir)
|
|
||||||
|
|
||||||
result = await delete_papers_by_date_range(
|
|
||||||
db_session,
|
|
||||||
date(2024, 1, 10),
|
|
||||||
date(2024, 1, 10),
|
|
||||||
)
|
|
||||||
assert result["deleted"] == 1
|
|
||||||
assert not (papers_dir / "2401.10001").exists()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_empty_range(self, db_session, sample_papers):
|
|
||||||
"""日期范围内无论文时返回 0。"""
|
|
||||||
from app.services.cleaner import delete_papers_by_date_range
|
|
||||||
|
|
||||||
result = await delete_papers_by_date_range(
|
|
||||||
db_session,
|
|
||||||
date(2025, 1, 1),
|
|
||||||
date(2025, 1, 31),
|
|
||||||
)
|
|
||||||
assert result["total"] == 0
|
|
||||||
assert result["deleted"] == 0
|
|
||||||
assert result["status"] == "success"
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Admin Routes 测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
@@ -363,12 +49,65 @@ class TestAdminAuth:
|
|||||||
|
|
||||||
def test_correct_token_accepted(self, auth_client, admin_headers):
|
def test_correct_token_accepted(self, auth_client, admin_headers):
|
||||||
"""正确 token 应被接受(crawl 可能会失败但不是 401)。"""
|
"""正确 token 应被接受(crawl 可能会失败但不是 401)。"""
|
||||||
# mock crawl_daily 避免 API 调用
|
|
||||||
with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl:
|
with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl:
|
||||||
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
|
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
|
||||||
resp = auth_client.post("/admin/crawl", headers=admin_headers)
|
resp = auth_client.post("/admin/crawl", headers=admin_headers)
|
||||||
assert resp.status_code != 401
|
assert resp.status_code != 401
|
||||||
|
|
||||||
|
# ── summarize route auth ────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_no_token_returns_401_for_summarize(self, client):
|
||||||
|
"""无 Bearer token 返回 401。"""
|
||||||
|
resp = client.post("/admin/summarize")
|
||||||
|
assert resp.status_code in (401, 403)
|
||||||
|
|
||||||
|
def test_wrong_token_returns_401_for_summarize(self, client):
|
||||||
|
resp = client.post(
|
||||||
|
"/admin/summarize",
|
||||||
|
headers={"Authorization": "Bearer wrong-token"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_correct_token_batch_summarize(self, client, admin_headers):
|
||||||
|
"""正确 token 调用 batch summarize,mock 掉服务层。"""
|
||||||
|
import app.config as config_mod
|
||||||
|
|
||||||
|
original = config_mod.settings.ADMIN_TOKEN
|
||||||
|
config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN
|
||||||
|
try:
|
||||||
|
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
|
||||||
|
mock.return_value = {"status": "success", "done": 0, "failed": 0, "total": 0}
|
||||||
|
resp = client.post("/admin/summarize", headers=admin_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "success"
|
||||||
|
finally:
|
||||||
|
config_mod.settings.ADMIN_TOKEN = original
|
||||||
|
|
||||||
|
def test_single_paper_not_found(self, client, admin_headers):
|
||||||
|
"""单篇总结不存在的论文返回 404。"""
|
||||||
|
import app.config as config_mod
|
||||||
|
|
||||||
|
original = config_mod.settings.ADMIN_TOKEN
|
||||||
|
config_mod.settings.ADMIN_TOKEN = ADMIN_TOKEN
|
||||||
|
try:
|
||||||
|
with patch(
|
||||||
|
"app.routes.admin.summarize_single",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
|
||||||
|
):
|
||||||
|
resp = client.post(
|
||||||
|
"/admin/summarize/nonexistent.99999",
|
||||||
|
headers=admin_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
finally:
|
||||||
|
config_mod.settings.ADMIN_TOKEN = original
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Admin Routes — Crawl
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
class TestAdminCrawl:
|
class TestAdminCrawl:
|
||||||
"""POST /admin/crawl 测试。"""
|
"""POST /admin/crawl 测试。"""
|
||||||
@@ -381,7 +120,6 @@ class TestAdminCrawl:
|
|||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["status"] == "success"
|
assert data["status"] == "success"
|
||||||
# 验证调用了 crawl_daily
|
|
||||||
mock_crawl.assert_called_once()
|
mock_crawl.assert_called_once()
|
||||||
|
|
||||||
def test_crawl_specific_date(self, auth_client, admin_headers):
|
def test_crawl_specific_date(self, auth_client, admin_headers):
|
||||||
@@ -395,6 +133,11 @@ class TestAdminCrawl:
|
|||||||
assert call_args[0][1] == "2024-01-15"
|
assert call_args[0][1] == "2024-01-15"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Admin Routes — Cleanup
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
class TestAdminCleanup:
|
class TestAdminCleanup:
|
||||||
"""POST /admin/cleanup 测试。"""
|
"""POST /admin/cleanup 测试。"""
|
||||||
|
|
||||||
@@ -421,6 +164,11 @@ class TestAdminCleanup:
|
|||||||
assert logs[-1].status == "success"
|
assert logs[-1].status == "success"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Admin Routes — Delete
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
class TestAdminDelete:
|
class TestAdminDelete:
|
||||||
"""POST /admin/delete 测试。"""
|
"""POST /admin/delete 测试。"""
|
||||||
|
|
||||||
@@ -438,7 +186,7 @@ class TestAdminDelete:
|
|||||||
)
|
)
|
||||||
assert resp.status_code == 422
|
assert resp.status_code == 422
|
||||||
|
|
||||||
def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers):
|
def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers_range):
|
||||||
"""confirm='DELETE' 时应执行删除。"""
|
"""confirm='DELETE' 时应执行删除。"""
|
||||||
resp = auth_client.post(
|
resp = auth_client.post(
|
||||||
"/admin/delete",
|
"/admin/delete",
|
||||||
@@ -480,6 +228,11 @@ class TestAdminDelete:
|
|||||||
assert resp.status_code == 422
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Admin Routes — Logs
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
class TestAdminLogs:
|
class TestAdminLogs:
|
||||||
"""GET /admin/logs 测试。"""
|
"""GET /admin/logs 测试。"""
|
||||||
|
|
||||||
@@ -494,7 +247,7 @@ class TestAdminLogs:
|
|||||||
resp = auth_client.get("/admin/logs")
|
resp = auth_client.get("/admin/logs")
|
||||||
assert resp.status_code in (403, 401)
|
assert resp.status_code in (403, 401)
|
||||||
|
|
||||||
def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers):
|
def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers_range):
|
||||||
"""日志页面应包含日志数据。"""
|
"""日志页面应包含日志数据。"""
|
||||||
# 先创建一条日志
|
# 先创建一条日志
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
@@ -519,11 +272,10 @@ class TestScheduler:
|
|||||||
def test_scheduler_disabled_by_default(self, monkeypatch):
|
def test_scheduler_disabled_by_default(self, monkeypatch):
|
||||||
"""SCHEDULER_ENABLED=false 时不应启动调度器。"""
|
"""SCHEDULER_ENABLED=false 时不应启动调度器。"""
|
||||||
monkeypatch.setattr(settings, "SCHEDULER_ENABLED", False)
|
monkeypatch.setattr(settings, "SCHEDULER_ENABLED", False)
|
||||||
from app.services.scheduler import start_scheduler
|
|
||||||
# 重置模块级变量
|
|
||||||
import app.services.scheduler as sched_mod
|
import app.services.scheduler as sched_mod
|
||||||
sched_mod._scheduler = None
|
sched_mod._scheduler = None
|
||||||
|
|
||||||
|
from app.services.scheduler import start_scheduler
|
||||||
result = start_scheduler()
|
result = start_scheduler()
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
@@ -550,7 +302,6 @@ class TestScheduler:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_scheduler_warns_multi_worker(self, monkeypatch, caplog):
|
async def test_scheduler_warns_multi_worker(self, monkeypatch, caplog):
|
||||||
"""APP_WORKERS > 1 时应打印警告。"""
|
"""APP_WORKERS > 1 时应打印警告。"""
|
||||||
import logging
|
|
||||||
monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
|
monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
|
||||||
monkeypatch.setattr(settings, "APP_WORKERS", 4)
|
monkeypatch.setattr(settings, "APP_WORKERS", 4)
|
||||||
import app.services.scheduler as sched_mod
|
import app.services.scheduler as sched_mod
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
"""Cleaner 服务测试 — cleanup_tmp、delete_papers_by_date_range。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import date, datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.models import (
|
||||||
|
CrawlLog,
|
||||||
|
DataDeleteJob,
|
||||||
|
Paper,
|
||||||
|
UserBookmark,
|
||||||
|
UserNote,
|
||||||
|
UserReadingStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_paper_with_user_data(db_session, sample_papers_range):
|
||||||
|
"""给第一篇论文添加用户数据(收藏、阅读状态、笔记)。"""
|
||||||
|
paper = sample_papers_range[0]
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
db_session.add(UserBookmark(paper_id=paper.id, created_at=now))
|
||||||
|
db_session.add(UserReadingStatus(paper_id=paper.id, status="read_summary", updated_at=now))
|
||||||
|
db_session.add(UserNote(
|
||||||
|
paper_id=paper.id,
|
||||||
|
content="My notes on this paper",
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
))
|
||||||
|
db_session.commit()
|
||||||
|
return paper
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# cleanup_tmp 测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupTmp:
|
||||||
|
"""app/services/cleaner.py — cleanup_tmp 测试。"""
|
||||||
|
|
||||||
|
def test_cleanup_removes_old_dirs(self, tmp_path, monkeypatch):
|
||||||
|
"""超过 24 小时的临时目录应被删除。"""
|
||||||
|
tmp_dir = tmp_path / "tmp"
|
||||||
|
tmp_dir.mkdir()
|
||||||
|
|
||||||
|
# 创建一个旧目录
|
||||||
|
old_dir = tmp_dir / "2401.00001"
|
||||||
|
old_dir.mkdir()
|
||||||
|
(old_dir / "paper.pdf").write_text("fake pdf")
|
||||||
|
|
||||||
|
# 修改目录时间为 25 小时前
|
||||||
|
old_mtime = time.time() - 25 * 3600
|
||||||
|
os.utime(old_dir, (old_mtime, old_mtime))
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||||
|
from app.services.cleaner import cleanup_tmp
|
||||||
|
result = cleanup_tmp()
|
||||||
|
|
||||||
|
assert result["scanned"] == 1
|
||||||
|
assert result["removed"] == 1
|
||||||
|
assert not old_dir.exists()
|
||||||
|
|
||||||
|
def test_cleanup_keeps_recent_dirs(self, tmp_path, monkeypatch):
|
||||||
|
"""24 小时内的临时目录应保留。"""
|
||||||
|
tmp_dir = tmp_path / "tmp"
|
||||||
|
tmp_dir.mkdir()
|
||||||
|
|
||||||
|
recent_dir = tmp_dir / "2401.00002"
|
||||||
|
recent_dir.mkdir()
|
||||||
|
(recent_dir / "paper.pdf").write_text("fake pdf")
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||||
|
from app.services.cleaner import cleanup_tmp
|
||||||
|
result = cleanup_tmp()
|
||||||
|
|
||||||
|
assert result["scanned"] == 1
|
||||||
|
assert result["removed"] == 0
|
||||||
|
assert recent_dir.exists()
|
||||||
|
|
||||||
|
def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""data/tmp/ 不存在时安全返回。"""
|
||||||
|
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent")
|
||||||
|
from app.services.cleaner import cleanup_tmp
|
||||||
|
result = cleanup_tmp()
|
||||||
|
assert result["scanned"] == 0
|
||||||
|
assert result["removed"] == 0
|
||||||
|
|
||||||
|
def test_cleanup_mixed_ages(self, tmp_path, monkeypatch):
|
||||||
|
"""混合新旧目录时只删除旧的。"""
|
||||||
|
tmp_dir = tmp_path / "tmp"
|
||||||
|
tmp_dir.mkdir()
|
||||||
|
|
||||||
|
old_dir = tmp_dir / "2401.old"
|
||||||
|
old_dir.mkdir()
|
||||||
|
old_mtime = time.time() - 30 * 3600
|
||||||
|
os.utime(old_dir, (old_mtime, old_mtime))
|
||||||
|
|
||||||
|
recent_dir = tmp_dir / "2401.new"
|
||||||
|
recent_dir.mkdir()
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||||
|
from app.services.cleaner import cleanup_tmp
|
||||||
|
result = cleanup_tmp()
|
||||||
|
|
||||||
|
assert result["scanned"] == 2
|
||||||
|
assert result["removed"] == 1
|
||||||
|
assert not old_dir.exists()
|
||||||
|
assert recent_dir.exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# delete_papers_by_date_range 测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeletePapersByDateRange:
|
||||||
|
"""app/services/cleaner.py — delete_papers_by_date_range 测试。"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_by_date_range(self, db_session, sample_papers_range):
|
||||||
|
"""删除指定日期范围的论文。"""
|
||||||
|
from app.services.cleaner import delete_papers_by_date_range
|
||||||
|
|
||||||
|
# 删除 1月11日 ~ 1月13日(3篇)
|
||||||
|
result = await delete_papers_by_date_range(
|
||||||
|
db_session,
|
||||||
|
date(2024, 1, 11),
|
||||||
|
date(2024, 1, 13),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["deleted"] == 3
|
||||||
|
assert result["total"] == 3
|
||||||
|
assert result["status"] == "success"
|
||||||
|
|
||||||
|
# 确认数据库中只剩 2 篇
|
||||||
|
remaining = db_session.execute(select(Paper)).scalars().all()
|
||||||
|
assert len(remaining) == 2
|
||||||
|
dates = {p.paper_date for p in remaining}
|
||||||
|
assert dates == {date(2024, 1, 10), date(2024, 1, 14)}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_creates_job_record(self, db_session, sample_papers_range):
|
||||||
|
"""删除操作应创建 data_delete_jobs 记录。"""
|
||||||
|
from app.services.cleaner import delete_papers_by_date_range
|
||||||
|
|
||||||
|
await delete_papers_by_date_range(
|
||||||
|
db_session,
|
||||||
|
date(2024, 1, 10),
|
||||||
|
date(2024, 1, 14),
|
||||||
|
)
|
||||||
|
|
||||||
|
jobs = db_session.execute(select(DataDeleteJob)).scalars().all()
|
||||||
|
assert len(jobs) == 1
|
||||||
|
assert jobs[0].status == "success"
|
||||||
|
assert jobs[0].date_start == date(2024, 1, 10)
|
||||||
|
assert jobs[0].date_end == date(2024, 1, 14)
|
||||||
|
assert jobs[0].paper_count == 5
|
||||||
|
assert jobs[0].completed_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_creates_crawl_log(self, db_session, sample_papers_range):
|
||||||
|
"""删除操作应写入 crawl_logs。"""
|
||||||
|
from app.services.cleaner import delete_papers_by_date_range
|
||||||
|
|
||||||
|
await delete_papers_by_date_range(
|
||||||
|
db_session,
|
||||||
|
date(2024, 1, 10),
|
||||||
|
date(2024, 1, 14),
|
||||||
|
)
|
||||||
|
|
||||||
|
logs = db_session.execute(
|
||||||
|
select(CrawlLog).where(CrawlLog.task == "delete")
|
||||||
|
).scalars().all()
|
||||||
|
assert len(logs) == 1
|
||||||
|
assert logs[0].status == "success"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data):
|
||||||
|
"""删除论文时应 cascade 删除关联的用户数据。"""
|
||||||
|
from app.services.cleaner import delete_papers_by_date_range
|
||||||
|
|
||||||
|
paper = sample_paper_with_user_data
|
||||||
|
|
||||||
|
# 删除
|
||||||
|
result = await delete_papers_by_date_range(
|
||||||
|
db_session,
|
||||||
|
date(2024, 1, 10),
|
||||||
|
date(2024, 1, 10),
|
||||||
|
)
|
||||||
|
assert result["deleted"] == 1
|
||||||
|
|
||||||
|
# 确认用户数据被 cascade 删除
|
||||||
|
assert db_session.execute(
|
||||||
|
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
||||||
|
).scalar_one_or_none() is None
|
||||||
|
assert db_session.execute(
|
||||||
|
select(UserReadingStatus).where(UserReadingStatus.paper_id == paper.id)
|
||||||
|
).scalar_one_or_none() is None
|
||||||
|
assert db_session.execute(
|
||||||
|
select(UserNote).where(UserNote.paper_id == paper.id)
|
||||||
|
).scalar_one_or_none() is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_removes_fts(self, db_session, sample_papers_range):
|
||||||
|
"""删除论文时应同步删除 FTS5 索引。"""
|
||||||
|
import sqlalchemy
|
||||||
|
from app.services.cleaner import delete_papers_by_date_range
|
||||||
|
|
||||||
|
await delete_papers_by_date_range(
|
||||||
|
db_session,
|
||||||
|
date(2024, 1, 10),
|
||||||
|
date(2024, 1, 14),
|
||||||
|
)
|
||||||
|
|
||||||
|
# FTS5 应为空
|
||||||
|
rows = db_session.execute(
|
||||||
|
sqlalchemy.text("SELECT count(*) FROM papers_fts")
|
||||||
|
).scalar()
|
||||||
|
assert rows == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_removes_local_files(self, db_session, sample_papers_range, tmp_path, monkeypatch):
|
||||||
|
"""删除论文时应删除本地文件目录。"""
|
||||||
|
from app.services.cleaner import delete_papers_by_date_range
|
||||||
|
|
||||||
|
papers_dir = tmp_path / "papers"
|
||||||
|
papers_dir.mkdir()
|
||||||
|
(papers_dir / "2401.10001").mkdir()
|
||||||
|
(papers_dir / "2401.10001" / "meta.json").write_text("{}")
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir)
|
||||||
|
|
||||||
|
result = await delete_papers_by_date_range(
|
||||||
|
db_session,
|
||||||
|
date(2024, 1, 10),
|
||||||
|
date(2024, 1, 10),
|
||||||
|
)
|
||||||
|
assert result["deleted"] == 1
|
||||||
|
assert not (papers_dir / "2401.10001").exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_empty_range(self, db_session, sample_papers_range):
|
||||||
|
"""日期范围内无论文时返回 0。"""
|
||||||
|
from app.services.cleaner import delete_papers_by_date_range
|
||||||
|
|
||||||
|
result = await delete_papers_by_date_range(
|
||||||
|
db_session,
|
||||||
|
date(2025, 1, 1),
|
||||||
|
date(2025, 1, 31),
|
||||||
|
)
|
||||||
|
assert result["total"] == 0
|
||||||
|
assert result["deleted"] == 0
|
||||||
|
assert result["status"] == "success"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch):
|
||||||
|
"""CHROMA 关闭时删除论文正常工作。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
import app.services.embedder as emb
|
||||||
|
emb._chroma.reset()
|
||||||
|
|
||||||
|
from app.services.cleaner import delete_papers_by_date_range
|
||||||
|
result = await delete_papers_by_date_range(
|
||||||
|
db_session,
|
||||||
|
date(2024, 1, 10),
|
||||||
|
date(2024, 1, 10),
|
||||||
|
)
|
||||||
|
assert result["status"] == "success"
|
||||||
|
assert result["deleted"] == 1
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
"""Embedder / Chroma 服务测试 — 初始化、索引、embedding API。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 初始化
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbedderInit:
|
||||||
|
"""embedder.py 初始化测试。"""
|
||||||
|
|
||||||
|
def test_chroma_disabled_skip_init(self, monkeypatch):
|
||||||
|
"""CHROMA_ENABLED=false 时不初始化。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
import app.services.embedder as emb
|
||||||
|
emb._chroma.reset()
|
||||||
|
emb.init_chroma()
|
||||||
|
assert emb._chroma._client is None
|
||||||
|
|
||||||
|
def test_chroma_init_success(self, monkeypatch, tmp_path):
|
||||||
|
"""CHROMA_ENABLED=true 时初始化成功。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||||
|
|
||||||
|
import app.services.embedder as emb
|
||||||
|
emb._chroma.reset()
|
||||||
|
emb.init_chroma()
|
||||||
|
|
||||||
|
assert emb._chroma._client is not None
|
||||||
|
assert emb._chroma._collection is not None
|
||||||
|
|
||||||
|
# 清理
|
||||||
|
emb._chroma.reset()
|
||||||
|
|
||||||
|
def test_get_collection_returns_none_when_disabled(self, monkeypatch):
|
||||||
|
"""CHROMA_ENABLED=false 时 get_collection 返回 None。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
import app.services.embedder as emb
|
||||||
|
emb._chroma.reset()
|
||||||
|
assert emb.get_collection() is None
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 索引
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbedderIndexing:
|
||||||
|
"""embedder.py 索引测试。"""
|
||||||
|
|
||||||
|
def test_index_paper_disabled(self, monkeypatch):
|
||||||
|
"""CHROMA_ENABLED=false 时 index_paper 返回 False。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
import app.services.embedder as emb
|
||||||
|
emb._chroma.reset()
|
||||||
|
assert emb.index_paper("test-id") is False
|
||||||
|
|
||||||
|
def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
|
||||||
|
"""没有 EMBED_API_BASE 时返回 False。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||||
|
monkeypatch.setattr(settings, "EMBED_API_BASE", "")
|
||||||
|
monkeypatch.setattr(settings, "EMBED_MODEL", "")
|
||||||
|
|
||||||
|
import app.services.embedder as emb
|
||||||
|
emb._chroma.reset()
|
||||||
|
emb.init_chroma()
|
||||||
|
|
||||||
|
result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
emb._chroma.reset()
|
||||||
|
|
||||||
|
def test_index_batch_disabled(self, monkeypatch):
|
||||||
|
"""CHROMA_ENABLED=false 时 index_batch 返回全失败。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
import app.services.embedder as emb
|
||||||
|
emb._chroma.reset()
|
||||||
|
result = emb.index_batch(["a", "b"])
|
||||||
|
assert result["success"] == 0
|
||||||
|
assert result["failed"] == 2
|
||||||
|
|
||||||
|
def test_index_batch_empty(self, monkeypatch):
|
||||||
|
"""空列表时返回 0。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
import app.services.embedder as emb
|
||||||
|
result = emb.index_batch([])
|
||||||
|
assert result["total"] == 0
|
||||||
|
|
||||||
|
def test_delete_paper_disabled(self, monkeypatch):
|
||||||
|
"""CHROMA_ENABLED=false 时 delete_paper 返回 False。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
import app.services.embedder as emb
|
||||||
|
emb._chroma.reset()
|
||||||
|
assert emb.delete_paper("test-id") is False
|
||||||
|
|
||||||
|
def test_search_similar_disabled(self, monkeypatch):
|
||||||
|
"""CHROMA_ENABLED=false 时 search_similar 返回空列表。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
import app.services.embedder as emb
|
||||||
|
emb._chroma.reset()
|
||||||
|
assert emb.search_similar("test query") == []
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Embedding API
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingApi:
|
||||||
|
"""_get_embedding 测试。"""
|
||||||
|
|
||||||
|
def test_no_api_base_returns_none(self, monkeypatch):
|
||||||
|
"""EMBED_API_BASE 为空时返回 None。"""
|
||||||
|
monkeypatch.setattr(settings, "EMBED_API_BASE", "")
|
||||||
|
monkeypatch.setattr(settings, "EMBED_MODEL", "")
|
||||||
|
import app.services.embedder as emb
|
||||||
|
assert emb._get_embedding("test") is None
|
||||||
|
|
||||||
|
def test_dimension_mismatch_returns_none(self, monkeypatch):
|
||||||
|
"""维度不匹配时返回 None。"""
|
||||||
|
monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
|
||||||
|
monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
|
||||||
|
monkeypatch.setattr(settings, "EMBED_API_KEY", "")
|
||||||
|
monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128)
|
||||||
|
monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
|
||||||
|
|
||||||
|
import app.services.embedder as emb
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]}
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("httpx.Client") as mock_client:
|
||||||
|
mock_client.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
||||||
|
mock_client.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
result = emb._get_embedding("test")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_api_failure_returns_none(self, monkeypatch):
|
||||||
|
"""API 调用失败时返回 None。"""
|
||||||
|
monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
|
||||||
|
monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
|
||||||
|
monkeypatch.setattr(settings, "EMBED_API_KEY", "")
|
||||||
|
monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0)
|
||||||
|
monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
|
||||||
|
|
||||||
|
import app.services.embedder as emb
|
||||||
|
|
||||||
|
with patch("httpx.Client") as mock_client:
|
||||||
|
mock_client.return_value.__enter__ = MagicMock()
|
||||||
|
mock_client.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
mock_client.return_value.__enter__.return_value.post.side_effect = Exception("timeout")
|
||||||
|
result = emb._get_embedding("test")
|
||||||
|
assert result is None
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
"""LaTeX 图片提取测试 — 从 .tex 源码中提取图片文件。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Image Extraction
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestImageExtraction:
|
||||||
|
"""LaTeX 图片提取测试。"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
|
||||||
|
"""源码目录不存在时返回 0。"""
|
||||||
|
monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
||||||
|
monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x)
|
||||||
|
from app.services.image_extractor import extract_images_from_source
|
||||||
|
result = await extract_images_from_source("2401.99999")
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
|
||||||
|
"""从 .tex 文件中提取图片。"""
|
||||||
|
from app.services.image_extractor import extract_images_from_source
|
||||||
|
|
||||||
|
tmp_source = tmp_path / "tmp" / "2401.00001" / "source"
|
||||||
|
tmp_source.mkdir(parents=True)
|
||||||
|
|
||||||
|
images_dir = tmp_source / "figs"
|
||||||
|
images_dir.mkdir()
|
||||||
|
(images_dir / "figure1.png").write_bytes(b"\x89PNG\r\n")
|
||||||
|
(images_dir / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0")
|
||||||
|
|
||||||
|
# 创建 .tex 文件
|
||||||
|
tex_content = r"""
|
||||||
|
\documentclass{article}
|
||||||
|
\begin{document}
|
||||||
|
\begin{figure}
|
||||||
|
\includegraphics[width=0.8\textwidth]{figs/figure1.png}
|
||||||
|
\includegraphics{figs/figure2.jpg}
|
||||||
|
\includegraphics[angle=90]{figs/nonexistent.pdf}
|
||||||
|
\end{figure}
|
||||||
|
\end{document}
|
||||||
|
"""
|
||||||
|
(tmp_source / "main.tex").write_text(tex_content)
|
||||||
|
|
||||||
|
papers_dir = tmp_path / "papers" / "2401.00001"
|
||||||
|
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
||||||
|
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
|
||||||
|
|
||||||
|
# Mock download_source_zip to avoid real network call (source dir already exists)
|
||||||
|
async def _noop_download(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
|
||||||
|
|
||||||
|
result = await extract_images_from_source("2401.00001")
|
||||||
|
|
||||||
|
assert result == 2
|
||||||
|
dest_images = papers_dir / "images"
|
||||||
|
assert dest_images.exists()
|
||||||
|
assert (dest_images / "figure1.png").exists()
|
||||||
|
assert (dest_images / "figure2.jpg").exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
|
||||||
|
""".tex 文件无图片时返回 0。"""
|
||||||
|
from app.services.image_extractor import extract_images_from_source
|
||||||
|
|
||||||
|
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
|
||||||
|
tmp_source.mkdir(parents=True)
|
||||||
|
(tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
||||||
|
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
|
||||||
|
|
||||||
|
# Mock download_source_zip to avoid real network call
|
||||||
|
async def _noop_download(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
|
||||||
|
|
||||||
|
result = await extract_images_from_source("2401.00002")
|
||||||
|
assert result == 0
|
||||||
@@ -0,0 +1,224 @@
|
|||||||
|
"""页面路由测试 — detail、trends、compare、graceful degradation。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from unittest.mock import patch as upatch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Detail 页 & 相似论文
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetailPage:
|
||||||
|
"""论文详情页测试。"""
|
||||||
|
|
||||||
|
def test_detail_page_renders(self, client, sample_papers_with_summary):
|
||||||
|
"""详情页正常渲染。"""
|
||||||
|
resp = client.get("/paper/2401.20001")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "测试论文" in resp.text or "Test Paper" in resp.text
|
||||||
|
|
||||||
|
def test_detail_page_not_found(self, client):
|
||||||
|
"""不存在的论文返回 404。"""
|
||||||
|
resp = client.get("/paper/nonexistent.99999")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Similar API(详情页内联)
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetailSimilarPapers:
|
||||||
|
"""详情页相似论文模块测试(CHROMA 关闭时的降级行为)。"""
|
||||||
|
|
||||||
|
def test_detail_page_renders_with_similar(self, client, sample_papers_with_summary):
|
||||||
|
"""详情页正常渲染(含相似论文模块)。"""
|
||||||
|
resp = client.get("/paper/2401.20001")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "测试论文" in resp.text or "Test Paper" in resp.text
|
||||||
|
|
||||||
|
def test_detail_page_not_found_similar(self, client):
|
||||||
|
"""不存在的论文返回 404。"""
|
||||||
|
resp = client.get("/paper/nonexistent.99999")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Trends Dashboard
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrendsDashboard:
|
||||||
|
"""趋势看板测试。"""
|
||||||
|
|
||||||
|
def test_trends_page_renders(self, client, sample_papers_with_summary):
|
||||||
|
"""趋势看板页面正常渲染。"""
|
||||||
|
resp = client.get("/trends")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "趋势看板" in resp.text
|
||||||
|
assert "chart" in resp.text.lower() or "Chart" in resp.text
|
||||||
|
|
||||||
|
def test_trends_api_returns_data(self, client, sample_papers_with_summary):
|
||||||
|
"""趋势 API 返回正确数据结构。"""
|
||||||
|
resp = client.get("/api/stats/trends")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
assert "daily_counts" in data
|
||||||
|
assert "top_tags" in data
|
||||||
|
assert "upvotes_dist" in data
|
||||||
|
assert "summary_completion" in data
|
||||||
|
|
||||||
|
assert isinstance(data["daily_counts"], list)
|
||||||
|
assert isinstance(data["top_tags"], list)
|
||||||
|
assert isinstance(data["upvotes_dist"], list)
|
||||||
|
assert isinstance(data["summary_completion"], list)
|
||||||
|
|
||||||
|
def test_trends_api_daily_counts(self, client, sample_papers_with_summary):
|
||||||
|
"""每日论文数量数据正确。"""
|
||||||
|
# 使用测试数据的日期范围
|
||||||
|
with upatch("app.services.trends.date") as mock_date:
|
||||||
|
mock_date.today.return_value = date(2024, 1, 20)
|
||||||
|
mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
|
||||||
|
|
||||||
|
resp = client.get("/api/stats/trends")
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data["daily_counts"]) == 5
|
||||||
|
for item in data["daily_counts"]:
|
||||||
|
assert "date" in item
|
||||||
|
assert "count" in item
|
||||||
|
assert item["count"] == 1
|
||||||
|
|
||||||
|
def test_trends_api_top_tags(self, client, sample_papers_with_summary):
|
||||||
|
"""热门标签数据正确。"""
|
||||||
|
resp = client.get("/api/stats/trends")
|
||||||
|
data = resp.json()
|
||||||
|
tags = {t["tag"]: t["count"] for t in data["top_tags"]}
|
||||||
|
assert "NLP" in tags
|
||||||
|
assert tags["NLP"] == 5 # 所有论文都有 NLP
|
||||||
|
|
||||||
|
def test_trends_api_summary_completion(self, client, sample_papers_with_summary):
|
||||||
|
"""总结完成率数据正确。"""
|
||||||
|
resp = client.get("/api/stats/trends")
|
||||||
|
data = resp.json()
|
||||||
|
statuses = {s["status"]: s["count"] for s in data["summary_completion"]}
|
||||||
|
assert "done" in statuses
|
||||||
|
assert statuses["done"] == 4 # 4 篇已完成
|
||||||
|
|
||||||
|
def test_trends_empty_db(self, client):
|
||||||
|
"""无数据时不崩溃。"""
|
||||||
|
resp = client.get("/api/stats/trends")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["daily_counts"] == []
|
||||||
|
assert data["top_tags"] == []
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Compare Page
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestComparePage:
|
||||||
|
"""论文对比页测试。"""
|
||||||
|
|
||||||
|
def test_compare_page_no_ids(self, client):
|
||||||
|
"""无 ID 时显示输入表单。"""
|
||||||
|
resp = client.get("/compare")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "对比" in resp.text
|
||||||
|
|
||||||
|
def test_compare_page_with_ids(self, client, sample_papers_with_summary):
|
||||||
|
"""对比多篇论文正常渲染。"""
|
||||||
|
resp = client.get("/compare?ids=2401.20001,2401.20002")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "2401.20001" in resp.text
|
||||||
|
assert "2401.20002" in resp.text
|
||||||
|
# 应包含对比字段
|
||||||
|
assert "一句话摘要" in resp.text
|
||||||
|
assert "研究问题" in resp.text
|
||||||
|
|
||||||
|
def test_compare_page_max_5(self, client, sample_papers_with_summary):
|
||||||
|
"""最多 5 篇。"""
|
||||||
|
ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005"
|
||||||
|
resp = client.get(f"/compare?ids={ids}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_compare_page_over_5_truncates(self, client, sample_papers_with_summary):
|
||||||
|
"""超过 5 篇截断。"""
|
||||||
|
ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006"
|
||||||
|
resp = client.get(f"/compare?ids={ids}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_compare_page_invalid_ids(self, client):
|
||||||
|
"""无效 ID 时显示空结果。"""
|
||||||
|
resp = client.get("/compare?ids=nonexistent.99999")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_compare_page_shows_no_summary_placeholder(self, client, sample_papers_with_summary):
|
||||||
|
"""无总结的论文显示占位文本。"""
|
||||||
|
# 2401.20005 没有 summary(status=pending)
|
||||||
|
resp = client.get("/compare?ids=2401.20005")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "暂无总结" in resp.text
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Nav Bar
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestNavBar:
|
||||||
|
"""导航栏测试。"""
|
||||||
|
|
||||||
|
def test_nav_includes_trends_link(self, client):
|
||||||
|
"""导航栏应包含趋势链接。"""
|
||||||
|
resp = client.get("/search")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "/trends" in resp.text
|
||||||
|
|
||||||
|
def test_nav_includes_compare_implicitly(self, client):
|
||||||
|
"""compare 页面可访问。"""
|
||||||
|
resp = client.get("/compare")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Graceful Degradation(CHROMA_ENABLED=false)
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestGracefulDegradation:
|
||||||
|
"""CHROMA_ENABLED=false 时优雅降级测试。"""
|
||||||
|
|
||||||
|
def test_search_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
|
||||||
|
"""CHROMA 关闭时 FTS5 搜索正常工作。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
resp = client.get("/search?q=Test")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "Test Paper" in resp.text or "测试论文" in resp.text
|
||||||
|
|
||||||
|
def test_detail_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
|
||||||
|
"""CHROMA 关闭时详情页正常工作。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
resp = client.get("/paper/2401.20001")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_trends_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
|
||||||
|
"""CHROMA 关闭时趋势看板正常工作。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
resp = client.get("/trends")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_compare_works_without_chroma(self, client, monkeypatch, sample_papers_with_summary):
|
||||||
|
"""CHROMA 关闭时对比页正常工作。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
resp = client.get("/compare?ids=2401.20001,2401.20002")
|
||||||
|
assert resp.status_code == 200
|
||||||
@@ -1,660 +0,0 @@
|
|||||||
"""Phase 5 后续增强测试 — embedder、semantic search、trends、compare、image extraction。"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import shutil
|
|
||||||
import time
|
|
||||||
from datetime import date, datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from app.config import settings
|
|
||||||
from app.database import get_db
|
|
||||||
from app.models import (
|
|
||||||
Paper,
|
|
||||||
PaperAuthor,
|
|
||||||
PaperSummary,
|
|
||||||
PaperTag,
|
|
||||||
SummaryStatus,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
ADMIN_TOKEN = "test-admin-token-12345"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def admin_headers():
|
|
||||||
return {"Authorization": f"Bearer " + ADMIN_TOKEN}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def auth_client(client, monkeypatch):
|
|
||||||
monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN)
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_papers_with_summary(db_session):
|
|
||||||
"""插入多篇带总结的论文。"""
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
papers = []
|
|
||||||
for i, (arxiv_id, paper_date_str) in enumerate([
|
|
||||||
("2401.20001", "2024-01-10"),
|
|
||||||
("2401.20002", "2024-01-11"),
|
|
||||||
("2401.20003", "2024-01-12"),
|
|
||||||
("2401.20004", "2024-01-13"),
|
|
||||||
("2401.20005", "2024-01-14"),
|
|
||||||
]):
|
|
||||||
paper_date = date.fromisoformat(paper_date_str)
|
|
||||||
p = Paper(
|
|
||||||
arxiv_id=arxiv_id,
|
|
||||||
title_en=f"Test Paper {i+1}",
|
|
||||||
title_zh=f"测试论文 {i+1}",
|
|
||||||
abstract=f"Abstract for paper {i+1}.",
|
|
||||||
paper_date=paper_date,
|
|
||||||
crawled_at=now,
|
|
||||||
upvotes=i * 10 + 5,
|
|
||||||
)
|
|
||||||
db_session.add(p)
|
|
||||||
db_session.flush()
|
|
||||||
|
|
||||||
db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
|
|
||||||
db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf"))
|
|
||||||
db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
|
|
||||||
|
|
||||||
db_session.add(SummaryStatus(
|
|
||||||
paper_id=p.id,
|
|
||||||
status="done" if i < 4 else "pending",
|
|
||||||
quality="normal",
|
|
||||||
))
|
|
||||||
|
|
||||||
# 添加总结(前 4 篇)
|
|
||||||
if i < 4:
|
|
||||||
from app.services.schemas import SummarySchema
|
|
||||||
summary = PaperSummary(
|
|
||||||
paper_id=p.id,
|
|
||||||
one_line=f"这是论文{i+1}的一句话摘要",
|
|
||||||
difficulty="中级",
|
|
||||||
motivation_problem=f"论文{i+1}的研究问题",
|
|
||||||
motivation_goal=f"论文{i+1}的研究目标",
|
|
||||||
method_key_idea=f"论文{i+1}的关键思路",
|
|
||||||
method_overview=f"论文{i+1}的方法概述",
|
|
||||||
updated_at=now,
|
|
||||||
full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}),
|
|
||||||
)
|
|
||||||
db_session.add(summary)
|
|
||||||
|
|
||||||
# FTS5
|
|
||||||
import sqlalchemy
|
|
||||||
db_session.execute(
|
|
||||||
sqlalchemy.text(
|
|
||||||
"INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
|
|
||||||
"VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
|
|
||||||
),
|
|
||||||
{
|
|
||||||
"id": p.id,
|
|
||||||
"title_en": p.title_en,
|
|
||||||
"title_zh": p.title_zh or "",
|
|
||||||
"abstract": p.abstract or "",
|
|
||||||
"authors": f"Author {i+1}",
|
|
||||||
"tags": f"NLP, Tag{i+1}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
papers.append(p)
|
|
||||||
db_session.commit()
|
|
||||||
return papers
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Embedder 服务测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbedderInit:
|
|
||||||
"""embedder.py 初始化测试。"""
|
|
||||||
|
|
||||||
def test_chroma_disabled_skip_init(self, monkeypatch):
|
|
||||||
"""CHROMA_ENABLED=false 时不初始化。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
import app.services.embedder as emb
|
|
||||||
emb._chroma.reset()
|
|
||||||
emb.init_chroma()
|
|
||||||
assert emb._chroma._client is None
|
|
||||||
|
|
||||||
def test_chroma_init_success(self, monkeypatch, tmp_path):
|
|
||||||
"""CHROMA_ENABLED=true 时初始化成功。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
|
||||||
|
|
||||||
import app.services.embedder as emb
|
|
||||||
emb._chroma.reset()
|
|
||||||
emb.init_chroma()
|
|
||||||
|
|
||||||
assert emb._chroma._client is not None
|
|
||||||
assert emb._chroma._collection is not None
|
|
||||||
|
|
||||||
# 清理
|
|
||||||
emb._chroma.reset()
|
|
||||||
|
|
||||||
def test_get_collection_returns_none_when_disabled(self, monkeypatch):
|
|
||||||
"""CHROMA_ENABLED=false 时 get_collection 返回 None。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
import app.services.embedder as emb
|
|
||||||
emb._chroma.reset()
|
|
||||||
assert emb.get_collection() is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbedderIndexing:
|
|
||||||
"""embedder.py 索引测试。"""
|
|
||||||
|
|
||||||
def test_index_paper_disabled(self, monkeypatch):
|
|
||||||
"""CHROMA_ENABLED=false 时 index_paper 返回 False。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
import app.services.embedder as emb
|
|
||||||
emb._chroma.reset()
|
|
||||||
assert emb.index_paper("test-id") is False
|
|
||||||
|
|
||||||
def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
|
|
||||||
"""没有 EMBED_API_BASE 时返回 False。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
|
||||||
monkeypatch.setattr(settings, "EMBED_API_BASE", "")
|
|
||||||
monkeypatch.setattr(settings, "EMBED_MODEL", "")
|
|
||||||
|
|
||||||
import app.services.embedder as emb
|
|
||||||
emb._chroma.reset()
|
|
||||||
emb.init_chroma()
|
|
||||||
|
|
||||||
result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
emb._chroma.reset()
|
|
||||||
|
|
||||||
def test_index_batch_disabled(self, monkeypatch):
|
|
||||||
"""CHROMA_ENABLED=false 时 index_batch 返回全失败。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
import app.services.embedder as emb
|
|
||||||
emb._chroma.reset()
|
|
||||||
result = emb.index_batch(["a", "b"])
|
|
||||||
assert result["success"] == 0
|
|
||||||
assert result["failed"] == 2
|
|
||||||
|
|
||||||
def test_index_batch_empty(self, monkeypatch):
|
|
||||||
"""空列表时返回 0。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
import app.services.embedder as emb
|
|
||||||
result = emb.index_batch([])
|
|
||||||
assert result["total"] == 0
|
|
||||||
|
|
||||||
def test_delete_paper_disabled(self, monkeypatch):
|
|
||||||
"""CHROMA_ENABLED=false 时 delete_paper 返回 False。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
import app.services.embedder as emb
|
|
||||||
emb._chroma.reset()
|
|
||||||
assert emb.delete_paper("test-id") is False
|
|
||||||
|
|
||||||
def test_search_similar_disabled(self, monkeypatch):
|
|
||||||
"""CHROMA_ENABLED=false 时 search_similar 返回空列表。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
import app.services.embedder as emb
|
|
||||||
emb._chroma.reset()
|
|
||||||
assert emb.search_similar("test query") == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingApi:
|
|
||||||
"""_get_embedding 测试。"""
|
|
||||||
|
|
||||||
def test_no_api_base_returns_none(self, monkeypatch):
|
|
||||||
"""EMBED_API_BASE 为空时返回 None。"""
|
|
||||||
monkeypatch.setattr(settings, "EMBED_API_BASE", "")
|
|
||||||
monkeypatch.setattr(settings, "EMBED_MODEL", "")
|
|
||||||
import app.services.embedder as emb
|
|
||||||
assert emb._get_embedding("test") is None
|
|
||||||
|
|
||||||
def test_dimension_mismatch_returns_none(self, monkeypatch):
|
|
||||||
"""维度不匹配时返回 None。"""
|
|
||||||
monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
|
|
||||||
monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
|
|
||||||
monkeypatch.setattr(settings, "EMBED_API_KEY", "")
|
|
||||||
monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128)
|
|
||||||
monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
|
|
||||||
|
|
||||||
import app.services.embedder as emb
|
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
|
||||||
mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]}
|
|
||||||
mock_resp.raise_for_status = MagicMock()
|
|
||||||
|
|
||||||
with patch("httpx.Client") as mock_client:
|
|
||||||
mock_client.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
|
||||||
mock_client.return_value.__exit__ = MagicMock(return_value=False)
|
|
||||||
result = emb._get_embedding("test")
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
def test_api_failure_returns_none(self, monkeypatch):
|
|
||||||
"""API 调用失败时返回 None。"""
|
|
||||||
monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
|
|
||||||
monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
|
|
||||||
monkeypatch.setattr(settings, "EMBED_API_KEY", "")
|
|
||||||
monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0)
|
|
||||||
monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
|
|
||||||
|
|
||||||
import app.services.embedder as emb
|
|
||||||
|
|
||||||
with patch("httpx.Client") as mock_client:
|
|
||||||
mock_client.return_value.__enter__ = MagicMock()
|
|
||||||
mock_client.return_value.__exit__ = MagicMock(return_value=False)
|
|
||||||
mock_client.return_value.__enter__.return_value.post.side_effect = Exception("timeout")
|
|
||||||
result = emb._get_embedding("test")
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Searcher 语义模式测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestSearchSemanticMode:
|
|
||||||
"""searcher.py 语义搜索模式测试。"""
|
|
||||||
|
|
||||||
def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
|
|
||||||
"""默认 keyword 模式走 FTS5。"""
|
|
||||||
from app.services.searcher import search_papers
|
|
||||||
result = search_papers(db_session, query="Test Paper", mode="keyword")
|
|
||||||
assert result["total"] >= 1
|
|
||||||
assert result["distances"] == {}
|
|
||||||
|
|
||||||
def test_semantic_mode_disabled_fallback(self, db_session, monkeypatch, sample_papers_with_summary):
|
|
||||||
"""CHROMA_ENABLED=false + semantic 模式走 FTS5。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
from app.services.searcher import search_papers
|
|
||||||
result = search_papers(db_session, query="Test", mode="semantic")
|
|
||||||
# 应回退到 FTS5
|
|
||||||
assert result["total"] >= 1
|
|
||||||
|
|
||||||
def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary):
|
|
||||||
"""搜索结果应包含 distances 字段。"""
|
|
||||||
from app.services.searcher import search_papers
|
|
||||||
result = search_papers(db_session, query="Test Paper")
|
|
||||||
assert "distances" in result
|
|
||||||
assert isinstance(result["distances"], dict)
|
|
||||||
|
|
||||||
def test_empty_query_returns_empty(self, db_session):
|
|
||||||
"""空查询无标签时返回空。"""
|
|
||||||
from app.services.searcher import search_papers
|
|
||||||
result = search_papers(db_session)
|
|
||||||
assert result["total"] == 0
|
|
||||||
assert result["results"] == []
|
|
||||||
|
|
||||||
def test_tag_only_search(self, db_session, sample_papers_with_summary):
|
|
||||||
"""仅标签搜索。"""
|
|
||||||
from app.services.searcher import search_papers
|
|
||||||
result = search_papers(db_session, tag="NLP")
|
|
||||||
assert result["total"] >= 1
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Search Routes 测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestSearchRoutes:
|
|
||||||
"""搜索路由测试。"""
|
|
||||||
|
|
||||||
def test_search_page_keyword(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""搜索页 keyword 模式。"""
|
|
||||||
resp = auth_client.get("/search?q=Test&mode=keyword")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "Test" in resp.text or "测试" in resp.text
|
|
||||||
|
|
||||||
def test_search_page_semantic_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
|
|
||||||
"""语义模式 CHROMA_ENABLED=false 时仍能工作。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
resp = auth_client.get("/search?q=Test&mode=semantic")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
|
|
||||||
def test_search_api_with_mode(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""搜索 API 支持 mode 参数。"""
|
|
||||||
resp = auth_client.get("/api/search?q=Test&mode=keyword")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
data = resp.json()
|
|
||||||
assert "results" in data
|
|
||||||
assert "total" in data
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Similar Paper API 测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestSimilarAPI:
|
|
||||||
"""相似论文 API 测试。"""
|
|
||||||
|
|
||||||
def test_similar_api_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
|
|
||||||
"""CHROMA_ENABLED=false 时返回空列表。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
resp = auth_client.get("/api/similar/2401.20001")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
data = resp.json()
|
|
||||||
assert data["results"] == []
|
|
||||||
|
|
||||||
def test_similar_api_paper_not_found(self, auth_client, monkeypatch):
|
|
||||||
"""不存在的论文返回空。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
resp = auth_client.get("/api/similar/nonexistent.99999")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["results"] == []
|
|
||||||
|
|
||||||
def test_similar_api_with_top_k(self, auth_client, monkeypatch, sample_papers_with_summary):
|
|
||||||
"""top_k 参数控制返回数量。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
resp = auth_client.get("/api/similar/2401.20001?top_k=3")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Detail Page 相似论文测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestDetailSimilarPapers:
|
|
||||||
"""详情页相似论文模块测试。"""
|
|
||||||
|
|
||||||
def test_detail_page_renders(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""详情页正常渲染。"""
|
|
||||||
resp = auth_client.get("/paper/2401.20001")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "测试论文" in resp.text or "Test Paper" in resp.text
|
|
||||||
|
|
||||||
def test_detail_page_not_found(self, auth_client):
|
|
||||||
"""不存在的论文返回 404。"""
|
|
||||||
resp = auth_client.get("/paper/nonexistent.99999")
|
|
||||||
assert resp.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Trends Dashboard 测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestTrendsDashboard:
|
|
||||||
"""趋势看板测试。"""
|
|
||||||
|
|
||||||
def test_trends_page_renders(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""趋势看板页面正常渲染。"""
|
|
||||||
resp = auth_client.get("/trends")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "趋势看板" in resp.text
|
|
||||||
assert "chart" in resp.text.lower() or "Chart" in resp.text
|
|
||||||
|
|
||||||
def test_trends_api_returns_data(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""趋势 API 返回正确数据结构。"""
|
|
||||||
resp = auth_client.get("/api/stats/trends")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
data = resp.json()
|
|
||||||
|
|
||||||
assert "daily_counts" in data
|
|
||||||
assert "top_tags" in data
|
|
||||||
assert "upvotes_dist" in data
|
|
||||||
assert "summary_completion" in data
|
|
||||||
|
|
||||||
assert isinstance(data["daily_counts"], list)
|
|
||||||
assert isinstance(data["top_tags"], list)
|
|
||||||
assert isinstance(data["upvotes_dist"], list)
|
|
||||||
assert isinstance(data["summary_completion"], list)
|
|
||||||
|
|
||||||
def test_trends_api_daily_counts(self, auth_client, sample_papers_with_summary, monkeypatch):
|
|
||||||
"""每日论文数量数据正确。"""
|
|
||||||
# 使用测试数据的日期范围
|
|
||||||
from unittest.mock import patch as upatch
|
|
||||||
import app.routes.trends as trends_mod
|
|
||||||
|
|
||||||
# monkeypatch get_trends_data 中的 date.today
|
|
||||||
with upatch("app.services.trends.date") as mock_date:
|
|
||||||
mock_date.today.return_value = date(2024, 1, 20)
|
|
||||||
mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
|
|
||||||
|
|
||||||
resp = auth_client.get("/api/stats/trends")
|
|
||||||
data = resp.json()
|
|
||||||
assert len(data["daily_counts"]) == 5
|
|
||||||
for item in data["daily_counts"]:
|
|
||||||
assert "date" in item
|
|
||||||
assert "count" in item
|
|
||||||
assert item["count"] == 1
|
|
||||||
|
|
||||||
def test_trends_api_top_tags(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""热门标签数据正确。"""
|
|
||||||
resp = auth_client.get("/api/stats/trends")
|
|
||||||
data = resp.json()
|
|
||||||
tags = {t["tag"]: t["count"] for t in data["top_tags"]}
|
|
||||||
assert "NLP" in tags
|
|
||||||
assert tags["NLP"] == 5 # 所有论文都有 NLP
|
|
||||||
|
|
||||||
def test_trends_api_summary_completion(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""总结完成率数据正确。"""
|
|
||||||
resp = auth_client.get("/api/stats/trends")
|
|
||||||
data = resp.json()
|
|
||||||
statuses = {s["status"]: s["count"] for s in data["summary_completion"]}
|
|
||||||
assert "done" in statuses
|
|
||||||
assert statuses["done"] == 4 # 4 篇已完成
|
|
||||||
|
|
||||||
def test_trends_empty_db(self, auth_client):
|
|
||||||
"""无数据时不崩溃。"""
|
|
||||||
resp = auth_client.get("/api/stats/trends")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
data = resp.json()
|
|
||||||
assert data["daily_counts"] == []
|
|
||||||
assert data["top_tags"] == []
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Compare Page 测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestComparePage:
|
|
||||||
"""论文对比页测试。"""
|
|
||||||
|
|
||||||
def test_compare_page_no_ids(self, auth_client):
|
|
||||||
"""无 ID 时显示输入表单。"""
|
|
||||||
resp = auth_client.get("/compare")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "对比" in resp.text
|
|
||||||
|
|
||||||
def test_compare_page_with_ids(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""对比多篇论文正常渲染。"""
|
|
||||||
resp = auth_client.get("/compare?ids=2401.20001,2401.20002")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "2401.20001" in resp.text
|
|
||||||
assert "2401.20002" in resp.text
|
|
||||||
# 应包含对比字段
|
|
||||||
assert "一句话摘要" in resp.text
|
|
||||||
assert "研究问题" in resp.text
|
|
||||||
|
|
||||||
def test_compare_page_max_5(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""最多 5 篇。"""
|
|
||||||
ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005"
|
|
||||||
resp = auth_client.get(f"/compare?ids={ids}")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
|
|
||||||
def test_compare_page_over_5_truncates(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""超过 5 篇截断。"""
|
|
||||||
ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006"
|
|
||||||
resp = auth_client.get(f"/compare?ids={ids}")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
# 不应包含第 6 篇(不存在)
|
|
||||||
|
|
||||||
def test_compare_page_invalid_ids(self, auth_client):
|
|
||||||
"""无效 ID 时显示空结果。"""
|
|
||||||
resp = auth_client.get("/compare?ids=nonexistent.99999")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
# 不存在的论文
|
|
||||||
assert "未找到" in resp.text or "暂无" in resp.text or resp.status_code == 200
|
|
||||||
|
|
||||||
def test_compare_page_shows_no_summary_placeholder(self, auth_client, sample_papers_with_summary):
|
|
||||||
"""无总结的论文显示占位文本。"""
|
|
||||||
# 2401.20005 没有 summary(status=pending)
|
|
||||||
resp = auth_client.get("/compare?ids=2401.20005")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "暂无总结" in resp.text
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Image Extraction 测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestImageExtraction:
|
|
||||||
"""LaTeX 图片提取测试。"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
|
|
||||||
"""源码目录不存在时返回 0。"""
|
|
||||||
monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
|
||||||
monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x)
|
|
||||||
from app.services.image_extractor import extract_images_from_source
|
|
||||||
result = await extract_images_from_source("2401.99999")
|
|
||||||
assert result == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
|
|
||||||
"""从 .tex 文件中提取图片。"""
|
|
||||||
from app.services.image_extractor import extract_images_from_source
|
|
||||||
|
|
||||||
tmp_source = tmp_path / "tmp" / "2401.00001" / "source"
|
|
||||||
tmp_source.mkdir(parents=True)
|
|
||||||
|
|
||||||
images_dir = tmp_source / "figs"
|
|
||||||
images_dir.mkdir()
|
|
||||||
(images_dir / "figure1.png").write_bytes(b"\x89PNG\r\n")
|
|
||||||
(images_dir / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0")
|
|
||||||
|
|
||||||
# 创建 .tex 文件
|
|
||||||
tex_content = r"""
|
|
||||||
\documentclass{article}
|
|
||||||
\begin{document}
|
|
||||||
\begin{figure}
|
|
||||||
\includegraphics[width=0.8\textwidth]{figs/figure1.png}
|
|
||||||
\includegraphics{figs/figure2.jpg}
|
|
||||||
\includegraphics[angle=90]{figs/nonexistent.pdf}
|
|
||||||
\end{figure}
|
|
||||||
\end{document}
|
|
||||||
"""
|
|
||||||
(tmp_source / "main.tex").write_text(tex_content)
|
|
||||||
|
|
||||||
papers_dir = tmp_path / "papers" / "2401.00001"
|
|
||||||
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
|
||||||
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
|
|
||||||
|
|
||||||
# Mock download_source_zip to avoid real network call (source dir already exists)
|
|
||||||
async def _noop_download(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
|
|
||||||
|
|
||||||
result = await extract_images_from_source("2401.00001")
|
|
||||||
|
|
||||||
assert result == 2
|
|
||||||
dest_images = papers_dir / "images"
|
|
||||||
assert dest_images.exists()
|
|
||||||
assert (dest_images / "figure1.png").exists()
|
|
||||||
assert (dest_images / "figure2.jpg").exists()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
|
|
||||||
""".tex 文件无图片时返回 0。"""
|
|
||||||
from app.services.image_extractor import extract_images_from_source
|
|
||||||
|
|
||||||
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
|
|
||||||
tmp_source.mkdir(parents=True)
|
|
||||||
(tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
|
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
|
||||||
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
|
|
||||||
|
|
||||||
# Mock download_source_zip to avoid real network call
|
|
||||||
async def _noop_download(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
|
|
||||||
|
|
||||||
result = await extract_images_from_source("2401.00002")
|
|
||||||
assert result == 0
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Nav Bar 测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestNavBar:
|
|
||||||
"""导航栏测试。"""
|
|
||||||
|
|
||||||
def test_nav_includes_trends_link(self, auth_client):
|
|
||||||
"""导航栏应包含趋势链接。"""
|
|
||||||
resp = auth_client.get("/search")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "/trends" in resp.text
|
|
||||||
|
|
||||||
def test_nav_includes_compare_implicitly(self, auth_client):
|
|
||||||
"""compare 页面可访问。"""
|
|
||||||
resp = auth_client.get("/compare")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Graceful Degradation 测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestGracefulDegradation:
|
|
||||||
"""CHROMA_ENABLED=false 时优雅降级测试。"""
|
|
||||||
|
|
||||||
def test_search_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
|
|
||||||
"""CHROMA 关闭时 FTS5 搜索正常工作。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
resp = auth_client.get("/search?q=Test")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert "Test Paper" in resp.text or "测试论文" in resp.text
|
|
||||||
|
|
||||||
def test_detail_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
|
|
||||||
"""CHROMA 关闭时详情页正常工作。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
resp = auth_client.get("/paper/2401.20001")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
|
|
||||||
def test_trends_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
|
|
||||||
"""CHROMA 关闭时趋势看板正常工作。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
resp = auth_client.get("/trends")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
|
|
||||||
def test_compare_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
|
|
||||||
"""CHROMA 关闭时对比页正常工作。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
resp = auth_client.get("/compare?ids=2401.20001,2401.20002")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch):
|
|
||||||
"""CHROMA 关闭时删除论文正常工作。"""
|
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
|
||||||
import app.services.embedder as emb
|
|
||||||
emb._chroma.reset()
|
|
||||||
|
|
||||||
from app.services.cleaner import delete_papers_by_date_range
|
|
||||||
result = await delete_papers_by_date_range(
|
|
||||||
db_session,
|
|
||||||
date(2024, 1, 10),
|
|
||||||
date(2024, 1, 10),
|
|
||||||
)
|
|
||||||
assert result["status"] == "success"
|
|
||||||
assert result["deleted"] == 1
|
|
||||||
@@ -0,0 +1,189 @@
|
|||||||
|
"""SummarySchema 校验、quality 分级、JSON 提取、错误分类测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.services.pi_client import (
|
||||||
|
JsonNotFoundError,
|
||||||
|
PiProcessError,
|
||||||
|
PiTimeoutError,
|
||||||
|
extract_json as _extract_json,
|
||||||
|
)
|
||||||
|
from app.services.pdf_downloader import PdfDownloadError
|
||||||
|
from app.services.schemas import (
|
||||||
|
SummarySchema,
|
||||||
|
assess_quality,
|
||||||
|
classify_validation_error,
|
||||||
|
flatten_for_db,
|
||||||
|
)
|
||||||
|
from app.services.summarizer import _classify_error
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# SummarySchema 校验
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestSummarySchema:
|
||||||
|
"""Pydantic schema 校验。"""
|
||||||
|
|
||||||
|
def test_valid_summary(self, sample_summary_dict):
|
||||||
|
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
assert schema.title_zh == "测试论文中文标题"
|
||||||
|
assert len(schema.tags) == 3
|
||||||
|
assert schema.motivation.problem
|
||||||
|
|
||||||
|
def test_missing_title_zh(self, sample_summary_dict):
|
||||||
|
del sample_summary_dict["title_zh"]
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
assert classify_validation_error(exc_info.value) == "field_missing"
|
||||||
|
|
||||||
|
def test_empty_one_line(self, sample_summary_dict):
|
||||||
|
sample_summary_dict["one_line"] = ""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
|
||||||
|
def test_empty_tags(self, sample_summary_dict):
|
||||||
|
sample_summary_dict["tags"] = []
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
|
||||||
|
def test_empty_motivation_problem(self, sample_summary_dict):
|
||||||
|
sample_summary_dict["motivation"]["problem"] = ""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
|
||||||
|
def test_empty_method_key_idea(self, sample_summary_dict):
|
||||||
|
sample_summary_dict["method"]["key_idea"] = ""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
|
||||||
|
def test_extra_fields_ignored(self, sample_summary_dict):
|
||||||
|
sample_summary_dict["figures"] = ["fig1.png"]
|
||||||
|
sample_summary_dict["takeaway"] = "important paper"
|
||||||
|
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
assert not hasattr(schema, "figures")
|
||||||
|
assert schema.title_zh # 正常解析
|
||||||
|
|
||||||
|
def test_flatten_for_db(self, sample_summary_dict):
|
||||||
|
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
flat = flatten_for_db(schema)
|
||||||
|
assert flat["one_line"] == schema.one_line
|
||||||
|
assert flat["motivation_problem"] == schema.motivation.problem
|
||||||
|
assert flat["method_key_idea"] == schema.method.key_idea
|
||||||
|
assert "full_json" in flat
|
||||||
|
assert "updated_at" in flat
|
||||||
|
# JSON 字段可解析
|
||||||
|
assert isinstance(json.loads(flat["prerequisites_json"]), dict)
|
||||||
|
assert isinstance(json.loads(flat["method_steps_json"]), list)
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Quality 分级
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestQualityAssessment:
|
||||||
|
"""质量分级测试。"""
|
||||||
|
|
||||||
|
def test_quality_normal(self, sample_summary_dict):
|
||||||
|
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
assert assess_quality(schema) == "normal"
|
||||||
|
|
||||||
|
def test_quality_degraded_missing_goal(self, sample_summary_dict):
|
||||||
|
sample_summary_dict["motivation"]["goal"] = ""
|
||||||
|
sample_summary_dict["motivation"]["gap"] = ""
|
||||||
|
sample_summary_dict["method"]["overview"] = ""
|
||||||
|
sample_summary_dict["results"]["main_findings"] = []
|
||||||
|
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
assert assess_quality(schema) == "degraded"
|
||||||
|
|
||||||
|
def test_quality_low_short_one_line(self, sample_summary_dict):
|
||||||
|
sample_summary_dict["one_line"] = "短"
|
||||||
|
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
assert assess_quality(schema) == "low"
|
||||||
|
|
||||||
|
def test_quality_low_short_key_idea(self, sample_summary_dict):
|
||||||
|
sample_summary_dict["method"]["key_idea"] = "短"
|
||||||
|
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||||
|
assert assess_quality(schema) == "low"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# JSON 提取
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonExtraction:
|
||||||
|
"""pi 输出的 JSON 提取。"""
|
||||||
|
|
||||||
|
def test_direct_json(self, sample_summary_json):
|
||||||
|
result = _extract_json(sample_summary_json)
|
||||||
|
assert result["title_zh"] == "测试论文中文标题"
|
||||||
|
|
||||||
|
def test_fenced_code_block(self, sample_summary_json):
|
||||||
|
raw = f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字"
|
||||||
|
result = _extract_json(raw)
|
||||||
|
assert result["title_zh"] == "测试论文中文标题"
|
||||||
|
|
||||||
|
def test_fenced_without_lang(self, sample_summary_json):
|
||||||
|
raw = f"文字\n```\n{sample_summary_json}\n```"
|
||||||
|
result = _extract_json(raw)
|
||||||
|
assert result["title_zh"] == "测试论文中文标题"
|
||||||
|
|
||||||
|
def test_embedded_braces(self, sample_summary_dict):
|
||||||
|
json_str = json.dumps(sample_summary_dict, ensure_ascii=False)
|
||||||
|
raw = f"Here is the summary:\n{json_str}\nEnd."
|
||||||
|
result = _extract_json(raw)
|
||||||
|
assert result["title_zh"] == "测试论文中文标题"
|
||||||
|
|
||||||
|
def test_no_json_raises(self):
|
||||||
|
with pytest.raises(JsonNotFoundError):
|
||||||
|
_extract_json("No JSON here at all.")
|
||||||
|
|
||||||
|
def test_json_without_title_zh_falls_through(self):
|
||||||
|
"""不含 title_zh 的 JSON 不是我们要的。"""
|
||||||
|
raw = json.dumps({"other": "data"})
|
||||||
|
# 如果有其他合法 JSON 块也能返回,但没有就直接找最大块
|
||||||
|
# 此场景 raw 本身就是一个 JSON dict,但没有 title_zh
|
||||||
|
# 策略 1 会跳过(无 title_zh),策略 2 无代码块,策略 3 找到最大块
|
||||||
|
result = _extract_json(raw)
|
||||||
|
assert result == {"other": "data"} # 最大块兜底
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 错误分类
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorClassification:
|
||||||
|
"""异常 → error_type 映射。"""
|
||||||
|
|
||||||
|
def test_pdf_download_error(self):
|
||||||
|
assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed"
|
||||||
|
|
||||||
|
def test_timeout_error(self):
|
||||||
|
assert _classify_error(PiTimeoutError("timeout")) == "timeout"
|
||||||
|
|
||||||
|
def test_process_error(self):
|
||||||
|
assert _classify_error(PiProcessError(1, "stderr")) == "process_error"
|
||||||
|
|
||||||
|
def test_json_not_found(self):
|
||||||
|
assert _classify_error(JsonNotFoundError("not found")) == "json_not_found"
|
||||||
|
|
||||||
|
def test_json_invalid(self):
|
||||||
|
assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid"
|
||||||
|
|
||||||
|
def test_field_missing(self):
|
||||||
|
try:
|
||||||
|
SummarySchema.model_validate({"title_zh": ""}) # type: ignore
|
||||||
|
except ValidationError as exc:
|
||||||
|
assert _classify_error(exc) == "field_missing"
|
||||||
|
|
||||||
|
def test_unknown_error(self):
|
||||||
|
assert _classify_error(RuntimeError("boom")) == "unknown"
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
"""搜索、阅读列表、RSS Feed 测试。"""
|
"""搜索服务 + 路由 + 阅读列表 + RSS + 语义模式测试。"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from datetime import date, datetime, timezone
|
from datetime import date, datetime, timezone
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
# 搜索服务单元测试
|
# 搜索服务单元测试
|
||||||
@@ -12,69 +14,59 @@ from datetime import date, datetime, timezone
|
|||||||
|
|
||||||
|
|
||||||
class TestSearchService:
|
class TestSearchService:
|
||||||
"""app/services/searcher.py 单元测试。"""
|
"""app/services/searcher.py — FTS5 关键词搜索单元测试。"""
|
||||||
|
|
||||||
def test_search_by_title(self, db_session, sample_paper):
|
def test_search_by_title(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
result = search_papers(db_session, query="Test Paper")
|
result = search_papers(db_session, query="Test Paper")
|
||||||
assert result["total"] == 1
|
assert result["total"] == 1
|
||||||
assert result["results"][0].arxiv_id == "2401.12345"
|
assert result["results"][0].arxiv_id == "2401.12345"
|
||||||
|
|
||||||
def test_search_by_abstract(self, db_session, sample_paper):
|
def test_search_by_abstract(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
result = search_papers(db_session, query="test abstract")
|
result = search_papers(db_session, query="test abstract")
|
||||||
assert result["total"] == 1
|
assert result["total"] == 1
|
||||||
|
|
||||||
def test_search_by_author(self, db_session, sample_paper):
|
def test_search_by_author(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
result = search_papers(db_session, query="Alice")
|
result = search_papers(db_session, query="Alice")
|
||||||
assert result["total"] == 1
|
assert result["total"] == 1
|
||||||
|
|
||||||
def test_search_by_tag_in_fts(self, db_session, sample_paper):
|
def test_search_by_tag_in_fts(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
# FTS5 索引中包含 tags 列,可以搜到
|
# FTS5 索引中包含 tags 列,可以搜到
|
||||||
result = search_papers(db_session, query="NLP")
|
result = search_papers(db_session, query="NLP")
|
||||||
assert result["total"] == 1
|
assert result["total"] == 1
|
||||||
|
|
||||||
def test_search_no_results(self, db_session, sample_paper):
|
def test_search_no_results(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
result = search_papers(db_session, query="quantum entanglement")
|
result = search_papers(db_session, query="quantum entanglement")
|
||||||
assert result["total"] == 0
|
assert result["total"] == 0
|
||||||
assert result["results"] == []
|
assert result["results"] == []
|
||||||
|
|
||||||
def test_search_empty_query_returns_empty(self, db_session):
|
def test_search_empty_query_returns_empty(self, db_session):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
result = search_papers(db_session, query="")
|
result = search_papers(db_session, query="")
|
||||||
assert result["total"] == 0
|
assert result["total"] == 0
|
||||||
assert result["results"] == []
|
assert result["results"] == []
|
||||||
|
|
||||||
def test_search_special_characters_sanitized(self, db_session, sample_paper):
|
def test_search_special_characters_sanitized(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
# 特殊字符被清除后,剩下 "Test" 仍然能搜到
|
# 特殊字符被清除后,剩下 "Test" 仍然能搜到
|
||||||
result = search_papers(db_session, query='Test "Paper" {test}')
|
result = search_papers(db_session, query='Test "Paper" {test}')
|
||||||
assert result["total"] >= 1
|
assert result["total"] >= 1
|
||||||
|
|
||||||
def test_search_with_tag_filter(self, db_session, sample_paper):
|
def test_search_with_tag_filter(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
# 关键词 + 标签筛选
|
# 关键词 + 标签筛选
|
||||||
result = search_papers(db_session, query="Paper", tag="NLP")
|
result = search_papers(db_session, query="Paper", tag="NLP")
|
||||||
assert result["total"] == 1
|
assert result["total"] == 1
|
||||||
|
|
||||||
# 标签不匹配 → 0
|
# 标签不匹配 → 0
|
||||||
result2 = search_papers(db_session, query="Paper", tag="nonexistent")
|
result2 = search_papers(db_session, query="Paper", tag="nonexistent")
|
||||||
assert result2["total"] == 0
|
assert result2["total"] == 0
|
||||||
|
|
||||||
def test_search_tag_only_no_query(self, db_session, sample_paper):
|
def test_search_tag_only_no_query(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
# 只有标签,无关键词
|
# 只有标签,无关键词
|
||||||
result = search_papers(db_session, tag="NLP")
|
result = search_papers(db_session, tag="NLP")
|
||||||
assert result["total"] == 1
|
assert result["total"] == 1
|
||||||
@@ -82,14 +74,12 @@ class TestSearchService:
|
|||||||
|
|
||||||
def test_search_pagination(self, db_session, sample_paper):
|
def test_search_pagination(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
result = search_papers(db_session, query="Test", page=2, page_size=10)
|
result = search_papers(db_session, query="Test", page=2, page_size=10)
|
||||||
assert result["page"] == 2
|
assert result["page"] == 2
|
||||||
assert result["total_pages"] == 1 # 只有 1 条结果,1 页
|
assert result["total_pages"] == 1 # 只有 1 条结果,1 页
|
||||||
|
|
||||||
def test_search_returns_snippets(self, db_session, sample_paper):
|
def test_search_returns_snippets(self, db_session, sample_paper):
|
||||||
from app.services.searcher import search_papers
|
from app.services.searcher import search_papers
|
||||||
|
|
||||||
result = search_papers(db_session, query="test abstract")
|
result = search_papers(db_session, query="test abstract")
|
||||||
assert result["total"] == 1
|
assert result["total"] == 1
|
||||||
paper_id = result["results"][0].id
|
paper_id = result["results"][0].id
|
||||||
@@ -99,12 +89,54 @@ class TestSearchService:
|
|||||||
|
|
||||||
def test_get_all_tags(self, db_session, sample_paper):
|
def test_get_all_tags(self, db_session, sample_paper):
|
||||||
from app.services.searcher import get_all_tags
|
from app.services.searcher import get_all_tags
|
||||||
|
|
||||||
tags = get_all_tags(db_session)
|
tags = get_all_tags(db_session)
|
||||||
assert "NLP" in tags
|
assert "NLP" in tags
|
||||||
assert "LLM" in tags
|
assert "LLM" in tags
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 语义 / Embedder 模式测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchSemanticMode:
|
||||||
|
"""searcher.py — semantic 模式(含 embedder 回退)测试。"""
|
||||||
|
|
||||||
|
def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
|
||||||
|
"""默认 keyword 模式走 FTS5。"""
|
||||||
|
from app.services.searcher import search_papers
|
||||||
|
result = search_papers(db_session, query="Test Paper", mode="keyword")
|
||||||
|
assert result["total"] >= 1
|
||||||
|
assert result["distances"] == {}
|
||||||
|
|
||||||
|
def test_semantic_mode_disabled_fallback(self, db_session, monkeypatch, sample_papers_with_summary):
|
||||||
|
"""CHROMA_ENABLED=false + semantic 模式走 FTS5。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
from app.services.searcher import search_papers
|
||||||
|
result = search_papers(db_session, query="Test", mode="semantic")
|
||||||
|
assert result["total"] >= 1
|
||||||
|
|
||||||
|
def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary):
|
||||||
|
"""搜索结果应包含 distances 字段。"""
|
||||||
|
from app.services.searcher import search_papers
|
||||||
|
result = search_papers(db_session, query="Test Paper")
|
||||||
|
assert "distances" in result
|
||||||
|
assert isinstance(result["distances"], dict)
|
||||||
|
|
||||||
|
def test_empty_query_returns_empty_no_tags(self, db_session):
|
||||||
|
"""空查询无标签时返回空。"""
|
||||||
|
from app.services.searcher import search_papers
|
||||||
|
result = search_papers(db_session)
|
||||||
|
assert result["total"] == 0
|
||||||
|
assert result["results"] == []
|
||||||
|
|
||||||
|
def test_tag_only_search(self, db_session, sample_papers_with_summary):
|
||||||
|
"""仅标签搜索。"""
|
||||||
|
from app.services.searcher import search_papers
|
||||||
|
result = search_papers(db_session, tag="NLP")
|
||||||
|
assert result["total"] >= 1
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
# 搜索路由 HTTP 测试
|
# 搜索路由 HTTP 测试
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
@@ -131,6 +163,18 @@ class TestSearchRoutes:
|
|||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert "2401.12345" in resp.text
|
assert "2401.12345" in resp.text
|
||||||
|
|
||||||
|
def test_search_page_keyword_mode(self, client, sample_papers_with_summary):
|
||||||
|
"""搜索页 keyword 模式。"""
|
||||||
|
resp = client.get("/search?q=Test&mode=keyword")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "Test" in resp.text or "测试" in resp.text
|
||||||
|
|
||||||
|
def test_search_page_semantic_disabled(self, client, monkeypatch, sample_papers_with_summary):
|
||||||
|
"""语义模式 CHROMA_ENABLED=false 时仍能工作。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
resp = client.get("/search?q=Test&mode=semantic")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
def test_search_api_json(self, client, sample_paper):
|
def test_search_api_json(self, client, sample_paper):
|
||||||
"""GET /api/search?q=Test 返回 JSON。"""
|
"""GET /api/search?q=Test 返回 JSON。"""
|
||||||
resp = client.get("/api/search?q=Test")
|
resp = client.get("/api/search?q=Test")
|
||||||
@@ -146,6 +190,14 @@ class TestSearchRoutes:
|
|||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["total"] == 1
|
assert data["total"] == 1
|
||||||
|
|
||||||
|
def test_search_api_with_mode(self, client, sample_papers_with_summary):
|
||||||
|
"""搜索 API 支持 mode 参数。"""
|
||||||
|
resp = client.get("/api/search?q=Test&mode=keyword")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "results" in data
|
||||||
|
assert "total" in data
|
||||||
|
|
||||||
def test_search_api_empty(self, client, sample_paper):
|
def test_search_api_empty(self, client, sample_paper):
|
||||||
"""GET /api/search?q=nonexistent 返回空结果。"""
|
"""GET /api/search?q=nonexistent 返回空结果。"""
|
||||||
resp = client.get("/api/search?q=nonexistent")
|
resp = client.get("/api/search?q=nonexistent")
|
||||||
@@ -161,6 +213,36 @@ class TestSearchRoutes:
|
|||||||
assert data["total"] >= 1
|
assert data["total"] >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Similar Paper API 测试
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimilarAPI:
|
||||||
|
"""相似论文 API 测试。"""
|
||||||
|
|
||||||
|
def test_similar_api_disabled(self, client, monkeypatch, sample_papers_with_summary):
|
||||||
|
"""CHROMA_ENABLED=false 时返回空列表。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
resp = client.get("/api/similar/2401.20001")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["results"] == []
|
||||||
|
|
||||||
|
def test_similar_api_paper_not_found(self, client, monkeypatch):
|
||||||
|
"""不存在的论文返回空。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
resp = client.get("/api/similar/nonexistent.99999")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["results"] == []
|
||||||
|
|
||||||
|
def test_similar_api_with_top_k(self, client, monkeypatch, sample_papers_with_summary):
|
||||||
|
"""top_k 参数控制返回数量。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
|
resp = client.get("/api/similar/2401.20001?top_k=3")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
# 阅读列表路由测试
|
# 阅读列表路由测试
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
@@ -185,8 +267,6 @@ class TestReadingListRoute:
|
|||||||
|
|
||||||
def test_reading_list_filter_by_status(self, client, sample_paper):
|
def test_reading_list_filter_by_status(self, client, sample_paper):
|
||||||
"""按阅读状态筛选。"""
|
"""按阅读状态筛选。"""
|
||||||
import json
|
|
||||||
|
|
||||||
# 设置阅读状态
|
# 设置阅读状态
|
||||||
client.post(
|
client.post(
|
||||||
"/api/reading-status/2401.12345",
|
"/api/reading-status/2401.12345",
|
||||||
@@ -204,8 +284,6 @@ class TestReadingListRoute:
|
|||||||
|
|
||||||
def test_reading_list_has_note_filter(self, client, sample_paper):
|
def test_reading_list_has_note_filter(self, client, sample_paper):
|
||||||
"""筛选有笔记的论文。"""
|
"""筛选有笔记的论文。"""
|
||||||
import json
|
|
||||||
|
|
||||||
# 写笔记
|
# 写笔记
|
||||||
client.post(
|
client.post(
|
||||||
"/api/note/2401.12345",
|
"/api/note/2401.12345",
|
||||||
+7
-239
@@ -1,15 +1,13 @@
|
|||||||
"""AI 总结服务测试 — Mock 全链路,不调用真实 pi。"""
|
"""AI 总结服务测试 — summarize_one 状态流转、批量处理、DB 更新、文件操作。"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
from datetime import date, datetime, timezone
|
from datetime import date, datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
from app.models import (
|
from app.models import (
|
||||||
@@ -20,193 +18,19 @@ from app.models import (
|
|||||||
SummaryStatus,
|
SummaryStatus,
|
||||||
TaskLock,
|
TaskLock,
|
||||||
)
|
)
|
||||||
from app.services.schemas import (
|
from app.services.pdf_downloader import (
|
||||||
SummarySchema,
|
PdfDownloadError,
|
||||||
assess_quality,
|
cleanup_tmp as _cleanup_tmp,
|
||||||
classify_validation_error,
|
|
||||||
flatten_for_db,
|
|
||||||
)
|
)
|
||||||
|
from app.services.pi_client import PiTimeoutError
|
||||||
|
from app.services.schemas import SummarySchema
|
||||||
from app.services.summarizer import (
|
from app.services.summarizer import (
|
||||||
_classify_error,
|
|
||||||
_save_files,
|
_save_files,
|
||||||
_save_raw_output_only,
|
_save_raw_output_only,
|
||||||
_update_summary_in_db,
|
_update_summary_in_db,
|
||||||
summarize_batch,
|
summarize_batch,
|
||||||
summarize_one,
|
summarize_one,
|
||||||
summarize_single,
|
|
||||||
)
|
)
|
||||||
from app.services.pi_client import (
|
|
||||||
JsonNotFoundError,
|
|
||||||
PiProcessError,
|
|
||||||
PiTimeoutError,
|
|
||||||
call_pi as _call_pi,
|
|
||||||
extract_json as _extract_json,
|
|
||||||
)
|
|
||||||
from app.services.pdf_downloader import (
|
|
||||||
PdfDownloadError,
|
|
||||||
cleanup_tmp as _cleanup_tmp,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Schema 校验测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestSummarySchema:
|
|
||||||
"""Pydantic schema 校验。"""
|
|
||||||
|
|
||||||
def test_valid_summary(self, sample_summary_dict):
|
|
||||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
assert schema.title_zh == "测试论文中文标题"
|
|
||||||
assert len(schema.tags) == 3
|
|
||||||
assert schema.motivation.problem
|
|
||||||
|
|
||||||
def test_missing_title_zh(self, sample_summary_dict):
|
|
||||||
del sample_summary_dict["title_zh"]
|
|
||||||
with pytest.raises(ValidationError) as exc_info:
|
|
||||||
SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
assert classify_validation_error(exc_info.value) == "field_missing"
|
|
||||||
|
|
||||||
def test_empty_one_line(self, sample_summary_dict):
|
|
||||||
sample_summary_dict["one_line"] = ""
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
|
|
||||||
def test_empty_tags(self, sample_summary_dict):
|
|
||||||
sample_summary_dict["tags"] = []
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
|
|
||||||
def test_empty_motivation_problem(self, sample_summary_dict):
|
|
||||||
sample_summary_dict["motivation"]["problem"] = ""
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
|
|
||||||
def test_empty_method_key_idea(self, sample_summary_dict):
|
|
||||||
sample_summary_dict["method"]["key_idea"] = ""
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
|
|
||||||
def test_extra_fields_ignored(self, sample_summary_dict):
|
|
||||||
sample_summary_dict["figures"] = ["fig1.png"]
|
|
||||||
sample_summary_dict["takeaway"] = "important paper"
|
|
||||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
assert not hasattr(schema, "figures")
|
|
||||||
assert schema.title_zh # 正常解析
|
|
||||||
|
|
||||||
def test_flatten_for_db(self, sample_summary_dict):
|
|
||||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
flat = flatten_for_db(schema)
|
|
||||||
assert flat["one_line"] == schema.one_line
|
|
||||||
assert flat["motivation_problem"] == schema.motivation.problem
|
|
||||||
assert flat["method_key_idea"] == schema.method.key_idea
|
|
||||||
assert "full_json" in flat
|
|
||||||
assert "updated_at" in flat
|
|
||||||
# JSON 字段可解析
|
|
||||||
assert isinstance(json.loads(flat["prerequisites_json"]), dict)
|
|
||||||
assert isinstance(json.loads(flat["method_steps_json"]), list)
|
|
||||||
|
|
||||||
|
|
||||||
class TestQualityAssessment:
|
|
||||||
"""质量分级测试。"""
|
|
||||||
|
|
||||||
def test_quality_normal(self, sample_summary_dict):
|
|
||||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
assert assess_quality(schema) == "normal"
|
|
||||||
|
|
||||||
def test_quality_degraded_missing_goal(self, sample_summary_dict):
|
|
||||||
sample_summary_dict["motivation"]["goal"] = ""
|
|
||||||
sample_summary_dict["motivation"]["gap"] = ""
|
|
||||||
sample_summary_dict["method"]["overview"] = ""
|
|
||||||
sample_summary_dict["results"]["main_findings"] = []
|
|
||||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
assert assess_quality(schema) == "degraded"
|
|
||||||
|
|
||||||
def test_quality_low_short_one_line(self, sample_summary_dict):
|
|
||||||
sample_summary_dict["one_line"] = "短"
|
|
||||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
assert assess_quality(schema) == "low"
|
|
||||||
|
|
||||||
def test_quality_low_short_key_idea(self, sample_summary_dict):
|
|
||||||
sample_summary_dict["method"]["key_idea"] = "短"
|
|
||||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
|
||||||
assert assess_quality(schema) == "low"
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# JSON 提取测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestJsonExtraction:
|
|
||||||
"""pi 输出的 JSON 提取。"""
|
|
||||||
|
|
||||||
def test_direct_json(self, sample_summary_json):
|
|
||||||
result = _extract_json(sample_summary_json)
|
|
||||||
assert result["title_zh"] == "测试论文中文标题"
|
|
||||||
|
|
||||||
def test_fenced_code_block(self, sample_summary_json):
|
|
||||||
raw = f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字"
|
|
||||||
result = _extract_json(raw)
|
|
||||||
assert result["title_zh"] == "测试论文中文标题"
|
|
||||||
|
|
||||||
def test_fenced_without_lang(self, sample_summary_json):
|
|
||||||
raw = f"文字\n```\n{sample_summary_json}\n```"
|
|
||||||
result = _extract_json(raw)
|
|
||||||
assert result["title_zh"] == "测试论文中文标题"
|
|
||||||
|
|
||||||
def test_embedded_braces(self, sample_summary_dict):
|
|
||||||
json_str = json.dumps(sample_summary_dict, ensure_ascii=False)
|
|
||||||
raw = f"Here is the summary:\n{json_str}\nEnd."
|
|
||||||
result = _extract_json(raw)
|
|
||||||
assert result["title_zh"] == "测试论文中文标题"
|
|
||||||
|
|
||||||
def test_no_json_raises(self):
|
|
||||||
with pytest.raises(JsonNotFoundError):
|
|
||||||
_extract_json("No JSON here at all.")
|
|
||||||
|
|
||||||
def test_json_without_title_zh_falls_through(self):
|
|
||||||
"""不含 title_zh 的 JSON 不是我们要的。"""
|
|
||||||
raw = json.dumps({"other": "data"})
|
|
||||||
# 如果有其他合法 JSON 块也能返回,但没有就直接找最大块
|
|
||||||
# 此场景 raw 本身就是一个 JSON dict,但没有 title_zh
|
|
||||||
# 策略 1 会跳过(无 title_zh),策略 2 无代码块,策略 3 找到最大块
|
|
||||||
result = _extract_json(raw)
|
|
||||||
assert result == {"other": "data"} # 最大块兜底
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# 错误分类测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestErrorClassification:
|
|
||||||
"""异常 → error_type 映射。"""
|
|
||||||
|
|
||||||
def test_pdf_download_error(self):
|
|
||||||
assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed"
|
|
||||||
|
|
||||||
def test_timeout_error(self):
|
|
||||||
assert _classify_error(PiTimeoutError("timeout")) == "timeout"
|
|
||||||
|
|
||||||
def test_process_error(self):
|
|
||||||
assert _classify_error(PiProcessError(1, "stderr")) == "process_error"
|
|
||||||
|
|
||||||
def test_json_not_found(self):
|
|
||||||
assert _classify_error(JsonNotFoundError("not found")) == "json_not_found"
|
|
||||||
|
|
||||||
def test_json_invalid(self):
|
|
||||||
assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid"
|
|
||||||
|
|
||||||
def test_field_missing(self):
|
|
||||||
try:
|
|
||||||
SummarySchema.model_validate({"title_zh": ""}) # type: ignore
|
|
||||||
except ValidationError as exc:
|
|
||||||
assert _classify_error(exc) == "field_missing"
|
|
||||||
|
|
||||||
def test_unknown_error(self):
|
|
||||||
assert _classify_error(RuntimeError("boom")) == "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
@@ -675,59 +499,3 @@ class TestBatchSummarize:
|
|||||||
result = await summarize_batch(db_session)
|
result = await summarize_batch(db_session)
|
||||||
assert result["status"] == "success"
|
assert result["status"] == "success"
|
||||||
assert result["total"] == 0
|
assert result["total"] == 0
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
# Admin 路由鉴权测试
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
|
|
||||||
class TestAdminAuth:
|
|
||||||
"""管理接口鉴权 — 只测 HTTP 层,mock 掉实际服务调用。"""
|
|
||||||
|
|
||||||
def test_no_token_returns_401(self, client):
|
|
||||||
"""无 Bearer token 返回 401。"""
|
|
||||||
resp = client.post("/admin/summarize")
|
|
||||||
assert resp.status_code in (401, 403)
|
|
||||||
|
|
||||||
def test_wrong_token_returns_401(self, client):
|
|
||||||
resp = client.post(
|
|
||||||
"/admin/summarize",
|
|
||||||
headers={"Authorization": "Bearer wrong-token"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
def test_correct_token_batch(self, client, admin_headers):
|
|
||||||
"""正确 token 调用 batch summarize,mock 掉服务层。"""
|
|
||||||
import app.config as config_mod
|
|
||||||
|
|
||||||
original = config_mod.settings.ADMIN_TOKEN
|
|
||||||
config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345"
|
|
||||||
try:
|
|
||||||
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
|
|
||||||
mock.return_value = {"status": "success", "done": 0, "failed": 0, "total": 0}
|
|
||||||
resp = client.post("/admin/summarize", headers=admin_headers)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["status"] == "success"
|
|
||||||
finally:
|
|
||||||
config_mod.settings.ADMIN_TOKEN = original
|
|
||||||
|
|
||||||
def test_single_paper_not_found(self, client, admin_headers):
|
|
||||||
"""单篇总结不存在的论文返回 404。"""
|
|
||||||
import app.config as config_mod
|
|
||||||
|
|
||||||
original = config_mod.settings.ADMIN_TOKEN
|
|
||||||
config_mod.settings.ADMIN_TOKEN = "test-admin-token-12345"
|
|
||||||
try:
|
|
||||||
with patch(
|
|
||||||
"app.routes.admin.summarize_single",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
|
|
||||||
):
|
|
||||||
resp = client.post(
|
|
||||||
"/admin/summarize/nonexistent.99999",
|
|
||||||
headers=admin_headers,
|
|
||||||
)
|
|
||||||
assert resp.status_code == 404
|
|
||||||
finally:
|
|
||||||
config_mod.settings.ADMIN_TOKEN = original
|
|
||||||
|
|||||||
Reference in New Issue
Block a user