Files
daily-paper/tests/test_admin_phase4.py
T
Rain-Bus 2cfd1a8a9f feat: add admin crawl, cleanup, delete, logs endpoints with scheduler and tests
- Add POST /admin/crawl with TaskLock-based reentrancy guard
- Add POST /admin/cleanup (tmp files older than 24h) with CrawlLog
- Add POST /admin/delete with date range and 'DELETE' confirm token
- Add GET /admin/logs (paginated CrawlLog + DataDeleteJob viewer)
- Add app/services/cleaner.py (cleanup_tmp, delete_papers_by_date_range)
- Add app/services/scheduler.py (APScheduler daily crawl/cleanup jobs)
- Wire scheduler startup/shutdown hooks in app/main.py
- Add admin nav link in base.html and APP_HOST security warning
- Add apscheduler>=3.10 dependency
- Add tests/test_admin_phase4.py covering the new endpoints
2026-06-05 23:07:45 +08:00

640 lines
23 KiB
Python

"""Phase 4 管理和自动化测试 — cleaner、admin routes、scheduler。"""
from __future__ import annotations
import os
import shutil
import time
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import select
from app.database import get_db
from app.config import settings
from app.models import (
CrawlLog,
DataDeleteJob,
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
TaskLock,
UserBookmark,
UserNote,
UserReadingStatus,
)
# ── Fixtures ────────────────────────────────────────────────────────────
# Bearer token used by the admin tests: the auth_client fixture patches it
# into settings.ADMIN_TOKEN and the admin_headers fixture sends it.
ADMIN_TOKEN = "test-admin-token-12345"
@pytest.fixture
def admin_headers():
    """Return an Authorization header carrying the valid admin bearer token."""
    bearer_value = "Bearer " + ADMIN_TOKEN
    return {"Authorization": bearer_value}
@pytest.fixture
def wrong_admin_headers():
    """Return an Authorization header whose bearer token is not ADMIN_TOKEN."""
    return dict(Authorization="Bearer wrong-token")
@pytest.fixture
def auth_client(client, monkeypatch):
    """Return the TestClient with settings.ADMIN_TOKEN patched to the test token."""
    expected_token = ADMIN_TOKEN
    monkeypatch.setattr(settings, "ADMIN_TOKEN", expected_token)
    return client
@pytest.fixture
def sample_papers(db_session):
    """Insert five papers dated 2024-01-10 through 2024-01-14 and return them.

    For every paper this also adds one PaperAuthor, one PaperTag, a pending
    SummaryStatus row and a matching row in the papers_fts table. Everything
    is committed once, after the last paper has been added.

    Returns:
        The list of inserted Paper objects, in insertion order.
    """
    # Imported and prepared once, outside the loop: neither the import nor
    # the statement text depends on the paper being inserted.
    from sqlalchemy import text

    insert_fts = text(
        "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
        "VALUES (:id, :title, :abstract, :authors, :tags)"
    )
    crawled_at = datetime.now(timezone.utc)
    seed_rows = [
        ("2401.10001", "2024-01-10"),
        ("2401.10002", "2024-01-11"),
        ("2401.10003", "2024-01-12"),
        ("2401.10004", "2024-01-13"),
        ("2401.10005", "2024-01-14"),
    ]
    papers = []
    for index, (arxiv_id, iso_day) in enumerate(seed_rows):
        ordinal = index + 1
        author_name = f"Author {ordinal}"
        tag_name = f"Tag{ordinal}"
        paper = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {ordinal}",
            abstract=f"Abstract for paper {ordinal}.",
            paper_date=date.fromisoformat(iso_day),
            crawled_at=crawled_at,
            upvotes=index * 10,
        )
        db_session.add(paper)
        # Flush before the child rows below read paper.id.
        db_session.flush()
        db_session.add(PaperAuthor(paper_id=paper.id, name=author_name, position=0))
        db_session.add(PaperTag(paper_id=paper.id, tag=tag_name, source="hf"))
        db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))
        # Keep the FTS5 table in step with the papers table.
        db_session.execute(
            insert_fts,
            {
                "id": paper.id,
                "title": paper.title_en,
                "abstract": paper.abstract,
                "authors": author_name,
                "tags": tag_name,
            },
        )
        papers.append(paper)
    db_session.commit()
    return papers
