381 lines
15 KiB
Python
381 lines
15 KiB
Python
"""管理后台新功能测试 — 任务监控、锁释放、重抓、失败分布、配置、导出、重建索引。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import pytest
|
||
from sqlalchemy import select
|
||
from unittest.mock import AsyncMock, patch
|
||
|
||
from app.models import Job, JobStatus, SummaryState, SummaryStatus, TaskLock
|
||
from app.services import admin as admin_svc
|
||
from app.services.crawler import recrawl_single
|
||
from app.services.jobs import create_job, run_job
|
||
from app.utils import utc_now
|
||
|
||
|
||
@pytest.fixture
def no_enqueue(monkeypatch):
    """Disable the route layer's ``enqueue_job`` so no background task really runs in a test."""
    from app.routes import admin as admin_route

    def _discard(*args, **kwargs):
        # Accept any call signature and do nothing, like the job was never queued.
        return None

    monkeypatch.setattr(admin_route, "enqueue_job", _discard)
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# Job monitoring
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
def _make_job(db_session, *, type="crawl_daily", status=JobStatus.QUEUED, owner="t"):
    """Add one ``Job`` row with an empty JSON payload, commit it, and return it.

    The keyword ``type`` shadows the builtin; it is kept because callers pass
    it by keyword.
    """
    fields = {
        "type": type,
        "status": status,
        "owner": owner,
        "payload_json": "{}",
        "created_at": utc_now(),
    }
    record = Job(**fields)
    db_session.add(record)
    db_session.commit()
    return record
|
||
|
||
|
||
def test_query_jobs_filter_and_pagination(db_session):
    """``query_jobs`` paginates the full set and narrows by status and by job type."""
    # Seed 25 successful jobs followed by 5 failed ones (30 rows in total).
    for seeded_status, how_many in ((JobStatus.SUCCESS, 25), (JobStatus.FAILED, 5)):
        for _ in range(how_many):
            _make_job(db_session, status=seeded_status)

    # No filter: pagination only.
    first_page, overall = admin_svc.query_jobs(db_session, page=1, per_page=20)
    assert overall == 30
    assert len(first_page) == 20
    second_page, _ = admin_svc.query_jobs(db_session, page=2, per_page=20)
    assert len(second_page) == 10

    # Filter by status.
    failed_rows, failed_total = admin_svc.query_jobs(
        db_session, status="failed", per_page=50
    )
    assert failed_total == 5
    assert len(failed_rows) == 5
    for row in failed_rows:
        assert row["status"] == "failed"

    # Filter by type.
    _make_job(db_session, type="reindex_fts", status=JobStatus.QUEUED)
    typed_rows, typed_total = admin_svc.query_jobs(db_session, job_type="reindex_fts")
    assert typed_total == 1
    assert typed_rows[0]["type"] == "reindex_fts"
|
||
|
||
|
||
def test_serialize_job_includes_duration(db_session):
    """A job with both timestamps set serializes to a non-negative duration."""
    finished = _make_job(db_session, status=JobStatus.SUCCESS)
    finished.started_at = utc_now()
    finished.completed_at = utc_now()
    db_session.commit()

    payload = admin_svc.serialize_job(finished)
    duration = payload["duration_seconds"]
    assert duration is not None
    assert duration >= 0
|
||
|
||
|
||
def test_serialize_job_running_without_completed(db_session):
    """Regression: a running job (``completed_at`` is None) still gets a duration.

    The case being guarded: ``started_at`` read back from the database is
    naive UTC and must not be subtracted directly from the timezone-aware
    ``utc_now()``.
    """
    job = _make_job(db_session, status=JobStatus.RUNNING)
    job.started_at = utc_now()
    db_session.commit()
    # Reload explicitly so ``started_at`` really is the database round-tripped
    # value. Without this the test only covers that path when the session
    # happens to expire attributes on commit.
    db_session.refresh(job)

    serialized = admin_svc.serialize_job(job)
    assert serialized["duration_seconds"] is not None
    assert serialized["duration_seconds"] >= 0
