feat: add concurrency safety, caption detection, admin enhancements, and performance improvements

This commit is contained in:
2026-06-14 22:20:02 +08:00
parent 8f13c31991
commit 29fb20828e
23 changed files with 1782 additions and 114 deletions
+380
View File
@@ -0,0 +1,380 @@
"""管理后台新功能测试 — 任务监控、锁释放、重抓、失败分布、配置、导出、重建索引。"""
from __future__ import annotations
import pytest
from sqlalchemy import select
from unittest.mock import AsyncMock, patch
from app.models import Job, JobStatus, SummaryState, SummaryStatus, TaskLock
from app.services import admin as admin_svc
from app.services.crawler import recrawl_single
from app.services.jobs import create_job, run_job
from app.utils import utc_now
@pytest.fixture
def no_enqueue(monkeypatch):
    """Replace the admin route module's ``enqueue_job`` with a no-op.

    The stand-in accepts any positional/keyword arguments and returns
    ``None``, so a test using this fixture never triggers the real
    background task.
    """
    from app.routes import admin as admin_route

    def _discard(*args, **kwargs):
        # Swallow the call; nothing is scheduled.
        return None

    monkeypatch.setattr(admin_route, "enqueue_job", _discard)
# ═══════════════════════════════════════════════════════════════════════
# 任务监控
# ═══════════════════════════════════════════════════════════════════════
def _make_job(db_session, *, type="crawl_daily", status=JobStatus.QUEUED, owner="t"):
    """Insert one ``Job`` row, commit it, and return the instance.

    All options are keyword-only. ``type`` shadows the builtin, but the name
    is kept because callers in this module pass it by keyword.
    The payload is always the empty JSON object and ``created_at`` is the
    current ``utc_now()`` value.
    """
    columns = {
        "type": type,
        "status": status,
        "owner": owner,
        "payload_json": "{}",
        "created_at": utc_now(),
    }
    record = Job(**columns)
    db_session.add(record)
    db_session.commit()
    return record
def test_query_jobs_filter_and_pagination(db_session):
    """``query_jobs`` paginates and honours the status / job_type filters."""
    # Seed 25 successful jobs followed by 5 failed ones.
    for seed_status in [JobStatus.SUCCESS] * 25 + [JobStatus.FAILED] * 5:
        _make_job(db_session, status=seed_status)

    # No filter: pagination only.
    first_page, total = admin_svc.query_jobs(db_session, page=1, per_page=20)
    assert total == 30
    assert len(first_page) == 20

    second_page, _ = admin_svc.query_jobs(db_session, page=2, per_page=20)
    assert len(second_page) == 10

    # Filter by status.
    failed_rows, failed_total = admin_svc.query_jobs(
        db_session, status="failed", per_page=50
    )
    assert failed_total == 5
    assert len(failed_rows) == 5
    for row in failed_rows:
        assert row["status"] == "failed"

    # Filter by job type.
    _make_job(db_session, type="reindex_fts", status=JobStatus.QUEUED)
    typed_rows, typed_total = admin_svc.query_jobs(db_session, job_type="reindex_fts")
    assert typed_total == 1
    assert typed_rows[0]["type"] == "reindex_fts"
def test_serialize_job_includes_duration(db_session):
    """A job with both timestamps set serializes a non-negative duration."""
    finished_job = _make_job(db_session, status=JobStatus.SUCCESS)
    finished_job.started_at = utc_now()
    finished_job.completed_at = utc_now()
    db_session.commit()

    duration = admin_svc.serialize_job(finished_job)["duration_seconds"]
    assert duration is not None
    assert duration >= 0
def test_serialize_job_running_without_completed(db_session):
    """A running job (no ``completed_at``) still serializes a duration.

    Regression test. Per the original note, ``started_at`` is read back from
    the database as naive UTC and must not be subtracted directly from the
    timezone-aware ``utc_now()``.
    """
    running_job = _make_job(db_session, status=JobStatus.RUNNING)
    running_job.started_at = utc_now()
    db_session.commit()

    duration = admin_svc.serialize_job(running_job)["duration_seconds"]
    assert duration is not None
    assert duration >= 0
def test_get_job_status_counts(db_session):
    """``get_job_status_counts`` tallies jobs per status value."""
    for seed_status in (JobStatus.QUEUED, JobStatus.QUEUED, JobStatus.RUNNING):
        _make_job(db_session, status=seed_status)

    tally = admin_svc.get_job_status_counts(db_session)
    assert tally.get("queued") == 2
    assert tally.get("running") == 1
