feat: add concurrency safety, caption detection, admin enhancements, and performance improvements

This commit is contained in:
2026-06-14 22:20:02 +08:00
parent 8f13c31991
commit 29fb20828e
23 changed files with 1782 additions and 114 deletions
+20
View File
@@ -24,6 +24,26 @@ from app.models import (
from app.utils import utc_now
# ── ChromaDB 隔离(autouse,所有测试)──────────────────────────────────
@pytest.fixture(autouse=True)
def _isolate_chroma(monkeypatch, tmp_path):
    """Redirect ChromaDB to a per-test temp directory and reset its singleton.

    Same idea as the in-memory DB isolation: the summarize post-processing
    goes through the real ``_maybe_index_chroma`` -> ``index_paper`` write
    path, so without isolation the test fixtures (``2401.*``) would leak into
    the production ``data/chroma`` directory and pollute semantic search.
    The ``_chroma`` singleton is reset before every test so that
    ``CHROMA_DIR`` points at this test's ``tmp_path``, and again afterwards.
    """
    import app.services.embedder as embedder_module
    from app.config import settings

    isolated_dir = tmp_path / "chroma"
    monkeypatch.setattr(settings, "CHROMA_DIR", str(isolated_dir))

    # Reset before the test body runs ...
    embedder_module._chroma.reset()
    yield
    # ... and once more on teardown.
    embedder_module._chroma.reset()
# ── 内存数据库 ──────────────────────────────────────────────────────────
+380
View File
@@ -0,0 +1,380 @@
"""管理后台新功能测试 — 任务监控、锁释放、重抓、失败分布、配置、导出、重建索引。"""
from __future__ import annotations
import pytest
from sqlalchemy import select
from unittest.mock import AsyncMock, patch
from app.models import Job, JobStatus, SummaryState, SummaryStatus, TaskLock
from app.services import admin as admin_svc
from app.services.crawler import recrawl_single
from app.services.jobs import create_job, run_job
from app.utils import utc_now
@pytest.fixture
def no_enqueue(monkeypatch):
    """Replace the admin route's ``enqueue_job`` with a no-op.

    Keeps background tasks from actually executing while a route test runs.
    """
    from app.routes import admin as admin_route

    def _discard(*_args, **_kwargs):
        # Accept any call shape and do nothing.
        return None

    monkeypatch.setattr(admin_route, "enqueue_job", _discard)
# ═══════════════════════════════════════════════════════════════════════
# 任务监控
# ═══════════════════════════════════════════════════════════════════════
def _make_job(db_session, *, type="crawl_daily", status=JobStatus.QUEUED, owner="t"):
    """Insert a minimal ``Job`` row, commit it, and return the instance.

    The keyword name ``type`` shadows the builtin; it is kept because it is
    the keyword callers pass and it matches the ``Job(type=...)`` field.
    """
    fields = {
        "type": type,
        "status": status,
        "owner": owner,
        "payload_json": "{}",
        "created_at": utc_now(),
    }
    record = Job(**fields)
    db_session.add(record)
    db_session.commit()
    return record
def test_query_jobs_filter_and_pagination(db_session):
    """``query_jobs`` paginates and filters by status and by job type."""
    for _ in range(25):
        _make_job(db_session, status=JobStatus.SUCCESS)
    for _ in range(5):
        _make_job(db_session, status=JobStatus.FAILED)

    # No filter: 30 rows are split 20 + 10 over two pages.
    first_page, total = admin_svc.query_jobs(db_session, page=1, per_page=20)
    assert total == 30
    assert len(first_page) == 20
    second_page, _ = admin_svc.query_jobs(db_session, page=2, per_page=20)
    assert len(second_page) == 10

    # Status filter.
    failed_rows, failed_total = admin_svc.query_jobs(
        db_session, status="failed", per_page=50
    )
    assert failed_total == 5
    assert len(failed_rows) == 5
    assert all(row["status"] == "failed" for row in failed_rows)

    # Type filter.
    _make_job(db_session, type="reindex_fts", status=JobStatus.QUEUED)
    typed_rows, typed_total = admin_svc.query_jobs(db_session, job_type="reindex_fts")
    assert typed_total == 1
    assert typed_rows[0]["type"] == "reindex_fts"
def test_serialize_job_includes_duration(db_session):
    """A job with start and completion timestamps serializes a non-negative duration."""
    finished_job = _make_job(db_session, status=JobStatus.SUCCESS)
    finished_job.started_at = utc_now()
    finished_job.completed_at = utc_now()
    db_session.commit()

    payload = admin_svc.serialize_job(finished_job)

    duration = payload["duration_seconds"]
    assert duration is not None
    assert duration >= 0
def test_serialize_job_running_without_completed(db_session):
    """A running job (no ``completed_at``) still serializes a non-negative duration."""
    # Regression test: for a running job completed_at is None, and started_at
    # comes back from the DB as naive UTC, which cannot be subtracted directly
    # from the timezone-aware utc_now().
    running_job = _make_job(db_session, status=JobStatus.RUNNING)
    running_job.started_at = utc_now()
    db_session.commit()

    payload = admin_svc.serialize_job(running_job)

    duration = payload["duration_seconds"]
    assert duration is not None
    assert duration >= 0
def test_get_job_status_counts(db_session):
    """``get_job_status_counts`` tallies jobs per status value."""
    for state in (JobStatus.QUEUED, JobStatus.QUEUED, JobStatus.RUNNING):
        _make_job(db_session, status=state)

    tally = admin_svc.get_job_status_counts(db_session)

    assert tally.get("queued") == 2
    assert tally.get("running") == 1
# ═══════════════════════════════════════════════════════════════════════
# 锁释放
# ═══════════════════════════════════════════════════════════════════════
def test_force_release_lock(db_session):
    """``force_release_lock`` finishes running/stale locks and rejects finished or unknown ones."""
    locks = {
        state: TaskLock(
            task="crawl", lock_key=key, status=state, acquired_at=utc_now()
        )
        for key, state in (("k1", "running"), ("k2", "stale"), ("k3", "finished"))
    }
    db_session.add_all(list(locks.values()))
    db_session.commit()

    assert admin_svc.force_release_lock(db_session, locks["running"].id) is True
    assert admin_svc.force_release_lock(db_session, locks["stale"].id) is True
    # A lock that is already finished must not be released a second time.
    assert admin_svc.force_release_lock(db_session, locks["finished"].id) is False
    # An id that does not exist.
    assert admin_svc.force_release_lock(db_session, 999999) is False

    for lock in locks.values():
        db_session.refresh(lock)

    assert locks["running"].status == "finished"
    assert locks["running"].released_at is not None
    assert locks["stale"].status == "finished"
    assert locks["finished"].status == "finished"
# ═══════════════════════════════════════════════════════════════════════
# 失败原因分布
# ═══════════════════════════════════════════════════════════════════════
def test_get_failure_breakdown(db_session, sample_papers_range):
    """``get_failure_breakdown`` groups failures by error type, sorted by count descending."""
    status_rows = (
        db_session.execute(select(SummaryStatus).order_by(SummaryStatus.id))
        .scalars()
        .all()
    )
    planned_failures = (
        (SummaryState.FAILED, "pdf_download_failed"),
        (SummaryState.FAILED, "timeout"),
        (SummaryState.PERMANENT_FAILURE, None),  # no error type -> counted as "unknown"
    )
    for position, (state, error_type) in enumerate(planned_failures):
        status_rows[position].status = state
        status_rows[position].error_type = error_type
    db_session.commit()

    breakdown = admin_svc.get_failure_breakdown(db_session)

    count_by_type = {entry["error_type"]: entry["count"] for entry in breakdown}
    assert count_by_type.get("pdf_download_failed") == 1
    assert count_by_type.get("timeout") == 1
    assert count_by_type.get("unknown") == 1

    # Entries must come back in descending count order.
    observed_counts = [entry["count"] for entry in breakdown]
    assert observed_counts == sorted(observed_counts, reverse=True)
# ═══════════════════════════════════════════════════════════════════════
# 配置概览
# ═══════════════════════════════════════════════════════════════════════
def test_get_config_overview_no_secrets():
    """The config overview exposes expected keys without leaking the default secret value."""
    overview = admin_svc.get_config_overview()

    assert "summary_backend" in overview
    assert "schedule_time" in overview
    # Only a "configured or not" flag is exposed, never the key itself.
    assert "api_key_configured" in overview

    # The default secret value must not appear anywhere in the rendered config.
    rendered = str(overview)
    assert "change-me" not in rendered
# ═══════════════════════════════════════════════════════════════════════
# 单篇/批量重抓
# ═══════════════════════════════════════════════════════════════════════
class TestRecrawl:
    """``recrawl_single``: unknown paper, full metadata refresh, paper absent from the daily feed."""

    @pytest.mark.asyncio
    async def test_not_found(self, db_session):
        """An arxiv id with no stored paper yields ``updated=False`` / ``not_found``."""
        outcome = await recrawl_single(db_session, "9999.99999")

        assert outcome["updated"] is False
        assert outcome["reason"] == "not_found"

    @pytest.mark.asyncio
    async def test_updates_full_metadata(self, db_session, sample_paper):
        """Title, abstract, upvotes, authors and tags are all replaced from the fetched item."""
        refreshed_paper = {
            "id": sample_paper.arxiv_id,
            "title": "Updated Title",
            "abstract": "New abstract",
            "publishedAt": "2024-01-15T00:00:00",
            "authors": [{"name": "New Author"}],
            "tags": [{"name": "CV"}, {"name": "Diffusion"}],
            "upvotes": 100,
        }
        fetch_stub = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[{"paper": refreshed_paper}],
        )
        with fetch_stub:
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Updated Title"
        assert sample_paper.abstract == "New abstract"
        assert sample_paper.upvotes == 100
        # Authors are rebuilt (previously Alice/Bob -> New Author).
        author_names = sorted(author.name for author in sample_paper.authors)
        assert author_names == ["New Author"]
        # Tags are rebuilt (previously NLP/LLM -> CV/Diffusion).
        tag_names = sorted(tag.tag for tag in sample_paper.tags)
        assert tag_names == ["CV", "Diffusion"]

    @pytest.mark.asyncio
    async def test_not_in_daily(self, db_session, sample_paper):
        """An empty daily feed yields ``not_in_daily`` and reports a ``date``."""
        empty_feed = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[],
        )
        with empty_feed:
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is False
        assert outcome["reason"] == "not_in_daily"
        assert "date" in outcome
