feat: add concurrency safety, caption detection, admin enhancements, and performance improvements
This commit is contained in:
@@ -24,6 +24,26 @@ from app.models import (
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
# ── ChromaDB 隔离(autouse,所有测试)──────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _isolate_chroma(monkeypatch, tmp_path):
    """Redirect ChromaDB to a per-test temp directory and reset its singleton.

    Applied automatically to every test so nothing is ever written to the
    real ``data/chroma`` store. Same rationale as the in-memory DB isolation:
    summarize post-processing goes through the real ``_maybe_index_chroma`` →
    ``index_paper`` write path, so without isolation the test fixtures
    (2401.*) would leak into production data and pollute semantic search.

    The ``_chroma`` singleton is reset before each test so that ``CHROMA_DIR``
    refers to this test's temp directory, and reset again on teardown.
    """
    import app.services.embedder as embedder_mod
    from app.config import settings

    chroma_dir = tmp_path / "chroma"
    monkeypatch.setattr(settings, "CHROMA_DIR", str(chroma_dir))

    embedder_mod._chroma.reset()
    yield
    embedder_mod._chroma.reset()
|
||||
|
||||
|
||||
# ── 内存数据库 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
"""管理后台新功能测试 — 任务监控、锁释放、重抓、失败分布、配置、导出、重建索引。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.models import Job, JobStatus, SummaryState, SummaryStatus, TaskLock
|
||||
from app.services import admin as admin_svc
|
||||
from app.services.crawler import recrawl_single
|
||||
from app.services.jobs import create_job, run_job
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
@pytest.fixture
def no_enqueue(monkeypatch):
    """Replace the admin route's ``enqueue_job`` with a no-op.

    Keeps background tasks from actually executing while a route test runs.
    """
    from app.routes import admin as admin_route

    def _discard(*args, **kwargs):
        return None

    monkeypatch.setattr(admin_route, "enqueue_job", _discard)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 任务监控
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _make_job(db_session, *, type="crawl_daily", status=JobStatus.QUEUED, owner="t"):
    """Insert and commit a ``Job`` row with an empty JSON payload, then return it.

    ``type``, ``status`` and ``owner`` are stored as given; ``created_at`` is
    set to the current UTC time.
    """
    columns = {
        "type": type,
        "status": status,
        "owner": owner,
        "payload_json": "{}",
        "created_at": utc_now(),
    }
    record = Job(**columns)
    db_session.add(record)
    db_session.commit()
    return record
|
||||
|
||||
|
||||
def test_query_jobs_filter_and_pagination(db_session):
    """``query_jobs`` paginates and filters by status and by job type."""
    seed_statuses = [JobStatus.SUCCESS] * 25 + [JobStatus.FAILED] * 5
    for seed_status in seed_statuses:
        _make_job(db_session, status=seed_status)

    # No filter: 30 rows split across pages of 20.
    first_page, total = admin_svc.query_jobs(db_session, page=1, per_page=20)
    assert total == 30
    assert len(first_page) == 20
    second_page, _ = admin_svc.query_jobs(db_session, page=2, per_page=20)
    assert len(second_page) == 10

    # Status filter.
    failed_rows, failed_total = admin_svc.query_jobs(
        db_session, status="failed", per_page=50
    )
    assert failed_total == 5
    assert len(failed_rows) == 5
    assert not [row for row in failed_rows if row["status"] != "failed"]

    # Type filter.
    _make_job(db_session, type="reindex_fts", status=JobStatus.QUEUED)
    typed_rows, typed_total = admin_svc.query_jobs(db_session, job_type="reindex_fts")
    assert typed_total == 1
    assert typed_rows[0]["type"] == "reindex_fts"
|
||||
|
||||
|
||||
def test_serialize_job_includes_duration(db_session):
    """A job with both timestamps serializes a non-negative ``duration_seconds``."""
    finished_job = _make_job(db_session, status=JobStatus.SUCCESS)
    finished_job.started_at = utc_now()
    finished_job.completed_at = utc_now()
    db_session.commit()

    payload = admin_svc.serialize_job(finished_job)
    duration = payload["duration_seconds"]
    assert duration is not None
    assert duration >= 0
|
||||
|
||||
|
||||
def test_serialize_job_running_without_completed(db_session):
    """Regression test: a running job (no ``completed_at``) still gets a duration.

    Per the original note, ``started_at`` comes back from the DB as naive UTC
    and must not be subtracted directly from the timezone-aware ``utc_now()``.
    """
    running_job = _make_job(db_session, status=JobStatus.RUNNING)
    running_job.started_at = utc_now()
    db_session.commit()

    payload = admin_svc.serialize_job(running_job)
    duration = payload["duration_seconds"]
    assert duration is not None
    assert duration >= 0
|
||||
|
||||
|
||||
def test_get_job_status_counts(db_session):
    """``get_job_status_counts`` reports how many jobs exist per status."""
    for seed_status in (JobStatus.QUEUED, JobStatus.QUEUED, JobStatus.RUNNING):
        _make_job(db_session, status=seed_status)

    counts = admin_svc.get_job_status_counts(db_session)

    expected = {"queued": 2, "running": 1}
    for status_name, expected_count in expected.items():
        assert counts.get(status_name) == expected_count
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 锁释放
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def test_force_release_lock(db_session):
    """Running and stale locks can be force-released; finished or missing ones cannot."""
    locks = {
        state: TaskLock(
            task="crawl", lock_key=key, status=state, acquired_at=utc_now()
        )
        for key, state in (("k1", "running"), ("k2", "stale"), ("k3", "finished"))
    }
    db_session.add_all(list(locks.values()))
    db_session.commit()

    assert admin_svc.force_release_lock(db_session, locks["running"].id) is True
    assert admin_svc.force_release_lock(db_session, locks["stale"].id) is True
    # An already-finished lock must not be released again.
    assert admin_svc.force_release_lock(db_session, locks["finished"].id) is False
    # Nonexistent id.
    assert admin_svc.force_release_lock(db_session, 999999) is False

    for lock in locks.values():
        db_session.refresh(lock)

    assert locks["running"].status == "finished"
    assert locks["running"].released_at is not None
    assert locks["stale"].status == "finished"
    assert locks["finished"].status == "finished"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 失败原因分布
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def test_get_failure_breakdown(db_session, sample_papers_range):
    """Failures are grouped by ``error_type`` (``None`` → ``unknown``), largest first."""
    query = select(SummaryStatus).order_by(SummaryStatus.id)
    rows = db_session.execute(query).scalars().all()

    planned = (
        (SummaryState.FAILED, "pdf_download_failed"),
        (SummaryState.FAILED, "timeout"),
        (SummaryState.PERMANENT_FAILURE, None),  # counted as "unknown"
    )
    for idx, (state, error_type) in enumerate(planned):
        rows[idx].status = state
        rows[idx].error_type = error_type
    db_session.commit()

    breakdown = admin_svc.get_failure_breakdown(db_session)

    count_by_type = {}
    for entry in breakdown:
        count_by_type[entry["error_type"]] = entry["count"]
    assert count_by_type.get("pdf_download_failed") == 1
    assert count_by_type.get("timeout") == 1
    assert count_by_type.get("unknown") == 1

    # Entries must come back in descending count order.
    ordered_counts = [entry["count"] for entry in breakdown]
    assert ordered_counts == sorted(ordered_counts, reverse=True)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 配置概览
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def test_get_config_overview_no_secrets():
    """The config overview exposes expected keys without leaking the default secret."""
    overview = admin_svc.get_config_overview()

    # "api_key_configured" only flags whether a key is set; it never shows the value.
    for required_key in ("summary_backend", "schedule_time", "api_key_configured"):
        assert required_key in overview

    # The default secret value must not appear anywhere in the overview.
    rendered = str(overview)
    assert "change-me" not in rendered
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 单篇/批量重抓
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestRecrawl:
    """Direct tests of ``recrawl_single`` with ``fetch_daily`` mocked out."""

    @pytest.mark.asyncio
    async def test_not_found(self, db_session):
        """An arXiv id with no matching paper reports ``not_found``."""
        outcome = await recrawl_single(db_session, "9999.99999")

        assert outcome["updated"] is False
        assert outcome["reason"] == "not_found"

    @pytest.mark.asyncio
    async def test_updates_full_metadata(self, db_session, sample_paper):
        """Title, abstract, upvotes, authors and tags are all refreshed."""
        paper_payload = {
            "id": sample_paper.arxiv_id,
            "title": "Updated Title",
            "abstract": "New abstract",
            "publishedAt": "2024-01-15T00:00:00",
            "authors": [{"name": "New Author"}],
            "tags": [{"name": "CV"}, {"name": "Diffusion"}],
            "upvotes": 100,
        }
        daily_items = [{"paper": paper_payload}]
        fetch_stub = AsyncMock(return_value=daily_items)

        with patch("app.services.crawler.fetch_daily", new=fetch_stub):
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Updated Title"
        assert sample_paper.abstract == "New abstract"
        assert sample_paper.upvotes == 100
        # Authors are rebuilt (previously Alice/Bob → New Author).
        author_names = sorted(author.name for author in sample_paper.authors)
        assert author_names == ["New Author"]
        # Tags are rebuilt (previously NLP/LLM → CV/Diffusion).
        tag_names = sorted(tag.tag for tag in sample_paper.tags)
        assert tag_names == ["CV", "Diffusion"]

    @pytest.mark.asyncio
    async def test_not_in_daily(self, db_session, sample_paper):
        """A known paper missing from the fetched listing reports ``not_in_daily``."""
        fetch_stub = AsyncMock(return_value=[])

        with patch("app.services.crawler.fetch_daily", new=fetch_stub):
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is False
        assert outcome["reason"] == "not_in_daily"
        assert "date" in outcome
|
||||
|
||||
|
||||
class TestDispatchRecrawl:
    """Recrawl jobs created with ``create_job`` and executed through ``run_job``."""

    @pytest.mark.asyncio
    async def test_recrawl_one_via_run_job(self, db_session, sample_paper):
        """A ``recrawl_one`` job updates the paper named in its payload."""
        paper_payload = {
            "id": sample_paper.arxiv_id,
            "title": "Via Job",
            "authors": [],
            "tags": [],
            "upvotes": 5,
        }
        fetch_stub = AsyncMock(return_value=[{"paper": paper_payload}])

        with patch("app.services.crawler.fetch_daily", new=fetch_stub):
            job = create_job(
                db_session,
                "recrawl_one",
                owner="test",
                payload={"arxiv_id": sample_paper.arxiv_id},
            )
            outcome = await run_job(db_session, job.id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Via Job"

    @pytest.mark.asyncio
    async def test_recrawl_batch_via_run_job(self, db_session, sample_papers_range):
        """A ``recrawl_batch`` job reports how many papers were updated and skipped."""
        arxiv_ids = [paper.arxiv_id for paper in sample_papers_range[:2]]

        items = []
        for aid in arxiv_ids:
            items.append(
                {
                    "paper": {
                        "id": aid,
                        "title": "Batch " + aid,
                        "authors": [],
                        "tags": [],
                        "upvotes": 1,
                    }
                }
            )
        # Every fetch returns the same listing regardless of its argument.
        fetch_stub = AsyncMock(side_effect=lambda d: items)

        with patch("app.services.crawler.fetch_daily", new=fetch_stub):
            job = create_job(
                db_session,
                "recrawl_batch",
                owner="test",
                payload={"arxiv_ids": arxiv_ids},
            )
            outcome = await run_job(db_session, job.id)

        assert outcome["updated"] == 2
        assert outcome["skipped"] == 0
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 路由
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestRoutes:
    """HTTP-level tests for the admin routes, using an authenticated client."""

    @staticmethod
    def _jobs_of_type(db_session, job_type):
        """Return every ``Job`` row whose ``type`` equals ``job_type``."""
        query = select(Job).where(Job.type == job_type)
        return db_session.execute(query).scalars().all()

    def test_jobs_page_renders(self, auth_client):
        """The job monitor page renders and contains its heading text."""
        resp = auth_client.get("/admin/jobs")

        assert resp.status_code == 200
        assert "任务监控" in resp.text

    def test_jobs_page_filters_by_status(self, auth_client, db_session):
        """The job monitor page accepts a ``status`` query parameter."""
        _make_job(db_session, status=JobStatus.FAILED)

        resp = auth_client.get("/admin/jobs?status=failed")

        assert resp.status_code == 200

    def test_export_csv(self, auth_client, sample_papers_range):
        """CSV export returns a BOM-prefixed CSV with a header and data rows."""
        resp = auth_client.get("/admin/papers/export.csv")

        assert resp.status_code == 200
        assert "text/csv" in resp.headers["content-type"]
        # UTF-8 BOM so Excel detects the encoding.
        assert resp.content.startswith(b"\xef\xbb\xbf")
        # Header row plus data.
        body = resp.text
        assert "arxiv_id" in body
        assert "2401.10001" in body

    def test_export_csv_respects_filter(self, auth_client, sample_papers_range):
        """The ``q`` filter narrows which papers are exported."""
        resp = auth_client.get("/admin/papers/export.csv?q=Paper%203")

        assert resp.status_code == 200
        body = resp.text
        assert "2401.10003" in body
        assert "2401.10001" not in body

    def test_rebuild_indexes_fts(self, auth_client, db_session, no_enqueue):
        """Target ``fts`` queues exactly one ``reindex_fts`` job."""
        resp = auth_client.post("/admin/rebuild-indexes", json={"target": "fts"})

        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "queued"
        assert len(data["job_ids"]) == 1
        assert len(self._jobs_of_type(db_session, "reindex_fts")) == 1

    def test_rebuild_indexes_both(self, auth_client, db_session, no_enqueue):
        """Target ``both`` queues two jobs."""
        resp = auth_client.post("/admin/rebuild-indexes", json={"target": "both"})

        assert resp.status_code == 200
        data = resp.json()
        assert len(data["job_ids"]) == 2

    def test_release_lock_route(self, auth_client, db_session):
        """Posting to the release endpoint marks a running lock as finished."""
        lock = TaskLock(
            task="crawl", lock_key="rt", status="running", acquired_at=utc_now()
        )
        db_session.add(lock)
        db_session.commit()

        release_url = f"/admin/locks/{lock.id}/release"
        resp = auth_client.post(release_url)

        assert resp.status_code == 200
        db_session.refresh(lock)
        assert lock.status == "finished"

    def test_paper_recrawl_route(
        self, auth_client, sample_paper, db_session, no_enqueue
    ):
        """The single-paper recrawl endpoint queues one ``recrawl_one`` job."""
        recrawl_url = f"/admin/paper-recrawl/{sample_paper.arxiv_id}"
        resp = auth_client.post(recrawl_url)

        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_one")) == 1

    def test_batch_recrawl_route(
        self, auth_client, sample_papers_range, db_session, no_enqueue
    ):
        """The batch action ``recrawl`` queues one ``recrawl_batch`` job."""
        ids = [paper.arxiv_id for paper in sample_papers_range[:3]]
        request_body = {"action": "recrawl", "arxiv_ids": ids}

        resp = auth_client.post("/admin/papers-batch-action", json=request_body)

        assert resp.status_code == 200
        assert resp.json()["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_batch")) == 1
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@@ -154,3 +156,84 @@ class TestEmbeddingApi:
|
||||
)
|
||||
result = emb._get_embedding("test")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 并发安全:init() 双重检查锁 + 集合访问串行化
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestEmbedderConcurrency:
    """Safety of concurrent embedder use.

    Per the original note, post-processing calls ``index_paper`` from several
    workers at once via ``asyncio.to_thread``.
    """

    def test_init_serialized_under_concurrency(self, monkeypatch, tmp_path):
        """Concurrent ``init()`` calls construct ``PersistentClient`` exactly once.

        Ten threads call ``init()`` at the same time while the fake
        ``PersistentClient`` sleeps to widen the connection window. The
        original note describes the failure this guards against: several
        threads entering chromadb's ``_create_system_if_not_exists`` and
        mutating its class-level cache concurrently. With the fix only the
        thread that wins the lock creates the client.
        """
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._chroma.reset()

        counter = {"n": 0}
        counter_lock = threading.Lock()

        def fake_persistent_client(path):
            with counter_lock:
                counter["n"] += 1
            # Stretch the connection window to amplify any race.
            time.sleep(0.05)
            fake_client = MagicMock()
            # A failing get_collection forces the create_collection path.
            fake_client.get_collection.side_effect = Exception("not exist")
            fake_client.create_collection.return_value = MagicMock()
            return fake_client

        with patch("chromadb.PersistentClient", side_effect=fake_persistent_client):
            workers = []
            for _ in range(10):
                workers.append(threading.Thread(target=emb._chroma.init))
            for worker_thread in workers:
                worker_thread.start()
            for worker_thread in workers:
                worker_thread.join()

        assert counter["n"] == 1, f"PersistentClient 应只调一次,实际 {counter['n']}"
        assert emb._chroma._client is not None
        emb._chroma.reset()

    def test_index_paper_concurrent_no_error(self, monkeypatch, tmp_path):
        """Ten concurrent ``index_paper`` calls all succeed, with one upsert each.

        Per the original note, embeddings are computed in parallel outside
        the lock while collection writes are serialized.
        """
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._chroma.reset()
        # Skip init() and inject a mock client and collection directly.
        mock_collection = MagicMock()
        mock_collection.count.return_value = 0
        emb._chroma._client = MagicMock()
        emb._chroma._collection = mock_collection

        with patch.object(emb, "_get_embedding", return_value=[0.1, 0.2, 0.3]):
            errors: list[BaseException] = []

            def worker(i: int) -> None:
                paper_id = f"id-{i}"
                fields = {"arxiv_id": f"id-{i}", "title_zh": f"标题{i}"}
                try:
                    emb.index_paper(paper_id, fields)
                except BaseException as exc:  # noqa: BLE001 — collect every error
                    errors.append(exc)

            workers = [
                threading.Thread(target=worker, args=(idx,)) for idx in range(10)
            ]
            for worker_thread in workers:
                worker_thread.start()
            for worker_thread in workers:
                worker_thread.join()

        assert errors == []
        assert mock_collection.upsert.call_count == 10
        emb._chroma.reset()
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
@@ -166,10 +168,10 @@ class TestClassMapping:
|
||||
def test_table(self):
    """A class id whose name is ``table`` maps to the ``table`` box class."""
    class_names = {5: "table"}
    mapped = _map_class_to_boxclass(5, class_names)
    assert mapped == "table"
|
||||
|
||||
def test_caption_classes(self):
    """Caption classes map to their own box classes instead of being dropped.

    The removed and added sides of the change were fused here, leaving two
    ``def`` headers and contradictory assertions (``is None`` and
    ``== "figure_caption"`` for the same input). Only the post-change
    expectation is kept.
    """
    names = {4: "figure_caption", 6: "table_caption"}
    assert _map_class_to_boxclass(4, names) == "figure_caption"
    assert _map_class_to_boxclass(6, names) == "table_caption"
|
||||
|
||||
def test_other_classes_ignored(self):
|
||||
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
|
||||
@@ -240,9 +242,11 @@ class TestPostprocessOutput:
|
||||
class TestDetectPage:
|
||||
@pytest.fixture(autouse=True)
def _reset_detector(self):
    """Rebuild the detector singleton before and after every test.

    Resets the ``_LayoutDetector`` instance and rebinds the module-level
    ``_detector`` to a new one (per the original note: with a new lock and an
    empty session), so a test never reuses the mock session installed by the
    previous test.

    The stale pre-change docstring that sat in front of the current one has
    been removed; it was the one Python treated as the real docstring.
    """
    mod._LayoutDetector.reset_instance()
    mod._detector = mod._LayoutDetector()
    yield
    mod._LayoutDetector.reset_instance()
    mod._detector = mod._LayoutDetector()
|
||||
|
||||
@staticmethod
|
||||
@@ -338,6 +342,20 @@ class TestDetectPage:
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "table"
|
||||
|
||||
def test_returns_caption_box_with_small_height(self, monkeypatch, tmp_path):
    """A thin ``figure_caption`` detection is returned with its height intact."""
    class_names = {4: "figure_caption"}
    # One detection: class 4, box (100, 405)-(300, 417), confidence 0.9.
    detections = [(4, 100, 405, 300, 417, 0.9)]
    sess, (pw, ph) = self._build_mock_session(595, 842, detections, class_names)
    self._setup(monkeypatch, tmp_path, sess)
    page = self._make_mock_page(595, 842, pw, ph)

    boxes = detect_page_layout(page)

    assert len(boxes) == 1
    caption_box = boxes[0]
    assert caption_box.boxclass == "figure_caption"
    assert caption_box.y1 - caption_box.y0 == pytest.approx(12, abs=1.0)
|
||||
|
||||
def test_filters_low_confidence(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure"}
|
||||
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
|
||||
@@ -401,3 +419,141 @@ class TestDetectPage:
|
||||
boxes = detect_page_layout(page)
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "picture"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 并发安全:锁串行化推理 + 单例 session 只初始化一次
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestDetectPageConcurrency:
    """Safety of concurrent ``detect_page_layout`` calls.

    Per the original note, a lock wraps the whole of ``detect_page``.
    """

    @pytest.fixture(autouse=True)
    def _reset_detector(self):
        """Rebuild the singleton (with a new lock) so lock state never leaks across tests."""
        mod._LayoutDetector.reset_instance()
        mod._detector = mod._LayoutDetector()
        yield
        mod._LayoutDetector.reset_instance()
        mod._detector = mod._LayoutDetector()

    @staticmethod
    def _build_mock_session(page_w, page_h, boxes, names):
        """Build a mock inference session whose output encodes ``boxes``.

        Each box is ``(cls_id, x0, y0, x1, y1, conf)`` in page coordinates;
        coordinates are scaled by the render ratio and shifted by the
        letterbox padding. Returns ``(session, (pix_w, pix_h), fake_output)``;
        ``fake_output`` is returned as well so a test can serve it from a
        custom ``side_effect``.
        """
        ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
        pix_w = round(page_w * ratio)
        pix_h = round(page_h * ratio)
        dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)

        rows = [
            [
                x0 * ratio + dw,
                y0 * ratio + dh,
                x1 * ratio + dw,
                y1 * ratio + dh,
                conf,
                cls_id,
            ]
            for cls_id, x0, y0, x1, y1, conf in boxes
        ]
        if rows:
            fake_output = np.array([rows], dtype=np.float32)
        else:
            fake_output = np.zeros((1, 0, 6), dtype=np.float32)

        input_meta = MagicMock()
        input_meta.name = "images"

        model_meta = MagicMock()
        serialized_names = json.dumps({str(k): v for k, v in names.items()})
        model_meta.custom_metadata_map = {"names": serialized_names}

        sess = MagicMock()
        sess.get_inputs.return_value = [input_meta]
        sess.run.return_value = [fake_output]
        sess.get_providers.return_value = ["CPUExecutionProvider"]
        sess.get_modelmeta.return_value = model_meta
        return sess, (pix_w, pix_h), fake_output

    @staticmethod
    def _make_mock_page(page_w, page_h, pix_w, pix_h):
        """Return a mock page whose pixmap is a uniform 3-channel image of value 128."""
        pixmap = MagicMock()
        pixmap.width = pix_w
        pixmap.height = pix_h
        pixmap.n = 3
        pixmap.samples = bytes([128] * (pix_w * pix_h * 3))

        page = MagicMock()
        page.rect.width = page_w
        page.rect.height = page_h
        page.get_pixmap.return_value = pixmap
        return page

    def _setup(self, monkeypatch, tmp_path, sess):
        """Create a dummy model file and make ``ort.InferenceSession`` return ``sess``."""
        model_file = tmp_path / "m.onnx"
        monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(model_file))
        model_file.write_bytes(b"x")
        monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)

    def test_detect_page_serializes_concurrent_calls(self, monkeypatch, tmp_path):
        """With several threads calling in, at most one is inside ``session.run`` at a time."""
        sess, (pw, ph), fake_output = self._build_mock_session(
            595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
        )
        in_critical = 0
        max_concurrent = 0
        counter_lock = threading.Lock()

        def counting_run(*args, **kwargs):
            nonlocal in_critical, max_concurrent
            with counter_lock:
                in_critical += 1
                if in_critical > max_concurrent:
                    max_concurrent = in_critical
            # Widen the window so concurrent threads get a chance to overlap.
            time.sleep(0.02)
            with counter_lock:
                in_critical -= 1
            return [fake_output]

        sess.run.side_effect = counting_run
        self._setup(monkeypatch, tmp_path, sess)

        workers = []
        for _ in range(8):
            page = self._make_mock_page(595, 842, pw, ph)
            workers.append(threading.Thread(target=detect_page_layout, args=(page,)))
        for worker_thread in workers:
            worker_thread.start()
        for worker_thread in workers:
            worker_thread.join()

        # Lock in effect → one thread at a time in the critical section.
        # Without the lock this would exceed 1 (regression guard).
        assert max_concurrent == 1

    def test_session_created_once_under_concurrency(self, monkeypatch, tmp_path):
        """Concurrent first calls create ``InferenceSession`` exactly once.

        Per the original note, the lock indirectly protects ``_init_session``.
        """
        sess, (pw, ph), _fake_output = self._build_mock_session(
            595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
        )
        create_count = 0
        create_lock = threading.Lock()

        def counting_init(*args, **kwargs):
            nonlocal create_count
            with create_lock:
                create_count += 1
            # Widen the window so every concurrent first call contends.
            time.sleep(0.02)
            return sess

        model_file = tmp_path / "m.onnx"
        monkeypatch.setattr(ort, "InferenceSession", counting_init)
        monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(model_file))
        model_file.write_bytes(b"x")

        workers = []
        for _ in range(6):
            page = self._make_mock_page(595, 842, pw, ph)
            workers.append(threading.Thread(target=detect_page_layout, args=(page,)))
        for worker_thread in workers:
            worker_thread.start()
        for worker_thread in workers:
            worker_thread.join()

        assert create_count == 1
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pymupdf
|
||||
|
||||
from app.services import pdf_image_extractor as mod
|
||||
from app.services.layout_detector import LayoutBox
|
||||
|
||||
|
||||
def test_process_page_extracts_doclayout_caption(tmp_path):
    """A ``figure_caption`` box below a picture is recorded in the figure's manifest entry."""
    images_dest = tmp_path / "images"
    images_dest.mkdir()
    manifest: dict[str, dict] = {}

    fake_pixmap = MagicMock()
    fake_pixmap.tobytes.return_value = b"jpeg"

    fake_page = MagicMock()
    fake_page.rect.width = 600
    fake_page.get_pixmap.return_value = fake_pixmap
    fake_page.get_text.return_value = "Figure 1: Overall architecture.\n"

    fake_doc = MagicMock()
    fake_doc.__getitem__.return_value = fake_page

    picture_box = LayoutBox(100, 100, 300, 300, "picture")
    caption_box = LayoutBox(95, 310, 320, 325, "figure_caption")

    extracted = mod._process_page(
        fake_doc,
        0,
        [picture_box, caption_box],
        images_dest=images_dest,
        manifest=manifest,
        seen_labels=set(),
        arxiv_id="2401.00001",
    )

    assert extracted == 1
    entry = manifest["figure_(p1-1).jpg"]
    assert entry["caption_text"] == "Figure 1: Overall architecture."
    assert entry["caption_source"] == "doclayout"
    assert entry["caption_box"] == [95.0, 310.0, 320.0, 325.0]
|
||||
|
||||
|
||||
def test_process_page_includes_caption_in_render(tmp_path):
    """The caption region is merged into the same rendered screenshot as the figure."""
    images_dest = tmp_path / "images"
    images_dest.mkdir()
    manifest: dict[str, dict] = {}

    fake_pixmap = MagicMock()
    fake_pixmap.tobytes.return_value = b"jpeg"

    fake_page = MagicMock()
    fake_page.rect.width = 600
    fake_page.get_pixmap.return_value = fake_pixmap
    fake_page.get_text.return_value = "Figure 1: Caption text.\n"

    fake_doc = MagicMock()
    fake_doc.__getitem__.return_value = fake_page

    picture_box = LayoutBox(100, 100, 300, 300, "picture")
    caption_box = LayoutBox(95, 310, 320, 325, "figure_caption")

    mod._process_page(
        fake_doc,
        0,
        [picture_box, caption_box],
        images_dest=images_dest,
        manifest=manifest,
        seen_labels=set(),
        arxiv_id="2401.00001",
    )

    # Content [100,100,300,300] ∪ caption [95,310,320,325], padded by
    # _REGION_PADDING=5 on every side → Rect(90, 95, 325, 330).
    expected_clip = pymupdf.Rect(90, 95, 325, 330)
    render_kwargs = fake_page.get_pixmap.call_args.kwargs
    assert render_kwargs["clip"] == expected_clip
|
||||
|
||||
|
||||
def test_label_images_preserves_doclayout_caption(tmp_path, monkeypatch):
    """Relabelling keeps the PDF caption and stores the summary caption separately."""
    arxiv_id = "2401.00001"
    paper_root = tmp_path / arxiv_id
    images_dest = paper_root / "images"
    images_dest.mkdir(parents=True)
    (images_dest / "figure_(p1-1).jpg").write_bytes(b"jpeg")

    # Manifest as left by extraction: the caption came from layout detection.
    initial_entry = {
        "page": 1,
        "type": "figure",
        "label": "Figure (p1-1)",
        "box": [100, 100, 300, 300],
        "caption_text": "Figure 1: PDF original caption.",
        "caption_source": "doclayout",
    }
    manifest_path = images_dest / "manifest.json"
    manifest_path.write_text(json.dumps({"figure_(p1-1).jpg": initial_entry}))

    pdf_path = tmp_path / "paper.pdf"
    pdf_path.write_bytes(b"%PDF")
    monkeypatch.setattr(mod, "paper_dir", lambda _arxiv_id: paper_root)

    fake_page = MagicMock()
    fake_page.search_for.return_value = [pymupdf.Rect(120, 305, 180, 320)]

    # Stand-in document that also works as a context manager.
    fake_doc = MagicMock()
    fake_doc.page_count = 1
    fake_doc.__getitem__.return_value = fake_page
    fake_doc.__enter__.return_value = fake_doc
    fake_doc.__exit__.return_value = False
    monkeypatch.setattr(mod.pymupdf, "open", lambda _path: fake_doc)

    summary_figures = [{"id": "Figure 1", "caption": "Summary caption."}]
    labeled = mod.label_images_by_summary(
        arxiv_id,
        summary_figures,
        pdf_path=pdf_path,
    )

    assert labeled == 1
    manifest = json.loads((images_dest / "manifest.json").read_text())
    entry = manifest["figure_1.jpg"]
    assert entry["caption_text"] == "Figure 1: PDF original caption."
    assert entry["caption_source"] == "doclayout"
    assert entry["summary_caption_text"] == "Summary caption."
|
||||
@@ -366,6 +366,38 @@ class TestSummarizeOneFlow:
|
||||
result = await summarize_one(db_session, sample_paper)
|
||||
assert result["status"] == "skipped"
|
||||
|
||||
@pytest.mark.asyncio
async def test_post_processing_runs_in_thread(
    self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths
):
    """Post-processing runs on a worker thread, not the event-loop thread.

    The image-extraction hook is replaced with a spy that records the thread
    it is called on; the ChromaDB indexing hook is patched out.
    """
    import threading

    seen_threads: list[int] = []
    main_thread = threading.current_thread().ident

    def spy_extract(arxiv_id, schema):
        seen_threads.append(threading.current_thread().ident)

    download_patch = patch(
        "app.services.summarizer.download_pdf", new_callable=AsyncMock
    )
    call_pi_patch = patch(
        "app.services.summary_generator.call_pi",
        new_callable=AsyncMock,
        return_value=(mock_pi_output, "test-session-id"),
    )
    extract_patch = patch(
        "app.services.summary_persister._maybe_extract_images",
        side_effect=spy_extract,
    )
    index_patch = patch("app.services.summary_persister._maybe_index_chroma")

    with download_patch, call_pi_patch, extract_patch, index_patch:
        result = await summarize_one(db_session, sample_paper)

    assert result["status"] == "done"
    assert seen_threads, "后处理未被调用"
    assert seen_threads[0] != main_thread, "后处理应在工作线程执行,不阻塞事件循环"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 批量操作测试
|
||||
|
||||
Reference in New Issue
Block a user