Files
daily-paper/tests/test_summarizer.py
T
Rain-Bus f7f1a4c0cb refactor: split monolithic phase tests into per-module test files
- rename test_admin_phase4.py -> test_admin.py, test_search.py -> test_searcher.py
- split test_phase5.py into test_cleaner, test_embedder, test_image_extractor, test_pages
- move schema tests from test_summarizer.py into dedicated test_schemas.py
- add sample_papers_range and sample_papers_with_summary fixtures in conftest
- update .gitignore to exclude all of data/
2026-06-06 00:34:30 +08:00

502 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 总结服务测试 — summarize_one 状态流转、批量处理、DB 更新、文件操作。"""
from __future__ import annotations
import json
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import text
from app.models import (
CrawlLog,
Paper,
PaperSummary,
PaperTag,
SummaryStatus,
TaskLock,
)
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp as _cleanup_tmp,
)
from app.services.pi_client import PiTimeoutError
from app.services.schemas import SummarySchema
from app.services.summarizer import (
_save_files,
_save_raw_output_only,
_update_summary_in_db,
summarize_batch,
summarize_one,
)
# ═══════════════════════════════════════════════════════════════════════
# DB 更新测试
# ═══════════════════════════════════════════════════════════════════════
class TestDbUpdate:
    """Checks for ``_update_summary_in_db``: summary row, paper fields, FTS index and AI tags."""

    @staticmethod
    def _apply(db_session, paper, summary_dict):
        """Validate ``summary_dict`` into a SummarySchema, persist it for ``paper``, return the schema."""
        parsed = SummarySchema.model_validate(summary_dict)
        _update_summary_in_db(db_session, paper, parsed, "normal", "raw")
        return parsed

    def test_summary_written(self, db_session, sample_paper, sample_summary_dict):
        parsed = self._apply(db_session, sample_paper, sample_summary_dict)

        stored = db_session.get(PaperSummary, sample_paper.id)
        assert stored is not None
        assert stored.one_line == parsed.one_line
        assert stored.motivation_problem == parsed.motivation.problem
        full = json.loads(stored.full_json)
        assert full["title_zh"] == parsed.title_zh

    def test_paper_title_zh_updated(self, db_session, sample_paper, sample_summary_dict):
        self._apply(db_session, sample_paper, sample_summary_dict)

        db_session.refresh(sample_paper)
        assert sample_paper.title_zh == "测试论文中文标题"
        assert sample_paper.summary_quality == "normal"

    def test_fts_updated(self, db_session, sample_paper, sample_summary_dict):
        parsed = self._apply(db_session, sample_paper, sample_summary_dict)

        query = text("SELECT title_zh, summary_text FROM papers_fts WHERE rowid = :id")
        fts = db_session.execute(query, {"id": sample_paper.id}).fetchone()
        assert fts is not None
        title_zh, summary_text = fts
        assert title_zh == "测试论文中文标题"
        assert parsed.one_line in summary_text

    def test_ai_tags_added(self, db_session, sample_paper, sample_summary_dict):
        self._apply(db_session, sample_paper, sample_summary_dict)

        ai_rows = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id, PaperTag.source == "ai")
            .all()
        )
        ai_names = {row.tag for row in ai_rows}
        # These two names come from the "tags" list of the sample_summary_dict fixture.
        for expected in ("自然语言处理", "大语言模型"):
            assert expected in ai_names

    def test_existing_tags_not_duplicated(self, db_session, sample_paper, sample_summary_dict):
        """A tag name the paper already has is not inserted again with source "ai"."""
        # Per the original test, sample_paper already carries NLP and LLM from
        # the hf source. The AI output repeats "NLP" and adds one new tag.
        sample_summary_dict["tags"] = ["NLP", "新标签"]
        self._apply(db_session, sample_paper, sample_summary_dict)

        every_row = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id)
            .all()
        )
        names = [row.tag for row in every_row]
        # "NLP" appears exactly once (the pre-existing row).
        assert names.count("NLP") == 1
        # The genuinely new tag was added.
        assert "新标签" in names
# ═══════════════════════════════════════════════════════════════════════
# 文件操作测试
# ═══════════════════════════════════════════════════════════════════════
class TestFileOperations:
    """File persistence (``_save_files``, ``_save_raw_output_only``) and tmp cleanup."""

    def test_save_files(self, tmp_path, sample_summary_dict):
        """``_save_files`` writes both summary.json and raw_output.txt into the paper dir."""
        schema = SummarySchema.model_validate(sample_summary_dict)
        with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
            _save_files("2401.12345", schema, "raw output text")

        # Named out_dir so it does not shadow the patched ``paper_dir`` helper.
        out_dir = tmp_path / "2401.12345"
        assert (out_dir / "summary.json").exists()
        assert (out_dir / "raw_output.txt").exists()
        # Decode explicitly as UTF-8: the JSON may contain Chinese text, and the
        # platform default encoding (e.g. on Windows) would mis-decode it.
        # NOTE(review): assumes _save_files writes UTF-8 — confirm.
        saved = json.loads((out_dir / "summary.json").read_text(encoding="utf-8"))
        assert saved["title_zh"] == "测试论文中文标题"
        assert "raw output text" in (out_dir / "raw_output.txt").read_text(encoding="utf-8")

    def test_save_raw_output_only(self, tmp_path):
        """``_save_raw_output_only`` writes raw_output.txt and no summary.json."""
        with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
            _save_raw_output_only("2401.12345", "raw output")

        out_dir = tmp_path / "2401.12345"
        raw_file = out_dir / "raw_output.txt"
        assert raw_file.exists()
        assert "raw output" in raw_file.read_text(encoding="utf-8")
        assert not (out_dir / "summary.json").exists()

    def test_cleanup_tmp(self, tmp_path):
        """``cleanup_tmp`` removes the per-paper directory under TMP_DIR, including its files."""
        tmp_paper = tmp_path / "2401.12345"
        tmp_paper.mkdir()
        (tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake")
        with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
            _cleanup_tmp("2401.12345")
        assert not tmp_paper.exists()

    def test_cleanup_tmp_nonexistent(self, tmp_path):
        """Cleaning a directory that does not exist does not raise."""
        with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
            _cleanup_tmp("nonexistent")  # must not raise
# ═══════════════════════════════════════════════════════════════════════
# 全流程状态流转测试
# ═══════════════════════════════════════════════════════════════════════
class TestSummarizeOneFlow:
    """State transitions of ``summarize_one`` with the PDF download and the pi call mocked."""

    @pytest.fixture
    def _patch_paths(self, tmp_path):
        """Redirect the papers/ and tmp/ data directories into ``tmp_path``."""
        with (
            patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
            patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
            patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
            patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
            patch("app.utils.TMP_DIR", tmp_path / "tmp"),
        ):
            yield

    @staticmethod
    def _seed_tmp_dir(tmp_path, arxiv_id):
        """Create a non-empty per-paper tmp dir (tmp_path/tmp/<arxiv_id>) and return its path.

        ``download_pdf`` is mocked in these tests, so the tmp directory may never
        be created during the run. Seeding it ensures that a "tmp was cleaned"
        assertion actually exercises the cleanup instead of passing vacuously.
        The sentinel file is deliberately not named paper.pdf, so it cannot be
        mistaken for a cached download.
        """
        tmp_paper = tmp_path / "tmp" / arxiv_id
        tmp_paper.mkdir(parents=True)
        (tmp_paper / "leftover.tmp").write_bytes(b"stale")
        return tmp_paper

    @pytest.mark.asyncio
    async def test_full_success_path(
        self, db_session, sample_paper, mock_pi_output, _patch_paths
    ):
        """pending → processing → done, with summary row, title_zh and FTS all persisted."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "done"
        assert result["quality"] == "normal"

        # Status and paper fields in the DB.
        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "done"
        assert sample_paper.summary_status.quality == "normal"
        assert sample_paper.title_zh == "测试论文中文标题"

        # The summary row was written.
        summary = db_session.get(PaperSummary, sample_paper.id)
        assert summary is not None
        assert summary.one_line

        # The FTS index was updated. Check the row exists first, so a missing
        # row fails as an assertion rather than a TypeError on None[0].
        fts_row = db_session.execute(
            text("SELECT title_zh FROM papers_fts WHERE rowid = :id"),
            {"id": sample_paper.id},
        ).fetchone()
        assert fts_row is not None
        assert fts_row[0] == "测试论文中文标题"

    @pytest.mark.asyncio
    async def test_pdf_download_failure(
        self, db_session, sample_paper, _patch_paths
    ):
        """A PDF download error fails the paper with error_type=pdf_download_failed.

        tmp cleanup on this path is verified separately by test_tmp_cleaned_on_failure.
        """
        with patch(
            "app.services.summarizer.download_pdf",
            new_callable=AsyncMock,
            side_effect=PdfDownloadError("network error"),
        ):
            result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "failed"
        assert result["error_type"] == "pdf_download_failed"

        db_session.refresh(sample_paper)
        status = sample_paper.summary_status
        assert status.error_type == "pdf_download_failed"

    @pytest.mark.asyncio
    async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
        """A pi timeout is reported as error_type=timeout and bumps retry_count to 1."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                side_effect=PiTimeoutError("timeout after 300s"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "failed"
        assert result["error_type"] == "timeout"
        assert result["retry_count"] == 1

    @pytest.mark.asyncio
    async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
        """pi output containing no JSON is reported as error_type=json_not_found."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value="No JSON in this output at all.",
            ),
        ):
            result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "failed"
        assert result["error_type"] == "json_not_found"

    @pytest.mark.asyncio
    async def test_field_missing_and_retry(
        self, db_session, sample_paper, _patch_paths
    ):
        """An empty required field → field_missing; the second failure → permanent_failure."""
        bad_json = json.dumps({
            "title_zh": "",  # required field left empty
            "one_line": "valid line",
            "tags": ["tag1"],
            "motivation": {"problem": "valid problem"},
            "method": {"key_idea": "valid idea"},
        }, ensure_ascii=False)
        bad_output = f"```json\n{bad_json}\n```"
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value=bad_output,
            ),
        ):
            # First failure: the paper is expected to remain eligible for a retry.
            result1 = await summarize_one(db_session, sample_paper)
            assert result1["status"] == "failed"
            assert result1["error_type"] == "field_missing"
            assert result1["retry_count"] == 1

            # Second failure: retry_count becomes 2, which exceeds the retry
            # budget, so the status becomes permanent_failure.
            # NOTE(review): assumes SUMMARY_MAX_RETRIES == 1 in the test
            # configuration; it is not patched here.
            db_session.refresh(sample_paper)
            result2 = await summarize_one(db_session, sample_paper)
            assert result2["status"] == "failed"
            assert result2["retry_count"] == 2
            db_session.refresh(sample_paper)
            assert sample_paper.summary_status.status == "permanent_failure"

    @pytest.mark.asyncio
    async def test_raw_output_saved_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """raw_output.txt is still saved when the summary fails to parse."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value="Some output without JSON",
            ),
        ):
            await summarize_one(db_session, sample_paper)
        raw_file = tmp_path / "papers" / sample_paper.arxiv_id / "raw_output.txt"
        assert raw_file.exists()
        assert "Some output without JSON" in raw_file.read_text(encoding="utf-8")

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_success(
        self, db_session, sample_paper, mock_pi_output, tmp_path, _patch_paths
    ):
        """The per-paper tmp directory is removed after a successful summary."""
        tmp_paper = self._seed_tmp_dir(tmp_path, sample_paper.arxiv_id)
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            await summarize_one(db_session, sample_paper)
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """The per-paper tmp directory is removed after a failure as well."""
        tmp_paper = self._seed_tmp_dir(tmp_path, sample_paper.arxiv_id)
        with patch(
            "app.services.summarizer.download_pdf",
            new_callable=AsyncMock,
            side_effect=PdfDownloadError("fail"),
        ):
            await summarize_one(db_session, sample_paper)
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_skips_done_paper(self, db_session, sample_paper, _patch_paths):
        """A paper whose status is already done is skipped."""
        sample_paper.summary_status.status = "done"
        db_session.commit()
        result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "skipped"
# ═══════════════════════════════════════════════════════════════════════
# 批量操作测试
# ═══════════════════════════════════════════════════════════════════════
class TestBatchSummarize:
    """``summarize_batch``: fan-out, failure isolation, TaskLock handling and the empty queue."""

    @pytest.fixture
    def _patch_paths(self, tmp_path):
        """Redirect the papers/ and tmp/ data directories into ``tmp_path``."""
        with (
            patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
            patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
            patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
            patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
            patch("app.utils.TMP_DIR", tmp_path / "tmp"),
        ):
            yield

    @staticmethod
    def _add_pending_papers(db_session, id_prefix, count, title_prefix):
        """Insert ``count`` papers, each with a ``pending`` SummaryStatus, then commit.

        arXiv ids are ``f"{id_prefix}{i}"`` and titles ``f"{title_prefix} {i}"``
        for ``i`` in ``range(count)``. The commit makes the rows visible to the
        separate worker sessions that summarize_batch opens.
        """
        now = datetime.now(timezone.utc)
        for i in range(count):
            arxiv_id = f"{id_prefix}{i}"
            paper = Paper(
                arxiv_id=arxiv_id,
                title_en=f"{title_prefix} {i}",
                abstract=f"Abstract {i}",
                paper_date=date(2024, 1, 15),
                crawled_at=now,
                pdf_url=f"https://arxiv.org/pdf/{arxiv_id}.pdf",
            )
            db_session.add(paper)
            db_session.flush()  # assigns paper.id, needed for the SummaryStatus FK
            db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db_session.commit()

    @staticmethod
    def _worker_session_factory(db_engine):
        """Return a session factory on the test engine; each batch worker gets its own session."""
        from sqlalchemy.orm import sessionmaker

        return sessionmaker(bind=db_engine, autoflush=False, autocommit=False)

    @pytest.mark.asyncio
    async def test_batch_multiple_papers(
        self, db_session, db_engine, mock_pi_output, _patch_paths
    ):
        """Several pending papers are all summarized and the run is logged in CrawlLog."""
        self._add_pending_papers(db_session, "2401.1234", 3, "Test Paper")

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            result = await summarize_batch(
                db_session, _session_factory=self._worker_session_factory(db_engine)
            )
        assert result["status"] == "success"
        assert result["done"] == 3
        assert result["failed"] == 0

        # The run was recorded in CrawlLog.
        log = db_session.query(CrawlLog).filter(CrawlLog.task == "summarize").first()
        assert log is not None
        assert log.status == "success"
        assert log.papers_found == 3

    @pytest.mark.asyncio
    async def test_single_failure_no_block(
        self, db_session, db_engine, mock_pi_output, _patch_paths
    ):
        """One paper failing does not stop the others from completing."""
        self._add_pending_papers(db_session, "2401.5678", 2, "Paper")

        call_count = 0

        async def _mock_call_pi(meta_path, pdf_path):
            # The first call (whichever paper reaches it first) times out;
            # every later call succeeds.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise PiTimeoutError("timeout")
            return mock_pi_output

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer.call_pi", new_callable=AsyncMock, side_effect=_mock_call_pi),
        ):
            result = await summarize_batch(
                db_session, _session_factory=self._worker_session_factory(db_engine)
            )
        assert result["done"] == 1
        assert result["failed"] == 1

    @pytest.mark.asyncio
    async def test_task_lock_conflict(self, db_session, _patch_paths):
        """A running TaskLock row makes a second batch return status=conflict."""
        # Insert a lock that is already running.
        db_session.add(
            TaskLock(
                task="summarize",
                lock_key="batch",
                status="running",
                acquired_at=datetime.now(timezone.utc),
            )
        )
        db_session.commit()
        result = await summarize_batch(db_session)
        assert result["status"] == "conflict"

    @pytest.mark.asyncio
    async def test_task_lock_released(self, db_session, db_engine, mock_pi_output, _patch_paths):
        """After a batch completes, its TaskLock row is marked finished with released_at set."""
        # Seed real work, so the batch cannot take an early path that never
        # touches the lock.
        self._add_pending_papers(db_session, "2401.9012", 1, "Lock Paper")

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            await summarize_batch(
                db_session, _session_factory=self._worker_session_factory(db_engine)
            )

        locks = db_session.query(TaskLock).filter(
            TaskLock.task == "summarize",
            TaskLock.lock_key == "batch",
        ).all()
        # Without this check, an empty result would make the loop below assert nothing.
        assert locks, "expected summarize_batch to record a TaskLock row"
        for lock in locks:
            assert lock.status == "finished"
            assert lock.released_at is not None

    @pytest.mark.asyncio
    async def test_batch_empty(self, db_session, _patch_paths):
        """With no pending papers the batch succeeds with total == 0."""
        result = await summarize_batch(db_session)
        assert result["status"] == "success"
        assert result["total"] == 0