class TestDispatchRecrawl:
    """Recrawl jobs (``recrawl_one`` / ``recrawl_batch``) executed through ``run_job``."""

    @pytest.mark.asyncio
    async def test_recrawl_one_via_run_job(self, db_session, sample_paper):
        """A ``recrawl_one`` job updates the targeted paper."""
        fetched_item = {
            "paper": {
                "id": sample_paper.arxiv_id,
                "title": "Via Job",
                "authors": [],
                "tags": [],
                "upvotes": 5,
            }
        }
        fetch_stub = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[fetched_item],
        )
        with fetch_stub:
            queued = create_job(
                db_session,
                "recrawl_one",
                owner="test",
                payload={"arxiv_id": sample_paper.arxiv_id},
            )
            outcome = await run_job(db_session, queued.id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Via Job"

    @pytest.mark.asyncio
    async def test_recrawl_batch_via_run_job(self, db_session, sample_papers_range):
        """A ``recrawl_batch`` job over two papers reports 2 updated and 0 skipped."""
        arxiv_ids = [paper.arxiv_id for paper in sample_papers_range[:2]]
        items = []
        for aid in arxiv_ids:
            items.append(
                {
                    "paper": {
                        "id": aid,
                        "title": "Batch " + aid,
                        "authors": [],
                        "tags": [],
                        "upvotes": 1,
                    }
                }
            )

        def _feed_for_any_date(d):
            # The same feed is returned regardless of the argument.
            return items

        fetch_stub = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            side_effect=_feed_for_any_date,
        )
        with fetch_stub:
            queued = create_job(
                db_session,
                "recrawl_batch",
                owner="test",
                payload={"arxiv_ids": arxiv_ids},
            )
            outcome = await run_job(db_session, queued.id)

        assert outcome["updated"] == 2
        assert outcome["skipped"] == 0