@pytest.fixture
def sample_paper_with_user_data(db_session, sample_papers):
    """Attach a bookmark, a reading status and a note to the first sample paper.

    Returns the paper that received the user data.
    """
    target = sample_papers[0]
    stamp = datetime.now(timezone.utc)
    bookmark = UserBookmark(paper_id=target.id, created_at=stamp)
    db_session.add(bookmark)
    reading_status = UserReadingStatus(
        paper_id=target.id,
        status="read_summary",
        updated_at=stamp,
    )
    db_session.add(reading_status)
    note = UserNote(
        paper_id=target.id,
        content="My notes on this paper",
        created_at=stamp,
        updated_at=stamp,
    )
    db_session.add(note)
    db_session.commit()
    return target
@pytest.fixture
def tmp_data_dir(tmp_path):
    """Create data/tmp and data/papers under tmp_path and return the data directory."""
    data_root = tmp_path / "data"
    for child_name in ("tmp", "papers"):
        (data_root / child_name).mkdir(parents=True)
    return data_root
# ═══════════════════════════════════════════════════════════════════════
# Cleaner 服务测试
# ═══════════════════════════════════════════════════════════════════════
class TestCleanupTmp:
    """Tests for cleanup_tmp in app/services/cleaner.py."""

    def test_cleanup_removes_old_dirs(self, tmp_path, monkeypatch):
        """A temp directory older than 24 hours gets deleted."""
        scratch = tmp_path / "tmp"
        scratch.mkdir()
        stale = scratch / "2401.00001"
        stale.mkdir()
        (stale / "paper.pdf").write_text("fake pdf")
        # Push the directory's timestamps 25 hours into the past.
        backdated = time.time() - 25 * 3600
        os.utime(stale, (backdated, backdated))
        monkeypatch.setattr("app.services.cleaner._TMP_DIR", scratch)
        from app.services.cleaner import cleanup_tmp

        stats = cleanup_tmp()

        assert stats["scanned"] == 1
        assert stats["removed"] == 1
        assert stale.exists() is False

    def test_cleanup_keeps_recent_dirs(self, tmp_path, monkeypatch):
        """A temp directory created just now is left in place."""
        scratch = tmp_path / "tmp"
        scratch.mkdir()
        fresh = scratch / "2401.00002"
        fresh.mkdir()
        (fresh / "paper.pdf").write_text("fake pdf")
        monkeypatch.setattr("app.services.cleaner._TMP_DIR", scratch)
        from app.services.cleaner import cleanup_tmp

        stats = cleanup_tmp()

        assert stats["scanned"] == 1
        assert stats["removed"] == 0
        assert fresh.exists() is True

    def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
        """When the tmp directory does not exist, both counters are zero."""
        missing = tmp_path / "nonexistent"
        monkeypatch.setattr("app.services.cleaner._TMP_DIR", missing)
        from app.services.cleaner import cleanup_tmp

        stats = cleanup_tmp()

        assert stats["scanned"] == 0
        assert stats["removed"] == 0

    def test_cleanup_mixed_ages(self, tmp_path, monkeypatch):
        """With one old and one new directory, only the old one is deleted."""
        scratch = tmp_path / "tmp"
        scratch.mkdir()
        stale = scratch / "2401.old"
        stale.mkdir()
        # Backdate the first directory by 30 hours.
        backdated = time.time() - 30 * 3600
        os.utime(stale, (backdated, backdated))
        fresh = scratch / "2401.new"
        fresh.mkdir()
        monkeypatch.setattr("app.services.cleaner._TMP_DIR", scratch)
        from app.services.cleaner import cleanup_tmp

        stats = cleanup_tmp()

        assert stats["scanned"] == 2
        assert stats["removed"] == 1
        assert stale.exists() is False
        assert fresh.exists() is True
