Files
daily-paper/tests/test_admin_features.py
T

381 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""管理后台新功能测试 — 任务监控、锁释放、重抓、失败分布、配置、导出、重建索引。"""
from __future__ import annotations
import pytest
from sqlalchemy import select
from unittest.mock import AsyncMock, patch
from app.models import Job, JobStatus, SummaryState, SummaryStatus, TaskLock
from app.services import admin as admin_svc
from app.services.crawler import recrawl_single
from app.services.jobs import create_job, run_job
from app.utils import utc_now
@pytest.fixture
def no_enqueue(monkeypatch):
    """Replace the admin route module's ``enqueue_job`` with a no-op.

    This keeps background tasks from actually executing during a test while
    leaving the rest of the route handler untouched.
    """
    from app.routes import admin as admin_routes

    def _skip_enqueue(*args, **kwargs):
        # Accept any call signature and do nothing.
        return None

    monkeypatch.setattr(admin_routes, "enqueue_job", _skip_enqueue)
# ═══════════════════════════════════════════════════════════════════════
# 任务监控
# ═══════════════════════════════════════════════════════════════════════
def _make_job(db_session, *, type="crawl_daily", status=JobStatus.QUEUED, owner="t"):
    """Insert and commit a minimal ``Job`` row, then return it.

    ``type`` shadows the builtin but is kept as-is because callers pass it by
    keyword. The payload is always an empty JSON object.
    """
    fields = {
        "type": type,
        "status": status,
        "owner": owner,
        "payload_json": "{}",
        "created_at": utc_now(),
    }
    new_job = Job(**fields)
    db_session.add(new_job)
    db_session.commit()
    return new_job
def test_query_jobs_filter_and_pagination(db_session):
    """query_jobs paginates results and filters by status and by job type."""
    # Loop counters are unused, so they are named ``_``.
    for _ in range(25):
        _make_job(db_session, status=JobStatus.SUCCESS)
    for _ in range(5):
        _make_job(db_session, status=JobStatus.FAILED)

    # No filter: 30 jobs split into a page of 20 and a page of 10.
    page1, total = admin_svc.query_jobs(db_session, page=1, per_page=20)
    assert total == 30
    assert len(page1) == 20
    page2, _ = admin_svc.query_jobs(db_session, page=2, per_page=20)
    assert len(page2) == 10

    # Status filter.
    failed, ftotal = admin_svc.query_jobs(db_session, status="failed", per_page=50)
    assert ftotal == 5
    assert len(failed) == 5
    assert all(j["status"] == "failed" for j in failed)

    # Type filter: exactly one job of this type exists.
    _make_job(db_session, type="reindex_fts", status=JobStatus.QUEUED)
    typed, ttotal = admin_svc.query_jobs(db_session, job_type="reindex_fts")
    assert ttotal == 1
    # Guard before indexing so a mismatch fails with a clear message.
    assert len(typed) == 1
    assert typed[0]["type"] == "reindex_fts"
def test_serialize_job_includes_duration(db_session):
    """A completed job serializes with a non-negative ``duration_seconds``."""
    finished = _make_job(db_session, status=JobStatus.SUCCESS)
    finished.started_at = utc_now()
    finished.completed_at = utc_now()
    db_session.commit()

    duration = admin_svc.serialize_job(finished)["duration_seconds"]
    assert duration is not None
    assert duration >= 0
def test_serialize_job_running_without_completed(db_session):
    """Regression: a running job (completed_at=None) still gets a duration.

    After a round trip through the database, ``started_at`` is a naive UTC
    datetime. serialize_job must not subtract it directly from the
    timezone-aware ``utc_now()``.
    """
    job = _make_job(db_session, status=JobStatus.RUNNING)
    job.started_at = utc_now()
    db_session.commit()
    # Reload explicitly so started_at is the value read back from the DB, not
    # the aware datetime assigned above. Without this, the regression is only
    # exercised when the session happens to expire objects on commit.
    db_session.refresh(job)

    serialized = admin_svc.serialize_job(job)
    assert serialized["duration_seconds"] is not None
    assert serialized["duration_seconds"] >= 0
def test_get_job_status_counts(db_session):
    """get_job_status_counts returns job counts keyed by status name."""
    for state in (JobStatus.QUEUED, JobStatus.QUEUED, JobStatus.RUNNING):
        _make_job(db_session, status=state)

    counts = admin_svc.get_job_status_counts(db_session)
    assert counts.get("queued") == 2
    assert counts.get("running") == 1
# ═══════════════════════════════════════════════════════════════════════
# 锁释放
# ═══════════════════════════════════════════════════════════════════════
def test_force_release_lock(db_session):
    """force_release_lock releases running and stale locks only.

    An already-finished lock and a nonexistent id both return False. After the
    calls, every lock is in the ``finished`` state.
    """
    locks = {}
    for key, state in (("k1", "running"), ("k2", "stale"), ("k3", "finished")):
        locks[state] = TaskLock(
            task="crawl", lock_key=key, status=state, acquired_at=utc_now()
        )
    db_session.add_all(list(locks.values()))
    db_session.commit()

    # Active and stale locks can be released.
    assert admin_svc.force_release_lock(db_session, locks["running"].id) is True
    assert admin_svc.force_release_lock(db_session, locks["stale"].id) is True
    # A finished lock must not be released a second time.
    assert admin_svc.force_release_lock(db_session, locks["finished"].id) is False
    # An id that does not exist.
    assert admin_svc.force_release_lock(db_session, 999999) is False

    for lock in locks.values():
        db_session.refresh(lock)
    assert locks["running"].released_at is not None
    assert all(lock.status == "finished" for lock in locks.values())
# ═══════════════════════════════════════════════════════════════════════
# 失败原因分布
# ═══════════════════════════════════════════════════════════════════════
def test_get_failure_breakdown(db_session, sample_papers_range):
    """get_failure_breakdown counts failures per error_type, largest first.

    FAILED and PERMANENT_FAILURE rows both count, and a missing error_type is
    reported under ``"unknown"``.
    """
    rows = (
        db_session.execute(select(SummaryStatus).order_by(SummaryStatus.id))
        .scalars()
        .all()
    )
    failures = (
        (SummaryState.FAILED, "pdf_download_failed"),
        (SummaryState.FAILED, "timeout"),
        (SummaryState.PERMANENT_FAILURE, None),  # reported as "unknown"
    )
    for idx, (state, error_type) in enumerate(failures):
        rows[idx].status = state
        rows[idx].error_type = error_type
    db_session.commit()

    breakdown = admin_svc.get_failure_breakdown(db_session)
    by_type = {entry["error_type"]: entry["count"] for entry in breakdown}
    for expected_type in ("pdf_download_failed", "timeout", "unknown"):
        assert by_type.get(expected_type) == 1
    # Entries are ordered by count, non-increasing.
    counts = [entry["count"] for entry in breakdown]
    assert all(a >= b for a, b in zip(counts, counts[1:]))
# ═══════════════════════════════════════════════════════════════════════
# 配置概览
# ═══════════════════════════════════════════════════════════════════════
def test_get_config_overview_no_secrets():
    """The config overview lists settings without exposing secret values."""
    overview = admin_svc.get_config_overview()
    # api_key_configured only flags whether a key is set; it never shows it.
    for required in ("summary_backend", "schedule_time", "api_key_configured"):
        assert required in overview
    # The default secret placeholder must not appear anywhere in the output.
    assert "change-me" not in str(overview)
# ═══════════════════════════════════════════════════════════════════════
# 单篇/批量重抓
# ═══════════════════════════════════════════════════════════════════════
class TestRecrawl:
    """recrawl_single cases: unknown paper, full metadata refresh, absent from daily."""

    @staticmethod
    def _stub_fetch_daily(items):
        """Return a patcher that makes the crawler's ``fetch_daily`` resolve to *items*."""
        return patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=items,
        )

    @pytest.mark.asyncio
    async def test_not_found(self, db_session):
        outcome = await recrawl_single(db_session, "9999.99999")
        assert outcome["updated"] is False
        assert outcome["reason"] == "not_found"

    @pytest.mark.asyncio
    async def test_updates_full_metadata(self, db_session, sample_paper):
        paper_payload = {
            "id": sample_paper.arxiv_id,
            "title": "Updated Title",
            "abstract": "New abstract",
            "publishedAt": "2024-01-15T00:00:00",
            "authors": [{"name": "New Author"}],
            "tags": [{"name": "CV"}, {"name": "Diffusion"}],
            "upvotes": 100,
        }
        with self._stub_fetch_daily([{"paper": paper_payload}]):
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Updated Title"
        assert sample_paper.abstract == "New abstract"
        assert sample_paper.upvotes == 100
        # Authors are rebuilt (previously Alice/Bob, now New Author).
        assert sorted(a.name for a in sample_paper.authors) == ["New Author"]
        # Tags are rebuilt (previously NLP/LLM, now CV/Diffusion).
        assert sorted(t.tag for t in sample_paper.tags) == ["CV", "Diffusion"]

    @pytest.mark.asyncio
    async def test_not_in_daily(self, db_session, sample_paper):
        with self._stub_fetch_daily([]):
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)
        assert outcome["updated"] is False
        assert outcome["reason"] == "not_in_daily"
        assert "date" in outcome