# ═══════════════════════════════════════════════════════════════════════
# 路由
# ═══════════════════════════════════════════════════════════════════════
class TestRoutes:
    """HTTP-level tests for the admin routes: jobs page, CSV export, index rebuild, locks, recrawl."""

    @staticmethod
    def _jobs_of_type(db_session, job_type):
        """Return every persisted ``Job`` whose ``type`` equals *job_type*."""
        stmt = select(Job).where(Job.type == job_type)
        return db_session.execute(stmt).scalars().all()

    def test_jobs_page_renders(self, auth_client):
        """The jobs page responds 200 and contains its heading text."""
        response = auth_client.get("/admin/jobs")
        assert response.status_code == 200
        assert "任务监控" in response.text

    def test_jobs_page_filters_by_status(self, auth_client, db_session):
        """The jobs page accepts a ``status`` query parameter."""
        _make_job(db_session, status=JobStatus.FAILED)
        response = auth_client.get("/admin/jobs?status=failed")
        assert response.status_code == 200

    def test_export_csv(self, auth_client, sample_papers_range):
        """CSV export returns text/csv, starts with a UTF-8 BOM, and has header plus data."""
        response = auth_client.get("/admin/papers/export.csv")
        assert response.status_code == 200
        assert "text/csv" in response.headers["content-type"]
        # UTF-8 BOM so that Excel detects the encoding.
        assert response.content.startswith(b"\xef\xbb\xbf")
        # Header row and at least one data row.
        body = response.text
        assert "arxiv_id" in body
        assert "2401.10001" in body

    def test_export_csv_respects_filter(self, auth_client, sample_papers_range):
        """The ``q`` filter narrows the exported rows."""
        response = auth_client.get("/admin/papers/export.csv?q=Paper%203")
        assert response.status_code == 200
        body = response.text
        assert "2401.10003" in body
        assert "2401.10001" not in body

    def test_rebuild_indexes_fts(self, auth_client, db_session, no_enqueue):
        """Target ``fts`` queues exactly one ``reindex_fts`` job."""
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "fts"})
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "queued"
        assert len(payload["job_ids"]) == 1
        assert len(self._jobs_of_type(db_session, "reindex_fts")) == 1

    def test_rebuild_indexes_both(self, auth_client, db_session, no_enqueue):
        """Target ``both`` returns two job ids."""
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "both"})
        assert response.status_code == 200
        payload = response.json()
        assert len(payload["job_ids"]) == 2

    def test_release_lock_route(self, auth_client, db_session):
        """Releasing a running lock through the route marks it finished."""
        lock = TaskLock(
            task="crawl", lock_key="rt", status="running", acquired_at=utc_now()
        )
        db_session.add(lock)
        db_session.commit()

        response = auth_client.post(f"/admin/locks/{lock.id}/release")

        assert response.status_code == 200
        db_session.refresh(lock)
        assert lock.status == "finished"

    def test_paper_recrawl_route(
        self, auth_client, sample_paper, db_session, no_enqueue
    ):
        """The single-paper recrawl route queues one ``recrawl_one`` job."""
        response = auth_client.post(f"/admin/paper-recrawl/{sample_paper.arxiv_id}")
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_one")) == 1

    def test_batch_recrawl_route(
        self, auth_client, sample_papers_range, db_session, no_enqueue
    ):
        """The batch action ``recrawl`` queues one ``recrawl_batch`` job."""
        ids = [paper.arxiv_id for paper in sample_papers_range[:3]]
        request_body = {"action": "recrawl", "arxiv_ids": ids}

        response = auth_client.post("/admin/papers-batch-action", json=request_body)

        assert response.status_code == 200
        assert response.json()["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_batch")) == 1
