feat: add admin dashboard, pipeline service, lightbox, and update dependencies
This commit is contained in:
+6
-17
@@ -3,14 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import date, datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import get_db
|
||||
@@ -23,21 +21,12 @@ from app.models import (
|
||||
PaperTag,
|
||||
SummaryStatus,
|
||||
)
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
# ── 内存数据库 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _TestBase(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
# 复用 app.models 的 Base metadata
|
||||
from app.database import Base as _AppBase # noqa: E402
|
||||
|
||||
_TestBase.metadata = _AppBase.metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine():
|
||||
"""创建内存 SQLite 引擎 + FTS5。"""
|
||||
@@ -94,7 +83,7 @@ _TEST_ADMIN_PASSWORD = "test-password-12345"
|
||||
@pytest.fixture
|
||||
def sample_paper(db_session):
|
||||
"""插入一篇测试论文 + 作者 + 标签 + summary_status(pending)。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
paper = Paper(
|
||||
arxiv_id=SAMPLE_ARXIV_ID,
|
||||
title_en="Test Paper Title",
|
||||
@@ -234,7 +223,7 @@ def auth_client(client, monkeypatch):
|
||||
@pytest.fixture
|
||||
def sample_papers_range(db_session):
|
||||
"""插入 5 篇不同日期的论文(用于 admin / cleaner 测试)。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
papers = []
|
||||
for i, (arxiv_id, paper_date_str) in enumerate(
|
||||
[
|
||||
@@ -281,7 +270,7 @@ def sample_papers_range(db_session):
|
||||
@pytest.fixture
|
||||
def sample_papers_with_summary(db_session):
|
||||
"""插入 5 篇带总结的论文(用于 search / pages / trends 测试)。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
papers = []
|
||||
for i, (arxiv_id, paper_date_str) in enumerate(
|
||||
[
|
||||
|
||||
+6
-11
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date, datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -14,6 +13,7 @@ from app.models import (
|
||||
CrawlLog,
|
||||
TaskLock,
|
||||
)
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -24,11 +24,6 @@ from app.models import (
|
||||
class TestAdminAuth:
|
||||
"""管理接口鉴权测试。"""
|
||||
|
||||
def test_unauthenticated_redirects_to_login(self, auth_client):
|
||||
"""未登录时请求管理接口应重定向到登录页。"""
|
||||
# 用未登录的 client(auth_client 已登录,这里直接用 client)
|
||||
pass # 见下方 test_no_session_returns_303
|
||||
|
||||
def test_no_session_returns_303(self, client, monkeypatch):
|
||||
"""无 session 时请求管理接口应返回 303 重定向。"""
|
||||
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
|
||||
@@ -58,7 +53,7 @@ class TestAdminAuth:
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert resp.status_code == 303
|
||||
assert "/admin/logs" in resp.headers.get("location", "")
|
||||
assert "/admin/" in resp.headers.get("location", "")
|
||||
|
||||
def test_logout_clears_session(self, auth_client, monkeypatch):
|
||||
"""退出登录后应清除 session。"""
|
||||
@@ -265,7 +260,7 @@ class TestAdminLogs:
|
||||
):
|
||||
"""日志页面应包含日志数据。"""
|
||||
# 先创建一条日志
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
db_session.add(
|
||||
CrawlLog(
|
||||
task="crawl",
|
||||
@@ -345,7 +340,7 @@ class TestScheduler:
|
||||
@pytest.mark.asyncio
|
||||
async def test_daily_pipeline_lock_prevents_reentry(self, db_session):
|
||||
"""pipeline 使用 task_locks 防重入。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
lock = TaskLock(
|
||||
task="scheduler",
|
||||
lock_key="pipeline-2024-01-15",
|
||||
@@ -380,7 +375,7 @@ class TestTaskLocks:
|
||||
|
||||
def test_unique_running_lock(self, db_session):
|
||||
"""同一 task + lock_key 只能有一个 running 锁。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
lock1 = TaskLock(
|
||||
task="crawl",
|
||||
lock_key="2024-01-15",
|
||||
@@ -405,7 +400,7 @@ class TestTaskLocks:
|
||||
|
||||
def test_released_lock_allows_new(self, db_session):
|
||||
"""已释放的锁允许新的 running 锁。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
lock1 = TaskLock(
|
||||
task="crawl",
|
||||
lock_key="2024-01-16",
|
||||
|
||||
+4
-25
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import date, datetime, timezone
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
@@ -18,6 +18,8 @@ from app.models import (
|
||||
UserNote,
|
||||
UserReadingStatus,
|
||||
)
|
||||
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
# ── Fixtures ────────────────────────────────────────────────────────────
|
||||
@@ -27,7 +29,7 @@ from app.models import (
|
||||
def sample_paper_with_user_data(db_session, sample_papers_range):
|
||||
"""给第一篇论文添加用户数据(收藏、阅读状态、笔记)。"""
|
||||
paper = sample_papers_range[0]
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
db_session.add(UserBookmark(paper_id=paper.id, created_at=now))
|
||||
db_session.add(
|
||||
UserReadingStatus(paper_id=paper.id, status="read_summary", updated_at=now)
|
||||
@@ -67,8 +69,6 @@ class TestCleanupTmp:
|
||||
os.utime(old_dir, (old_mtime, old_mtime))
|
||||
|
||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
|
||||
result = cleanup_tmp()
|
||||
|
||||
assert result["scanned"] == 1
|
||||
@@ -85,8 +85,6 @@ class TestCleanupTmp:
|
||||
(recent_dir / "paper.pdf").write_text("fake pdf")
|
||||
|
||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
|
||||
result = cleanup_tmp()
|
||||
|
||||
assert result["scanned"] == 1
|
||||
@@ -96,8 +94,6 @@ class TestCleanupTmp:
|
||||
def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
|
||||
"""data/tmp/ 不存在时安全返回。"""
|
||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent")
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
|
||||
result = cleanup_tmp()
|
||||
assert result["scanned"] == 0
|
||||
assert result["removed"] == 0
|
||||
@@ -116,8 +112,6 @@ class TestCleanupTmp:
|
||||
recent_dir.mkdir()
|
||||
|
||||
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
|
||||
result = cleanup_tmp()
|
||||
|
||||
assert result["scanned"] == 2
|
||||
@@ -137,8 +131,6 @@ class TestDeletePapersByDateRange:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_by_date_range(self, db_session, sample_papers_range):
|
||||
"""删除指定日期范围的论文。"""
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
# 删除 1月11日 ~ 1月13日(3篇)
|
||||
result = await delete_papers_by_date_range(
|
||||
db_session,
|
||||
@@ -159,8 +151,6 @@ class TestDeletePapersByDateRange:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_creates_job_record(self, db_session, sample_papers_range):
|
||||
"""删除操作应创建 data_delete_jobs 记录。"""
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
await delete_papers_by_date_range(
|
||||
db_session,
|
||||
date(2024, 1, 10),
|
||||
@@ -178,8 +168,6 @@ class TestDeletePapersByDateRange:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_creates_crawl_log(self, db_session, sample_papers_range):
|
||||
"""删除操作应写入 crawl_logs。"""
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
await delete_papers_by_date_range(
|
||||
db_session,
|
||||
date(2024, 1, 10),
|
||||
@@ -199,8 +187,6 @@ class TestDeletePapersByDateRange:
|
||||
self, db_session, sample_paper_with_user_data
|
||||
):
|
||||
"""删除论文时应 cascade 删除关联的用户数据。"""
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
paper = sample_paper_with_user_data
|
||||
|
||||
# 删除
|
||||
@@ -235,7 +221,6 @@ class TestDeletePapersByDateRange:
|
||||
async def test_delete_removes_fts(self, db_session, sample_papers_range):
|
||||
"""删除论文时应同步删除 FTS5 索引。"""
|
||||
import sqlalchemy
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
await delete_papers_by_date_range(
|
||||
db_session,
|
||||
@@ -254,8 +239,6 @@ class TestDeletePapersByDateRange:
|
||||
self, db_session, sample_papers_range, tmp_path, monkeypatch
|
||||
):
|
||||
"""删除论文时应删除本地文件目录。"""
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
papers_dir = tmp_path / "papers"
|
||||
papers_dir.mkdir()
|
||||
(papers_dir / "2401.10001").mkdir()
|
||||
@@ -274,8 +257,6 @@ class TestDeletePapersByDateRange:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_empty_range(self, db_session, sample_papers_range):
|
||||
"""日期范围内无论文时返回 0。"""
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
result = await delete_papers_by_date_range(
|
||||
db_session,
|
||||
date(2025, 1, 1),
|
||||
@@ -295,8 +276,6 @@ class TestDeletePapersByDateRange:
|
||||
|
||||
emb._chroma.reset()
|
||||
|
||||
from app.services.cleaner import delete_papers_by_date_range
|
||||
|
||||
result = await delete_papers_by_date_range(
|
||||
db_session,
|
||||
date(2024, 1, 10),
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import settings
|
||||
|
||||
@@ -84,24 +83,6 @@ class TestEmbedderIndexing:
|
||||
|
||||
emb._chroma.reset()
|
||||
|
||||
def test_index_batch_disabled(self, monkeypatch):
|
||||
"""CHROMA_ENABLED=false 时 index_batch 返回全失败。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
result = emb.index_batch(["a", "b"])
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 2
|
||||
|
||||
def test_index_batch_empty(self, monkeypatch):
|
||||
"""空列表时返回 0。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
import app.services.embedder as emb
|
||||
|
||||
result = emb.index_batch([])
|
||||
assert result["total"] == 0
|
||||
|
||||
def test_delete_paper_disabled(self, monkeypatch):
|
||||
"""CHROMA_ENABLED=false 时 delete_paper 返回 False。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
from datetime import date
|
||||
from unittest.mock import patch as upatch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import settings
|
||||
|
||||
@@ -30,26 +29,6 @@ class TestDetailPage:
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Similar API(详情页内联)
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestDetailSimilarPapers:
|
||||
"""详情页相似论文模块测试(CHROMA 关闭时的降级行为)。"""
|
||||
|
||||
def test_detail_page_renders_with_similar(self, client, sample_papers_with_summary):
|
||||
"""详情页正常渲染(含相似论文模块)。"""
|
||||
resp = client.get("/paper/2401.20001")
|
||||
assert resp.status_code == 200
|
||||
assert "测试论文" in resp.text or "Test Paper" in resp.text
|
||||
|
||||
def test_detail_page_not_found_similar(self, client):
|
||||
"""不存在的论文返回 404。"""
|
||||
resp = client.get("/paper/nonexistent.99999")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Trends Dashboard
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
+5
-48
@@ -2,10 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
from datetime import date, datetime, timezone
|
||||
|
||||
from app.config import settings
|
||||
from app.services.searcher import get_all_tags, search_papers
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -17,90 +19,60 @@ class TestSearchService:
|
||||
"""app/services/searcher.py — FTS5 关键词搜索单元测试。"""
|
||||
|
||||
def test_search_by_title(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test Paper")
|
||||
assert result["total"] == 1
|
||||
assert result["results"][0].arxiv_id == "2401.12345"
|
||||
|
||||
def test_search_by_abstract(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="test abstract")
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_search_by_author(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Alice")
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_search_by_tag_in_fts(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# FTS5 索引中包含 tags 列,可以搜到
|
||||
result = search_papers(db_session, query="NLP")
|
||||
assert result["total"] == 1
|
||||
|
||||
def test_search_no_results(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="quantum entanglement")
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
def test_search_empty_query_returns_empty(self, db_session):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="")
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
def test_search_special_characters_sanitized(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# 特殊字符被清除后,剩下 "Test" 仍然能搜到
|
||||
result = search_papers(db_session, query='Test "Paper" {test}')
|
||||
assert result["total"] >= 1
|
||||
|
||||
def test_search_with_tag_filter(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# 关键词 + 标签筛选
|
||||
result = search_papers(db_session, query="Paper", tag="NLP")
|
||||
assert result["total"] == 1
|
||||
# 标签不匹配 → 0
|
||||
result2 = search_papers(db_session, query="Paper", tag="nonexistent")
|
||||
assert result2["total"] == 0
|
||||
|
||||
def test_search_tag_only_no_query(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
# 只有标签,无关键词
|
||||
result = search_papers(db_session, tag="NLP")
|
||||
assert result["total"] == 1
|
||||
assert result["results"][0].arxiv_id == "2401.12345"
|
||||
|
||||
def test_search_pagination(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test", page=2, page_size=10)
|
||||
assert result["page"] == 2
|
||||
assert result["total_pages"] == 1 # 只有 1 条结果,1 页
|
||||
assert result["total_pages"] == 1
|
||||
|
||||
def test_search_returns_snippets(self, db_session, sample_paper):
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="test abstract")
|
||||
assert result["total"] == 1
|
||||
paper_id = result["results"][0].id
|
||||
assert paper_id in result["snippets"]
|
||||
snippet = result["snippets"][paper_id]
|
||||
assert "abstract" in snippet
|
||||
assert "abstract" in result["snippets"][paper_id]
|
||||
|
||||
def test_get_all_tags(self, db_session, sample_paper):
|
||||
from app.services.searcher import get_all_tags
|
||||
|
||||
tags = get_all_tags(db_session)
|
||||
assert "NLP" in tags
|
||||
assert "LLM" in tags
|
||||
@@ -115,9 +87,6 @@ class TestSearchSemanticMode:
|
||||
"""searcher.py — semantic 模式(含 embedder 回退)测试。"""
|
||||
|
||||
def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
|
||||
"""默认 keyword 模式走 FTS5。"""
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test Paper", mode="keyword")
|
||||
assert result["total"] >= 1
|
||||
assert result["distances"] == {}
|
||||
@@ -125,35 +94,23 @@ class TestSearchSemanticMode:
|
||||
def test_semantic_mode_disabled_fallback(
|
||||
self, db_session, monkeypatch, sample_papers_with_summary
|
||||
):
|
||||
"""CHROMA_ENABLED=false + semantic 模式走 FTS5。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test", mode="semantic")
|
||||
assert result["total"] >= 1
|
||||
|
||||
def test_search_returns_distances_dict(
|
||||
self, db_session, sample_papers_with_summary
|
||||
):
|
||||
"""搜索结果应包含 distances 字段。"""
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, query="Test Paper")
|
||||
assert "distances" in result
|
||||
assert isinstance(result["distances"], dict)
|
||||
|
||||
def test_empty_query_returns_empty_no_tags(self, db_session):
|
||||
"""空查询无标签时返回空。"""
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session)
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
def test_tag_only_search(self, db_session, sample_papers_with_summary):
|
||||
"""仅标签搜索。"""
|
||||
from app.services.searcher import search_papers
|
||||
|
||||
result = search_papers(db_session, tag="NLP")
|
||||
assert result["total"] >= 1
|
||||
|
||||
|
||||
+37
-51
@@ -3,8 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import date, datetime, timezone
|
||||
from pathlib import Path
|
||||
from datetime import date
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -26,11 +25,27 @@ from app.services.pi_client import PiTimeoutError
|
||||
from app.services.schemas import SummarySchema
|
||||
from app.services.summarizer import (
|
||||
_save_files,
|
||||
_save_raw_output_only,
|
||||
_update_summary_in_db,
|
||||
summarize_batch,
|
||||
summarize_one,
|
||||
)
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
# ── 共享 fixture ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _summarize_tmp_paths(tmp_path):
|
||||
"""将 data 目录重定向到 tmp_path(供 summarizer 测试使用)。"""
|
||||
with (
|
||||
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
|
||||
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
||||
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
||||
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
||||
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -130,7 +145,7 @@ class TestFileOperations:
|
||||
|
||||
def test_save_raw_output_only(self, tmp_path):
|
||||
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
|
||||
_save_raw_output_only("2401.12345", "raw output")
|
||||
_save_files("2401.12345", None, "raw output")
|
||||
paper_dir = tmp_path / "2401.12345"
|
||||
assert (paper_dir / "raw_output.txt").exists()
|
||||
assert not (paper_dir / "summary.json").exists()
|
||||
@@ -157,24 +172,9 @@ class TestFileOperations:
|
||||
class TestSummarizeOneFlow:
|
||||
"""summarize_one 的状态流转(mock pi 和 PDF)。"""
|
||||
|
||||
@pytest.fixture
|
||||
def _patch_paths(self, tmp_path):
|
||||
"""将 data 目录重定向到 tmp_path。"""
|
||||
with (
|
||||
patch(
|
||||
"app.services.summarizer.paper_dir",
|
||||
lambda aid: tmp_path / "papers" / aid,
|
||||
),
|
||||
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
||||
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
||||
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
||||
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_success_path(
|
||||
self, db_session, sample_paper, mock_pi_output, _patch_paths
|
||||
self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths
|
||||
):
|
||||
"""pending → processing → done 全流程。"""
|
||||
with (
|
||||
@@ -209,7 +209,7 @@ class TestSummarizeOneFlow:
|
||||
assert fts_row[0] == "测试论文中文标题"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_download_failure(self, db_session, sample_paper, _patch_paths):
|
||||
async def test_pdf_download_failure(self, db_session, sample_paper, _summarize_tmp_paths):
|
||||
"""PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。"""
|
||||
with (
|
||||
patch(
|
||||
@@ -228,7 +228,7 @@ class TestSummarizeOneFlow:
|
||||
assert status.error_type == "pdf_download_failed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
|
||||
async def test_pi_timeout(self, db_session, sample_paper, _summarize_tmp_paths):
|
||||
"""pi 超时 → timeout 错误,retry_count 递增。"""
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
@@ -245,7 +245,7 @@ class TestSummarizeOneFlow:
|
||||
assert result["retry_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
|
||||
async def test_json_not_found(self, db_session, sample_paper, _summarize_tmp_paths):
|
||||
"""pi 输出无 JSON → 验证循环重试 4 次后 ValueError (unknown)。"""
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
@@ -262,7 +262,7 @@ class TestSummarizeOneFlow:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_fails_and_retries(
|
||||
self, db_session, sample_paper, _patch_paths
|
||||
self, db_session, sample_paper, _summarize_tmp_paths
|
||||
):
|
||||
"""验证失败(字段不符合要求)→ 重试多次后失败。"""
|
||||
bad_json = json.dumps(
|
||||
@@ -294,7 +294,7 @@ class TestSummarizeOneFlow:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_output_saved_on_failure(
|
||||
self, db_session, sample_paper, tmp_path, _patch_paths
|
||||
self, db_session, sample_paper, tmp_path, _summarize_tmp_paths
|
||||
):
|
||||
"""失败时仍保存 raw_output.txt。"""
|
||||
with (
|
||||
@@ -313,7 +313,7 @@ class TestSummarizeOneFlow:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_cleaned_on_success(
|
||||
self, db_session, sample_paper, mock_pi_output, tmp_path, _patch_paths
|
||||
self, db_session, sample_paper, mock_pi_output, tmp_path, _summarize_tmp_paths
|
||||
):
|
||||
"""成功后清理 tmp 目录。"""
|
||||
with (
|
||||
@@ -331,7 +331,7 @@ class TestSummarizeOneFlow:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tmp_cleaned_on_failure(
|
||||
self, db_session, sample_paper, tmp_path, _patch_paths
|
||||
self, db_session, sample_paper, tmp_path, _summarize_tmp_paths
|
||||
):
|
||||
"""失败后也清理 tmp 目录。"""
|
||||
with (
|
||||
@@ -347,7 +347,7 @@ class TestSummarizeOneFlow:
|
||||
assert not tmp_paper.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_done_paper(self, db_session, sample_paper, _patch_paths):
|
||||
async def test_skips_done_paper(self, db_session, sample_paper, _summarize_tmp_paths):
|
||||
"""已完成的论文跳过。"""
|
||||
sample_paper.summary_status.status = "done"
|
||||
db_session.commit()
|
||||
@@ -364,26 +364,12 @@ class TestSummarizeOneFlow:
|
||||
class TestBatchSummarize:
|
||||
"""批量总结测试。"""
|
||||
|
||||
@pytest.fixture
|
||||
def _patch_paths(self, tmp_path):
|
||||
with (
|
||||
patch(
|
||||
"app.services.summarizer.paper_dir",
|
||||
lambda aid: tmp_path / "papers" / aid,
|
||||
),
|
||||
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
||||
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
||||
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
||||
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_multiple_papers(
|
||||
self, db_session, db_engine, mock_pi_output, _patch_paths
|
||||
self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths
|
||||
):
|
||||
"""批量处理多篇论文。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
for i in range(3):
|
||||
p = Paper(
|
||||
arxiv_id=f"2401.1234{i}",
|
||||
@@ -426,10 +412,10 @@ class TestBatchSummarize:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_failure_no_block(
|
||||
self, db_session, db_engine, mock_pi_output, _patch_paths
|
||||
self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths
|
||||
):
|
||||
"""一篇失败不阻塞其他。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
for i in range(2):
|
||||
p = Paper(
|
||||
arxiv_id=f"2401.5678{i}",
|
||||
@@ -451,7 +437,7 @@ class TestBatchSummarize:
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def _mock_call_pi(meta_path, pdf_path):
|
||||
async def _mock_call_pi(meta_path, pdf_path, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
@@ -468,7 +454,7 @@ class TestBatchSummarize:
|
||||
assert result["failed"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_lock_conflict(self, db_session, _patch_paths):
|
||||
async def test_task_lock_conflict(self, db_session, _summarize_tmp_paths):
|
||||
"""TaskLock 防止并发 batch。"""
|
||||
# 先插入一个 running 锁
|
||||
db_session.add(
|
||||
@@ -476,7 +462,7 @@ class TestBatchSummarize:
|
||||
task="summarize",
|
||||
lock_key="batch",
|
||||
status="running",
|
||||
acquired_at=datetime.now(timezone.utc),
|
||||
acquired_at=utc_now(),
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
@@ -486,7 +472,7 @@ class TestBatchSummarize:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_lock_released(
|
||||
self, db_session, db_engine, mock_pi_output, _patch_paths
|
||||
self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths
|
||||
):
|
||||
"""完成后释放 TaskLock。"""
|
||||
from sqlalchemy.orm import sessionmaker as _sm
|
||||
@@ -516,7 +502,7 @@ class TestBatchSummarize:
|
||||
assert lock.released_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_empty(self, db_session, _patch_paths):
|
||||
async def test_batch_empty(self, db_session, _summarize_tmp_paths):
|
||||
"""无 pending 论文时返回空结果。"""
|
||||
result = await summarize_batch(db_session)
|
||||
assert result["status"] == "success"
|
||||
|
||||
+8
-30
@@ -2,8 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from app.services.user_data import (
|
||||
get_note,
|
||||
save_note,
|
||||
set_reading_status,
|
||||
toggle_bookmark,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -13,22 +17,16 @@ from datetime import datetime, timezone
|
||||
|
||||
class TestBookmarkService:
|
||||
def test_toggle_bookmark_add(self, db_session, sample_paper):
|
||||
from app.services.user_data import toggle_bookmark
|
||||
|
||||
result = toggle_bookmark(db_session, "2401.12345")
|
||||
assert result["bookmarked"] is True
|
||||
assert result["arxiv_id"] == "2401.12345"
|
||||
|
||||
def test_toggle_bookmark_remove(self, db_session, sample_paper):
|
||||
from app.services.user_data import toggle_bookmark
|
||||
|
||||
toggle_bookmark(db_session, "2401.12345") # 添加
|
||||
result = toggle_bookmark(db_session, "2401.12345") # 移除
|
||||
toggle_bookmark(db_session, "2401.12345")
|
||||
result = toggle_bookmark(db_session, "2401.12345")
|
||||
assert result["bookmarked"] is False
|
||||
|
||||
def test_toggle_bookmark_not_found(self, db_session):
|
||||
from app.services.user_data import toggle_bookmark
|
||||
|
||||
result = toggle_bookmark(db_session, "nonexistent")
|
||||
assert "error" in result
|
||||
assert result["error"] == "not_found"
|
||||
@@ -41,36 +39,26 @@ class TestBookmarkService:
|
||||
|
||||
class TestReadingStatusService:
|
||||
def test_set_reading_status(self, db_session, sample_paper):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
result = set_reading_status(db_session, "2401.12345", "read_summary")
|
||||
assert result["status"] == "read_summary"
|
||||
assert result["arxiv_id"] == "2401.12345"
|
||||
|
||||
def test_set_reading_status_invalid(self, db_session, sample_paper):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
result = set_reading_status(db_session, "2401.12345", "invalid_status")
|
||||
assert "error" in result
|
||||
assert result["error"] == "invalid_status"
|
||||
|
||||
def test_update_existing_status(self, db_session, sample_paper):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
set_reading_status(db_session, "2401.12345", "skimmed")
|
||||
result = set_reading_status(db_session, "2401.12345", "read_full")
|
||||
assert result["status"] == "read_full"
|
||||
|
||||
def test_set_reading_status_not_found(self, db_session):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
result = set_reading_status(db_session, "nonexistent", "unread")
|
||||
assert "error" in result
|
||||
assert result["error"] == "not_found"
|
||||
|
||||
def test_all_valid_statuses(self, db_session, sample_paper):
|
||||
from app.services.user_data import set_reading_status
|
||||
|
||||
for status in ("unread", "skimmed", "read_summary", "read_full"):
|
||||
result = set_reading_status(db_session, "2401.12345", status)
|
||||
assert result["status"] == status
|
||||
@@ -83,8 +71,6 @@ class TestReadingStatusService:
|
||||
|
||||
class TestNoteService:
|
||||
def test_save_and_get_note(self, db_session, sample_paper):
|
||||
from app.services.user_data import get_note, save_note
|
||||
|
||||
save_note(db_session, "2401.12345", "这是一条测试笔记")
|
||||
result = get_note(db_session, "2401.12345")
|
||||
assert result["content"] == "这是一条测试笔记"
|
||||
@@ -92,29 +78,21 @@ class TestNoteService:
|
||||
assert result["updated_at"] is not None
|
||||
|
||||
def test_update_note(self, db_session, sample_paper):
|
||||
from app.services.user_data import get_note, save_note
|
||||
|
||||
save_note(db_session, "2401.12345", "旧笔记")
|
||||
save_note(db_session, "2401.12345", "新笔记")
|
||||
result = get_note(db_session, "2401.12345")
|
||||
assert result["content"] == "新笔记"
|
||||
|
||||
def test_get_note_empty(self, db_session, sample_paper):
|
||||
from app.services.user_data import get_note
|
||||
|
||||
result = get_note(db_session, "2401.12345")
|
||||
assert result["content"] == ""
|
||||
assert result["updated_at"] is None
|
||||
|
||||
def test_get_note_paper_not_found(self, db_session):
|
||||
from app.services.user_data import get_note
|
||||
|
||||
result = get_note(db_session, "nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_save_note_paper_not_found(self, db_session):
|
||||
from app.services.user_data import save_note
|
||||
|
||||
result = save_note(db_session, "nonexistent", "内容")
|
||||
assert "error" in result
|
||||
assert result["error"] == "not_found"
|
||||
|
||||
Reference in New Issue
Block a user