Files
daily-paper/tests/test_admin.py
T
Rain-Bus 90fe705e8f refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO
- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
2026-06-14 10:41:44 +08:00

474 lines
18 KiB
Python

"""管理接口测试 — admin routes、auth、scheduler、task locks。"""
from __future__ import annotations
import logging
from unittest.mock import patch
import pytest
from sqlalchemy import select, text
from app.config import settings
from app.models import (
CrawlLog,
Job,
SummaryState,
SummaryStatus,
TaskLock,
)
from app.utils import utc_now
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — 鉴权测试
# ═══════════════════════════════════════════════════════════════════════
class TestAdminAuth:
    """Authentication behaviour of the admin endpoints."""

    def test_no_session_returns_303(self, client, monkeypatch):
        """A request without a session is redirected (303) to the login page."""
        monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")

        response = client.post("/admin/crawl", follow_redirects=False)

        assert response.status_code == 303
        redirect_target = response.headers.get("location", "")
        assert "/admin/login" in redirect_target

    def test_wrong_password_shows_error(self, client, monkeypatch):
        """A wrong password re-renders the login page (200) with an error hint."""
        for field, value in (
            ("ADMIN_USERNAME", "admin"),
            ("ADMIN_PASSWORD", "correct-pass"),
        ):
            monkeypatch.setattr(settings, field, value)

        credentials = {"username": "admin", "password": "wrong-pass"}
        response = client.post(
            "/admin/login", data=credentials, follow_redirects=False
        )

        assert response.status_code == 200
        page = response.text
        # Accept either the Chinese or the English wording of the error.
        assert "错误" in page or "error" in page.lower()

    def test_correct_login_redirects_to_logs(self, client, monkeypatch):
        """Valid credentials yield a 303 whose Location lies under /admin/."""
        for field, value in (
            ("ADMIN_USERNAME", "admin"),
            ("ADMIN_PASSWORD", "test-pass"),
        ):
            monkeypatch.setattr(settings, field, value)

        credentials = {"username": "admin", "password": "test-pass"}
        response = client.post(
            "/admin/login", data=credentials, follow_redirects=False
        )

        assert response.status_code == 303
        redirect_target = response.headers.get("location", "")
        assert "/admin/" in redirect_target

    def test_logout_clears_session(self, auth_client, monkeypatch):
        """After logging out, admin pages redirect again."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)

        logout_response = auth_client.post("/admin/logout", follow_redirects=False)
        assert logout_response.status_code == 303

        # With the session gone, the admin page must redirect instead of rendering.
        logs_response = auth_client.get("/admin/logs", follow_redirects=False)
        assert logs_response.status_code == 303

    def test_correct_session_batch_summarize(self, auth_client):
        """A logged-in batch summarize call reports a queued job."""
        with patch("app.routes.admin.enqueue_job"):
            response = auth_client.post("/admin/summarize")

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "queued"
        assert "job_id" in body
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — Crawl
# ═══════════════════════════════════════════════════════════════════════
class TestAdminCrawl:
    """Tests for POST /admin/crawl."""

    def test_crawl_specific_date(self, auth_client):
        """Crawling an explicit date echoes that date back as target_date."""
        with patch("app.routes.admin.enqueue_job"):
            response = auth_client.post("/admin/crawl?date=2024-01-15")

        assert response.status_code == 200
        body = response.json()
        assert body["target_date"] == "2024-01-15"
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — Cleanup
# ═══════════════════════════════════════════════════════════════════════
class TestAdminCleanup:
    """Tests for POST /admin/cleanup-now."""

    def test_cleanup_returns_stats(self, auth_client):
        """The synchronous cleanup endpoint returns the cleanup statistics."""
        fake_stats = {"scanned": 3, "removed": 1, "errors": []}
        with patch("app.routes.admin.cleanup_tmp", return_value=fake_stats):
            response = auth_client.post("/admin/cleanup-now")

        assert response.status_code == 200
        body = response.json()
        assert body["scanned"] == 3
        assert body["removed"] == 1

    def test_cleanup_writes_log(self, auth_client, db_session):
        """The synchronous cleanup endpoint records a row in crawl_logs."""
        fake_stats = {"scanned": 0, "removed": 0, "errors": []}
        with patch("app.routes.admin.cleanup_tmp", return_value=fake_stats):
            auth_client.post("/admin/cleanup-now")

        cleanup_query = select(CrawlLog).where(CrawlLog.task == "cleanup")
        cleanup_logs = db_session.execute(cleanup_query).scalars().all()

        assert len(cleanup_logs) >= 1
        assert cleanup_logs[-1].status == "success"
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — Delete
# ═══════════════════════════════════════════════════════════════════════
class TestAdminDelete:
    """Tests for POST /admin/delete."""

    def test_delete_requires_confirm(self, auth_client):
        """A confirm value other than 'DELETE' is rejected with 422."""
        request_body = {
            "date_start": "2024-01-10",
            "date_end": "2024-01-12",
            "include_notes": True,
            "confirm": "WRONG",
        }

        response = auth_client.post("/admin/delete", json=request_body)

        assert response.status_code == 422

    def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range):
        """confirm='DELETE' queues a background delete job that is stored in the DB."""
        request_body = {
            "date_start": "2024-01-10",
            "date_end": "2024-01-12",
            "include_notes": True,
            "confirm": "DELETE",
        }

        with patch("app.routes.admin.enqueue_job"):
            response = auth_client.post("/admin/delete", json=request_body)

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "queued"
        # The job row must exist even though enqueueing itself is patched out.
        assert db_session.get(Job, body["job_id"]) is not None

    def test_delete_invalid_date_range(self, auth_client):
        """date_start later than date_end is rejected with 400."""
        request_body = {
            "date_start": "2024-01-15",
            "date_end": "2024-01-10",
            "confirm": "DELETE",
        }

        response = auth_client.post("/admin/delete", json=request_body)

        assert response.status_code == 400
# ═══════════════════════════════════════════════════════════════════════
# Admin Routes — Logs
# ═══════════════════════════════════════════════════════════════════════
class TestAdminLogs:
    """Tests for GET /admin/logs."""

    def test_logs_requires_auth(self, client, monkeypatch):
        """The logs page redirects (303) when no session is present."""
        monkeypatch.setattr(settings, "ADMIN_PASSWORD", "some-password")

        response = client.get("/admin/logs", follow_redirects=False)

        assert response.status_code == 303

    def test_logs_contains_data(self, auth_client, db_session, sample_papers_range):
        """The logs page renders and mentions log data."""
        # Seed one log row so the page has something to show.
        timestamp = utc_now()
        seeded_log = CrawlLog(
            task="crawl",
            status="success",
            started_at=timestamp,
            completed_at=timestamp,
        )
        db_session.add(seeded_log)
        db_session.commit()

        response = auth_client.get("/admin/logs")

        assert response.status_code == 200
        page = response.text
        assert "crawl" in page.lower() or "日志" in page
class TestAdminJobs:
    """Tests for the background job lookup endpoint."""

    def test_job_detail_returns_payload_and_events(self, auth_client, db_session):
        """GET /admin/jobs/{id} returns the job record together with its events."""
        # Create a job through the crawl endpoint; enqueueing is patched out.
        with patch("app.routes.admin.enqueue_job"):
            created = auth_client.post("/admin/crawl?date=2024-01-15")
        job_id = created.json()["job_id"]

        detail = auth_client.get(f"/admin/jobs/{job_id}")

        assert detail.status_code == 200
        job = detail.json()
        assert job["id"] == job_id
        assert job["type"] == "crawl_daily"
        assert job["payload"] == {"target_date": "2024-01-15"}
        first_event = job["events"][0]
        assert first_event["stage"] == "created"

    def test_job_detail_not_found(self, auth_client):
        """Requesting job id 999999 yields 404."""
        response = auth_client.get("/admin/jobs/999999")

        assert response.status_code == 404
class TestAdminSummaryStatus:
    """Tests for the summary status management endpoints."""

    def test_summary_status_json_filters_failed(
        self, auth_client, db_session, sample_paper
    ):
        """Filtering by status=failed lists the single failed paper."""
        summary_status = sample_paper.summary_status
        summary_status.status = SummaryState.FAILED
        summary_status.retry_count = 2
        summary_status.error_type = "timeout"
        db_session.commit()

        response = auth_client.get("/admin/summary-status?status=failed")

        assert response.status_code == 200
        body = response.json()
        assert body["total"] == 1
        first_item = body["items"][0]
        assert first_item["arxiv_id"] == sample_paper.arxiv_id
        assert first_item["retry_count"] == 2

    def test_retry_failed_resets_failed_statuses(
        self, auth_client, db_session, sample_paper
    ):
        """Retrying failures moves a PERMANENT_FAILURE row back to PENDING and clears its error fields."""
        summary_status = sample_paper.summary_status
        summary_status.status = SummaryState.PERMANENT_FAILURE
        summary_status.error = "bad json"
        summary_status.error_type = "json_invalid"
        db_session.commit()

        response = auth_client.post("/admin/summary-retry-failed")

        assert response.status_code == 200
        assert response.json()["count"] == 1

        # Reload the row so the assertions see what the endpoint wrote.
        db_session.refresh(sample_paper.summary_status)
        reloaded = sample_paper.summary_status
        assert reloaded.status == SummaryState.PENDING
        assert reloaded.error is None
        assert reloaded.error_type is None
class TestAdminPapers:
    """Tests for single and batch paper management actions."""

    def test_single_delete_removes_paper_and_fts(
        self, auth_client, db_session, sample_paper
    ):
        """Deleting one paper removes both its row and its papers_fts entry."""
        paper_id = sample_paper.id
        delete_url = f"/admin/paper-delete/{sample_paper.arxiv_id}"

        response = auth_client.post(delete_url)

        assert response.status_code == 200
        assert db_session.get(type(sample_paper), paper_id) is None

        fts_lookup = text("SELECT rowid FROM papers_fts WHERE rowid = :id")
        fts_row = db_session.execute(fts_lookup, {"id": paper_id}).fetchone()
        assert fts_row is None

    def test_batch_delete_removes_papers_and_fts(
        self, auth_client, db_session, sample_papers_range
    ):
        """A batch delete of two papers reports count 2 and leaves no papers_fts rows for them."""
        doomed = sample_papers_range[:2]
        target_ids = [paper.id for paper in doomed]
        target_arxiv_ids = [paper.arxiv_id for paper in doomed]

        request_body = {"action": "delete", "arxiv_ids": target_arxiv_ids}
        response = auth_client.post("/admin/papers-batch-action", json=request_body)

        assert response.status_code == 200
        assert response.json()["count"] == 2

        fts_lookup = text("SELECT rowid FROM papers_fts WHERE rowid IN (:id1, :id2)")
        first_id, second_id = target_ids
        remaining = db_session.execute(
            fts_lookup, {"id1": first_id, "id2": second_id}
        ).fetchall()
        assert remaining == []

    def test_batch_summarize_sets_pending_status(
        self, auth_client, db_session, sample_papers_range
    ):
        """A batch summarize moves a DONE paper back to PENDING."""
        paper = sample_papers_range[0]
        paper.summary_status.status = SummaryState.DONE
        db_session.commit()

        request_body = {"action": "summarize", "arxiv_ids": [paper.arxiv_id]}
        response = auth_client.post("/admin/papers-batch-action", json=request_body)

        assert response.status_code == 200

        status_query = select(SummaryStatus).where(SummaryStatus.paper_id == paper.id)
        status = db_session.scalar(status_query)
        assert status is not None
        assert status.status == SummaryState.PENDING
# ═══════════════════════════════════════════════════════════════════════
# Scheduler 测试
# ═══════════════════════════════════════════════════════════════════════
class TestScheduler:
    """Tests for app/services/scheduler.py."""

    @staticmethod
    def _stop_if_running(sched_mod):
        """Stop the module-level scheduler if one is still registered.

        Called from ``finally`` blocks so a failing assertion cannot leave a
        scheduler registered for later tests.
        """
        if sched_mod._scheduler is not None:
            sched_mod.stop_scheduler()

    def test_scheduler_disabled_by_default(self, monkeypatch):
        """With SCHEDULER_ENABLED off, start_scheduler() returns None."""
        monkeypatch.setattr(settings, "SCHEDULER_ENABLED", False)
        import app.services.scheduler as sched_mod

        sched_mod._scheduler = None
        from app.services.scheduler import start_scheduler

        try:
            result = start_scheduler()
            assert result is None
        finally:
            self._stop_if_running(sched_mod)

    @pytest.mark.asyncio
    async def test_scheduler_start_stop(self, monkeypatch):
        """The scheduler starts with the daily pipeline job and stops cleanly."""
        monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
        monkeypatch.setattr(settings, "APP_WORKERS", 1)
        import app.services.scheduler as sched_mod

        sched_mod._scheduler = None
        from app.services.scheduler import start_scheduler, stop_scheduler

        try:
            scheduler = start_scheduler()
            assert scheduler is not None

            # Check by membership: the position of a job in get_jobs() is not
            # something this test should depend on.
            job_ids = [job.id for job in scheduler.get_jobs()]
            assert "daily_pipeline" in job_ids

            stop_scheduler()
            assert sched_mod._scheduler is None
        finally:
            # Only acts if an assertion above failed before the scheduler stopped.
            self._stop_if_running(sched_mod)

    @pytest.mark.asyncio
    async def test_scheduler_warns_multi_worker(self, monkeypatch, caplog):
        """Starting with APP_WORKERS > 1 logs a warning mentioning APP_WORKERS."""
        monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
        monkeypatch.setattr(settings, "APP_WORKERS", 4)
        import app.services.scheduler as sched_mod

        sched_mod._scheduler = None
        from app.services.scheduler import start_scheduler

        try:
            with caplog.at_level(logging.WARNING):
                scheduler = start_scheduler()
            assert scheduler is not None
            assert any("APP_WORKERS" in r.message for r in caplog.records)
        finally:
            self._stop_if_running(sched_mod)

    @pytest.mark.asyncio
    async def test_daily_pipeline_lock_prevents_reentry(self, db_session):
        """A second running task_locks row for the same pipeline key is rejected.

        This exercises only the task_locks table, not the pipeline code itself.
        """
        # Local import keeps this block self-contained; sqlalchemy is already
        # a dependency of this module.
        from sqlalchemy.exc import IntegrityError

        now = utc_now()
        first_lock = TaskLock(
            task="scheduler",
            lock_key="pipeline-2024-01-15",
            status="running",
            owner="test",
            acquired_at=now,
        )
        db_session.add(first_lock)
        db_session.commit()

        # Acquiring the same lock a second time must fail.
        second_lock = TaskLock(
            task="scheduler",
            lock_key="pipeline-2024-01-15",
            status="running",
            owner="test2",
            acquired_at=now,
        )
        db_session.add(second_lock)
        # Expect the constraint violation specifically, so unrelated errors
        # cannot make the test pass. Assumes the rule is enforced by the database.
        with pytest.raises(IntegrityError):
            db_session.commit()
        db_session.rollback()
# ═══════════════════════════════════════════════════════════════════════
# TaskLock 集成测试
# ═══════════════════════════════════════════════════════════════════════
class TestTaskLocks:
    """Tests for the task_locks re-entry guard."""

    def test_unique_running_lock(self, db_session):
        """Only one running lock may exist per task + lock_key."""
        # Local import keeps this block self-contained; sqlalchemy is already
        # a dependency of this module.
        from sqlalchemy.exc import IntegrityError

        now = utc_now()
        first_lock = TaskLock(
            task="crawl",
            lock_key="2024-01-15",
            status="running",
            owner="test1",
            acquired_at=now,
        )
        db_session.add(first_lock)
        db_session.commit()

        second_lock = TaskLock(
            task="crawl",
            lock_key="2024-01-15",
            status="running",
            owner="test2",
            acquired_at=now,
        )
        db_session.add(second_lock)
        # Expect the constraint violation specifically, so unrelated errors
        # cannot make the test pass. Assumes the rule is enforced by the database.
        with pytest.raises(IntegrityError):
            db_session.commit()
        db_session.rollback()

    def test_released_lock_allows_new(self, db_session):
        """A finished lock does not block a new running lock for the same key."""
        now = utc_now()
        finished_lock = TaskLock(
            task="crawl",
            lock_key="2024-01-16",
            status="finished",
            owner="test1",
            acquired_at=now,
            released_at=now,
        )
        db_session.add(finished_lock)
        db_session.commit()

        running_lock = TaskLock(
            task="crawl",
            lock_key="2024-01-16",
            status="running",
            owner="test2",
            acquired_at=now,
        )
        db_session.add(running_lock)
        db_session.commit()  # must succeed