+83
View File
@@ -2,6 +2,8 @@
from __future__ import annotations
import threading
import time
from unittest.mock import MagicMock, patch
@@ -154,3 +156,84 @@ class TestEmbeddingApi:
)
result = emb._get_embedding("test")
assert result is None
# ═══════════════════════════════════════════════════════════════════════
# 并发安全:init() 双重检查锁 + 集合访问串行化
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderConcurrency:
    """Safety of ``index_paper`` when post-processing calls it from several ``asyncio.to_thread`` workers."""

    def test_init_serialized_under_concurrency(self, monkeypatch, tmp_path):
        """Concurrent ``init()`` calls construct ``PersistentClient`` exactly once.

        Covers the fix for the chromadb ``SharedSystemClient`` cache race. The
        failure condition is reproduced with 10 threads calling ``init()`` at
        once while the fake ``PersistentClient`` deliberately sleeps to widen
        the connection window. Before the fix, several threads entered
        ``_create_system_if_not_exists`` together and mutated the class-level
        cache concurrently; after the fix (double-checked locking) only the
        thread that wins the lock builds the client.
        """
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._chroma.reset()

        counter = {"n": 0}
        counter_guard = threading.Lock()

        def fake_persistent_client(path):
            with counter_guard:
                counter["n"] += 1
            # Stretch the connection window to amplify any race.
            time.sleep(0.05)
            fake_client = MagicMock()
            # A failing get_collection forces the create path.
            fake_client.get_collection.side_effect = Exception("not exist")
            fake_client.create_collection.return_value = MagicMock()
            return fake_client

        with patch("chromadb.PersistentClient", side_effect=fake_persistent_client):
            workers = [threading.Thread(target=emb._chroma.init) for _ in range(10)]
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join()

        assert counter["n"] == 1, f"PersistentClient 应只调一次,实际 {counter['n']}"
        assert emb._chroma._client is not None
        emb._chroma.reset()

    def test_index_paper_concurrent_no_error(self, monkeypatch, tmp_path):
        """Ten concurrent ``index_paper`` calls all succeed and each upserts once.

        The intent is that embedding runs in parallel outside the lock while
        collection writes are serialized.
        """
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._chroma.reset()
        # Skip init() and inject a mock client and collection directly.
        emb._chroma._client = MagicMock()
        collection = MagicMock()
        collection.count.return_value = 0
        emb._chroma._collection = collection

        with patch.object(emb, "_get_embedding", return_value=[0.1, 0.2, 0.3]):
            errors: list[BaseException] = []

            def worker(i: int) -> None:
                paper_id = f"id-{i}"
                try:
                    emb.index_paper(
                        paper_id, {"arxiv_id": f"id-{i}", "title_zh": f"标题{i}"}
                    )
                except BaseException as exc:  # noqa: BLE001 — collect every error
                    errors.append(exc)

            workers = [threading.Thread(target=worker, args=(i,)) for i in range(10)]
            for thread in workers:
                thread.start()
            for thread in workers:
                thread.join()

        assert errors == []
        assert collection.upsert.call_count == 10
        emb._chroma.reset()