# ═══════════════════════════════════════════════════════════════════════
# 锁释放
# ═══════════════════════════════════════════════════════════════════════
def test_force_release_lock(db_session):
    """``force_release_lock`` releases running/stale locks and nothing else."""
    # Built in the same order as before: running, stale, finished.
    running, stale, finished = (
        TaskLock(task="crawl", lock_key=key, status=state, acquired_at=utc_now())
        for key, state in (("k1", "running"), ("k2", "stale"), ("k3", "finished"))
    )
    db_session.add_all([running, stale, finished])
    db_session.commit()

    # Running and stale locks are released; an already finished one is not.
    for lock, expected in ((running, True), (stale, True), (finished, False)):
        assert admin_svc.force_release_lock(db_session, lock.id) is expected
    # An id that does not exist.
    assert admin_svc.force_release_lock(db_session, 999999) is False

    for lock in (running, stale, finished):
        db_session.refresh(lock)

    assert running.status == "finished"
    assert running.released_at is not None
    assert stale.status == "finished"
    assert finished.status == "finished"
# ═══════════════════════════════════════════════════════════════════════
# 失败原因分布
# ═══════════════════════════════════════════════════════════════════════
def test_get_failure_breakdown(db_session, sample_papers_range):
    """``get_failure_breakdown`` groups failures by error type, largest first."""
    ordered_query = select(SummaryStatus).order_by(SummaryStatus.id)
    rows = db_session.execute(ordered_query).scalars().all()

    # (state, error_type) applied to the first three status rows; the
    # ``None`` error type is expected to be reported as "unknown".
    planned = (
        (SummaryState.FAILED, "pdf_download_failed"),
        (SummaryState.FAILED, "timeout"),
        (SummaryState.PERMANENT_FAILURE, None),
    )
    for index, (state, error_type) in enumerate(planned):
        rows[index].status = state
        rows[index].error_type = error_type
    db_session.commit()

    breakdown = admin_svc.get_failure_breakdown(db_session)
    count_by_type = {entry["error_type"]: entry["count"] for entry in breakdown}
    assert count_by_type.get("pdf_download_failed") == 1
    assert count_by_type.get("timeout") == 1
    assert count_by_type.get("unknown") == 1

    # Entries must come back in descending count order.
    observed = [entry["count"] for entry in breakdown]
    assert observed == sorted(observed, reverse=True)
# ═══════════════════════════════════════════════════════════════════════
# 配置概览
# ═══════════════════════════════════════════════════════════════════════
def test_get_config_overview_no_secrets():
    """The config overview exposes expected keys without the default secret."""
    overview = admin_svc.get_config_overview()

    # "api_key_configured" only flags whether a key is set, not its value.
    for expected_key in ("summary_backend", "schedule_time", "api_key_configured"):
        assert expected_key in overview

    # The default secret placeholder must not leak into the overview.
    assert "change-me" not in str(overview)
# ═══════════════════════════════════════════════════════════════════════
# 单篇/批量重抓
# ═══════════════════════════════════════════════════════════════════════
class TestRecrawl:
    """Direct calls to ``recrawl_single`` with ``fetch_daily`` mocked out."""

    @pytest.mark.asyncio
    async def test_not_found(self, db_session):
        """An unknown arXiv id yields ``updated=False`` / ``not_found``."""
        outcome = await recrawl_single(db_session, "9999.99999")
        assert outcome["updated"] is False
        assert outcome["reason"] == "not_found"

    @pytest.mark.asyncio
    async def test_updates_full_metadata(self, db_session, sample_paper):
        """A matching daily item overwrites title, abstract, authors and tags."""
        refreshed_paper = {
            "id": sample_paper.arxiv_id,
            "title": "Updated Title",
            "abstract": "New abstract",
            "publishedAt": "2024-01-15T00:00:00",
            "authors": [{"name": "New Author"}],
            "tags": [{"name": "CV"}, {"name": "Diffusion"}],
            "upvotes": 100,
        }
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[{"paper": refreshed_paper}],
        )
        with fetch_patch:
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Updated Title"
        assert sample_paper.abstract == "New abstract"
        assert sample_paper.upvotes == 100
        # Authors are rebuilt (original note: Alice/Bob -> New Author).
        author_names = sorted(author.name for author in sample_paper.authors)
        assert author_names == ["New Author"]
        # Tags are rebuilt (original note: NLP/LLM -> CV/Diffusion).
        tag_names = sorted(entry.tag for entry in sample_paper.tags)
        assert tag_names == ["CV", "Diffusion"]

    @pytest.mark.asyncio
    async def test_not_in_daily(self, db_session, sample_paper):
        """An empty daily listing yields ``not_in_daily`` plus a ``date`` key."""
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[],
        )
        with fetch_patch:
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is False
        assert outcome["reason"] == "not_in_daily"
        assert "date" in outcome