class TestDispatchRecrawl:
    """recrawl_one / recrawl_batch jobs dispatched end to end through ``run_job``."""

    @pytest.mark.asyncio
    async def test_recrawl_one_via_run_job(self, db_session, sample_paper):
        new_item = {
            "paper": {
                "id": sample_paper.arxiv_id,
                "title": "Via Job",
                "authors": [],
                "tags": [],
                "upvotes": 5,
            }
        }
        with patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[new_item],
        ):
            job = create_job(
                db_session,
                "recrawl_one",
                owner="test",
                payload={"arxiv_id": sample_paper.arxiv_id},
            )
            result = await run_job(db_session, job.id)
        assert result["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Via Job"

    @pytest.mark.asyncio
    async def test_recrawl_batch_via_run_job(self, db_session, sample_papers_range):
        targets = sample_papers_range[:2]
        arxiv_ids = [p.arxiv_id for p in targets]
        items = [
            {
                "paper": {
                    "id": aid,
                    "title": "Batch " + aid,
                    "authors": [],
                    "tags": [],
                    "upvotes": 1,
                }
            }
            for aid in arxiv_ids
        ]
        # return_value serves the same list on every call, whatever the call
        # signature. The previous ``side_effect=lambda d: items`` would raise
        # TypeError if fetch_daily were called with the date by keyword.
        with patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=items,
        ):
            job = create_job(
                db_session,
                "recrawl_batch",
                owner="test",
                payload={"arxiv_ids": arxiv_ids},
            )
            result = await run_job(db_session, job.id)
        assert result["updated"] == 2
        assert result["skipped"] == 0
        # Like the single-paper test, check that the new titles were persisted.
        for paper in targets:
            db_session.refresh(paper)
            assert paper.title_en == "Batch " + paper.arxiv_id