+160 -4
View File
@@ -7,6 +7,8 @@
from __future__ import annotations
import json
import threading
import time
from unittest.mock import MagicMock
import numpy as np
@@ -166,10 +168,10 @@ class TestClassMapping:
def test_table(self):
    """Class id 5 named ``"table"`` maps to the ``"table"`` box class."""
    class_names = {5: "table"}
    assert _map_class_to_boxclass(5, class_names) == "table"
def test_caption_ignored(self):
def test_caption_classes(self):
names = {4: "figure_caption", 6: "table_caption"}
assert _map_class_to_boxclass(4, names) is None
assert _map_class_to_boxclass(6, names) is None
assert _map_class_to_boxclass(4, names) == "figure_caption"
assert _map_class_to_boxclass(6, names) == "table_caption"
def test_other_classes_ignored(self):
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
@@ -240,9 +242,11 @@ class TestPostprocessOutput:
class TestDetectPage:
@pytest.fixture(autouse=True)
def _reset_detector(self):
"""每个测试前重置模块级单例,避免复用上个测试的 mock session。"""
"""每个测试前重建单例(带新锁 + 空 session),避免复用上个测试的 mock session。"""
mod._LayoutDetector.reset_instance()
mod._detector = mod._LayoutDetector()
yield
mod._LayoutDetector.reset_instance()
mod._detector = mod._LayoutDetector()
@staticmethod
@@ -338,6 +342,20 @@ class TestDetectPage:
assert len(boxes) == 1
assert boxes[0].boxclass == "table"
def test_returns_caption_box_with_small_height(self, monkeypatch, tmp_path):
    """A thin ``figure_caption`` detection is returned with its height (417 - 405 = 12) preserved within 1.0."""
    page_w, page_h = 595, 842
    caption_detection = (4, 100, 405, 300, 417, 0.9)
    sess, (pw, ph) = self._build_mock_session(
        page_w, page_h, [caption_detection], {4: "figure_caption"}
    )
    self._setup(monkeypatch, tmp_path, sess)
    page = self._make_mock_page(page_w, page_h, pw, ph)

    boxes = detect_page_layout(page)

    assert len(boxes) == 1
    assert boxes[0].boxclass == "figure_caption"
    assert boxes[0].y1 - boxes[0].y0 == pytest.approx(12, abs=1.0)