|
||
|
||
|
||
def test_get_job_status_counts(db_session):
    """Status counts report two queued jobs and one running job."""
    for seeded_status in (JobStatus.QUEUED, JobStatus.QUEUED, JobStatus.RUNNING):
        _make_job(db_session, status=seeded_status)

    tally = admin_svc.get_job_status_counts(db_session)
    assert tally.get("queued") == 2
    assert tally.get("running") == 1
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# Lock release
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
def test_force_release_lock(db_session):
    """``force_release_lock`` accepts running/stale locks and rejects finished or missing ones."""
    # One lock per starting status, keyed by that status for readability.
    locks = {}
    for key, state in (("k1", "running"), ("k2", "stale"), ("k3", "finished")):
        locks[state] = TaskLock(
            task="crawl", lock_key=key, status=state, acquired_at=utc_now()
        )
    db_session.add_all(list(locks.values()))
    db_session.commit()

    assert admin_svc.force_release_lock(db_session, locks["running"].id) is True
    assert admin_svc.force_release_lock(db_session, locks["stale"].id) is True
    # A lock that is already finished must not be released again.
    assert admin_svc.force_release_lock(db_session, locks["finished"].id) is False
    # An id that does not exist.
    assert admin_svc.force_release_lock(db_session, 999999) is False

    for lock in locks.values():
        db_session.refresh(lock)
    assert locks["running"].status == "finished"
    assert locks["running"].released_at is not None
    assert locks["stale"].status == "finished"
    assert locks["finished"].status == "finished"
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# Failure-reason breakdown
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
def test_get_failure_breakdown(db_session, sample_papers_range):
    """The failure breakdown counts per error type and is ordered by count, descending."""
    query = select(SummaryStatus).order_by(SummaryStatus.id)
    rows = db_session.execute(query).scalars().all()

    planned = (
        (SummaryState.FAILED, "pdf_download_failed"),
        (SummaryState.FAILED, "timeout"),
        (SummaryState.PERMANENT_FAILURE, None),  # expected to be reported as "unknown"
    )
    # Index explicitly so a fixture with fewer than three rows fails loudly.
    for position, (state, error_type) in enumerate(planned):
        rows[position].status = state
        rows[position].error_type = error_type
    db_session.commit()

    breakdown = admin_svc.get_failure_breakdown(db_session)
    count_by_type = {}
    for entry in breakdown:
        count_by_type[entry["error_type"]] = entry["count"]
    assert count_by_type.get("pdf_download_failed") == 1
    assert count_by_type.get("timeout") == 1
    assert count_by_type.get("unknown") == 1

    # Descending order by count.
    ordered = [entry["count"] for entry in breakdown]
    assert ordered == sorted(ordered, reverse=True)
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# Configuration overview
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
def test_get_config_overview_no_secrets():
    """The config overview has the expected keys and does not leak the default secret."""
    overview = admin_svc.get_config_overview()

    # "api_key_configured" only flags whether a key is set; it must not show the value.
    for expected_key in ("summary_backend", "schedule_time", "api_key_configured"):
        assert expected_key in overview

    # The default secret value must not appear anywhere in the overview.
    rendered = str(overview)
    assert "change-me" not in rendered
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# Single / batch re-crawl
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestRecrawl:
    """Direct tests of ``recrawl_single``, with ``fetch_daily`` patched where it is needed."""

    @pytest.mark.asyncio
    async def test_not_found(self, db_session):
        """An id the test expects to be absent yields ``updated=False`` / ``not_found``."""
        outcome = await recrawl_single(db_session, "9999.99999")
        assert outcome["updated"] is False
        assert outcome["reason"] == "not_found"

    @pytest.mark.asyncio
    async def test_updates_full_metadata(self, db_session, sample_paper):
        """A matching daily item overwrites title, abstract, upvotes, authors and tags."""
        fresh_paper = {
            "id": sample_paper.arxiv_id,
            "title": "Updated Title",
            "abstract": "New abstract",
            "publishedAt": "2024-01-15T00:00:00",
            "authors": [{"name": "New Author"}],
            "tags": [{"name": "CV"}, {"name": "Diffusion"}],
            "upvotes": 100,
        }
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[{"paper": fresh_paper}],
        )
        with fetch_patch:
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Updated Title"
        assert sample_paper.abstract == "New abstract"
        assert sample_paper.upvotes == 100
        # Authors are rebuilt (previously Alice/Bob -> New Author).
        author_names = sorted(author.name for author in sample_paper.authors)
        assert author_names == ["New Author"]
        # Tags are rebuilt (previously NLP/LLM -> CV/Diffusion).
        tag_names = sorted(entry.tag for entry in sample_paper.tags)
        assert tag_names == ["CV", "Diffusion"]

    @pytest.mark.asyncio
    async def test_not_in_daily(self, db_session, sample_paper):
        """An empty daily feed yields ``not_in_daily`` and the result carries a ``date``."""
        empty_feed = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[],
        )
        with empty_feed:
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)
        assert outcome["updated"] is False
        assert outcome["reason"] == "not_in_daily"
        assert "date" in outcome