class TestDeletePapersByDateRange:
    """Tests for delete_papers_by_date_range in app/services/cleaner.py."""

    @pytest.mark.asyncio
    async def test_delete_by_date_range(self, db_session, sample_papers):
        """Only papers inside the requested date range are deleted."""
        from app.services.cleaner import delete_papers_by_date_range

        # Jan 11 through Jan 13 covers three of the five sample papers.
        result = await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 11),
            date(2024, 1, 13),
        )
        assert result["deleted"] == 3
        assert result["total"] == 3
        assert result["status"] == "success"

        # The two papers outside the range must still be in the database.
        remaining = db_session.execute(select(Paper)).scalars().all()
        assert len(remaining) == 2
        remaining_dates = {p.paper_date for p in remaining}
        assert remaining_dates == {date(2024, 1, 10), date(2024, 1, 14)}

    @pytest.mark.asyncio
    async def test_delete_creates_job_record(self, db_session, sample_papers):
        """A deletion writes one DataDeleteJob row describing the run."""
        from app.services.cleaner import delete_papers_by_date_range

        await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 14),
        )
        jobs = db_session.execute(select(DataDeleteJob)).scalars().all()
        assert len(jobs) == 1
        job = jobs[0]
        assert job.status == "success"
        assert job.date_start == date(2024, 1, 10)
        assert job.date_end == date(2024, 1, 14)
        assert job.paper_count == 5
        assert job.completed_at is not None

    @pytest.mark.asyncio
    async def test_delete_creates_crawl_log(self, db_session, sample_papers):
        """A deletion writes one CrawlLog row with task 'delete'."""
        from app.services.cleaner import delete_papers_by_date_range

        await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 14),
        )
        logs = db_session.execute(
            select(CrawlLog).where(CrawlLog.task == "delete")
        ).scalars().all()
        assert len(logs) == 1
        assert logs[0].status == "success"

    @pytest.mark.asyncio
    async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data):
        """Deleting a paper also removes its bookmark, reading status and note."""
        from app.services.cleaner import delete_papers_by_date_range

        # Read the id before deleting so the later queries do not touch an
        # attribute of an ORM instance whose row no longer exists.
        paper_id = sample_paper_with_user_data.id
        user_models = (UserBookmark, UserReadingStatus, UserNote)

        # Precondition: the user data really exists, otherwise the
        # "is None" checks after the delete would pass vacuously.
        for model in user_models:
            assert db_session.execute(
                select(model).where(model.paper_id == paper_id)
            ).scalar_one_or_none() is not None

        result = await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 10),
        )
        assert result["deleted"] == 1

        # Every kind of user data attached to the paper must be gone.
        for model in user_models:
            assert db_session.execute(
                select(model).where(model.paper_id == paper_id)
            ).scalar_one_or_none() is None

    @pytest.mark.asyncio
    async def test_delete_removes_fts(self, db_session, sample_papers):
        """Deleting every paper leaves the papers_fts table empty."""
        from sqlalchemy import text

        from app.services.cleaner import delete_papers_by_date_range

        await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 14),
        )
        fts_row_count = db_session.execute(
            text("SELECT count(*) FROM papers_fts")
        ).scalar()
        assert fts_row_count == 0

    @pytest.mark.asyncio
    async def test_delete_removes_local_files(self, db_session, sample_papers, tmp_path, monkeypatch):
        """Deleting a paper also removes its directory under _PAPERS_DIR."""
        from app.services.cleaner import delete_papers_by_date_range

        papers_dir = tmp_path / "papers"
        papers_dir.mkdir()
        paper_dir = papers_dir / "2401.10001"
        paper_dir.mkdir()
        (paper_dir / "meta.json").write_text("{}")
        monkeypatch.setattr("app.services.cleaner._PAPERS_DIR", papers_dir)

        result = await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 10),
        )
        assert result["deleted"] == 1
        assert not paper_dir.exists()

    @pytest.mark.asyncio
    async def test_delete_empty_range(self, db_session, sample_papers):
        """A range that contains no papers reports zero and still succeeds."""
        from app.services.cleaner import delete_papers_by_date_range

        result = await delete_papers_by_date_range(
            db_session,
            date(2025, 1, 1),
            date(2025, 1, 31),
        )
        assert result["total"] == 0
        assert result["deleted"] == 0
        assert result["status"] == "success"
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes 测试
# ═══════════════════════════════════════════════════════════════════════
class TestAdminAuth:
    """Authentication tests for the admin endpoints."""

    def test_no_token_returns_403(self, auth_client):
        """A request without a token is rejected with 403 or 401."""
        response = auth_client.post("/admin/crawl")
        assert response.status_code in (403, 401)

    def test_wrong_token_returns_401(self, auth_client, wrong_admin_headers):
        """A request with the wrong token is rejected with 401."""
        response = auth_client.post("/admin/crawl", headers=wrong_admin_headers)
        assert response.status_code == 401

    def test_correct_token_accepted(self, auth_client, admin_headers):
        """The correct token is accepted: whatever the crawl does, it is not a 401."""
        stub_result = {"found": 0, "new": 0, "status": "success"}
        # crawl_daily is replaced so the test makes no external API calls.
        with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl:
            mock_crawl.return_value = stub_result
            response = auth_client.post("/admin/crawl", headers=admin_headers)
        assert response.status_code != 401