def test_filters_low_confidence(self, monkeypatch, tmp_path):
names = {3: "figure"}
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
@@ -401,3 +419,141 @@ class TestDetectPage:
boxes = detect_page_layout(page)
assert len(boxes) == 1
assert boxes[0].boxclass == "picture"
# ═══════════════════════════════════════════════════════════════════════
# 并发安全:锁串行化推理 + 单例 session 只初始化一次
# ═══════════════════════════════════════════════════════════════════════
class TestDetectPageConcurrency:
    """Safety of concurrent ``detect_page_layout`` calls once the detector lock wraps the whole detection."""

    @pytest.fixture(autouse=True)
    def _reset_detector(self):
        """Rebuild the singleton (with a fresh lock) so lock state never leaks between tests."""
        mod._LayoutDetector.reset_instance()
        mod._detector = mod._LayoutDetector()
        yield
        mod._LayoutDetector.reset_instance()
        mod._detector = mod._LayoutDetector()

    @staticmethod
    def _build_mock_session(page_w, page_h, boxes, names):
        """Build a mock ONNX session whose output encodes *boxes*.

        Same as ``TestDetectPage._build_mock_session`` but additionally returns
        the raw ``fake_output`` array so a test can serve it from a
        ``side_effect``.

        Args:
            page_w: Page width passed to ``_compute_render_geometry``.
            page_h: Page height passed to ``_compute_render_geometry``.
            boxes: Iterable of ``(cls_id, x0, y0, x1, y1, conf)`` tuples in page
                coordinates.
            names: Mapping of class id to class name, stored as JSON in the
                session's ``custom_metadata_map["names"]``.

        Returns:
            ``(session_mock, (pix_w, pix_h), fake_output)``.
        """
        ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
        pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
        dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
        # Scale each box by the render ratio and shift it by the letterbox padding.
        rows = [
            [
                x0 * ratio + dw,
                y0 * ratio + dh,
                x1 * ratio + dw,
                y1 * ratio + dh,
                conf,
                cls_id,
            ]
            for cls_id, x0, y0, x1, y1, conf in boxes
        ]
        fake_output = (
            np.array([rows], dtype=np.float32)
            if rows
            else np.zeros((1, 0, 6), dtype=np.float32)
        )
        sess = MagicMock()
        inp = MagicMock()
        inp.name = "images"
        sess.get_inputs.return_value = [inp]
        sess.run.return_value = [fake_output]
        sess.get_providers.return_value = ["CPUExecutionProvider"]
        meta = MagicMock()
        meta.custom_metadata_map = {
            "names": json.dumps({str(k): v for k, v in names.items()})
        }
        sess.get_modelmeta.return_value = meta
        return sess, (pix_w, pix_h), fake_output

    @staticmethod
    def _make_mock_page(page_w, page_h, pix_w, pix_h):
        """Return a mock page whose ``get_pixmap`` yields a uniform 3-channel pixmap."""
        pix = MagicMock()
        pix.width = pix_w
        pix.height = pix_h
        pix.n = 3
        # Repeat a single byte instead of materialising a multi-million-element
        # list first; the resulting bytes object is identical.
        pix.samples = bytes([128]) * (pix_w * pix_h * 3)
        page = MagicMock()
        page.rect.width = page_w
        page.rect.height = page_h
        page.get_pixmap.return_value = pix
        return page

    def _setup(self, monkeypatch, tmp_path, sess):
        """Point the model path at a dummy file and make ``InferenceSession`` return *sess*."""
        monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
        (tmp_path / "m.onnx").write_bytes(b"x")
        monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)

    @staticmethod
    def _detect_concurrently(pages):
        """Run ``detect_page_layout`` on each page in its own thread.

        Returns the exceptions raised by the workers. An exception raised in a
        bare ``threading.Thread`` target would otherwise be lost, letting a
        concurrency assertion pass even though most workers crashed.
        """
        errors = []

        def worker(page):
            try:
                detect_page_layout(page)
            except Exception as exc:  # noqa: BLE001 — surfaced via the returned list
                errors.append(exc)

        threads = [threading.Thread(target=worker, args=(p,)) for p in pages]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        return errors

    def test_detect_page_serializes_concurrent_calls(self, monkeypatch, tmp_path):
        """With many threads calling ``detect_page_layout``, at most one is inside ``session.run``."""
        sess, (pw, ph), fake_output = self._build_mock_session(
            595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
        )
        in_critical = 0
        max_concurrent = 0
        counter_lock = threading.Lock()

        def counting_run(*args, **kwargs):
            nonlocal in_critical, max_concurrent
            with counter_lock:
                in_critical += 1
                max_concurrent = max(max_concurrent, in_critical)
            time.sleep(0.02)  # widen the window so concurrent threads can overlap
            try:
                return [fake_output]
            finally:
                with counter_lock:
                    in_critical -= 1

        sess.run.side_effect = counting_run
        self._setup(monkeypatch, tmp_path, sess)
        pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(8)]

        errors = self._detect_concurrently(pages)

        # Every worker must have completed; otherwise the check below is vacuous.
        assert errors == []
        # Lock effective -> one thread in the critical section at a time; without
        # the lock this would exceed 1 (regression guard).
        assert max_concurrent == 1

    def test_session_created_once_under_concurrency(self, monkeypatch, tmp_path):
        """Concurrent first calls create ``InferenceSession`` only once (the lock indirectly guards ``_init_session``)."""
        sess, (pw, ph), _fake_output = self._build_mock_session(
            595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
        )
        create_count = 0
        create_lock = threading.Lock()

        def counting_init(*args, **kwargs):
            nonlocal create_count
            with create_lock:
                create_count += 1
            time.sleep(0.02)  # widen the window so all first callers contend
            return sess

        monkeypatch.setattr(ort, "InferenceSession", counting_init)
        monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
        (tmp_path / "m.onnx").write_bytes(b"x")
        pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(6)]

        errors = self._detect_concurrently(pages)

        assert errors == []
        assert create_count == 1
