feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
This commit is contained in:
+7
-1
@@ -161,7 +161,13 @@ def sample_summary_dict() -> dict:
|
||||
"results": {
|
||||
"main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。",
|
||||
"benchmarks": [
|
||||
{"task": "长文本摘要", "metric": "ROUGE-L", "this_work": "42.1", "baseline": "38.9", "improvement": "+3.2"},
|
||||
{
|
||||
"task": "长文本摘要",
|
||||
"metric": "ROUGE-L",
|
||||
"this_work": "42.1",
|
||||
"baseline": "38.9",
|
||||
"improvement": "+3.2",
|
||||
},
|
||||
],
|
||||
"limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。",
|
||||
},
|
||||
|
||||
+9
-13
@@ -67,7 +67,7 @@ class TestAdminAuth:
|
||||
def test_correct_session_accepted(self, auth_client):
|
||||
"""已登录 session 应被接受(crawl 可能会失败但不是 303)。"""
|
||||
with patch(
|
||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
||||
"app.routes.admin.run_crawl", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl")
|
||||
@@ -83,9 +83,7 @@ class TestAdminAuth:
|
||||
|
||||
def test_correct_session_batch_summarize(self, auth_client):
|
||||
"""已登录调用 batch summarize,mock 掉服务层。"""
|
||||
with patch(
|
||||
"app.routes.admin.summarize_batch", new_callable=AsyncMock
|
||||
) as mock:
|
||||
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = {
|
||||
"status": "success",
|
||||
"done": 0,
|
||||
@@ -98,10 +96,12 @@ class TestAdminAuth:
|
||||
|
||||
def test_single_paper_not_found(self, auth_client):
|
||||
"""单篇总结不存在的论文返回 404。"""
|
||||
from app.exceptions import NotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routes.admin.summarize_single",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
|
||||
side_effect=NotFoundError("Paper not found: nonexistent.99999"),
|
||||
):
|
||||
resp = auth_client.post("/admin/summarize/nonexistent.99999")
|
||||
assert resp.status_code == 404
|
||||
@@ -118,7 +118,7 @@ class TestAdminCrawl:
|
||||
def test_crawl_default_today(self, auth_client):
|
||||
"""不指定日期时默认抓取今天。"""
|
||||
with patch(
|
||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
||||
"app.routes.admin.run_crawl", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl")
|
||||
@@ -130,7 +130,7 @@ class TestAdminCrawl:
|
||||
def test_crawl_specific_date(self, auth_client):
|
||||
"""指定日期抓取。"""
|
||||
with patch(
|
||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
||||
"app.routes.admin.run_crawl", new_callable=AsyncMock
|
||||
) as mock_crawl:
|
||||
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
|
||||
resp = auth_client.post("/admin/crawl?date=2024-01-15")
|
||||
@@ -194,9 +194,7 @@ class TestAdminDelete:
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_delete_with_confirm(
|
||||
self, auth_client, db_session, sample_papers_range
|
||||
):
|
||||
def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range):
|
||||
"""confirm='DELETE' 时应执行删除。"""
|
||||
resp = auth_client.post(
|
||||
"/admin/delete",
|
||||
@@ -255,9 +253,7 @@ class TestAdminLogs:
|
||||
resp = client.get("/admin/logs", follow_redirects=False)
|
||||
assert resp.status_code == 303
|
||||
|
||||
def test_logs_contains_data(
|
||||
self, auth_client, db_session, sample_papers_range
|
||||
):
|
||||
def test_logs_contains_data(self, auth_client, db_session, sample_papers_range):
|
||||
"""日志页面应包含日志数据。"""
|
||||
# 先创建一条日志
|
||||
now = utc_now()
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
"""爬虫服务测试 — _parse_paper、fetch_daily、upsert_papers、crawl_daily。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.crawler import (
|
||||
_parse_paper,
|
||||
crawl_daily,
|
||||
fetch_daily,
|
||||
upsert_papers,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# _parse_paper
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestParsePaper:
    """Field-mapping checks for ``_parse_paper``."""

    def test_normal_item(self):
        """A fully populated, ``paper``-wrapped item yields the expected flat keys."""
        paper_fields = {
            "id": "2401.12345",
            "title": "Test Paper",
            "abstract": "Abstract text",
            "publishedAt": "2024-01-15T00:00:00",
            "authors": [{"name": "Alice"}, {"name": "Bob"}],
            "tags": [{"name": "NLP"}, {"name": "LLM"}],
            "upvotes": 42,
        }

        parsed = _parse_paper({"paper": paper_fields})

        assert parsed["arxiv_id"] == "2401.12345"
        assert parsed["title_en"] == "Test Paper"
        assert len(parsed["authors"]) == 2
        assert parsed["authors"] == ["Alice", "Bob"]
        assert parsed["tags"] == ["NLP", "LLM"]
        assert parsed["upvotes"] == 42
        assert "huggingface.co" in parsed["hf_url"]

    def test_empty_id(self):
        """An empty id produces an empty ``arxiv_id`` and an empty ``hf_url``."""
        blank_id_item = {"paper": {"id": "", "authors": [], "tags": []}}

        parsed = _parse_paper(blank_id_item)

        assert parsed["arxiv_id"] == ""
        assert parsed["hf_url"] == ""

    def test_missing_published_at(self):
        """Without ``publishedAt`` the parsed ``published_at`` is ``None``."""
        undated_item = {
            "paper": {"id": "2401.00001", "title": "T", "authors": [], "tags": []}
        }

        parsed = _parse_paper(undated_item)

        assert parsed["published_at"] is None

    def test_flat_structure_fallback(self):
        """Without a ``paper`` wrapper, fields are read from the top level."""
        flat_item = {"id": "2401.99999", "title": "Flat", "authors": [], "tags": []}

        parsed = _parse_paper(flat_item)

        assert parsed["arxiv_id"] == "2401.99999"
        assert parsed["title_en"] == "Flat"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# fetch_daily
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestFetchDaily:
    """Tests for ``fetch_daily`` with ``make_http_client`` replaced by a mock."""

    @staticmethod
    def _make_http_client(payload):
        """Return a mock HTTP client whose ``get()`` response decodes to *payload*.

        The mock also works as an async context manager: entering it returns
        the mock itself, so the configured ``get`` is the one that gets used.
        """
        response = MagicMock()
        response.json.return_value = payload
        response.raise_for_status = MagicMock()

        client = AsyncMock()
        client.get.return_value = response
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return client

    @pytest.mark.asyncio
    async def test_returns_papers(self):
        """The items decoded from the response are returned to the caller."""
        fake_data = [{"paper": {"id": "2401.00001"}}]
        mock_client = self._make_http_client(fake_data)

        with patch("app.services.crawler.make_http_client", return_value=mock_client):
            result = await fetch_daily("2024-01-15")

        assert len(result) == 1
        assert result[0]["paper"]["id"] == "2401.00001"

    @pytest.mark.asyncio
    async def test_respects_top_n(self):
        """With ``top_n=3`` only three of the ten served items come back."""
        fake_data = [{"paper": {"id": f"2401.{i:05d}"}} for i in range(10)]
        mock_client = self._make_http_client(fake_data)

        with patch("app.services.crawler.make_http_client", return_value=mock_client):
            result = await fetch_daily("2024-01-15", top_n=3)

        assert len(result) == 3
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# upsert_papers
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestUpsertPapers:
    """Checks for ``upsert_papers`` against the ``db_session`` fixture."""

    def test_inserts_new_paper(self, db_session):
        """An unseen id comes back as a single entry carrying the parsed fields."""
        fresh_item = {
            "paper": {
                "id": "2401.00001",
                "title": "New Paper",
                "abstract": "Abstract",
                "authors": [{"name": "Alice"}],
                "tags": [{"name": "CV"}],
                "upvotes": 5,
            }
        }

        inserted = upsert_papers(db_session, [fresh_item], "2024-01-15")

        assert len(inserted) == 1
        assert inserted[0].arxiv_id == "2401.00001"
        assert inserted[0].title_en == "New Paper"

    def test_updates_existing_upvotes(self, db_session, sample_paper):
        """Re-submitting a known id returns nothing new but refreshes its upvotes."""
        known_item = {
            "paper": {
                "id": sample_paper.arxiv_id,
                "title": sample_paper.title_en,
                "upvotes": 999,
                "authors": [],
                "tags": [],
            }
        }

        inserted = upsert_papers(db_session, [known_item], "2024-01-15")

        assert len(inserted) == 0  # nothing newly added
        db_session.refresh(sample_paper)
        assert sample_paper.upvotes == 999

    def test_skips_empty_id(self, db_session):
        """An item with an empty id yields no new entries."""
        blank_id_item = {"paper": {"id": "", "title": "Nope", "authors": [], "tags": []}}

        inserted = upsert_papers(db_session, [blank_id_item], "2024-01-15")

        assert len(inserted) == 0
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# crawl_daily
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestCrawlDaily:
    """Tests for ``crawl_daily`` with ``fetch_daily`` mocked out."""

    @pytest.mark.asyncio
    async def test_success_flow(self, db_session):
        """One fetched paper is reported as found=1, new=1 with status success."""
        fetched_items = [
            {
                "paper": {
                    "id": "2401.00001",
                    "title": "T",
                    "authors": [],
                    "tags": [],
                    "upvotes": 0,
                }
            }
        ]
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=fetched_items,
        )

        with fetch_patch:
            outcome = await crawl_daily(db_session, "2024-01-15")

        assert outcome["status"] == "success"
        assert outcome["new"] == 1
        assert outcome["found"] == 1

    @pytest.mark.asyncio
    async def test_failure_returns_failed(self, db_session):
        """A fetch error is reported as status failed with the message in ``error``."""
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            side_effect=ConnectionError("network error"),
        )

        with fetch_patch:
            outcome = await crawl_daily(db_session, "2024-01-15")

        assert outcome["status"] == "failed"
        assert "network error" in outcome["error"]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""PDF 下载测试 — download_pdf、路径工具、错误处理。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.pdf_downloader import (
|
||||
PdfDownloadError,
|
||||
download_pdf,
|
||||
paper_dir,
|
||||
tmp_dir,
|
||||
)
|
||||
from app.utils import PAPERS_DIR, TMP_DIR
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 路径工具
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestPathHelpers:
    """Checks that the path helpers append the arXiv id to their base directory."""

    def test_paper_dir(self):
        """``paper_dir`` resolves to ``PAPERS_DIR / <arxiv_id>``."""
        arxiv_id = "2401.12345"
        expected = PAPERS_DIR / arxiv_id
        assert paper_dir(arxiv_id) == expected

    def test_tmp_dir(self):
        """``tmp_dir`` resolves to ``TMP_DIR / <arxiv_id>``."""
        arxiv_id = "2401.12345"
        expected = TMP_DIR / arxiv_id
        assert tmp_dir(arxiv_id) == expected
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# download_pdf
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestDownloadPdf:
    """Tests for ``download_pdf`` with the HTTP session and tmp directory patched."""

    @pytest.mark.asyncio
    async def test_success_download(self, tmp_path):
        """The response body is written to a file named ``paper.pdf``."""
        pdf_bytes = b"%PDF-1.4 fake"
        response = MagicMock()
        response.content = pdf_bytes
        response.raise_for_status = MagicMock()

        session = MagicMock()
        session.get.return_value = response

        tmp_dir_patch = patch("app.services.pdf_downloader.TMP_DIR", tmp_path)
        session_patch = patch(
            "app.services.pdf_downloader._get_session", return_value=session
        )
        with tmp_dir_patch, session_patch:
            saved = await download_pdf(
                "2401.12345", "https://arxiv.org/pdf/2401.12345.pdf"
            )

        assert saved.exists()
        assert saved.name == "paper.pdf"
        assert saved.read_bytes() == b"%PDF-1.4 fake"

    @pytest.mark.asyncio
    async def test_empty_pdf_url_raises(self):
        """An empty URL is rejected with ``PdfDownloadError`` ("no pdf_url")."""
        with pytest.raises(PdfDownloadError, match="no pdf_url"):
            await download_pdf("2401.12345", "")

    @pytest.mark.asyncio
    async def test_http_failure_raises(self, tmp_path):
        """A failing ``session.get`` surfaces as ``PdfDownloadError``."""
        failing_session = MagicMock()
        failing_session.get.side_effect = ConnectionError("refused")

        with (
            patch("app.services.pdf_downloader.TMP_DIR", tmp_path),
            patch(
                "app.services.pdf_downloader._get_session",
                return_value=failing_session,
            ),
            pytest.raises(PdfDownloadError, match="failed to download"),
        ):
            await download_pdf("2401.12345", "https://bad.url/pdf.pdf")
|
||||
@@ -0,0 +1,77 @@
|
||||
"""流水线编排测试 — run_pipeline (crawl → summarize → cleanup)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models import TaskLock
|
||||
from app.services.pipeline import run_pipeline
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
class TestRunPipeline:
    """Tests for ``run_pipeline`` with crawl, summarize and cleanup mocked out."""

    @pytest.mark.asyncio
    async def test_full_pipeline_success(self, db_session):
        """When every stage succeeds, status is success and each stage is called once."""
        with (
            patch(
                "app.services.pipeline.crawl_daily", new_callable=AsyncMock
            ) as mock_crawl,
            patch(
                "app.services.pipeline.summarize_batch", new_callable=AsyncMock
            ) as mock_summ,
            patch("app.services.pipeline.cleanup_tmp") as mock_clean,
        ):
            mock_crawl.return_value = {"status": "success", "found": 5, "new": 2}
            mock_summ.return_value = {"status": "success", "done": 2, "failed": 0}
            mock_clean.return_value = {"removed": 0}

            result = await run_pipeline(db_session, "2024-01-15", "test")

        assert result["status"] == "success"
        mock_crawl.assert_called_once()
        mock_summ.assert_called_once()
        mock_clean.assert_called_once()

    @pytest.mark.asyncio
    async def test_pipeline_lock_prevents_reentry(self, db_session):
        """An existing ``running`` lock row makes ``run_pipeline`` raise RuntimeError."""
        now = utc_now()
        db_session.add(
            TaskLock(
                task="scheduler",
                lock_key="pipeline-2024-01-15",
                status="running",
                owner="other",
                acquired_at=now,
            )
        )
        db_session.commit()

        with pytest.raises(RuntimeError, match="already running"):
            await run_pipeline(db_session, "2024-01-15", "test")

    @pytest.mark.asyncio
    async def test_crawl_failure_still_runs_summarize_and_cleanup(self, db_session):
        """When crawl raises, the pipeline returns ``failed`` and skips summarize.

        NOTE(review): the method name says summarize and cleanup still run, but
        the assertions below require the opposite for summarize. The name is
        kept so the test id stays stable; consider renaming it to something
        like ``test_crawl_failure_aborts_pipeline``.
        """
        with (
            patch(
                "app.services.pipeline.crawl_daily", new_callable=AsyncMock
            ) as mock_crawl,
            patch(
                "app.services.pipeline.summarize_batch", new_callable=AsyncMock
            ) as mock_summ,
            patch("app.services.pipeline.cleanup_tmp") as mock_clean,
        ):
            mock_crawl.side_effect = ConnectionError("timeout")
            mock_summ.return_value = {"status": "success", "done": 0}
            mock_clean.return_value = {"removed": 0}

            result = await run_pipeline(db_session, "2024-01-15", "test")

        # The pipeline catches the exception and reports it as a failed run.
        assert result["status"] == "failed"
        assert "timeout" in result["error"]
        # Summarize must not run after the crawl error.
        # NOTE(review): cleanup is presumably skipped as well, but that is not
        # asserted here — confirm against run_pipeline before adding a check.
        mock_summ.assert_not_called()
|
||||
@@ -20,7 +20,7 @@ from app.services.schemas import (
|
||||
classify_validation_error,
|
||||
flatten_for_db,
|
||||
)
|
||||
from app.services.summarizer import _classify_error
|
||||
from app.services.summary_generator import _classify_error
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
+35
-22
@@ -23,12 +23,8 @@ from app.services.pdf_downloader import (
|
||||
)
|
||||
from app.services.pi_client import PiTimeoutError
|
||||
from app.services.schemas import SummarySchema
|
||||
from app.services.summarizer import (
|
||||
_save_files,
|
||||
_update_summary_in_db,
|
||||
summarize_batch,
|
||||
summarize_one,
|
||||
)
|
||||
from app.services.summarizer import summarize_batch, summarize_one
|
||||
from app.services.summary_persister import _save_files, _update_summary_in_db
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
@@ -39,7 +35,14 @@ from app.utils import utc_now
|
||||
def _summarize_tmp_paths(tmp_path):
|
||||
"""将 data 目录重定向到 tmp_path(供 summarizer 测试使用)。"""
|
||||
with (
|
||||
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
|
||||
patch(
|
||||
"app.services.summary_persister.paper_dir",
|
||||
lambda aid: tmp_path / "papers" / aid,
|
||||
),
|
||||
patch(
|
||||
"app.services.summary_generator.paper_dir",
|
||||
lambda aid: tmp_path / "papers" / aid,
|
||||
),
|
||||
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
||||
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
||||
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
||||
@@ -134,7 +137,9 @@ class TestFileOperations:
|
||||
|
||||
def test_save_files(self, tmp_path, sample_summary_dict):
|
||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
|
||||
with patch(
|
||||
"app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
|
||||
):
|
||||
_save_files("2401.12345", schema, "raw output text")
|
||||
|
||||
paper_dir = tmp_path / "2401.12345"
|
||||
@@ -144,7 +149,9 @@ class TestFileOperations:
|
||||
assert saved["title_zh"] == "测试论文中文标题"
|
||||
|
||||
def test_save_raw_output_only(self, tmp_path):
|
||||
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
|
||||
with patch(
|
||||
"app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
|
||||
):
|
||||
_save_files("2401.12345", None, "raw output")
|
||||
paper_dir = tmp_path / "2401.12345"
|
||||
assert (paper_dir / "raw_output.txt").exists()
|
||||
@@ -180,7 +187,7 @@ class TestSummarizeOneFlow:
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
"app.services.summary_generator.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_pi_output, "test-session-id"),
|
||||
),
|
||||
@@ -209,7 +216,9 @@ class TestSummarizeOneFlow:
|
||||
assert fts_row[0] == "测试论文中文标题"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_download_failure(self, db_session, sample_paper, _summarize_tmp_paths):
|
||||
async def test_pdf_download_failure(
|
||||
self, db_session, sample_paper, _summarize_tmp_paths
|
||||
):
|
||||
"""PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。"""
|
||||
with (
|
||||
patch(
|
||||
@@ -233,7 +242,7 @@ class TestSummarizeOneFlow:
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
"app.services.summary_generator.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=PiTimeoutError("timeout after 300s"),
|
||||
),
|
||||
@@ -250,7 +259,7 @@ class TestSummarizeOneFlow:
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
"app.services.summary_generator.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=("No JSON in this output at all.", "test-session-id"),
|
||||
),
|
||||
@@ -281,7 +290,7 @@ class TestSummarizeOneFlow:
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
"app.services.summary_generator.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(bad_output, "test-session-id"),
|
||||
),
|
||||
@@ -300,7 +309,7 @@ class TestSummarizeOneFlow:
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
"app.services.summary_generator.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=("Some output without JSON", "test-session-id"),
|
||||
),
|
||||
@@ -319,7 +328,7 @@ class TestSummarizeOneFlow:
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
"app.services.summary_generator.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_pi_output, "test-session-id"),
|
||||
),
|
||||
@@ -347,7 +356,9 @@ class TestSummarizeOneFlow:
|
||||
assert not tmp_paper.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_done_paper(self, db_session, sample_paper, _summarize_tmp_paths):
|
||||
async def test_skips_done_paper(
|
||||
self, db_session, sample_paper, _summarize_tmp_paths
|
||||
):
|
||||
"""已完成的论文跳过。"""
|
||||
sample_paper.summary_status.status = "done"
|
||||
db_session.commit()
|
||||
@@ -393,7 +404,7 @@ class TestBatchSummarize:
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
"app.services.summary_generator.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_pi_output, "test-session-id"),
|
||||
),
|
||||
@@ -446,7 +457,7 @@ class TestBatchSummarize:
|
||||
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch("app.services.summarizer.call_pi", side_effect=_mock_call_pi),
|
||||
patch("app.services.summary_generator.call_pi", side_effect=_mock_call_pi),
|
||||
):
|
||||
result = await summarize_batch(db_session, _session_factory=_TestSession)
|
||||
|
||||
@@ -456,6 +467,8 @@ class TestBatchSummarize:
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_lock_conflict(self, db_session, _summarize_tmp_paths):
|
||||
"""TaskLock 防止并发 batch。"""
|
||||
from app.exceptions import ConflictError
|
||||
|
||||
# 先插入一个 running 锁
|
||||
db_session.add(
|
||||
TaskLock(
|
||||
@@ -467,8 +480,8 @@ class TestBatchSummarize:
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
result = await summarize_batch(db_session)
|
||||
assert result["status"] == "conflict"
|
||||
with pytest.raises(ConflictError):
|
||||
await summarize_batch(db_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_lock_released(
|
||||
@@ -482,7 +495,7 @@ class TestBatchSummarize:
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summarizer.call_pi",
|
||||
"app.services.summary_generator.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_pi_output, "test-session-id"),
|
||||
),
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
"""summary_utils 测试 — PDF 文本提取、正文裁剪、JSON 提取、meta.json 写入、prompt 构建。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.summary_utils import (
|
||||
JsonNotFoundError,
|
||||
_trim_body,
|
||||
build_prompt,
|
||||
extract_json,
|
||||
extract_pdf_text,
|
||||
write_meta_json,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# _trim_body 正文裁剪
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestTrimBody:
    """Checks for ``_trim_body`` section removal and length capping."""

    def test_removes_references_section(self):
        """The ``References`` heading is dropped while the introduction is kept."""
        raw = "Intro\n\nSome content here.\n\nReferences\n[1] Smith et al."

        trimmed = _trim_body(raw)

        assert "References" not in trimmed
        assert "Intro" in trimmed

    def test_removes_bibliography(self):
        """A ``Bibliography`` heading is dropped as well."""
        raw = "Body\n\nBibliography\n[1] Smith"

        trimmed = _trim_body(raw)

        assert "Bibliography" not in trimmed

    def test_keeps_appendix_after_references(self):
        """An appendix that follows the references survives the trimming."""
        raw = "Body\n\nReferences\n[1] X\n\nAppendix\nExtra content"

        trimmed = _trim_body(raw)

        assert "Appendix" in trimmed
        assert "Extra content" in trimmed
        assert "References" not in trimmed

    def test_removes_acknowledgments(self):
        """The ``Acknowledgments`` heading is dropped."""
        raw = "Body\n\nAcknowledgments\nThanks to everyone."

        trimmed = _trim_body(raw)

        assert "Acknowledgments" not in trimmed

    def test_max_chars_truncation(self):
        """With ``max_chars=100`` the output is at most 100 characters long."""
        long_text = "A" * 1000

        trimmed = _trim_body(long_text, max_chars=100)

        assert len(trimmed) <= 100

    def test_no_truncation_when_none(self):
        """With ``max_chars=None`` a 500-character input keeps its full length."""
        plain_text = "A" * 500

        trimmed = _trim_body(plain_text, max_chars=None)

        assert len(trimmed) == 500
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# extract_pdf_text
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestExtractPdfText:
    """Tests for ``extract_pdf_text`` with ``pymupdf.open`` mocked out."""

    def test_extracts_text_and_saves(self, tmp_path):
        """Page text from the mocked document ends up in a sibling ``.txt`` file."""
        pdf_path = tmp_path / "test.pdf"
        pdf_path.write_bytes(b"%PDF-fake")

        mock_page = MagicMock()
        mock_page.get_text.return_value = "Page 1 text"
        # The mocked document is both iterable (one page) and a context manager.
        mock_doc = MagicMock()
        mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
        mock_doc.__enter__ = MagicMock(return_value=mock_doc)
        mock_doc.__exit__ = MagicMock(return_value=False)

        with (
            patch("pymupdf.open", return_value=mock_doc),
            # Make trimming a pass-through so the page text arrives unchanged.
            patch(
                "app.services.summary_utils._trim_body", side_effect=lambda t, **kw: t
            ),
        ):
            result_path = extract_pdf_text(pdf_path)

        assert result_path.suffix == ".txt"
        assert result_path.exists()
        # Read with an explicit encoding so the check does not depend on the
        # platform's locale default.
        assert "Page 1 text" in result_path.read_text(encoding="utf-8")

    def test_uses_cached_txt(self, tmp_path):
        """An existing ``.txt`` next to the PDF is returned without opening the PDF."""
        pdf_path = tmp_path / "test.pdf"
        pdf_path.write_bytes(b"%PDF-fake")
        txt_path = tmp_path / "test.txt"
        txt_path.write_text("cached", encoding="utf-8")

        with patch("pymupdf.open") as mock_open:
            result = extract_pdf_text(pdf_path)

        mock_open.assert_not_called()
        assert result == txt_path
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# write_meta_json
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestWriteMetaJson:
    """Tests for ``write_meta_json`` writing into a temporary paper directory."""

    def test_writes_meta_json(self, tmp_path, sample_paper):
        """A ``meta.json`` is written and carries the paper's id and English title."""
        # NOTE(review): this patches ``paper_dir`` on ``pdf_downloader``. If
        # ``summary_utils`` imports the function by name, the patch would have to
        # target ``app.services.summary_utils.paper_dir`` to take effect — confirm.
        redirect = patch(
            "app.services.pdf_downloader.paper_dir", lambda aid: tmp_path / aid
        )
        with redirect:
            meta_path = write_meta_json(sample_paper)

        assert meta_path.exists()
        assert meta_path.name == "meta.json"

        payload = json.loads(meta_path.read_text(encoding="utf-8"))
        assert payload["arxiv_id"] == "2401.12345"
        assert payload["title_en"] == "Test Paper Title"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# build_prompt
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestBuildPrompt:
    """Checks on the text ``build_prompt`` produces for each mode."""

    def test_inject_mode_contains_schema(self, tmp_path):
        """The ``inject`` prompt names ``title_zh`` and the required-fields phrase."""
        meta_path = tmp_path / "meta"
        txt_path = tmp_path / "txt"

        prompt = build_prompt("2401.12345", meta_path, txt_path, "inject")

        assert "title_zh" in prompt
        assert "必须包含以下字段" in prompt

    def test_search_mode_contains_read_instruction(self, tmp_path):
        """The ``search`` prompt mentions "read" (any case) and ``title_zh``."""
        meta_path = tmp_path / "meta"
        txt_path = tmp_path / "txt"

        prompt = build_prompt("2401.12345", meta_path, txt_path, "search")

        assert "read" in prompt.lower()
        assert "title_zh" in prompt

    def test_fix_errors_mode(self, tmp_path):
        """Passing ``fix_errors`` puts each error and a correction request in the prompt."""
        reported_errors = ["字段缺失"]

        prompt = build_prompt(
            "2401.12345",
            tmp_path / "meta",
            tmp_path / "txt",
            "inject",
            fix_errors=reported_errors,
        )

        assert "字段缺失" in prompt
        assert "修正" in prompt
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# extract_json
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestExtractJson:
    """Checks for the ways ``extract_json`` locates a JSON object in model output."""

    def test_direct_json(self, sample_summary_json):
        """Input that is itself JSON is parsed directly."""
        parsed = extract_json(sample_summary_json)

        assert parsed["title_zh"] == "测试论文中文标题"

    def test_fenced_code_block(self, sample_summary_json):
        """JSON inside a fenced ``json`` code block is found amid surrounding text."""
        wrapped = f"some text\n```json\n{sample_summary_json}\n```\nmore text"

        parsed = extract_json(wrapped)

        assert parsed["title_zh"] == "测试论文中文标题"

    def test_brace_matching_fallback(self, sample_summary_dict):
        """An unfenced JSON object embedded in a sentence is still extracted."""
        embedded = json.dumps(sample_summary_dict, ensure_ascii=False)
        wrapped = f"Here is the result: {embedded} end."

        parsed = extract_json(wrapped)

        assert parsed["title_zh"] == "测试论文中文标题"

    def test_no_json_raises(self):
        """Text with no JSON object raises ``JsonNotFoundError``."""
        with pytest.raises(JsonNotFoundError):
            extract_json("plain text no json here at all")
|
||||
+13
-14
@@ -2,6 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions import NotFoundError, ValidationError
|
||||
from app.services.user_data import (
|
||||
get_note,
|
||||
save_note,
|
||||
@@ -27,9 +30,8 @@ class TestBookmarkService:
|
||||
assert result["bookmarked"] is False
|
||||
|
||||
def test_toggle_bookmark_not_found(self, db_session):
|
||||
result = toggle_bookmark(db_session, "nonexistent")
|
||||
assert "error" in result
|
||||
assert result["error"] == "not_found"
|
||||
with pytest.raises(NotFoundError):
|
||||
toggle_bookmark(db_session, "nonexistent")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -44,9 +46,8 @@ class TestReadingStatusService:
|
||||
assert result["arxiv_id"] == "2401.12345"
|
||||
|
||||
def test_set_reading_status_invalid(self, db_session, sample_paper):
|
||||
result = set_reading_status(db_session, "2401.12345", "invalid_status")
|
||||
assert "error" in result
|
||||
assert result["error"] == "invalid_status"
|
||||
with pytest.raises(ValidationError):
|
||||
set_reading_status(db_session, "2401.12345", "invalid_status")
|
||||
|
||||
def test_update_existing_status(self, db_session, sample_paper):
|
||||
set_reading_status(db_session, "2401.12345", "skimmed")
|
||||
@@ -54,9 +55,8 @@ class TestReadingStatusService:
|
||||
assert result["status"] == "read_full"
|
||||
|
||||
def test_set_reading_status_not_found(self, db_session):
|
||||
result = set_reading_status(db_session, "nonexistent", "unread")
|
||||
assert "error" in result
|
||||
assert result["error"] == "not_found"
|
||||
with pytest.raises(NotFoundError):
|
||||
set_reading_status(db_session, "nonexistent", "unread")
|
||||
|
||||
def test_all_valid_statuses(self, db_session, sample_paper):
|
||||
for status in ("unread", "skimmed", "read_summary", "read_full"):
|
||||
@@ -93,9 +93,8 @@ class TestNoteService:
|
||||
assert result is None
|
||||
|
||||
def test_save_note_paper_not_found(self, db_session):
|
||||
result = save_note(db_session, "nonexistent", "内容")
|
||||
assert "error" in result
|
||||
assert result["error"] == "not_found"
|
||||
with pytest.raises(NotFoundError):
|
||||
save_note(db_session, "nonexistent", "内容")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
@@ -143,12 +142,12 @@ class TestUserDataRoutes:
|
||||
assert data["status"] == "read_summary"
|
||||
|
||||
def test_reading_status_invalid(self, client, sample_paper):
|
||||
"""无效状态返回 422。"""
|
||||
"""无效状态返回 400 (ValidationError)。"""
|
||||
resp = client.post(
|
||||
"/api/reading-status/2401.12345",
|
||||
json={"status": "invalid"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_reading_status_not_found(self, client):
|
||||
"""不存在的论文返回 404。"""
|
||||
|
||||
Reference in New Issue
Block a user