class TestAdminCrawl:
    """Tests for POST /admin/crawl."""

    def test_crawl_default_today(self, auth_client, admin_headers):
        """Without a date parameter the endpoint still runs crawl_daily once."""
        stub_result = {"found": 5, "new": 3, "status": "success"}
        with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl:
            mock_crawl.return_value = stub_result
            response = auth_client.post("/admin/crawl", headers=admin_headers)
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "success"
        # The route must have delegated to crawl_daily exactly once.
        mock_crawl.assert_called_once()

    def test_crawl_specific_date(self, auth_client, admin_headers):
        """An explicit date is forwarded as crawl_daily's second positional argument."""
        stub_result = {"found": 2, "new": 1, "status": "success"}
        with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl:
            mock_crawl.return_value = stub_result
            response = auth_client.post("/admin/crawl?date=2024-01-15", headers=admin_headers)
        assert response.status_code == 200
        mock_crawl.assert_called_once()
        positional_args, _keyword_args = mock_crawl.call_args
        assert positional_args[1] == "2024-01-15"
class TestAdminCleanup:
    """Tests for POST /admin/cleanup."""

    def test_cleanup_returns_stats(self, auth_client, admin_headers):
        """The response body carries the counters returned by cleanup_tmp."""
        stub_stats = {"scanned": 3, "removed": 1, "errors": []}
        with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
            mock_cleanup.return_value = stub_stats
            response = auth_client.post("/admin/cleanup", headers=admin_headers)
        assert response.status_code == 200
        payload = response.json()
        assert payload["scanned"] == 3
        assert payload["removed"] == 1

    def test_cleanup_writes_log(self, auth_client, admin_headers, db_session):
        """A cleanup run records a CrawlLog row with task 'cleanup'."""
        stub_stats = {"scanned": 0, "removed": 0, "errors": []}
        with patch("app.routes.admin.cleanup_tmp") as mock_cleanup:
            mock_cleanup.return_value = stub_stats
            auth_client.post("/admin/cleanup", headers=admin_headers)
        cleanup_query = select(CrawlLog).where(CrawlLog.task == "cleanup")
        cleanup_logs = db_session.execute(cleanup_query).scalars().all()
        assert len(cleanup_logs) >= 1
        assert cleanup_logs[-1].status == "success"