+134
View File
@@ -0,0 +1,134 @@
from __future__ import annotations
import json
from unittest.mock import MagicMock
import pymupdf
from app.services import pdf_image_extractor as mod
from app.services.layout_detector import LayoutBox
def test_process_page_extracts_doclayout_caption(tmp_path):
    """A ``figure_caption`` box next to a picture is recorded in the manifest as a doclayout caption."""
    images_dest = tmp_path / "images"
    images_dest.mkdir()
    manifest: dict[str, dict] = {}

    rendered = MagicMock()
    rendered.tobytes.return_value = b"jpeg"
    page = MagicMock()
    page.rect.width = 600
    page.get_pixmap.return_value = rendered
    page.get_text.return_value = "Figure 1: Overall architecture.\n"
    doc = MagicMock()
    doc.__getitem__.return_value = page

    picture_box = LayoutBox(100, 100, 300, 300, "picture")
    caption_box = LayoutBox(95, 310, 320, 325, "figure_caption")

    extracted = mod._process_page(
        doc,
        0,
        [picture_box, caption_box],
        images_dest=images_dest,
        manifest=manifest,
        seen_labels=set(),
        arxiv_id="2401.00001",
    )

    assert extracted == 1
    entry = manifest["figure_(p1-1).jpg"]
    assert entry["caption_text"] == "Figure 1: Overall architecture."
    assert entry["caption_source"] == "doclayout"
    assert entry["caption_box"] == [95.0, 310.0, 320.0, 325.0]
