21f16e6756
- Split summarizer into summary_generator and summary_persister modules - Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection - Add layout_detector service for PicoDet-S_layout_3cls integration - Add exceptions module with ConflictError and NotFoundError - Improve admin dashboard with better statistics and task management - Add design review document with system optimization suggestions - Add new tests for crawler, pdf_downloader, pipeline, and summary_utils - Update dependencies and configuration - Clean up dead code and improve error handling
523 lines
20 KiB
Python
523 lines
20 KiB
Python
"""AI 总结服务测试 — summarize_one 状态流转、批量处理、DB 更新、文件操作。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from datetime import date
|
||
from unittest.mock import AsyncMock, patch
|
||
|
||
import pytest
|
||
from sqlalchemy import text
|
||
|
||
from app.models import (
|
||
CrawlLog,
|
||
Paper,
|
||
PaperSummary,
|
||
PaperTag,
|
||
SummaryStatus,
|
||
TaskLock,
|
||
)
|
||
from app.services.pdf_downloader import (
|
||
PdfDownloadError,
|
||
cleanup_tmp as _cleanup_tmp,
|
||
)
|
||
from app.services.pi_client import PiTimeoutError
|
||
from app.services.schemas import SummarySchema
|
||
from app.services.summarizer import summarize_batch, summarize_one
|
||
from app.services.summary_persister import _save_files, _update_summary_in_db
|
||
from app.utils import utc_now
|
||
|
||
|
||
# ── 共享 fixture ──────────────────────────────────────────────────────────
|
||
|
||
|
||
@pytest.fixture
def _summarize_tmp_paths(tmp_path):
    """Redirect the data directories to ``tmp_path`` for summarizer tests.

    While the fixture is active, ``paper_dir`` in the persister and generator
    modules resolves to ``tmp_path / "papers" / <arxiv_id>``, and the
    ``PAPERS_DIR`` / ``TMP_DIR`` constants in ``pdf_downloader`` and
    ``app.utils`` point at ``tmp_path / "papers"`` and ``tmp_path / "tmp"``.
    """
    papers_root = tmp_path / "papers"
    scratch_root = tmp_path / "tmp"

    def _fake_paper_dir(aid):
        # Stand-in for paper_dir: one sub-directory per arXiv id.
        return papers_root / aid

    with (
        patch("app.services.summary_persister.paper_dir", _fake_paper_dir),
        patch("app.services.summary_generator.paper_dir", _fake_paper_dir),
        patch("app.services.pdf_downloader.PAPERS_DIR", papers_root),
        patch("app.services.pdf_downloader.TMP_DIR", scratch_root),
        patch("app.utils.PAPERS_DIR", papers_root),
        patch("app.utils.TMP_DIR", scratch_root),
    ):
        yield
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# DB 更新测试
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestDbUpdate:
    """Checks for ``_update_summary_in_db``."""

    @staticmethod
    def _persist(db_session, paper, summary_dict):
        """Validate ``summary_dict`` and write it to the DB; return the schema."""
        parsed = SummarySchema.model_validate(summary_dict)
        _update_summary_in_db(db_session, paper, parsed, "normal", "raw")
        return parsed

    def test_summary_written(self, db_session, sample_paper, sample_summary_dict):
        """A PaperSummary row exists and mirrors the schema's fields."""
        parsed = self._persist(db_session, sample_paper, sample_summary_dict)

        stored = db_session.get(PaperSummary, sample_paper.id)
        assert stored is not None
        assert stored.one_line == parsed.one_line
        assert stored.motivation_problem == parsed.motivation.problem
        full_payload = json.loads(stored.full_json)
        assert full_payload["title_zh"] == parsed.title_zh

    def test_paper_title_zh_updated(
        self, db_session, sample_paper, sample_summary_dict
    ):
        """The paper's ``title_zh`` and ``summary_quality`` are updated."""
        self._persist(db_session, sample_paper, sample_summary_dict)

        db_session.refresh(sample_paper)
        assert sample_paper.title_zh == "测试论文中文标题"
        assert sample_paper.summary_quality == "normal"

    def test_fts_updated(self, db_session, sample_paper, sample_summary_dict):
        """The ``papers_fts`` row carries the Chinese title and the one-liner."""
        parsed = self._persist(db_session, sample_paper, sample_summary_dict)

        fts_query = text(
            "SELECT title_zh, summary_text FROM papers_fts WHERE rowid = :id"
        )
        fts_row = db_session.execute(fts_query, {"id": sample_paper.id}).fetchone()
        assert fts_row is not None
        assert fts_row[0] == "测试论文中文标题"
        assert parsed.one_line in fts_row[1]

    def test_ai_tags_added(self, db_session, sample_paper, sample_summary_dict):
        """Tags from the schema are stored with ``source == "ai"``."""
        self._persist(db_session, sample_paper, sample_summary_dict)

        ai_rows = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id, PaperTag.source == "ai")
            .all()
        )
        ai_tag_names = {row.tag for row in ai_rows}
        # The AI tags come from schema.tags.
        for expected in ("自然语言处理", "大语言模型"):
            assert expected in ai_tag_names

    def test_existing_tags_not_duplicated(
        self, db_session, sample_paper, sample_summary_dict
    ):
        """A tag name that already exists is not inserted again with source "ai"."""
        # Per the original note, the sample_paper fixture (defined elsewhere)
        # already carries NLP (hf) and LLM (hf). Make the AI output contain
        # NLP (a duplicate of the HF tag) plus one brand-new tag.
        sample_summary_dict["tags"] = ["NLP", "新标签"]
        self._persist(db_session, sample_paper, sample_summary_dict)

        every_row = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id)
            .all()
        )
        names = [row.tag for row in every_row]
        # NLP appears exactly once (the pre-existing one); AI did not re-add it.
        assert names.count("NLP") == 1
        # The new tag was added by the AI output.
        assert "新标签" in names
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# 文件操作测试
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestFileOperations:
    """File saving and tmp-directory cleanup."""

    def test_save_files(self, tmp_path, sample_summary_dict):
        """Both ``summary.json`` and ``raw_output.txt`` are written."""
        arxiv_id = "2401.12345"
        parsed = SummarySchema.model_validate(sample_summary_dict)
        with patch(
            "app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
        ):
            _save_files(arxiv_id, parsed, "raw output text")

        out_dir = tmp_path / arxiv_id
        summary_file = out_dir / "summary.json"
        assert summary_file.exists()
        assert (out_dir / "raw_output.txt").exists()
        # NOTE(review): read with the platform default encoding, as before —
        # confirm it matches the encoding _save_files writes with.
        saved = json.loads(summary_file.read_text())
        assert saved["title_zh"] == "测试论文中文标题"

    def test_save_raw_output_only(self, tmp_path):
        """With no schema, only ``raw_output.txt`` is written."""
        arxiv_id = "2401.12345"
        with patch(
            "app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
        ):
            _save_files(arxiv_id, None, "raw output")

        out_dir = tmp_path / arxiv_id
        assert (out_dir / "raw_output.txt").exists()
        assert not (out_dir / "summary.json").exists()

    def test_cleanup_tmp(self, tmp_path):
        """``cleanup_tmp`` removes the paper's tmp directory."""
        arxiv_id = "2401.12345"
        scratch_dir = tmp_path / arxiv_id
        scratch_dir.mkdir()
        (scratch_dir / "paper.pdf").write_bytes(b"%PDF-fake")

        with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
            _cleanup_tmp(arxiv_id)

        assert not scratch_dir.exists()

    def test_cleanup_tmp_nonexistent(self, tmp_path):
        """Cleaning a directory that does not exist raises no error."""
        with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
            _cleanup_tmp("nonexistent")  # must not raise
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# 全流程状态流转测试
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestSummarizeOneFlow:
    """State transitions of ``summarize_one`` with pi and the PDF download mocked."""

    @pytest.mark.asyncio
    async def test_full_success_path(
        self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths
    ):
        """Full pending → processing → done flow."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summary_generator.call_pi",
                new_callable=AsyncMock,
                return_value=(mock_pi_output, "test-session-id"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "done"
        assert result["quality"] == "normal"

        # Verify the DB state.
        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "done"
        assert sample_paper.summary_status.quality == "normal"
        assert sample_paper.title_zh == "测试论文中文标题"

        # Verify the summary row was written.
        summary = db_session.get(PaperSummary, sample_paper.id)
        assert summary is not None
        assert summary.one_line

        # Verify the FTS index was updated.
        fts_row = db_session.execute(
            text("SELECT title_zh FROM papers_fts WHERE rowid = :id"),
            {"id": sample_paper.id},
        ).fetchone()
        # fetchone() yields None when no row matches; fail with a clear
        # assertion instead of a TypeError on the subscript below.
        assert fts_row is not None
        assert fts_row[0] == "测试论文中文标题"

    @pytest.mark.asyncio
    async def test_pdf_download_failure(
        self, db_session, sample_paper, _summarize_tmp_paths
    ):
        """PDF download failure → ``error_type == "pdf_download_failed"``.

        The tmp cleanup mentioned in the original note is not asserted here;
        see ``test_tmp_cleaned_on_failure``.
        """
        with (
            patch(
                "app.services.summarizer.download_pdf",
                new_callable=AsyncMock,
                side_effect=PdfDownloadError("network error"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "pdf_download_failed"

        db_session.refresh(sample_paper)
        status = sample_paper.summary_status
        assert status.error_type == "pdf_download_failed"

    @pytest.mark.asyncio
    async def test_pi_timeout(self, db_session, sample_paper, _summarize_tmp_paths):
        """pi timeout → ``timeout`` error and ``retry_count`` of 1."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summary_generator.call_pi",
                new_callable=AsyncMock,
                side_effect=PiTimeoutError("timeout after 300s"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "timeout"
        assert result["retry_count"] == 1

    @pytest.mark.asyncio
    async def test_json_not_found(self, db_session, sample_paper, _summarize_tmp_paths):
        """pi output without JSON → failure with ``error_type == "unknown"``.

        Per the original note, the validation loop retries 4 times and then
        raises ValueError; the retry count itself is not asserted here.
        """
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summary_generator.call_pi",
                new_callable=AsyncMock,
                return_value=("No JSON in this output at all.", "test-session-id"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "unknown"

    @pytest.mark.asyncio
    async def test_validation_fails_and_retries(
        self, db_session, sample_paper, _summarize_tmp_paths
    ):
        """Validation failure (a field violates the requirements) → failed."""
        bad_json = json.dumps(
            {
                "arxiv_id": sample_paper.arxiv_id,
                "title_zh": "",  # required field left empty
                "one_line": "valid line",
                "tags": ["tag1"],
                "motivation": {"problem": "valid problem"},
                "method": {"key_idea": "valid idea"},
            },
            ensure_ascii=False,
        )
        bad_output = f"```json\n{bad_json}\n```"

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summary_generator.call_pi",
                new_callable=AsyncMock,
                return_value=(bad_output, "test-session-id"),
            ),
        ):
            # Per the original note: _validate_summary rejects the payload
            # first; after all 4 rounds fail a ValueError maps to "unknown".
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "unknown"
        assert result["retry_count"] == 1

    @pytest.mark.asyncio
    async def test_raw_output_saved_on_failure(
        self, db_session, sample_paper, tmp_path, _summarize_tmp_paths
    ):
        """``raw_output.txt`` is still saved when summarization fails."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summary_generator.call_pi",
                new_callable=AsyncMock,
                return_value=("Some output without JSON", "test-session-id"),
            ),
        ):
            await summarize_one(db_session, sample_paper)

        raw_file = tmp_path / "papers" / sample_paper.arxiv_id / "raw_output.txt"
        assert raw_file.exists()
        assert "Some output without JSON" in raw_file.read_text()

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_success(
        self, db_session, sample_paper, mock_pi_output, tmp_path, _summarize_tmp_paths
    ):
        """The tmp directory is gone after a successful run."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summary_generator.call_pi",
                new_callable=AsyncMock,
                return_value=(mock_pi_output, "test-session-id"),
            ),
        ):
            await summarize_one(db_session, sample_paper)

        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_failure(
        self, db_session, sample_paper, tmp_path, _summarize_tmp_paths
    ):
        """The tmp directory is gone after a failed run as well."""
        with (
            patch(
                "app.services.summarizer.download_pdf",
                new_callable=AsyncMock,
                side_effect=PdfDownloadError("fail"),
            ),
        ):
            await summarize_one(db_session, sample_paper)

        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_skips_done_paper(
        self, db_session, sample_paper, _summarize_tmp_paths
    ):
        """A paper whose status is already "done" is skipped."""
        sample_paper.summary_status.status = "done"
        db_session.commit()

        result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "skipped"
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
# 批量操作测试
|
||
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestBatchSummarize:
    """Batch summarization tests."""

    @staticmethod
    def _seed_pending(db_session, id_stem, title_stem, count):
        """Insert ``count`` papers, each with a pending SummaryStatus, then commit."""
        crawled = utc_now()
        for idx in range(count):
            arxiv_id = f"{id_stem}{idx}"
            paper = Paper(
                arxiv_id=arxiv_id,
                title_en=f"{title_stem} {idx}",
                abstract=f"Abstract {idx}",
                paper_date=date(2024, 1, 15),
                crawled_at=crawled,
                pdf_url=f"https://arxiv.org/pdf/{arxiv_id}.pdf",
            )
            db_session.add(paper)
            # Flush before paper.id is read on the next line.
            db_session.flush()
            db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db_session.commit()

    @staticmethod
    def _worker_sessions(db_engine):
        """Return a session factory bound to ``db_engine``.

        Per the original note, each worker uses its own session on the same
        in-memory engine.
        """
        from sqlalchemy.orm import sessionmaker as _sm

        return _sm(bind=db_engine, autoflush=False, autocommit=False)

    @pytest.mark.asyncio
    async def test_batch_multiple_papers(
        self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths
    ):
        """Several pending papers are processed in one batch."""
        self._seed_pending(db_session, "2401.1234", "Test Paper", 3)
        session_factory = self._worker_sessions(db_engine)

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summary_generator.call_pi",
                new_callable=AsyncMock,
                return_value=(mock_pi_output, "test-session-id"),
            ),
        ):
            outcome = await summarize_batch(
                db_session, _session_factory=session_factory
            )

        assert outcome["status"] == "success"
        assert outcome["done"] == 3
        assert outcome["failed"] == 0

        # Verify the CrawlLog entry.
        crawl_log = (
            db_session.query(CrawlLog).filter(CrawlLog.task == "summarize").first()
        )
        assert crawl_log is not None
        assert crawl_log.status == "success"
        assert crawl_log.papers_found == 3

    @pytest.mark.asyncio
    async def test_single_failure_no_block(
        self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths
    ):
        """One failing paper does not block the others."""
        self._seed_pending(db_session, "2401.5678", "Paper", 2)
        session_factory = self._worker_sessions(db_engine)

        invocations = 0

        async def _mock_call_pi(meta_path, pdf_path, **kwargs):
            # The first call times out; every later call succeeds.
            nonlocal invocations
            invocations += 1
            if invocations == 1:
                raise PiTimeoutError("timeout")
            return mock_pi_output, "test-session-id"

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch("app.services.summary_generator.call_pi", side_effect=_mock_call_pi),
        ):
            outcome = await summarize_batch(
                db_session, _session_factory=session_factory
            )

        assert outcome["done"] == 1
        assert outcome["failed"] == 1

    @pytest.mark.asyncio
    async def test_task_lock_conflict(self, db_session, _summarize_tmp_paths):
        """An existing running TaskLock makes a second batch raise ConflictError."""
        from app.exceptions import ConflictError

        # Insert a lock that is already in the "running" state.
        running_lock = TaskLock(
            task="summarize",
            lock_key="batch",
            status="running",
            acquired_at=utc_now(),
        )
        db_session.add(running_lock)
        db_session.commit()

        with pytest.raises(ConflictError):
            await summarize_batch(db_session)

    @pytest.mark.asyncio
    async def test_task_lock_released(
        self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths
    ):
        """Every batch TaskLock is finished and released after the run."""
        session_factory = self._worker_sessions(db_engine)

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summary_generator.call_pi",
                new_callable=AsyncMock,
                return_value=(mock_pi_output, "test-session-id"),
            ),
        ):
            await summarize_batch(db_session, _session_factory=session_factory)

        batch_locks = (
            db_session.query(TaskLock)
            .filter(
                TaskLock.task == "summarize",
                TaskLock.lock_key == "batch",
            )
            .all()
        )
        # NOTE(review): this loop passes vacuously when no lock rows exist —
        # confirm whether summarize_batch creates a lock for an empty batch.
        for held in batch_locks:
            assert held.status == "finished"
            assert held.released_at is not None

    @pytest.mark.asyncio
    async def test_batch_empty(self, db_session, _summarize_tmp_paths):
        """With no pending papers the batch reports success and a total of 0."""
        outcome = await summarize_batch(db_session)
        assert outcome["status"] == "success"
        assert outcome["total"] == 0
|