class TestAdminDelete:
    """Tests for POST /admin/delete."""

    @staticmethod
    def _post_delete(client, headers, payload):
        """POST payload as JSON to /admin/delete and return the response."""
        return client.post("/admin/delete", json=payload, headers=headers)

    def test_delete_requires_confirm(self, auth_client, admin_headers):
        """A confirm value other than 'DELETE' is rejected with 422."""
        payload = {
            "date_start": "2024-01-10",
            "date_end": "2024-01-12",
            "include_notes": True,
            "confirm": "WRONG",
        }
        response = self._post_delete(auth_client, admin_headers, payload)
        assert response.status_code == 422

    def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers):
        """With confirm='DELETE' the three papers in range are deleted."""
        payload = {
            "date_start": "2024-01-10",
            "date_end": "2024-01-12",
            "include_notes": True,
            "confirm": "DELETE",
        }
        response = self._post_delete(auth_client, admin_headers, payload)
        assert response.status_code == 200
        body = response.json()
        assert body["deleted"] == 3

    def test_delete_invalid_date_range(self, auth_client, admin_headers):
        """A date_start later than date_end is rejected with 400."""
        payload = {
            "date_start": "2024-01-15",
            "date_end": "2024-01-10",
            "confirm": "DELETE",
        }
        response = self._post_delete(auth_client, admin_headers, payload)
        assert response.status_code == 400

    def test_delete_without_confirm_field(self, auth_client, admin_headers):
        """A request body without the confirm field is rejected with 422."""
        payload = {
            "date_start": "2024-01-10",
            "date_end": "2024-01-12",
        }
        response = self._post_delete(auth_client, admin_headers, payload)
        assert response.status_code == 422
class TestAdminLogs:
    """Tests for GET /admin/logs."""

    def test_logs_returns_page(self, auth_client, admin_headers):
        """An authenticated request gets an HTML page back."""
        response = auth_client.get("/admin/logs", headers=admin_headers)
        assert response.status_code == 200
        content_type = response.headers.get("content-type", "")
        assert "text/html" in content_type

    def test_logs_requires_auth(self, auth_client):
        """A request without a token is rejected with 403 or 401."""
        response = auth_client.get("/admin/logs")
        assert response.status_code in (403, 401)

    def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers):
        """The page mentions log data once a CrawlLog row exists."""
        # Seed one log entry so the page has something to render.
        stamp = datetime.now(timezone.utc)
        seeded_log = CrawlLog(
            task="crawl",
            status="success",
            started_at=stamp,
            completed_at=stamp,
        )
        db_session.add(seeded_log)
        db_session.commit()

        response = auth_client.get("/admin/logs", headers=admin_headers)
        assert response.status_code == 200
        page = response.text
        assert "crawl" in page.lower() or "日志" in page
