85c4cfb9e8
- Add image_extractor, pdf_downloader, pi_client, trends services - Add shared utils module - Refactor summarizer, embedder, routes for cleaner separation - Update tests to match new service structure
640 lines
23 KiB
Python
640 lines
23 KiB
Python
"""Phase 4 管理和自动化测试 — cleaner、admin routes、scheduler。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import shutil
|
|
import time
|
|
from datetime import date, datetime, timezone
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import select
|
|
|
|
from app.database import get_db
|
|
from app.config import settings
|
|
from app.models import (
|
|
CrawlLog,
|
|
DataDeleteJob,
|
|
Paper,
|
|
PaperAuthor,
|
|
PaperSummary,
|
|
PaperTag,
|
|
SummaryStatus,
|
|
TaskLock,
|
|
UserBookmark,
|
|
UserNote,
|
|
UserReadingStatus,
|
|
)
|
|
|
|
# ── Fixtures ────────────────────────────────────────────────────────────
|
|
|
|
# Bearer token that the auth_client fixture installs as settings.ADMIN_TOKEN
# and that admin_headers presents; TestAdminAuth expects any other token to
# be rejected.
ADMIN_TOKEN = "test-admin-token-12345"
|
|
|
|
|
|
@pytest.fixture
def admin_headers():
    """Request headers that authenticate with the test ADMIN_TOKEN."""
    headers = {"Authorization": f"Bearer {ADMIN_TOKEN}"}
    return headers
|
|
|
|
|
|
@pytest.fixture
def wrong_admin_headers():
    """Request headers carrying a bearer token that does not match ADMIN_TOKEN."""
    headers = {"Authorization": "Bearer wrong-token"}
    return headers
|
|
|
|
|
|
@pytest.fixture
def auth_client(client, monkeypatch):
    """Return the shared TestClient with settings.ADMIN_TOKEN set to ADMIN_TOKEN.

    The override goes through monkeypatch, so the original setting is
    restored when the test finishes.
    """
    monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN, raising=True)
    patched_client = client
    return patched_client
|
|
|
|
|
|
@pytest.fixture
def sample_papers(db_session):
    """Insert five papers dated 2024-01-10 through 2024-01-14 and return them.

    Paper number n (1-based) gets:
      * arxiv_id "2401.1000n", title "Test Paper n", upvotes (n - 1) * 10;
      * one PaperAuthor "Author n" and one PaperTag "Tagn" (source "hf");
      * a SummaryStatus row with status "pending";
      * a hand-written papers_fts row, so tests can check FTS cleanup.

    Returns:
        list[Paper]: the inserted papers in date order, committed.
    """
    import sqlalchemy

    # The FTS insert is identical for every paper, so build it once instead of
    # re-importing and re-creating it inside the loop.
    fts_insert = sqlalchemy.text(
        "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) "
        "VALUES (:id, :title, :abstract, :authors, :tags)"
    )
    specs = [
        ("2401.10001", "2024-01-10"),
        ("2401.10002", "2024-01-11"),
        ("2401.10003", "2024-01-12"),
        ("2401.10004", "2024-01-13"),
        ("2401.10005", "2024-01-14"),
    ]
    now = datetime.now(timezone.utc)
    papers = []
    for i, (arxiv_id, paper_date_str) in enumerate(specs):
        n = i + 1
        author = f"Author {n}"
        tag = f"Tag{n}"
        p = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {n}",
            abstract=f"Abstract for paper {n}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=now,
            upvotes=i * 10,
        )
        db_session.add(p)
        # Flush so p.id is populated before it is used as a foreign key below.
        db_session.flush()
        db_session.add(PaperAuthor(paper_id=p.id, name=author, position=0))
        db_session.add(PaperTag(paper_id=p.id, tag=tag, source="hf"))
        db_session.add(SummaryStatus(paper_id=p.id, status="pending"))
        db_session.execute(
            fts_insert,
            {
                "id": p.id,
                "title": p.title_en,
                "abstract": p.abstract,
                "authors": author,
                "tags": tag,
            },
        )
        papers.append(p)
    db_session.commit()
    return papers
|
|
|
|
|
|
@pytest.fixture
def sample_paper_with_user_data(db_session, sample_papers):
    """Attach a bookmark, a reading status and a note to the first sample paper.

    Returns:
        Paper: sample_papers[0] (arxiv_id 2401.10001, dated 2024-01-10).
    """
    target = sample_papers[0]
    ts = datetime.now(timezone.utc)
    db_session.add_all([
        UserBookmark(paper_id=target.id, created_at=ts),
        UserReadingStatus(paper_id=target.id, status="read_summary", updated_at=ts),
        UserNote(
            paper_id=target.id,
            content="My notes on this paper",
            created_at=ts,
            updated_at=ts,
        ),
    ])
    db_session.commit()
    return target
|
|
|
|
|
|
@pytest.fixture
def tmp_data_dir(tmp_path):
    """Create data/tmp and data/papers under tmp_path and return the data/ root."""
    data_root = tmp_path / "data"
    for subdir in ("tmp", "papers"):
        (data_root / subdir).mkdir(parents=True)
    return data_root
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
# Cleaner 服务测试
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestCleanupTmp:
    """Tests for cleanup_tmp in app/services/cleaner.py."""

    @staticmethod
    def _run_cleanup(monkeypatch, tmp_dir):
        """Point cleaner.TMP_DIR at tmp_dir, run cleanup_tmp and return its result dict."""
        monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
        from app.services.cleaner import cleanup_tmp
        return cleanup_tmp()

    @staticmethod
    def _backdate(path, hours):
        """Set both access and modification time of path to `hours` hours ago."""
        stamp = time.time() - hours * 3600
        os.utime(path, (stamp, stamp))

    def test_cleanup_removes_old_dirs(self, tmp_path, monkeypatch):
        """A temporary directory older than 24 hours is removed."""
        tmp_dir = tmp_path / "tmp"
        tmp_dir.mkdir()

        stale = tmp_dir / "2401.00001"
        stale.mkdir()
        (stale / "paper.pdf").write_text("fake pdf")
        # Backdate only after writing the file: adding an entry updates the
        # directory's mtime.
        self._backdate(stale, 25)

        outcome = self._run_cleanup(monkeypatch, tmp_dir)

        assert (outcome["scanned"], outcome["removed"]) == (1, 1)
        assert not stale.exists()

    def test_cleanup_keeps_recent_dirs(self, tmp_path, monkeypatch):
        """A temporary directory younger than 24 hours is kept."""
        tmp_dir = tmp_path / "tmp"
        tmp_dir.mkdir()

        fresh = tmp_dir / "2401.00002"
        fresh.mkdir()
        (fresh / "paper.pdf").write_text("fake pdf")

        outcome = self._run_cleanup(monkeypatch, tmp_dir)

        assert (outcome["scanned"], outcome["removed"]) == (1, 0)
        assert fresh.exists()

    def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
        """A missing data/tmp/ directory is handled and reports zero counts."""
        outcome = self._run_cleanup(monkeypatch, tmp_path / "nonexistent")
        assert (outcome["scanned"], outcome["removed"]) == (0, 0)

    def test_cleanup_mixed_ages(self, tmp_path, monkeypatch):
        """Only the stale directory is removed when old and new ones coexist."""
        tmp_dir = tmp_path / "tmp"
        tmp_dir.mkdir()

        stale = tmp_dir / "2401.old"
        stale.mkdir()
        self._backdate(stale, 30)

        fresh = tmp_dir / "2401.new"
        fresh.mkdir()

        outcome = self._run_cleanup(monkeypatch, tmp_dir)

        assert (outcome["scanned"], outcome["removed"]) == (2, 1)
        assert not stale.exists()
        assert fresh.exists()
|
|
|
|
|
|
class TestDeletePapersByDateRange:
    """Tests for delete_papers_by_date_range in app/services/cleaner.py."""

    @pytest.mark.asyncio
    async def test_delete_by_date_range(self, db_session, sample_papers):
        """Papers inside the inclusive date range are deleted; others remain."""
        from app.services.cleaner import delete_papers_by_date_range

        # 2024-01-11 .. 2024-01-13 covers three of the five sample papers.
        result = await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 11),
            date(2024, 1, 13),
        )

        assert result["deleted"] == 3
        assert result["total"] == 3
        assert result["status"] == "success"

        # Only the papers dated 01-10 and 01-14 should be left.
        remaining = db_session.execute(select(Paper)).scalars().all()
        assert len(remaining) == 2
        dates = {p.paper_date for p in remaining}
        assert dates == {date(2024, 1, 10), date(2024, 1, 14)}

    @pytest.mark.asyncio
    async def test_delete_creates_job_record(self, db_session, sample_papers):
        """A deletion writes one data_delete_jobs row describing the run."""
        from app.services.cleaner import delete_papers_by_date_range

        await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 14),
        )

        jobs = db_session.execute(select(DataDeleteJob)).scalars().all()
        assert len(jobs) == 1
        assert jobs[0].status == "success"
        assert jobs[0].date_start == date(2024, 1, 10)
        assert jobs[0].date_end == date(2024, 1, 14)
        assert jobs[0].paper_count == 5
        assert jobs[0].completed_at is not None

    @pytest.mark.asyncio
    async def test_delete_creates_crawl_log(self, db_session, sample_papers):
        """A deletion writes one crawl_logs row with task 'delete'."""
        from app.services.cleaner import delete_papers_by_date_range

        await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 14),
        )

        logs = db_session.execute(
            select(CrawlLog).where(CrawlLog.task == "delete")
        ).scalars().all()
        assert len(logs) == 1
        assert logs[0].status == "success"

    @pytest.mark.asyncio
    async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data):
        """Deleting a paper also removes its bookmark, reading status and note."""
        from app.services.cleaner import delete_papers_by_date_range

        # Capture the id up front: the Paper instance itself is deleted below.
        paper_id = sample_paper_with_user_data.id
        user_models = (UserBookmark, UserReadingStatus, UserNote)

        def _user_rows():
            """Return the (possibly None) user-data row of each model for paper_id."""
            return [
                db_session.execute(
                    select(model).where(model.paper_id == paper_id)
                ).scalar_one_or_none()
                for model in user_models
            ]

        # Precondition: the fixture really created all three kinds of user data.
        # (Previously this check was an always-true `... or True` expression.)
        assert all(row is not None for row in _user_rows())

        result = await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 10),
        )
        assert result["deleted"] == 1

        # All user data tied to the paper must be gone.
        assert all(row is None for row in _user_rows())

    @pytest.mark.asyncio
    async def test_delete_removes_fts(self, db_session, sample_papers):
        """Deleting papers also removes their papers_fts rows."""
        import sqlalchemy
        from app.services.cleaner import delete_papers_by_date_range

        await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 14),
        )

        # Every sample paper was deleted, so the FTS5 table must be empty.
        rows = db_session.execute(
            sqlalchemy.text("SELECT count(*) FROM papers_fts")
        ).scalar()
        assert rows == 0

    @pytest.mark.asyncio
    async def test_delete_removes_local_files(self, db_session, sample_papers, tmp_path, monkeypatch):
        """Deleting a paper removes its directory under PAPERS_DIR."""
        from app.services.cleaner import delete_papers_by_date_range

        papers_dir = tmp_path / "papers"
        papers_dir.mkdir()
        (papers_dir / "2401.10001").mkdir()
        (papers_dir / "2401.10001" / "meta.json").write_text("{}")

        monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir)

        result = await delete_papers_by_date_range(
            db_session,
            date(2024, 1, 10),
            date(2024, 1, 10),
        )
        assert result["deleted"] == 1
        assert not (papers_dir / "2401.10001").exists()

    @pytest.mark.asyncio
    async def test_delete_empty_range(self, db_session, sample_papers):
        """A range containing no papers succeeds with zero counts."""
        from app.services.cleaner import delete_papers_by_date_range

        result = await delete_papers_by_date_range(
            db_session,
            date(2025, 1, 1),
            date(2025, 1, 31),
        )
        assert result["total"] == 0
        assert result["deleted"] == 0
        assert result["status"] == "success"
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
# Admin Routes 测试
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestAdminAuth:
    """Authentication checks for the admin endpoints."""

    def test_no_token_returns_403(self, auth_client):
        """A request without an Authorization header is rejected (401 or 403)."""
        resp = auth_client.post("/admin/crawl")
        assert resp.status_code in (403, 401)

    def test_wrong_token_returns_401(self, auth_client, wrong_admin_headers):
        """A bearer token that does not match settings.ADMIN_TOKEN yields 401."""
        resp = auth_client.post("/admin/crawl", headers=wrong_admin_headers)
        assert resp.status_code == 401

    def test_correct_token_accepted(self, auth_client, admin_headers):
        """The configured token passes authentication.

        crawl_daily is mocked so no external API is called. The test only
        checks that the response is not an authentication failure. Both 401
        and 403 are excluded, because the no-token case above may surface as
        either one.
        """
        with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl:
            mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
            resp = auth_client.post("/admin/crawl", headers=admin_headers)
        assert resp.status_code not in (401, 403)
|
|
|
|
|
|
class TestAdminCrawl:
    """Tests for POST /admin/crawl, with crawl_daily replaced by an AsyncMock."""

    def test_crawl_default_today(self, auth_client, admin_headers):
        """Without a date parameter the endpoint triggers exactly one crawl.

        The intent is that the crawl defaults to today. This test only checks
        for a single crawl_daily call and a success status, not the date passed.
        """
        canned = {"found": 5, "new": 3, "status": "success"}
        with patch(
            "app.routes.admin.crawl_daily",
            new_callable=AsyncMock,
            return_value=canned,
        ) as mock_crawl:
            resp = auth_client.post("/admin/crawl", headers=admin_headers)
            assert resp.status_code == 200
            payload = resp.json()
            assert payload["status"] == "success"
            mock_crawl.assert_called_once()

    def test_crawl_specific_date(self, auth_client, admin_headers):
        """?date=YYYY-MM-DD is forwarded as the second positional arg of crawl_daily."""
        canned = {"found": 2, "new": 1, "status": "success"}
        with patch(
            "app.routes.admin.crawl_daily",
            new_callable=AsyncMock,
            return_value=canned,
        ) as mock_crawl:
            resp = auth_client.post("/admin/crawl?date=2024-01-15", headers=admin_headers)
            assert resp.status_code == 200
            mock_crawl.assert_called_once()
            positional = mock_crawl.call_args.args
            assert positional[1] == "2024-01-15"
|
|
|
|
|
|
class TestAdminCleanup:
    """Tests for POST /admin/cleanup, with cleanup_tmp mocked."""

    def test_cleanup_returns_stats(self, auth_client, admin_headers):
        """The response echoes the scanned/removed counters from cleanup_tmp."""
        stats = {"scanned": 3, "removed": 1, "errors": []}
        with patch("app.routes.admin.cleanup_tmp", return_value=stats):
            resp = auth_client.post("/admin/cleanup", headers=admin_headers)
            assert resp.status_code == 200
            body = resp.json()
            assert body["scanned"] == 3
            assert body["removed"] == 1

    def test_cleanup_writes_log(self, auth_client, admin_headers, db_session):
        """A cleanup run records a successful crawl_logs row with task 'cleanup'."""
        stats = {"scanned": 0, "removed": 0, "errors": []}
        with patch("app.routes.admin.cleanup_tmp", return_value=stats):
            auth_client.post("/admin/cleanup", headers=admin_headers)

        query = select(CrawlLog).where(CrawlLog.task == "cleanup")
        cleanup_logs = db_session.execute(query).scalars().all()
        assert cleanup_logs
        assert cleanup_logs[-1].status == "success"
|
|
|
|
|
|
class TestAdminDelete:
    """Tests for POST /admin/delete."""

    @staticmethod
    def _post_delete(client, headers, payload):
        """POST payload as JSON to /admin/delete using the given headers."""
        return client.post("/admin/delete", json=payload, headers=headers)

    def test_delete_requires_confirm(self, auth_client, admin_headers):
        """A confirm value other than 'DELETE' is rejected with 422."""
        payload = {
            "date_start": "2024-01-10",
            "date_end": "2024-01-12",
            "include_notes": True,
            "confirm": "WRONG",
        }
        response = self._post_delete(auth_client, admin_headers, payload)
        assert response.status_code == 422

    def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers):
        """confirm='DELETE' performs the deletion (3 sample papers fall in range)."""
        payload = {
            "date_start": "2024-01-10",
            "date_end": "2024-01-12",
            "include_notes": True,
            "confirm": "DELETE",
        }
        response = self._post_delete(auth_client, admin_headers, payload)
        assert response.status_code == 200
        assert response.json()["deleted"] == 3

    def test_delete_invalid_date_range(self, auth_client, admin_headers):
        """date_start later than date_end is rejected with 400."""
        payload = {
            "date_start": "2024-01-15",
            "date_end": "2024-01-10",
            "confirm": "DELETE",
        }
        response = self._post_delete(auth_client, admin_headers, payload)
        assert response.status_code == 400

    def test_delete_without_confirm_field(self, auth_client, admin_headers):
        """Omitting the confirm field is rejected with 422."""
        payload = {
            "date_start": "2024-01-10",
            "date_end": "2024-01-12",
        }
        response = self._post_delete(auth_client, admin_headers, payload)
        assert response.status_code == 422
|
|
|
|
|
|
class TestAdminLogs:
    """Tests for GET /admin/logs."""

    def test_logs_returns_page(self, auth_client, admin_headers):
        """An authenticated request receives an HTML page."""
        response = auth_client.get("/admin/logs", headers=admin_headers)
        assert response.status_code == 200
        content_type = response.headers.get("content-type", "")
        assert "text/html" in content_type

    def test_logs_requires_auth(self, auth_client):
        """The logs page rejects requests without credentials (401 or 403)."""
        response = auth_client.get("/admin/logs")
        assert response.status_code in (403, 401)

    def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers):
        """After a crawl log is recorded, the page renders log content."""
        stamp = datetime.now(timezone.utc)
        entry = CrawlLog(task="crawl", status="success", started_at=stamp, completed_at=stamp)
        db_session.add(entry)
        db_session.commit()

        response = auth_client.get("/admin/logs", headers=admin_headers)
        assert response.status_code == 200
        page = response.text
        assert "crawl" in page.lower() or "日志" in page
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
# Scheduler 测试
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestScheduler:
    """Tests for app/services/scheduler.py.

    The module-level ``_scheduler`` is reset through ``monkeypatch``, so its
    previous value is restored after each test. Every scheduler a test starts
    is stopped in a ``finally`` block, so a failing assertion cannot leave it
    running for later tests.
    """

    def test_scheduler_disabled_by_default(self, monkeypatch):
        """With SCHEDULER_ENABLED=False, start_scheduler() returns None."""
        monkeypatch.setattr(settings, "SCHEDULER_ENABLED", False)
        from app.services.scheduler import start_scheduler
        import app.services.scheduler as sched_mod

        # Reset module state; monkeypatch restores the old value on teardown.
        monkeypatch.setattr(sched_mod, "_scheduler", None)

        result = start_scheduler()
        assert result is None

    @pytest.mark.asyncio
    async def test_scheduler_start_stop(self, monkeypatch):
        """The scheduler starts with the daily_pipeline job, and stop_scheduler() clears it."""
        monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
        monkeypatch.setattr(settings, "APP_WORKERS", 1)
        import app.services.scheduler as sched_mod
        monkeypatch.setattr(sched_mod, "_scheduler", None)

        from app.services.scheduler import start_scheduler, stop_scheduler
        scheduler = start_scheduler()
        try:
            assert scheduler is not None
            # The daily pipeline job must be registered first.
            jobs = scheduler.get_jobs()
            assert len(jobs) >= 1
            assert jobs[0].id == "daily_pipeline"
        finally:
            stop_scheduler()
        assert sched_mod._scheduler is None

    @pytest.mark.asyncio
    async def test_scheduler_warns_multi_worker(self, monkeypatch, caplog):
        """APP_WORKERS > 1 logs a warning that mentions APP_WORKERS."""
        import logging

        monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True)
        monkeypatch.setattr(settings, "APP_WORKERS", 4)
        import app.services.scheduler as sched_mod
        monkeypatch.setattr(sched_mod, "_scheduler", None)

        from app.services.scheduler import start_scheduler, stop_scheduler
        with caplog.at_level(logging.WARNING):
            scheduler = start_scheduler()
        try:
            assert scheduler is not None
            assert any("APP_WORKERS" in r.getMessage() for r in caplog.records)
        finally:
            stop_scheduler()

    @pytest.mark.asyncio
    async def test_daily_pipeline_lock_prevents_reentry(self, db_session):
        """A second running lock for the same pipeline key is rejected by the database.

        Only IntegrityError is accepted, so unrelated errors (e.g. a bad model
        keyword) cannot make this test pass. NOTE(review): this presumably
        relies on a unique constraint over running (task, lock_key) rows;
        confirm against the TaskLock model.
        """
        from sqlalchemy.exc import IntegrityError

        now = datetime.now(timezone.utc)
        lock = TaskLock(
            task="scheduler",
            lock_key="pipeline-2024-01-15",
            status="running",
            owner="test",
            acquired_at=now,
        )
        db_session.add(lock)
        db_session.commit()

        # A second acquisition attempt for the same key must fail on commit.
        lock2 = TaskLock(
            task="scheduler",
            lock_key="pipeline-2024-01-15",
            status="running",
            owner="test2",
            acquired_at=now,
        )
        db_session.add(lock2)
        with pytest.raises(IntegrityError):
            db_session.commit()
        db_session.rollback()
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
# TaskLock 集成测试
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestTaskLocks:
    """Re-entry protection via the task_locks table."""

    def test_unique_running_lock(self, db_session):
        """Only one running lock may exist per (task, lock_key).

        Only IntegrityError is accepted as the failure, so unrelated errors
        cannot make this test pass. After rolling back, the first lock must
        be the only one left.
        """
        from sqlalchemy.exc import IntegrityError

        now = datetime.now(timezone.utc)
        lock1 = TaskLock(
            task="crawl", lock_key="2024-01-15",
            status="running", owner="test1", acquired_at=now,
        )
        db_session.add(lock1)
        db_session.commit()

        lock2 = TaskLock(
            task="crawl", lock_key="2024-01-15",
            status="running", owner="test2", acquired_at=now,
        )
        db_session.add(lock2)
        with pytest.raises(IntegrityError):
            db_session.commit()
        db_session.rollback()

        survivors = db_session.execute(
            select(TaskLock).where(
                TaskLock.task == "crawl", TaskLock.lock_key == "2024-01-15"
            )
        ).scalars().all()
        assert len(survivors) == 1

    def test_released_lock_allows_new(self, db_session):
        """A finished lock does not block a new running lock for the same key."""
        now = datetime.now(timezone.utc)
        lock1 = TaskLock(
            task="crawl", lock_key="2024-01-16",
            status="finished", owner="test1",
            acquired_at=now, released_at=now,
        )
        db_session.add(lock1)
        db_session.commit()

        lock2 = TaskLock(
            task="crawl", lock_key="2024-01-16",
            status="running", owner="test2", acquired_at=now,
        )
        db_session.add(lock2)
        db_session.commit()  # must succeed

        # Both the finished and the new running lock are persisted side by side.
        rows = db_session.execute(
            select(TaskLock).where(
                TaskLock.task == "crawl", TaskLock.lock_key == "2024-01-16"
            )
        ).scalars().all()
        assert len(rows) == 2
|