class TestDispatchRecrawl:
    """Recrawl jobs created with ``create_job`` and executed via ``run_job``."""

    @pytest.mark.asyncio
    async def test_recrawl_one_via_run_job(self, db_session, sample_paper):
        """A ``recrawl_one`` job updates the paper named in its payload."""
        daily_item = {
            "paper": {
                "id": sample_paper.arxiv_id,
                "title": "Via Job",
                "authors": [],
                "tags": [],
                "upvotes": 5,
            }
        }
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[daily_item],
        )
        with fetch_patch:
            job = create_job(
                db_session,
                "recrawl_one",
                owner="test",
                payload={"arxiv_id": sample_paper.arxiv_id},
            )
            outcome = await run_job(db_session, job.id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Via Job"

    @pytest.mark.asyncio
    async def test_recrawl_batch_via_run_job(self, db_session, sample_papers_range):
        """A ``recrawl_batch`` job reports two updated papers and none skipped."""
        arxiv_ids = [paper.arxiv_id for paper in sample_papers_range[:2]]

        items = []
        for aid in arxiv_ids:
            items.append(
                {
                    "paper": {
                        "id": aid,
                        "title": "Batch " + aid,
                        "authors": [],
                        "tags": [],
                        "upvotes": 1,
                    }
                }
            )

        def _fake_fetch(d):
            # Same listing regardless of the argument passed in.
            return items

        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            side_effect=_fake_fetch,
        )
        with fetch_patch:
            job = create_job(
                db_session,
                "recrawl_batch",
                owner="test",
                payload={"arxiv_ids": arxiv_ids},
            )
            outcome = await run_job(db_session, job.id)

        assert outcome["updated"] == 2
        assert outcome["skipped"] == 0
# ═══════════════════════════════════════════════════════════════════════
# 路由
# ═══════════════════════════════════════════════════════════════════════
class TestRoutes:
    """HTTP-level checks for the admin routes, using the ``auth_client`` fixture."""

    @staticmethod
    def _jobs_of_type(db_session, job_type):
        """Return every ``Job`` row whose ``type`` equals ``job_type``."""
        query = select(Job).where(Job.type == job_type)
        return db_session.execute(query).scalars().all()

    def test_jobs_page_renders(self, auth_client):
        """The jobs page responds 200 and contains its heading text."""
        response = auth_client.get("/admin/jobs")
        assert response.status_code == 200
        assert "任务监控" in response.text

    def test_jobs_page_filters_by_status(self, auth_client, db_session):
        """The jobs page accepts a ``status`` query parameter."""
        _make_job(db_session, status=JobStatus.FAILED)
        response = auth_client.get("/admin/jobs?status=failed")
        assert response.status_code == 200

    def test_export_csv(self, auth_client, sample_papers_range):
        """The CSV export is BOM-prefixed CSV with a header and data rows."""
        response = auth_client.get("/admin/papers/export.csv")
        assert response.status_code == 200
        assert "text/csv" in response.headers["content-type"]
        # UTF-8 BOM so that Excel detects the encoding.
        assert response.content.startswith(b"\xef\xbb\xbf")
        # Header row plus data.
        body = response.text
        assert "arxiv_id" in body
        assert "2401.10001" in body

    def test_export_csv_respects_filter(self, auth_client, sample_papers_range):
        """The ``q`` parameter narrows the exported rows."""
        response = auth_client.get("/admin/papers/export.csv?q=Paper%203")
        assert response.status_code == 200
        body = response.text
        assert "2401.10003" in body
        assert "2401.10001" not in body

    def test_rebuild_indexes_fts(self, auth_client, db_session, no_enqueue):
        """Target ``fts`` queues exactly one ``reindex_fts`` job."""
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "fts"})
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "queued"
        assert len(payload["job_ids"]) == 1
        assert len(self._jobs_of_type(db_session, "reindex_fts")) == 1

    def test_rebuild_indexes_both(self, auth_client, db_session, no_enqueue):
        """Target ``both`` returns two job ids."""
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "both"})
        assert response.status_code == 200
        payload = response.json()
        assert len(payload["job_ids"]) == 2

    def test_release_lock_route(self, auth_client, db_session):
        """Posting to the release route marks a running lock as finished."""
        held_lock = TaskLock(
            task="crawl", lock_key="rt", status="running", acquired_at=utc_now()
        )
        db_session.add(held_lock)
        db_session.commit()

        response = auth_client.post(f"/admin/locks/{held_lock.id}/release")
        assert response.status_code == 200
        db_session.refresh(held_lock)
        assert held_lock.status == "finished"

    def test_paper_recrawl_route(
        self, auth_client, sample_paper, db_session, no_enqueue
    ):
        """The single-paper recrawl route queues one ``recrawl_one`` job."""
        response = auth_client.post(f"/admin/paper-recrawl/{sample_paper.arxiv_id}")
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_one")) == 1

    def test_batch_recrawl_route(
        self, auth_client, sample_papers_range, db_session, no_enqueue
    ):
        """The batch ``recrawl`` action queues one ``recrawl_batch`` job."""
        ids = [paper.arxiv_id for paper in sample_papers_range[:3]]
        request_body = {"action": "recrawl", "arxiv_ids": ids}
        response = auth_client.post("/admin/papers-batch-action", json=request_body)
        assert response.status_code == 200
        assert response.json()["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_batch")) == 1