# ═══════════════════════════════════════════════════════════════════════
# Scheduler 测试
# ═══════════════════════════════════════════════════════════════════════
class TestScheduler:
    """Tests for app/services/scheduler.py."""

    def test_scheduler_disabled_by_default(self, monkeypatch):
        """With SCHEDULER_ENABLED false, start_scheduler returns None."""
        monkeypatch.setattr(settings, "SCHEDULER_ENABLED", False)
        import app.services.scheduler as sched_mod
        from app.services.scheduler import start_scheduler

        # Reset the module-level scheduler reference before starting.
        sched_mod._scheduler = None
        assert start_scheduler() is None

    @pytest.mark.asyncio
    async def test_scheduler_start_stop(self, monkeypatch):
        """The scheduler starts with the daily_pipeline job and stops cleanly."""
        monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
        monkeypatch.setattr(settings, "APP_WORKERS", 1)
        import app.services.scheduler as sched_mod
        from app.services.scheduler import start_scheduler, stop_scheduler

        sched_mod._scheduler = None
        scheduler = start_scheduler()
        try:
            assert scheduler is not None
            # Membership rather than jobs[0]: the check must not depend on
            # the order in which get_jobs() returns the registered jobs.
            job_ids = {job.id for job in scheduler.get_jobs()}
            assert "daily_pipeline" in job_ids
        finally:
            # Stop even when an assertion fails so a started scheduler does
            # not outlive this test.
            if scheduler is not None:
                stop_scheduler()
        assert sched_mod._scheduler is None

    @pytest.mark.asyncio
    async def test_scheduler_warns_multi_worker(self, monkeypatch, caplog):
        """With APP_WORKERS > 1 a warning mentioning APP_WORKERS is logged."""
        import logging

        monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
        monkeypatch.setattr(settings, "APP_WORKERS", 4)
        import app.services.scheduler as sched_mod
        from app.services.scheduler import start_scheduler, stop_scheduler

        sched_mod._scheduler = None
        with caplog.at_level(logging.WARNING):
            scheduler = start_scheduler()
        try:
            assert scheduler is not None
            assert any("APP_WORKERS" in r.message for r in caplog.records)
        finally:
            # Stop even when an assertion fails so a started scheduler does
            # not outlive this test.
            if scheduler is not None:
                stop_scheduler()

    @pytest.mark.asyncio
    async def test_daily_pipeline_lock_prevents_reentry(self, db_session):
        """A second running lock for the same pipeline key cannot be committed."""
        from sqlalchemy.exc import IntegrityError

        now = datetime.now(timezone.utc)
        first_lock = TaskLock(
            task="scheduler",
            lock_key="pipeline-2024-01-15",
            status="running",
            owner="test",
            acquired_at=now,
        )
        db_session.add(first_lock)
        db_session.commit()

        # Same task and lock_key, still "running": the commit must fail.
        second_lock = TaskLock(
            task="scheduler",
            lock_key="pipeline-2024-01-15",
            status="running",
            owner="test2",
            acquired_at=now,
        )
        db_session.add(second_lock)
        # Expect the constraint violation specifically, so an unrelated
        # error (a typo, a missing column) cannot make this test pass.
        with pytest.raises(IntegrityError):
            db_session.commit()
        db_session.rollback()
# ═══════════════════════════════════════════════════════════════════════
# TaskLock 集成测试
# ═══════════════════════════════════════════════════════════════════════
class TestTaskLocks:
    """Tests for the task_locks re-entrancy guard."""

    def test_unique_running_lock(self, db_session):
        """Only one running lock may exist for a given task and lock_key."""
        from sqlalchemy.exc import IntegrityError

        now = datetime.now(timezone.utc)
        first_lock = TaskLock(
            task="crawl",
            lock_key="2024-01-15",
            status="running",
            owner="test1",
            acquired_at=now,
        )
        db_session.add(first_lock)
        db_session.commit()

        second_lock = TaskLock(
            task="crawl",
            lock_key="2024-01-15",
            status="running",
            owner="test2",
            acquired_at=now,
        )
        db_session.add(second_lock)
        # Expect the constraint violation specifically, so an unrelated
        # error (a typo, a missing column) cannot make this test pass.
        with pytest.raises(IntegrityError):
            db_session.commit()
        db_session.rollback()

    def test_released_lock_allows_new(self, db_session):
        """A finished lock does not block a new running lock for the same key."""
        now = datetime.now(timezone.utc)
        released_lock = TaskLock(
            task="crawl",
            lock_key="2024-01-16",
            status="finished",
            owner="test1",
            acquired_at=now,
            released_at=now,
        )
        db_session.add(released_lock)
        db_session.commit()

        new_lock = TaskLock(
            task="crawl",
            lock_key="2024-01-16",
            status="running",
            owner="test2",
            acquired_at=now,
        )
        db_session.add(new_lock)
        # Must succeed: the test fails if this commit raises.
        db_session.commit()