feat: add admin dashboard, pipeline service, lightbox, and update dependencies

This commit is contained in:
2026-06-09 09:32:10 +08:00
parent 0d293422ac
commit 32978b3fc5
50 changed files with 4054 additions and 1618 deletions
+6 -17
View File
@@ -3,14 +3,12 @@
from __future__ import annotations
import json
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock
from datetime import date
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, event
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database import get_db
@@ -23,21 +21,12 @@ from app.models import (
PaperTag,
SummaryStatus,
)
from app.utils import utc_now
# ── 内存数据库 ──────────────────────────────────────────────────────────
class _TestBase(DeclarativeBase):
pass
# 复用 app.models 的 Base metadata
from app.database import Base as _AppBase # noqa: E402
_TestBase.metadata = _AppBase.metadata
@pytest.fixture
def db_engine():
"""创建内存 SQLite 引擎 + FTS5。"""
@@ -94,7 +83,7 @@ _TEST_ADMIN_PASSWORD = "test-password-12345"
@pytest.fixture
def sample_paper(db_session):
"""插入一篇测试论文 + 作者 + 标签 + summary_status(pending)。"""
now = datetime.now(timezone.utc)
now = utc_now()
paper = Paper(
arxiv_id=SAMPLE_ARXIV_ID,
title_en="Test Paper Title",
@@ -234,7 +223,7 @@ def auth_client(client, monkeypatch):
@pytest.fixture
def sample_papers_range(db_session):
"""插入 5 篇不同日期的论文(用于 admin / cleaner 测试)。"""
now = datetime.now(timezone.utc)
now = utc_now()
papers = []
for i, (arxiv_id, paper_date_str) in enumerate(
[
@@ -281,7 +270,7 @@ def sample_papers_range(db_session):
@pytest.fixture
def sample_papers_with_summary(db_session):
"""插入 5 篇带总结的论文(用于 search / pages / trends 测试)。"""
now = datetime.now(timezone.utc)
now = utc_now()
papers = []
for i, (arxiv_id, paper_date_str) in enumerate(
[
+6 -11
View File
@@ -3,7 +3,6 @@
from __future__ import annotations
import logging
from datetime import date, datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
@@ -14,6 +13,7 @@ from app.models import (
CrawlLog,
TaskLock,
)
from app.utils import utc_now
# ═══════════════════════════════════════════════════════════════════════
@@ -24,11 +24,6 @@ from app.models import (
class TestAdminAuth:
"""管理接口鉴权测试。"""
def test_unauthenticated_redirects_to_login(self, auth_client):
"""未登录时请求管理接口应重定向到登录页。"""
# 用未登录的 clientauth_client 已登录,这里直接用 client)
pass # 见下方 test_no_session_returns_303
def test_no_session_returns_303(self, client, monkeypatch):
"""无 session 时请求管理接口应返回 303 重定向。"""
monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")
@@ -58,7 +53,7 @@ class TestAdminAuth:
follow_redirects=False,
)
assert resp.status_code == 303
assert "/admin/logs" in resp.headers.get("location", "")
assert "/admin/" in resp.headers.get("location", "")
def test_logout_clears_session(self, auth_client, monkeypatch):
"""退出登录后应清除 session。"""
@@ -265,7 +260,7 @@ class TestAdminLogs:
):
"""日志页面应包含日志数据。"""
# 先创建一条日志
now = datetime.now(timezone.utc)
now = utc_now()
db_session.add(
CrawlLog(
task="crawl",
@@ -345,7 +340,7 @@ class TestScheduler:
@pytest.mark.asyncio
async def test_daily_pipeline_lock_prevents_reentry(self, db_session):
"""pipeline 使用 task_locks 防重入。"""
now = datetime.now(timezone.utc)
now = utc_now()
lock = TaskLock(
task="scheduler",
lock_key="pipeline-2024-01-15",
@@ -380,7 +375,7 @@ class TestTaskLocks:
def test_unique_running_lock(self, db_session):
"""同一 task + lock_key 只能有一个 running 锁。"""
now = datetime.now(timezone.utc)
now = utc_now()
lock1 = TaskLock(
task="crawl",
lock_key="2024-01-15",
@@ -405,7 +400,7 @@ class TestTaskLocks:
def test_released_lock_allows_new(self, db_session):
"""已释放的锁允许新的 running 锁。"""
now = datetime.now(timezone.utc)
now = utc_now()
lock1 = TaskLock(
task="crawl",
lock_key="2024-01-16",
+4 -25
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import os
import time
from datetime import date, datetime, timezone
from datetime import date
import pytest
from sqlalchemy import select
@@ -18,6 +18,8 @@ from app.models import (
UserNote,
UserReadingStatus,
)
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
from app.utils import utc_now
# ── Fixtures ────────────────────────────────────────────────────────────
@@ -27,7 +29,7 @@ from app.models import (
def sample_paper_with_user_data(db_session, sample_papers_range):
"""给第一篇论文添加用户数据(收藏、阅读状态、笔记)。"""
paper = sample_papers_range[0]
now = datetime.now(timezone.utc)
now = utc_now()
db_session.add(UserBookmark(paper_id=paper.id, created_at=now))
db_session.add(
UserReadingStatus(paper_id=paper.id, status="read_summary", updated_at=now)
@@ -67,8 +69,6 @@ class TestCleanupTmp:
os.utime(old_dir, (old_mtime, old_mtime))
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
assert result["scanned"] == 1
@@ -85,8 +85,6 @@ class TestCleanupTmp:
(recent_dir / "paper.pdf").write_text("fake pdf")
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
assert result["scanned"] == 1
@@ -96,8 +94,6 @@ class TestCleanupTmp:
def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
"""data/tmp/ 不存在时安全返回。"""
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent")
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
assert result["scanned"] == 0
assert result["removed"] == 0
@@ -116,8 +112,6 @@ class TestCleanupTmp:
recent_dir.mkdir()
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
assert result["scanned"] == 2
@@ -137,8 +131,6 @@ class TestDeletePapersByDateRange:
@pytest.mark.asyncio
async def test_delete_by_date_range(self, db_session, sample_papers_range):
"""删除指定日期范围的论文。"""
from app.services.cleaner import delete_papers_by_date_range
# 删除 1月11日 ~ 1月13日(3篇)
result = await delete_papers_by_date_range(
db_session,
@@ -159,8 +151,6 @@ class TestDeletePapersByDateRange:
@pytest.mark.asyncio
async def test_delete_creates_job_record(self, db_session, sample_papers_range):
"""删除操作应创建 data_delete_jobs 记录。"""
from app.services.cleaner import delete_papers_by_date_range
await delete_papers_by_date_range(
db_session,
date(2024, 1, 10),
@@ -178,8 +168,6 @@ class TestDeletePapersByDateRange:
@pytest.mark.asyncio
async def test_delete_creates_crawl_log(self, db_session, sample_papers_range):
"""删除操作应写入 crawl_logs。"""
from app.services.cleaner import delete_papers_by_date_range
await delete_papers_by_date_range(
db_session,
date(2024, 1, 10),
@@ -199,8 +187,6 @@ class TestDeletePapersByDateRange:
self, db_session, sample_paper_with_user_data
):
"""删除论文时应 cascade 删除关联的用户数据。"""
from app.services.cleaner import delete_papers_by_date_range
paper = sample_paper_with_user_data
# 删除
@@ -235,7 +221,6 @@ class TestDeletePapersByDateRange:
async def test_delete_removes_fts(self, db_session, sample_papers_range):
"""删除论文时应同步删除 FTS5 索引。"""
import sqlalchemy
from app.services.cleaner import delete_papers_by_date_range
await delete_papers_by_date_range(
db_session,
@@ -254,8 +239,6 @@ class TestDeletePapersByDateRange:
self, db_session, sample_papers_range, tmp_path, monkeypatch
):
"""删除论文时应删除本地文件目录。"""
from app.services.cleaner import delete_papers_by_date_range
papers_dir = tmp_path / "papers"
papers_dir.mkdir()
(papers_dir / "2401.10001").mkdir()
@@ -274,8 +257,6 @@ class TestDeletePapersByDateRange:
@pytest.mark.asyncio
async def test_delete_empty_range(self, db_session, sample_papers_range):
"""日期范围内无论文时返回 0。"""
from app.services.cleaner import delete_papers_by_date_range
result = await delete_papers_by_date_range(
db_session,
date(2025, 1, 1),
@@ -295,8 +276,6 @@ class TestDeletePapersByDateRange:
emb._chroma.reset()
from app.services.cleaner import delete_papers_by_date_range
result = await delete_papers_by_date_range(
db_session,
date(2024, 1, 10),
-19
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from app.config import settings
@@ -84,24 +83,6 @@ class TestEmbedderIndexing:
emb._chroma.reset()
def test_index_batch_disabled(self, monkeypatch):
"""CHROMA_ENABLED=false 时 index_batch 返回全失败。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
emb._chroma.reset()
result = emb.index_batch(["a", "b"])
assert result["success"] == 0
assert result["failed"] == 2
def test_index_batch_empty(self, monkeypatch):
"""空列表时返回 0。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
result = emb.index_batch([])
assert result["total"] == 0
def test_delete_paper_disabled(self, monkeypatch):
"""CHROMA_ENABLED=false 时 delete_paper 返回 False。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
-21
View File
@@ -5,7 +5,6 @@ from __future__ import annotations
from datetime import date
from unittest.mock import patch as upatch
import pytest
from app.config import settings
@@ -30,26 +29,6 @@ class TestDetailPage:
assert resp.status_code == 404
# ═══════════════════════════════════════════════════════════════════════
# Similar API(详情页内联)
# ═══════════════════════════════════════════════════════════════════════
class TestDetailSimilarPapers:
"""详情页相似论文模块测试(CHROMA 关闭时的降级行为)。"""
def test_detail_page_renders_with_similar(self, client, sample_papers_with_summary):
"""详情页正常渲染(含相似论文模块)。"""
resp = client.get("/paper/2401.20001")
assert resp.status_code == 200
assert "测试论文" in resp.text or "Test Paper" in resp.text
def test_detail_page_not_found_similar(self, client):
"""不存在的论文返回 404。"""
resp = client.get("/paper/nonexistent.99999")
assert resp.status_code == 404
# ═══════════════════════════════════════════════════════════════════════
# Trends Dashboard
# ═══════════════════════════════════════════════════════════════════════
+5 -48
View File
@@ -2,10 +2,12 @@
from __future__ import annotations
from datetime import date
import pytest
from datetime import date, datetime, timezone
from app.config import settings
from app.services.searcher import get_all_tags, search_papers
# ═══════════════════════════════════════════════════════════════════════
@@ -17,90 +19,60 @@ class TestSearchService:
"""app/services/searcher.py — FTS5 关键词搜索单元测试。"""
def test_search_by_title(self, db_session, sample_paper):
from app.services.searcher import search_papers
result = search_papers(db_session, query="Test Paper")
assert result["total"] == 1
assert result["results"][0].arxiv_id == "2401.12345"
def test_search_by_abstract(self, db_session, sample_paper):
from app.services.searcher import search_papers
result = search_papers(db_session, query="test abstract")
assert result["total"] == 1
def test_search_by_author(self, db_session, sample_paper):
from app.services.searcher import search_papers
result = search_papers(db_session, query="Alice")
assert result["total"] == 1
def test_search_by_tag_in_fts(self, db_session, sample_paper):
from app.services.searcher import search_papers
# FTS5 索引中包含 tags 列,可以搜到
result = search_papers(db_session, query="NLP")
assert result["total"] == 1
def test_search_no_results(self, db_session, sample_paper):
from app.services.searcher import search_papers
result = search_papers(db_session, query="quantum entanglement")
assert result["total"] == 0
assert result["results"] == []
def test_search_empty_query_returns_empty(self, db_session):
from app.services.searcher import search_papers
result = search_papers(db_session, query="")
assert result["total"] == 0
assert result["results"] == []
def test_search_special_characters_sanitized(self, db_session, sample_paper):
from app.services.searcher import search_papers
# 特殊字符被清除后,剩下 "Test" 仍然能搜到
result = search_papers(db_session, query='Test "Paper" {test}')
assert result["total"] >= 1
def test_search_with_tag_filter(self, db_session, sample_paper):
from app.services.searcher import search_papers
# 关键词 + 标签筛选
result = search_papers(db_session, query="Paper", tag="NLP")
assert result["total"] == 1
# 标签不匹配 → 0
result2 = search_papers(db_session, query="Paper", tag="nonexistent")
assert result2["total"] == 0
def test_search_tag_only_no_query(self, db_session, sample_paper):
from app.services.searcher import search_papers
# 只有标签,无关键词
result = search_papers(db_session, tag="NLP")
assert result["total"] == 1
assert result["results"][0].arxiv_id == "2401.12345"
def test_search_pagination(self, db_session, sample_paper):
from app.services.searcher import search_papers
result = search_papers(db_session, query="Test", page=2, page_size=10)
assert result["page"] == 2
assert result["total_pages"] == 1 # 只有 1 条结果,1 页
assert result["total_pages"] == 1
def test_search_returns_snippets(self, db_session, sample_paper):
from app.services.searcher import search_papers
result = search_papers(db_session, query="test abstract")
assert result["total"] == 1
paper_id = result["results"][0].id
assert paper_id in result["snippets"]
snippet = result["snippets"][paper_id]
assert "abstract" in snippet
assert "abstract" in result["snippets"][paper_id]
def test_get_all_tags(self, db_session, sample_paper):
from app.services.searcher import get_all_tags
tags = get_all_tags(db_session)
assert "NLP" in tags
assert "LLM" in tags
@@ -115,9 +87,6 @@ class TestSearchSemanticMode:
"""searcher.py — semantic 模式(含 embedder 回退)测试。"""
def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
"""默认 keyword 模式走 FTS5。"""
from app.services.searcher import search_papers
result = search_papers(db_session, query="Test Paper", mode="keyword")
assert result["total"] >= 1
assert result["distances"] == {}
@@ -125,35 +94,23 @@ class TestSearchSemanticMode:
def test_semantic_mode_disabled_fallback(
self, db_session, monkeypatch, sample_papers_with_summary
):
"""CHROMA_ENABLED=false + semantic 模式走 FTS5。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
from app.services.searcher import search_papers
result = search_papers(db_session, query="Test", mode="semantic")
assert result["total"] >= 1
def test_search_returns_distances_dict(
self, db_session, sample_papers_with_summary
):
"""搜索结果应包含 distances 字段。"""
from app.services.searcher import search_papers
result = search_papers(db_session, query="Test Paper")
assert "distances" in result
assert isinstance(result["distances"], dict)
def test_empty_query_returns_empty_no_tags(self, db_session):
"""空查询无标签时返回空。"""
from app.services.searcher import search_papers
result = search_papers(db_session)
assert result["total"] == 0
assert result["results"] == []
def test_tag_only_search(self, db_session, sample_papers_with_summary):
"""仅标签搜索。"""
from app.services.searcher import search_papers
result = search_papers(db_session, tag="NLP")
assert result["total"] >= 1
+37 -51
View File
@@ -3,8 +3,7 @@
from __future__ import annotations
import json
from datetime import date, datetime, timezone
from pathlib import Path
from datetime import date
from unittest.mock import AsyncMock, patch
import pytest
@@ -26,11 +25,27 @@ from app.services.pi_client import PiTimeoutError
from app.services.schemas import SummarySchema
from app.services.summarizer import (
_save_files,
_save_raw_output_only,
_update_summary_in_db,
summarize_batch,
summarize_one,
)
from app.utils import utc_now
# ── 共享 fixture ──────────────────────────────────────────────────────────
@pytest.fixture
def _summarize_tmp_paths(tmp_path):
"""将 data 目录重定向到 tmp_path(供 summarizer 测试使用)。"""
with (
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
):
yield
# ═══════════════════════════════════════════════════════════════════════
@@ -130,7 +145,7 @@ class TestFileOperations:
def test_save_raw_output_only(self, tmp_path):
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
_save_raw_output_only("2401.12345", "raw output")
_save_files("2401.12345", None, "raw output")
paper_dir = tmp_path / "2401.12345"
assert (paper_dir / "raw_output.txt").exists()
assert not (paper_dir / "summary.json").exists()
@@ -157,24 +172,9 @@ class TestFileOperations:
class TestSummarizeOneFlow:
"""summarize_one 的状态流转(mock pi 和 PDF)。"""
@pytest.fixture
def _patch_paths(self, tmp_path):
"""将 data 目录重定向到 tmp_path。"""
with (
patch(
"app.services.summarizer.paper_dir",
lambda aid: tmp_path / "papers" / aid,
),
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
):
yield
@pytest.mark.asyncio
async def test_full_success_path(
self, db_session, sample_paper, mock_pi_output, _patch_paths
self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths
):
"""pending → processing → done 全流程。"""
with (
@@ -209,7 +209,7 @@ class TestSummarizeOneFlow:
assert fts_row[0] == "测试论文中文标题"
@pytest.mark.asyncio
async def test_pdf_download_failure(self, db_session, sample_paper, _patch_paths):
async def test_pdf_download_failure(self, db_session, sample_paper, _summarize_tmp_paths):
"""PDF 下载失败 → error_type=pdf_download_failedtmp 被清理。"""
with (
patch(
@@ -228,7 +228,7 @@ class TestSummarizeOneFlow:
assert status.error_type == "pdf_download_failed"
@pytest.mark.asyncio
async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
async def test_pi_timeout(self, db_session, sample_paper, _summarize_tmp_paths):
"""pi 超时 → timeout 错误,retry_count 递增。"""
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
@@ -245,7 +245,7 @@ class TestSummarizeOneFlow:
assert result["retry_count"] == 1
@pytest.mark.asyncio
async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
async def test_json_not_found(self, db_session, sample_paper, _summarize_tmp_paths):
"""pi 输出无 JSON → 验证循环重试 4 次后 ValueError (unknown)。"""
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
@@ -262,7 +262,7 @@ class TestSummarizeOneFlow:
@pytest.mark.asyncio
async def test_validation_fails_and_retries(
self, db_session, sample_paper, _patch_paths
self, db_session, sample_paper, _summarize_tmp_paths
):
"""验证失败(字段不符合要求)→ 重试多次后失败。"""
bad_json = json.dumps(
@@ -294,7 +294,7 @@ class TestSummarizeOneFlow:
@pytest.mark.asyncio
async def test_raw_output_saved_on_failure(
self, db_session, sample_paper, tmp_path, _patch_paths
self, db_session, sample_paper, tmp_path, _summarize_tmp_paths
):
"""失败时仍保存 raw_output.txt。"""
with (
@@ -313,7 +313,7 @@ class TestSummarizeOneFlow:
@pytest.mark.asyncio
async def test_tmp_cleaned_on_success(
self, db_session, sample_paper, mock_pi_output, tmp_path, _patch_paths
self, db_session, sample_paper, mock_pi_output, tmp_path, _summarize_tmp_paths
):
"""成功后清理 tmp 目录。"""
with (
@@ -331,7 +331,7 @@ class TestSummarizeOneFlow:
@pytest.mark.asyncio
async def test_tmp_cleaned_on_failure(
self, db_session, sample_paper, tmp_path, _patch_paths
self, db_session, sample_paper, tmp_path, _summarize_tmp_paths
):
"""失败后也清理 tmp 目录。"""
with (
@@ -347,7 +347,7 @@ class TestSummarizeOneFlow:
assert not tmp_paper.exists()
@pytest.mark.asyncio
async def test_skips_done_paper(self, db_session, sample_paper, _patch_paths):
async def test_skips_done_paper(self, db_session, sample_paper, _summarize_tmp_paths):
"""已完成的论文跳过。"""
sample_paper.summary_status.status = "done"
db_session.commit()
@@ -364,26 +364,12 @@ class TestSummarizeOneFlow:
class TestBatchSummarize:
"""批量总结测试。"""
@pytest.fixture
def _patch_paths(self, tmp_path):
with (
patch(
"app.services.summarizer.paper_dir",
lambda aid: tmp_path / "papers" / aid,
),
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
):
yield
@pytest.mark.asyncio
async def test_batch_multiple_papers(
self, db_session, db_engine, mock_pi_output, _patch_paths
self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths
):
"""批量处理多篇论文。"""
now = datetime.now(timezone.utc)
now = utc_now()
for i in range(3):
p = Paper(
arxiv_id=f"2401.1234{i}",
@@ -426,10 +412,10 @@ class TestBatchSummarize:
@pytest.mark.asyncio
async def test_single_failure_no_block(
self, db_session, db_engine, mock_pi_output, _patch_paths
self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths
):
"""一篇失败不阻塞其他。"""
now = datetime.now(timezone.utc)
now = utc_now()
for i in range(2):
p = Paper(
arxiv_id=f"2401.5678{i}",
@@ -451,7 +437,7 @@ class TestBatchSummarize:
call_count = 0
async def _mock_call_pi(meta_path, pdf_path):
async def _mock_call_pi(meta_path, pdf_path, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
@@ -468,7 +454,7 @@ class TestBatchSummarize:
assert result["failed"] == 1
@pytest.mark.asyncio
async def test_task_lock_conflict(self, db_session, _patch_paths):
async def test_task_lock_conflict(self, db_session, _summarize_tmp_paths):
"""TaskLock 防止并发 batch。"""
# 先插入一个 running 锁
db_session.add(
@@ -476,7 +462,7 @@ class TestBatchSummarize:
task="summarize",
lock_key="batch",
status="running",
acquired_at=datetime.now(timezone.utc),
acquired_at=utc_now(),
)
)
db_session.commit()
@@ -486,7 +472,7 @@ class TestBatchSummarize:
@pytest.mark.asyncio
async def test_task_lock_released(
self, db_session, db_engine, mock_pi_output, _patch_paths
self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths
):
"""完成后释放 TaskLock。"""
from sqlalchemy.orm import sessionmaker as _sm
@@ -516,7 +502,7 @@ class TestBatchSummarize:
assert lock.released_at is not None
@pytest.mark.asyncio
async def test_batch_empty(self, db_session, _patch_paths):
async def test_batch_empty(self, db_session, _summarize_tmp_paths):
"""无 pending 论文时返回空结果。"""
result = await summarize_batch(db_session)
assert result["status"] == "success"
+8 -30
View File
@@ -2,8 +2,12 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from app.services.user_data import (
get_note,
save_note,
set_reading_status,
toggle_bookmark,
)
# ═══════════════════════════════════════════════════════════════════════
@@ -13,22 +17,16 @@ from datetime import datetime, timezone
class TestBookmarkService:
def test_toggle_bookmark_add(self, db_session, sample_paper):
from app.services.user_data import toggle_bookmark
result = toggle_bookmark(db_session, "2401.12345")
assert result["bookmarked"] is True
assert result["arxiv_id"] == "2401.12345"
def test_toggle_bookmark_remove(self, db_session, sample_paper):
from app.services.user_data import toggle_bookmark
toggle_bookmark(db_session, "2401.12345") # 添加
result = toggle_bookmark(db_session, "2401.12345") # 移除
toggle_bookmark(db_session, "2401.12345")
result = toggle_bookmark(db_session, "2401.12345")
assert result["bookmarked"] is False
def test_toggle_bookmark_not_found(self, db_session):
from app.services.user_data import toggle_bookmark
result = toggle_bookmark(db_session, "nonexistent")
assert "error" in result
assert result["error"] == "not_found"
@@ -41,36 +39,26 @@ class TestBookmarkService:
class TestReadingStatusService:
def test_set_reading_status(self, db_session, sample_paper):
from app.services.user_data import set_reading_status
result = set_reading_status(db_session, "2401.12345", "read_summary")
assert result["status"] == "read_summary"
assert result["arxiv_id"] == "2401.12345"
def test_set_reading_status_invalid(self, db_session, sample_paper):
from app.services.user_data import set_reading_status
result = set_reading_status(db_session, "2401.12345", "invalid_status")
assert "error" in result
assert result["error"] == "invalid_status"
def test_update_existing_status(self, db_session, sample_paper):
from app.services.user_data import set_reading_status
set_reading_status(db_session, "2401.12345", "skimmed")
result = set_reading_status(db_session, "2401.12345", "read_full")
assert result["status"] == "read_full"
def test_set_reading_status_not_found(self, db_session):
from app.services.user_data import set_reading_status
result = set_reading_status(db_session, "nonexistent", "unread")
assert "error" in result
assert result["error"] == "not_found"
def test_all_valid_statuses(self, db_session, sample_paper):
from app.services.user_data import set_reading_status
for status in ("unread", "skimmed", "read_summary", "read_full"):
result = set_reading_status(db_session, "2401.12345", status)
assert result["status"] == status
@@ -83,8 +71,6 @@ class TestReadingStatusService:
class TestNoteService:
def test_save_and_get_note(self, db_session, sample_paper):
from app.services.user_data import get_note, save_note
save_note(db_session, "2401.12345", "这是一条测试笔记")
result = get_note(db_session, "2401.12345")
assert result["content"] == "这是一条测试笔记"
@@ -92,29 +78,21 @@ class TestNoteService:
assert result["updated_at"] is not None
def test_update_note(self, db_session, sample_paper):
from app.services.user_data import get_note, save_note
save_note(db_session, "2401.12345", "旧笔记")
save_note(db_session, "2401.12345", "新笔记")
result = get_note(db_session, "2401.12345")
assert result["content"] == "新笔记"
def test_get_note_empty(self, db_session, sample_paper):
from app.services.user_data import get_note
result = get_note(db_session, "2401.12345")
assert result["content"] == ""
assert result["updated_at"] is None
def test_get_note_paper_not_found(self, db_session):
from app.services.user_data import get_note
result = get_note(db_session, "nonexistent")
assert result is None
def test_save_note_paper_not_found(self, db_session):
from app.services.user_data import save_note
result = save_note(db_session, "nonexistent", "内容")
assert "error" in result
assert result["error"] == "not_found"