|
||
|
||
|
||
class TestDispatchRecrawl:
    """Re-crawl jobs created with ``create_job`` and executed through ``run_job``."""

    @pytest.mark.asyncio
    async def test_recrawl_one_via_run_job(self, db_session, sample_paper):
        """A ``recrawl_one`` job updates the paper title from the mocked daily feed."""
        feed = [
            {
                "paper": {
                    "id": sample_paper.arxiv_id,
                    "title": "Via Job",
                    "authors": [],
                    "tags": [],
                    "upvotes": 5,
                }
            }
        ]
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=feed,
        )
        with fetch_patch:
            queued = create_job(
                db_session,
                "recrawl_one",
                owner="test",
                payload={"arxiv_id": sample_paper.arxiv_id},
            )
            outcome = await run_job(db_session, queued.id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Via Job"

    @pytest.mark.asyncio
    async def test_recrawl_batch_via_run_job(self, db_session, sample_papers_range):
        """A ``recrawl_batch`` job over two papers reports two updated and none skipped."""
        arxiv_ids = [paper.arxiv_id for paper in sample_papers_range[:2]]

        items = []
        for aid in arxiv_ids:
            paper_fields = {
                "id": aid,
                "title": "Batch " + aid,
                "authors": [],
                "tags": [],
                "upvotes": 1,
            }
            items.append({"paper": paper_fields})

        def _same_feed(d):
            # Exactly one positional argument, as in the mocked call; the
            # same feed is returned whatever that argument is.
            return items

        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            side_effect=_same_feed,
        )
        with fetch_patch:
            queued = create_job(
                db_session,
                "recrawl_batch",
                owner="test",
                payload={"arxiv_ids": arxiv_ids},
            )
            outcome = await run_job(db_session, queued.id)

        assert outcome["updated"] == 2
        assert outcome["skipped"] == 0
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# Routes
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestRoutes:
    """HTTP-level tests of the admin routes, issued through the ``auth_client`` fixture."""

    @staticmethod
    def _jobs_of_type(db_session, job_type):
        """Return all ``Job`` rows whose ``type`` equals *job_type* (not collected as a test)."""
        query = select(Job).where(Job.type == job_type)
        return db_session.execute(query).scalars().all()

    def test_jobs_page_renders(self, auth_client):
        """The jobs page answers 200 and contains its page title text."""
        response = auth_client.get("/admin/jobs")
        assert response.status_code == 200
        assert "任务监控" in response.text

    def test_jobs_page_filters_by_status(self, auth_client, db_session):
        """The jobs page accepts a ``status`` query parameter."""
        _make_job(db_session, status=JobStatus.FAILED)
        response = auth_client.get("/admin/jobs?status=failed")
        assert response.status_code == 200

    def test_export_csv(self, auth_client, sample_papers_range):
        """The CSV export is ``text/csv``, starts with a UTF-8 BOM, and has header plus data."""
        response = auth_client.get("/admin/papers/export.csv")
        assert response.status_code == 200
        assert "text/csv" in response.headers["content-type"]
        # UTF-8 BOM so that Excel detects the encoding.
        assert response.content.startswith(b"\xef\xbb\xbf")
        # Header row and at least one data row.
        body = response.text
        assert "arxiv_id" in body
        assert "2401.10001" in body

    def test_export_csv_respects_filter(self, auth_client, sample_papers_range):
        """The ``q`` filter limits which papers appear in the export."""
        response = auth_client.get("/admin/papers/export.csv?q=Paper%203")
        assert response.status_code == 200
        body = response.text
        assert "2401.10003" in body
        assert "2401.10001" not in body

    def test_rebuild_indexes_fts(self, auth_client, db_session, no_enqueue):
        """Target ``fts`` queues exactly one ``reindex_fts`` job."""
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "fts"})
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "queued"
        assert len(payload["job_ids"]) == 1
        assert len(self._jobs_of_type(db_session, "reindex_fts")) == 1

    def test_rebuild_indexes_both(self, auth_client, db_session, no_enqueue):
        """Target ``both`` returns two job ids."""
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "both"})
        assert response.status_code == 200
        payload = response.json()
        assert len(payload["job_ids"]) == 2

    def test_release_lock_route(self, auth_client, db_session):
        """Posting to the release route marks a running lock as finished."""
        held = TaskLock(
            task="crawl", lock_key="rt", status="running", acquired_at=utc_now()
        )
        db_session.add(held)
        db_session.commit()

        response = auth_client.post(f"/admin/locks/{held.id}/release")
        assert response.status_code == 200
        db_session.refresh(held)
        assert held.status == "finished"

    def test_paper_recrawl_route(
        self, auth_client, sample_paper, db_session, no_enqueue
    ):
        """The single-paper re-crawl route queues one ``recrawl_one`` job."""
        response = auth_client.post(f"/admin/paper-recrawl/{sample_paper.arxiv_id}")
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_one")) == 1

    def test_batch_recrawl_route(
        self, auth_client, sample_papers_range, db_session, no_enqueue
    ):
        """The batch action ``recrawl`` queues one ``recrawl_batch`` job."""
        ids = [paper.arxiv_id for paper in sample_papers_range[:3]]
        request_body = {"action": "recrawl", "arxiv_ids": ids}
        response = auth_client.post("/admin/papers-batch-action", json=request_body)
        assert response.status_code == 200
        assert response.json()["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_batch")) == 1
|