feat: refactor summarizer and PDF extraction pipeline

- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
This commit is contained in:
2026-06-13 13:16:47 +08:00
parent e2f0e1a8be
commit 21f16e6756
43 changed files with 3304 additions and 1494 deletions
+7 -1
View File
@@ -161,7 +161,13 @@ def sample_summary_dict() -> dict:
"results": {
"main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。",
"benchmarks": [
{"task": "长文本摘要", "metric": "ROUGE-L", "this_work": "42.1", "baseline": "38.9", "improvement": "+3.2"},
{
"task": "长文本摘要",
"metric": "ROUGE-L",
"this_work": "42.1",
"baseline": "38.9",
"improvement": "+3.2",
},
],
"limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。",
},
+9 -13
View File
@@ -67,7 +67,7 @@ class TestAdminAuth:
def test_correct_session_accepted(self, auth_client):
"""已登录 session 应被接受(crawl 可能会失败但不是 303)。"""
with patch(
"app.routes.admin.crawl_daily", new_callable=AsyncMock
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
resp = auth_client.post("/admin/crawl")
@@ -83,9 +83,7 @@ class TestAdminAuth:
def test_correct_session_batch_summarize(self, auth_client):
"""已登录调用 batch summarize,mock 掉服务层。"""
with patch(
"app.routes.admin.summarize_batch", new_callable=AsyncMock
) as mock:
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
mock.return_value = {
"status": "success",
"done": 0,
@@ -98,10 +96,12 @@ class TestAdminAuth:
def test_single_paper_not_found(self, auth_client):
"""单篇总结不存在的论文返回 404。"""
from app.exceptions import NotFoundError
with patch(
"app.routes.admin.summarize_single",
new_callable=AsyncMock,
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
side_effect=NotFoundError("Paper not found: nonexistent.99999"),
):
resp = auth_client.post("/admin/summarize/nonexistent.99999")
assert resp.status_code == 404
@@ -118,7 +118,7 @@ class TestAdminCrawl:
def test_crawl_default_today(self, auth_client):
"""不指定日期时默认抓取今天。"""
with patch(
"app.routes.admin.crawl_daily", new_callable=AsyncMock
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
resp = auth_client.post("/admin/crawl")
@@ -130,7 +130,7 @@ class TestAdminCrawl:
def test_crawl_specific_date(self, auth_client):
"""指定日期抓取。"""
with patch(
"app.routes.admin.crawl_daily", new_callable=AsyncMock
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
resp = auth_client.post("/admin/crawl?date=2024-01-15")
@@ -194,9 +194,7 @@ class TestAdminDelete:
)
assert resp.status_code == 422
def test_delete_with_confirm(
self, auth_client, db_session, sample_papers_range
):
def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range):
"""confirm='DELETE' 时应执行删除。"""
resp = auth_client.post(
"/admin/delete",
@@ -255,9 +253,7 @@ class TestAdminLogs:
resp = client.get("/admin/logs", follow_redirects=False)
assert resp.status_code == 303
def test_logs_contains_data(
self, auth_client, db_session, sample_papers_range
):
def test_logs_contains_data(self, auth_client, db_session, sample_papers_range):
"""日志页面应包含日志数据。"""
# 先创建一条日志
now = utc_now()
+189
View File
@@ -0,0 +1,189 @@
"""爬虫服务测试 — _parse_paper、fetch_daily、upsert_papers、crawl_daily。"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.crawler import (
_parse_paper,
crawl_daily,
fetch_daily,
upsert_papers,
)
# ═══════════════════════════════════════════════════════════════════════
# _parse_paper
# ═══════════════════════════════════════════════════════════════════════
class TestParsePaper:
    """Unit tests for ``_parse_paper``, which flattens one raw API item."""

    def test_normal_item(self):
        """A fully populated, ``paper``-wrapped item is mapped field by field."""
        payload = {
            "id": "2401.12345",
            "title": "Test Paper",
            "abstract": "Abstract text",
            "publishedAt": "2024-01-15T00:00:00",
            "authors": [{"name": "Alice"}, {"name": "Bob"}],
            "tags": [{"name": "NLP"}, {"name": "LLM"}],
            "upvotes": 42,
        }

        parsed = _parse_paper({"paper": payload})

        assert parsed["arxiv_id"] == "2401.12345"
        assert parsed["title_en"] == "Test Paper"
        assert len(parsed["authors"]) == 2
        # Author/tag objects are reduced to their plain ``name`` strings.
        assert parsed["authors"] == ["Alice", "Bob"]
        assert parsed["tags"] == ["NLP", "LLM"]
        assert parsed["upvotes"] == 42
        assert "huggingface.co" in parsed["hf_url"]

    def test_empty_id(self):
        """An empty id yields an empty ``arxiv_id`` and an empty ``hf_url``."""
        parsed = _parse_paper({"paper": {"id": "", "authors": [], "tags": []}})

        assert parsed["arxiv_id"] == ""
        assert parsed["hf_url"] == ""

    def test_missing_published_at(self):
        """Without ``publishedAt`` the parsed ``published_at`` is ``None``."""
        payload = {"id": "2401.00001", "title": "T", "authors": [], "tags": []}

        parsed = _parse_paper({"paper": payload})

        assert parsed["published_at"] is None

    def test_flat_structure_fallback(self):
        """Without a ``paper`` wrapper, fields are read from the top level."""
        flat_item = {"id": "2401.99999", "title": "Flat", "authors": [], "tags": []}

        parsed = _parse_paper(flat_item)

        assert parsed["arxiv_id"] == "2401.99999"
        assert parsed["title_en"] == "Flat"
# ═══════════════════════════════════════════════════════════════════════
# fetch_daily
# ═══════════════════════════════════════════════════════════════════════
class TestFetchDaily:
    """Tests for ``fetch_daily`` with the crawler's HTTP client mocked out."""

    @staticmethod
    def _make_client(payload):
        """Return a mock HTTP client whose ``get`` response decodes to *payload*.

        The client doubles as its own async context manager, so
        ``async with make_http_client() as c`` hands back the same mock.
        """
        response = MagicMock()
        response.json.return_value = payload
        response.raise_for_status = MagicMock()

        client = AsyncMock()
        client.get.return_value = response
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return client

    @pytest.mark.asyncio
    async def test_returns_papers(self):
        """The decoded JSON list is returned as-is."""
        client = self._make_client([{"paper": {"id": "2401.00001"}}])

        with patch("app.services.crawler.make_http_client", return_value=client):
            result = await fetch_daily("2024-01-15")

        assert len(result) == 1
        assert result[0]["paper"]["id"] == "2401.00001"

    @pytest.mark.asyncio
    async def test_respects_top_n(self):
        """``top_n`` caps how many of the fetched items are returned."""
        ten_items = [{"paper": {"id": f"2401.{i:05d}"}} for i in range(10)]
        client = self._make_client(ten_items)

        with patch("app.services.crawler.make_http_client", return_value=client):
            result = await fetch_daily("2024-01-15", top_n=3)

        assert len(result) == 3
# ═══════════════════════════════════════════════════════════════════════
# upsert_papers
# ═══════════════════════════════════════════════════════════════════════
class TestUpsertPapers:
    """Tests for ``upsert_papers`` against the ``db_session`` fixture."""

    def test_inserts_new_paper(self, db_session):
        """An unseen arXiv id is inserted and returned in the new-papers list."""
        raw_item = {
            "paper": {
                "id": "2401.00001",
                "title": "New Paper",
                "abstract": "Abstract",
                "authors": [{"name": "Alice"}],
                "tags": [{"name": "CV"}],
                "upvotes": 5,
            }
        }

        inserted = upsert_papers(db_session, [raw_item], "2024-01-15")

        assert len(inserted) == 1
        created = inserted[0]
        assert created.arxiv_id == "2401.00001"
        assert created.title_en == "New Paper"

    def test_updates_existing_upvotes(self, db_session, sample_paper):
        """A known arXiv id is not re-inserted, but its upvotes are refreshed."""
        raw_item = {
            "paper": {
                "id": sample_paper.arxiv_id,
                "title": sample_paper.title_en,
                "upvotes": 999,
                "authors": [],
                "tags": [],
            }
        }

        inserted = upsert_papers(db_session, [raw_item], "2024-01-15")

        # Nothing new was created for an id that already exists.
        assert len(inserted) == 0
        db_session.refresh(sample_paper)
        assert sample_paper.upvotes == 999

    def test_skips_empty_id(self, db_session):
        """Items with an empty id are skipped."""
        raw_item = {"paper": {"id": "", "title": "Nope", "authors": [], "tags": []}}

        inserted = upsert_papers(db_session, [raw_item], "2024-01-15")

        assert len(inserted) == 0
# ═══════════════════════════════════════════════════════════════════════
# crawl_daily
# ═══════════════════════════════════════════════════════════════════════
class TestCrawlDaily:
    """Tests for ``crawl_daily`` with ``fetch_daily`` mocked out."""

    @pytest.mark.asyncio
    async def test_success_flow(self, db_session):
        """One fetched paper is reported as one found and one new."""
        fetched = [
            {
                "paper": {
                    "id": "2401.00001",
                    "title": "T",
                    "authors": [],
                    "tags": [],
                    "upvotes": 0,
                }
            }
        ]

        with patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=fetched,
        ):
            outcome = await crawl_daily(db_session, "2024-01-15")

        assert outcome["status"] == "success"
        assert outcome["new"] == 1
        assert outcome["found"] == 1

    @pytest.mark.asyncio
    async def test_failure_returns_failed(self, db_session):
        """A fetch error is reported in the result dict rather than raised."""
        with patch(
            "app.services.crawler.fetch_daily", new_callable=AsyncMock
        ) as fetch_mock:
            fetch_mock.side_effect = ConnectionError("network error")
            outcome = await crawl_daily(db_session, "2024-01-15")

        assert outcome["status"] == "failed"
        assert "network error" in outcome["error"]
+77
View File
@@ -0,0 +1,77 @@
"""PDF 下载测试 — download_pdf、路径工具、错误处理。"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from app.services.pdf_downloader import (
PdfDownloadError,
download_pdf,
paper_dir,
tmp_dir,
)
from app.utils import PAPERS_DIR, TMP_DIR
# ═══════════════════════════════════════════════════════════════════════
# 路径工具
# ═══════════════════════════════════════════════════════════════════════
class TestPathHelpers:
    """The path helpers place each paper in a directory named after its id."""

    def test_paper_dir(self):
        """``paper_dir`` resolves to ``PAPERS_DIR/<arxiv_id>``."""
        arxiv_id = "2401.12345"
        expected = PAPERS_DIR / arxiv_id
        assert paper_dir(arxiv_id) == expected

    def test_tmp_dir(self):
        """``tmp_dir`` resolves to ``TMP_DIR/<arxiv_id>``."""
        arxiv_id = "2401.12345"
        expected = TMP_DIR / arxiv_id
        assert tmp_dir(arxiv_id) == expected
# ═══════════════════════════════════════════════════════════════════════
# download_pdf
# ═══════════════════════════════════════════════════════════════════════
class TestDownloadPdf:
    """Tests for ``download_pdf`` with the HTTP session and ``TMP_DIR`` patched."""

    @pytest.mark.asyncio
    async def test_success_download(self, tmp_path):
        """The response body is written to a file named ``paper.pdf``."""
        pdf_bytes = b"%PDF-1.4 fake"
        response = MagicMock()
        response.content = pdf_bytes
        response.raise_for_status = MagicMock()
        session = MagicMock()
        session.get.return_value = response

        tmp_patch = patch("app.services.pdf_downloader.TMP_DIR", tmp_path)
        session_patch = patch(
            "app.services.pdf_downloader._get_session", return_value=session
        )
        with tmp_patch, session_patch:
            saved = await download_pdf(
                "2401.12345", "https://arxiv.org/pdf/2401.12345.pdf"
            )

        assert saved.exists()
        assert saved.name == "paper.pdf"
        assert saved.read_bytes() == b"%PDF-1.4 fake"

    @pytest.mark.asyncio
    async def test_empty_pdf_url_raises(self):
        """An empty URL is rejected with ``PdfDownloadError``."""
        with pytest.raises(PdfDownloadError, match="no pdf_url"):
            await download_pdf("2401.12345", "")

    @pytest.mark.asyncio
    async def test_http_failure_raises(self, tmp_path):
        """A transport error from the session surfaces as ``PdfDownloadError``."""
        session = MagicMock()
        session.get.side_effect = ConnectionError("refused")

        tmp_patch = patch("app.services.pdf_downloader.TMP_DIR", tmp_path)
        session_patch = patch(
            "app.services.pdf_downloader._get_session", return_value=session
        )
        with tmp_patch, session_patch:
            with pytest.raises(PdfDownloadError, match="failed to download"):
                await download_pdf("2401.12345", "https://bad.url/pdf.pdf")
+77
View File
@@ -0,0 +1,77 @@
"""流水线编排测试 — run_pipeline (crawl → summarize → cleanup)。"""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.models import TaskLock
from app.services.pipeline import run_pipeline
from app.utils import utc_now
class TestRunPipeline:
    """Tests for ``run_pipeline`` with its crawl/summarize/cleanup stages mocked."""

    @pytest.mark.asyncio
    async def test_full_pipeline_success(self, db_session):
        """When every stage succeeds, each stage is invoked exactly once."""
        with (
            patch(
                "app.services.pipeline.crawl_daily", new_callable=AsyncMock
            ) as mock_crawl,
            patch(
                "app.services.pipeline.summarize_batch", new_callable=AsyncMock
            ) as mock_summ,
            patch("app.services.pipeline.cleanup_tmp") as mock_clean,
        ):
            mock_crawl.return_value = {"status": "success", "found": 5, "new": 2}
            mock_summ.return_value = {"status": "success", "done": 2, "failed": 0}
            mock_clean.return_value = {"removed": 0}

            result = await run_pipeline(db_session, "2024-01-15", "test")

        assert result["status"] == "success"
        mock_crawl.assert_called_once()
        mock_summ.assert_called_once()
        mock_clean.assert_called_once()

    @pytest.mark.asyncio
    async def test_pipeline_lock_prevents_reentry(self, db_session):
        """An existing ``running`` lock for the same date raises ``RuntimeError``."""
        now = utc_now()
        db_session.add(
            TaskLock(
                task="scheduler",
                lock_key="pipeline-2024-01-15",
                status="running",
                owner="other",
                acquired_at=now,
            )
        )
        db_session.commit()

        with pytest.raises(RuntimeError, match="already running"):
            await run_pipeline(db_session, "2024-01-15", "test")

    @pytest.mark.asyncio
    async def test_crawl_failure_aborts_pipeline(self, db_session):
        """A crawl error yields a ``failed`` result and summarize is never run.

        This test was previously named as if summarize and cleanup still ran
        after a crawl failure; the assertions below verify the opposite.
        """
        with (
            patch(
                "app.services.pipeline.crawl_daily", new_callable=AsyncMock
            ) as mock_crawl,
            patch(
                "app.services.pipeline.summarize_batch", new_callable=AsyncMock
            ) as mock_summ,
            patch("app.services.pipeline.cleanup_tmp") as mock_clean,
        ):
            mock_crawl.side_effect = ConnectionError("timeout")
            mock_summ.return_value = {"status": "success", "done": 0}
            mock_clean.return_value = {"removed": 0}

            result = await run_pipeline(db_session, "2024-01-15", "test")

        # The exception is not propagated; it is reported in the result dict.
        assert result["status"] == "failed"
        assert "timeout" in result["error"]
        # The summarize stage is never reached once the crawl has raised.
        # NOTE(review): the original comment says cleanup is skipped as well,
        # but that was never asserted -- confirm against run_pipeline before
        # adding mock_clean.assert_not_called().
        mock_summ.assert_not_called()
+1 -1
View File
@@ -20,7 +20,7 @@ from app.services.schemas import (
classify_validation_error,
flatten_for_db,
)
from app.services.summarizer import _classify_error
from app.services.summary_generator import _classify_error
# ═══════════════════════════════════════════════════════════════════════
+35 -22
View File
@@ -23,12 +23,8 @@ from app.services.pdf_downloader import (
)
from app.services.pi_client import PiTimeoutError
from app.services.schemas import SummarySchema
from app.services.summarizer import (
_save_files,
_update_summary_in_db,
summarize_batch,
summarize_one,
)
from app.services.summarizer import summarize_batch, summarize_one
from app.services.summary_persister import _save_files, _update_summary_in_db
from app.utils import utc_now
@@ -39,7 +35,14 @@ from app.utils import utc_now
def _summarize_tmp_paths(tmp_path):
"""将 data 目录重定向到 tmp_path(供 summarizer 测试使用)。"""
with (
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
patch(
"app.services.summary_persister.paper_dir",
lambda aid: tmp_path / "papers" / aid,
),
patch(
"app.services.summary_generator.paper_dir",
lambda aid: tmp_path / "papers" / aid,
),
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
@@ -134,7 +137,9 @@ class TestFileOperations:
def test_save_files(self, tmp_path, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
with patch(
"app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
):
_save_files("2401.12345", schema, "raw output text")
paper_dir = tmp_path / "2401.12345"
@@ -144,7 +149,9 @@ class TestFileOperations:
assert saved["title_zh"] == "测试论文中文标题"
def test_save_raw_output_only(self, tmp_path):
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
with patch(
"app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
):
_save_files("2401.12345", None, "raw output")
paper_dir = tmp_path / "2401.12345"
assert (paper_dir / "raw_output.txt").exists()
@@ -180,7 +187,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(mock_pi_output, "test-session-id"),
),
@@ -209,7 +216,9 @@ class TestSummarizeOneFlow:
assert fts_row[0] == "测试论文中文标题"
@pytest.mark.asyncio
async def test_pdf_download_failure(self, db_session, sample_paper, _summarize_tmp_paths):
async def test_pdf_download_failure(
self, db_session, sample_paper, _summarize_tmp_paths
):
"""PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。"""
with (
patch(
@@ -233,7 +242,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
side_effect=PiTimeoutError("timeout after 300s"),
),
@@ -250,7 +259,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=("No JSON in this output at all.", "test-session-id"),
),
@@ -281,7 +290,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(bad_output, "test-session-id"),
),
@@ -300,7 +309,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=("Some output without JSON", "test-session-id"),
),
@@ -319,7 +328,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(mock_pi_output, "test-session-id"),
),
@@ -347,7 +356,9 @@ class TestSummarizeOneFlow:
assert not tmp_paper.exists()
@pytest.mark.asyncio
async def test_skips_done_paper(self, db_session, sample_paper, _summarize_tmp_paths):
async def test_skips_done_paper(
self, db_session, sample_paper, _summarize_tmp_paths
):
"""已完成的论文跳过。"""
sample_paper.summary_status.status = "done"
db_session.commit()
@@ -393,7 +404,7 @@ class TestBatchSummarize:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(mock_pi_output, "test-session-id"),
),
@@ -446,7 +457,7 @@ class TestBatchSummarize:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.call_pi", side_effect=_mock_call_pi),
patch("app.services.summary_generator.call_pi", side_effect=_mock_call_pi),
):
result = await summarize_batch(db_session, _session_factory=_TestSession)
@@ -456,6 +467,8 @@ class TestBatchSummarize:
@pytest.mark.asyncio
async def test_task_lock_conflict(self, db_session, _summarize_tmp_paths):
"""TaskLock 防止并发 batch。"""
from app.exceptions import ConflictError
# 先插入一个 running 锁
db_session.add(
TaskLock(
@@ -467,8 +480,8 @@ class TestBatchSummarize:
)
db_session.commit()
result = await summarize_batch(db_session)
assert result["status"] == "conflict"
with pytest.raises(ConflictError):
await summarize_batch(db_session)
@pytest.mark.asyncio
async def test_task_lock_released(
@@ -482,7 +495,7 @@ class TestBatchSummarize:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(mock_pi_output, "test-session-id"),
),
+174
View File
@@ -0,0 +1,174 @@
"""summary_utils 测试 — PDF 文本提取、正文裁剪、JSON 提取、meta.json 写入、prompt 构建。"""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
from app.services.summary_utils import (
JsonNotFoundError,
_trim_body,
build_prompt,
extract_json,
extract_pdf_text,
write_meta_json,
)
# ═══════════════════════════════════════════════════════════════════════
# _trim_body 正文裁剪
# ═══════════════════════════════════════════════════════════════════════
class TestTrimBody:
    """Tests for ``_trim_body``: section stripping and length capping."""

    def test_removes_references_section(self):
        """The references heading is dropped while earlier content is kept."""
        source = "Intro\n\nSome content here.\n\nReferences\n[1] Smith et al."
        trimmed = _trim_body(source)
        assert "References" not in trimmed
        assert "Intro" in trimmed

    def test_removes_bibliography(self):
        """A ``Bibliography`` heading is removed."""
        trimmed = _trim_body("Body\n\nBibliography\n[1] Smith")
        assert "Bibliography" not in trimmed

    def test_keeps_appendix_after_references(self):
        """An appendix that follows the references survives the trim."""
        source = "Body\n\nReferences\n[1] X\n\nAppendix\nExtra content"
        trimmed = _trim_body(source)
        assert "Appendix" in trimmed
        assert "Extra content" in trimmed
        assert "References" not in trimmed

    def test_removes_acknowledgments(self):
        """An ``Acknowledgments`` heading is removed."""
        trimmed = _trim_body("Body\n\nAcknowledgments\nThanks to everyone.")
        assert "Acknowledgments" not in trimmed

    def test_max_chars_truncation(self):
        """``max_chars`` is an upper bound on the returned length."""
        trimmed = _trim_body("A" * 1000, max_chars=100)
        assert len(trimmed) <= 100

    def test_no_truncation_when_none(self):
        """``max_chars=None`` leaves the length untouched."""
        trimmed = _trim_body("A" * 500, max_chars=None)
        assert len(trimmed) == 500
# ═══════════════════════════════════════════════════════════════════════
# extract_pdf_text
# ═══════════════════════════════════════════════════════════════════════
class TestExtractPdfText:
    """Tests for ``extract_pdf_text`` with ``pymupdf.open`` mocked out."""

    def test_extracts_text_and_saves(self, tmp_path):
        """Page text from the mocked document ends up in a ``.txt`` file."""
        pdf_path = tmp_path / "test.pdf"
        pdf_path.write_bytes(b"%PDF-fake")

        mock_page = MagicMock()
        mock_page.get_text.return_value = "Page 1 text"
        # The document mock supports both iteration over pages and ``with``.
        mock_doc = MagicMock()
        mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
        mock_doc.__enter__ = MagicMock(return_value=mock_doc)
        mock_doc.__exit__ = MagicMock(return_value=False)

        with (
            patch("pymupdf.open", return_value=mock_doc),
            # Identity stand-in so the assertion sees the untrimmed page text.
            patch(
                "app.services.summary_utils._trim_body", side_effect=lambda t, **kw: t
            ),
        ):
            result_path = extract_pdf_text(pdf_path)

        assert result_path.suffix == ".txt"
        assert result_path.exists()
        # Decode explicitly so the check does not depend on the platform's
        # default locale encoding (the sibling test also writes UTF-8).
        assert "Page 1 text" in result_path.read_text(encoding="utf-8")

    def test_uses_cached_txt(self, tmp_path):
        """An existing sibling ``.txt`` is returned without opening the PDF."""
        pdf_path = tmp_path / "test.pdf"
        pdf_path.write_bytes(b"%PDF-fake")
        txt_path = tmp_path / "test.txt"
        txt_path.write_text("cached", encoding="utf-8")

        with patch("pymupdf.open") as mock_open:
            result = extract_pdf_text(pdf_path)

        mock_open.assert_not_called()
        assert result == txt_path
# ═══════════════════════════════════════════════════════════════════════
# write_meta_json
# ═══════════════════════════════════════════════════════════════════════
class TestWriteMetaJson:
    """Tests for ``write_meta_json`` using the ``sample_paper`` fixture."""

    def test_writes_meta_json(self, tmp_path, sample_paper):
        """A ``meta.json`` file is written and carries the paper's id and title."""
        # NOTE(review): this patches ``paper_dir`` on ``pdf_downloader``. If
        # ``summary_utils`` binds the name at import time, the patch has no
        # effect there and the file lands in the real data dir -- confirm.
        redirect = patch(
            "app.services.pdf_downloader.paper_dir", lambda aid: tmp_path / aid
        )
        with redirect:
            meta_path = write_meta_json(sample_paper)

        assert meta_path.exists()
        assert meta_path.name == "meta.json"
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        assert meta["arxiv_id"] == "2401.12345"
        assert meta["title_en"] == "Test Paper Title"
# ═══════════════════════════════════════════════════════════════════════
# build_prompt
# ═══════════════════════════════════════════════════════════════════════
class TestBuildPrompt:
    """Tests for ``build_prompt`` in its ``inject`` and ``search`` modes."""

    def test_inject_mode_contains_schema(self, tmp_path):
        """The inject-mode prompt contains the required-fields listing."""
        meta_path = tmp_path / "meta"
        txt_path = tmp_path / "txt"

        prompt = build_prompt("2401.12345", meta_path, txt_path, "inject")

        assert "title_zh" in prompt
        assert "必须包含以下字段" in prompt

    def test_search_mode_contains_read_instruction(self, tmp_path):
        """The search-mode prompt mentions ``read`` and the schema field."""
        meta_path = tmp_path / "meta"
        txt_path = tmp_path / "txt"

        prompt = build_prompt("2401.12345", meta_path, txt_path, "search")

        assert "read" in prompt.lower()
        assert "title_zh" in prompt

    def test_fix_errors_mode(self, tmp_path):
        """Passing ``fix_errors`` puts the errors and a fix request in the prompt."""
        errors = ["字段缺失"]

        prompt = build_prompt(
            "2401.12345",
            tmp_path / "meta",
            tmp_path / "txt",
            "inject",
            fix_errors=errors,
        )

        assert "字段缺失" in prompt
        assert "修正" in prompt
# ═══════════════════════════════════════════════════════════════════════
# extract_json
# ═══════════════════════════════════════════════════════════════════════
class TestExtractJson:
    """Tests for ``extract_json`` across the supported output shapes."""

    def test_direct_json(self, sample_summary_json):
        """A bare JSON document is parsed directly."""
        parsed = extract_json(sample_summary_json)
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_fenced_code_block(self, sample_summary_json):
        """JSON inside a fenced ``json`` block is found amid other text."""
        wrapped = f"some text\n```json\n{sample_summary_json}\n```\nmore text"
        parsed = extract_json(wrapped)
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_brace_matching_fallback(self, sample_summary_dict):
        """An unfenced JSON object embedded in a sentence is still extracted."""
        embedded = json.dumps(sample_summary_dict, ensure_ascii=False)
        parsed = extract_json(f"Here is the result: {embedded} end.")
        assert parsed["title_zh"] == "测试论文中文标题"

    def test_no_json_raises(self):
        """Input without any JSON raises ``JsonNotFoundError``."""
        plain = "plain text no json here at all"
        with pytest.raises(JsonNotFoundError):
            extract_json(plain)
+13 -14
View File
@@ -2,6 +2,9 @@
from __future__ import annotations
import pytest
from app.exceptions import NotFoundError, ValidationError
from app.services.user_data import (
get_note,
save_note,
@@ -27,9 +30,8 @@ class TestBookmarkService:
assert result["bookmarked"] is False
def test_toggle_bookmark_not_found(self, db_session):
result = toggle_bookmark(db_session, "nonexistent")
assert "error" in result
assert result["error"] == "not_found"
with pytest.raises(NotFoundError):
toggle_bookmark(db_session, "nonexistent")
# ═══════════════════════════════════════════════════════════════════════
@@ -44,9 +46,8 @@ class TestReadingStatusService:
assert result["arxiv_id"] == "2401.12345"
def test_set_reading_status_invalid(self, db_session, sample_paper):
result = set_reading_status(db_session, "2401.12345", "invalid_status")
assert "error" in result
assert result["error"] == "invalid_status"
with pytest.raises(ValidationError):
set_reading_status(db_session, "2401.12345", "invalid_status")
def test_update_existing_status(self, db_session, sample_paper):
set_reading_status(db_session, "2401.12345", "skimmed")
@@ -54,9 +55,8 @@ class TestReadingStatusService:
assert result["status"] == "read_full"
def test_set_reading_status_not_found(self, db_session):
result = set_reading_status(db_session, "nonexistent", "unread")
assert "error" in result
assert result["error"] == "not_found"
with pytest.raises(NotFoundError):
set_reading_status(db_session, "nonexistent", "unread")
def test_all_valid_statuses(self, db_session, sample_paper):
for status in ("unread", "skimmed", "read_summary", "read_full"):
@@ -93,9 +93,8 @@ class TestNoteService:
assert result is None
def test_save_note_paper_not_found(self, db_session):
result = save_note(db_session, "nonexistent", "内容")
assert "error" in result
assert result["error"] == "not_found"
with pytest.raises(NotFoundError):
save_note(db_session, "nonexistent", "内容")
# ═══════════════════════════════════════════════════════════════════════
@@ -143,12 +142,12 @@ class TestUserDataRoutes:
assert data["status"] == "read_summary"
def test_reading_status_invalid(self, client, sample_paper):
"""无效状态返回 422"""
"""无效状态返回 400 (ValidationError)"""
resp = client.post(
"/api/reading-status/2401.12345",
json={"status": "invalid"},
)
assert resp.status_code == 422
assert resp.status_code == 400
def test_reading_status_not_found(self, client):
"""不存在的论文返回 404。"""