# ═══════════════════════════════════════════════════════════════════════
# 路由
# ═══════════════════════════════════════════════════════════════════════
class TestRoutes:
    """HTTP-level checks for the admin job, export, index and recrawl routes."""

    @staticmethod
    def _jobs_of_type(db_session, job_type):
        """Return every ``Job`` row whose ``type`` equals *job_type*."""
        stmt = select(Job).where(Job.type == job_type)
        return db_session.execute(stmt).scalars().all()

    def test_jobs_page_renders(self, auth_client):
        page = auth_client.get("/admin/jobs")
        assert page.status_code == 200
        assert "任务监控" in page.text

    def test_jobs_page_filters_by_status(self, auth_client, db_session):
        _make_job(db_session, status=JobStatus.FAILED)
        page = auth_client.get("/admin/jobs?status=failed")
        assert page.status_code == 200

    def test_export_csv(self, auth_client, sample_papers_range):
        export = auth_client.get("/admin/papers/export.csv")
        assert export.status_code == 200
        assert "text/csv" in export.headers["content-type"]
        # UTF-8 BOM so Excel detects the encoding.
        assert export.content.startswith(b"\xef\xbb\xbf")
        # Header row plus data rows.
        assert "arxiv_id" in export.text
        assert "2401.10001" in export.text

    def test_export_csv_respects_filter(self, auth_client, sample_papers_range):
        export = auth_client.get("/admin/papers/export.csv?q=Paper%203")
        assert export.status_code == 200
        assert "2401.10003" in export.text
        assert "2401.10001" not in export.text

    def test_rebuild_indexes_fts(self, auth_client, db_session, no_enqueue):
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "fts"})
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "queued"
        assert len(body["job_ids"]) == 1
        assert len(self._jobs_of_type(db_session, "reindex_fts")) == 1

    def test_rebuild_indexes_both(self, auth_client, db_session, no_enqueue):
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "both"})
        assert response.status_code == 200
        assert len(response.json()["job_ids"]) == 2

    def test_release_lock_route(self, auth_client, db_session):
        lock = TaskLock(
            task="crawl", lock_key="rt", status="running", acquired_at=utc_now()
        )
        db_session.add(lock)
        db_session.commit()

        response = auth_client.post(f"/admin/locks/{lock.id}/release")
        assert response.status_code == 200
        db_session.refresh(lock)
        assert lock.status == "finished"

    def test_paper_recrawl_route(
        self, auth_client, sample_paper, db_session, no_enqueue
    ):
        response = auth_client.post(f"/admin/paper-recrawl/{sample_paper.arxiv_id}")
        assert response.status_code == 200
        assert response.json()["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_one")) == 1

    def test_batch_recrawl_route(
        self, auth_client, sample_papers_range, db_session, no_enqueue
    ):
        ids = [p.arxiv_id for p in sample_papers_range[:3]]
        response = auth_client.post(
            "/admin/papers-batch-action", json={"action": "recrawl", "arxiv_ids": ids}
        )
        assert response.status_code == 200
        assert response.json()["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_batch")) == 1