def test_process_page_includes_caption_in_render(tmp_path):
    """The caption region is merged into the same rendered clip as the figure."""
    images_dest = tmp_path / "images"
    images_dest.mkdir()
    manifest: dict[str, dict] = {}

    rendered = MagicMock()
    rendered.tobytes.return_value = b"jpeg"
    page = MagicMock()
    page.rect.width = 600
    page.get_pixmap.return_value = rendered
    page.get_text.return_value = "Figure 1: Caption text.\n"
    doc = MagicMock()
    doc.__getitem__.return_value = page

    layout = [
        LayoutBox(100, 100, 300, 300, "picture"),
        LayoutBox(95, 310, 320, 325, "figure_caption"),
    ]
    mod._process_page(
        doc,
        0,
        layout,
        images_dest=images_dest,
        manifest=manifest,
        seen_labels=set(),
        arxiv_id="2401.00001",
    )

    # Content [100,100,300,300] united with caption [95,310,320,325], then
    # padded by _REGION_PADDING=5 on every side -> Rect(90, 95, 325, 330).
    expected_clip = pymupdf.Rect(90, 95, 325, 330)
    actual_clip = page.get_pixmap.call_args.kwargs["clip"]
    assert actual_clip == expected_clip
def test_label_images_preserves_doclayout_caption(tmp_path, monkeypatch):
    """Relabeling keeps the doclayout caption and stores the summary caption separately."""
    arxiv_id = "2401.00001"
    paper_root = tmp_path / arxiv_id
    images_dest = paper_root / "images"
    images_dest.mkdir(parents=True)
    manifest_path = images_dest / "manifest.json"

    (images_dest / "figure_(p1-1).jpg").write_bytes(b"jpeg")
    original_entry = {
        "page": 1,
        "type": "figure",
        "label": "Figure (p1-1)",
        "box": [100, 100, 300, 300],
        "caption_text": "Figure 1: PDF original caption.",
        "caption_source": "doclayout",
    }
    manifest_path.write_text(json.dumps({"figure_(p1-1).jpg": original_entry}))
    pdf_path = tmp_path / "paper.pdf"
    pdf_path.write_bytes(b"%PDF")

    def _fake_paper_dir(_arxiv_id):
        return paper_root

    monkeypatch.setattr(mod, "paper_dir", _fake_paper_dir)

    page = MagicMock()
    page.search_for.return_value = [pymupdf.Rect(120, 305, 180, 320)]
    fake_doc = MagicMock()
    fake_doc.page_count = 1
    fake_doc.__getitem__.return_value = page
    # Make the mock usable as a context manager that yields itself.
    fake_doc.__enter__.return_value = fake_doc
    fake_doc.__exit__.return_value = False

    def _fake_open(_path):
        return fake_doc

    monkeypatch.setattr(mod.pymupdf, "open", _fake_open)

    labeled = mod.label_images_by_summary(
        arxiv_id,
        [{"id": "Figure 1", "caption": "Summary caption."}],
        pdf_path=pdf_path,
    )

    assert labeled == 1
    manifest = json.loads(manifest_path.read_text())
    entry = manifest["figure_1.jpg"]
    assert entry["caption_text"] == "Figure 1: PDF original caption."
    assert entry["caption_source"] == "doclayout"
    assert entry["summary_caption_text"] == "Summary caption."
+32
View File
@@ -366,6 +366,38 @@ class TestSummarizeOneFlow:
result = await summarize_one(db_session, sample_paper)
assert result["status"] == "skipped"
@pytest.mark.asyncio
async def test_post_processing_runs_in_thread(
    self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths
):
    """Post-processing (image extraction / ChromaDB) runs on a worker thread, not the event-loop thread."""
    import threading

    main_thread = threading.current_thread().ident
    seen_threads: list[int] = []

    def spy_extract(arxiv_id, schema):
        # Record which thread the post-processing hook is invoked on.
        seen_threads.append(threading.current_thread().ident)

    stub_download = patch("app.services.summarizer.download_pdf", new_callable=AsyncMock)
    stub_pi = patch(
        "app.services.summary_generator.call_pi",
        new_callable=AsyncMock,
        return_value=(mock_pi_output, "test-session-id"),
    )
    stub_extract = patch(
        "app.services.summary_persister._maybe_extract_images",
        side_effect=spy_extract,
    )
    stub_index = patch("app.services.summary_persister._maybe_index_chroma")

    with stub_download, stub_pi, stub_extract, stub_index:
        result = await summarize_one(db_session, sample_paper)

    assert result["status"] == "done"
    assert seen_threads, "后处理未被调用"
    assert seen_threads[0] != main_thread, "后处理应在工作线程执行,不阻塞事件循环"
# ═══════════════════════════════════════════════════════════════════════
# 批量操作测试