532 lines
20 KiB
Python
532 lines
20 KiB
Python
"""AI 总结服务测试 — summarize_one 状态流转、批量处理、DB 更新、文件操作。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from datetime import date, datetime, timezone
|
||
from pathlib import Path
|
||
from unittest.mock import AsyncMock, patch
|
||
|
||
import pytest
|
||
from sqlalchemy import text
|
||
|
||
from app.models import (
|
||
CrawlLog,
|
||
Paper,
|
||
PaperSummary,
|
||
PaperTag,
|
||
SummaryStatus,
|
||
TaskLock,
|
||
)
|
||
from app.services.pdf_downloader import (
|
||
PdfDownloadError,
|
||
cleanup_tmp as _cleanup_tmp,
|
||
)
|
||
from app.services.pi_client import PiTimeoutError
|
||
from app.services.schemas import SummarySchema
|
||
from app.services.summarizer import (
|
||
_save_files,
|
||
_save_raw_output_only,
|
||
_update_summary_in_db,
|
||
summarize_batch,
|
||
summarize_one,
|
||
)
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
# DB update tests
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestDbUpdate:
    """Checks on what ``_update_summary_in_db`` writes for a paper."""

    @staticmethod
    def _persist(db_session, sample_paper, sample_summary_dict):
        """Validate the summary dict, store it via the function under test.

        Returns the validated ``SummarySchema`` so callers can compare the
        stored values against it.
        """
        parsed = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, parsed, "normal", "raw")
        return parsed

    def test_summary_written(self, db_session, sample_paper, sample_summary_dict):
        """A PaperSummary looked up by the paper id mirrors the schema."""
        parsed = self._persist(db_session, sample_paper, sample_summary_dict)

        stored = db_session.get(PaperSummary, sample_paper.id)
        assert stored is not None
        assert stored.one_line == parsed.one_line
        assert stored.motivation_problem == parsed.motivation.problem
        # full_json holds the serialized summary; check one key round-trips.
        full_payload = json.loads(stored.full_json)
        assert full_payload["title_zh"] == parsed.title_zh

    def test_paper_title_zh_updated(
        self, db_session, sample_paper, sample_summary_dict
    ):
        """The paper row receives the Chinese title and the quality label."""
        self._persist(db_session, sample_paper, sample_summary_dict)

        # Reload from the DB so the assertions see persisted values.
        db_session.refresh(sample_paper)
        assert sample_paper.title_zh == "测试论文中文标题"
        assert sample_paper.summary_quality == "normal"

    def test_fts_updated(self, db_session, sample_paper, sample_summary_dict):
        """The papers_fts row for the paper carries the title and one-liner."""
        parsed = self._persist(db_session, sample_paper, sample_summary_dict)

        query = text(
            "SELECT title_zh, summary_text FROM papers_fts WHERE rowid = :id"
        )
        fts_row = db_session.execute(query, {"id": sample_paper.id}).fetchone()
        assert fts_row is not None
        assert fts_row[0] == "测试论文中文标题"
        assert parsed.one_line in fts_row[1]

    def test_ai_tags_added(self, db_session, sample_paper, sample_summary_dict):
        """Tags from the schema are stored with source ``"ai"``."""
        self._persist(db_session, sample_paper, sample_summary_dict)

        ai_rows = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id, PaperTag.source == "ai")
            .all()
        )
        ai_names = {row.tag for row in ai_rows}
        # The AI tags originate from schema.tags.
        for expected in ("自然语言处理", "大语言模型"):
            assert expected in ai_names

    def test_existing_tags_not_duplicated(
        self, db_session, sample_paper, sample_summary_dict
    ):
        """A tag name that already exists is not inserted again for AI."""
        # Per the original note, the sample_paper fixture already carries the
        # HF-sourced tags NLP and LLM (fixture not visible in this file).
        # Make the AI output contain NLP (a duplicate) plus one new tag.
        sample_summary_dict["tags"] = ["NLP", "新标签"]
        self._persist(db_session, sample_paper, sample_summary_dict)

        paper_tags = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id)
            .all()
        )
        names = [row.tag for row in paper_tags]
        # NLP must appear exactly once: the pre-existing row, no AI copy.
        assert names.count("NLP") == 1
        # The genuinely new tag was added.
        assert "新标签" in names
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
# File operation tests
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestFileOperations:
    """Saving summary files and cleaning up the temp directory."""

    def test_save_files(self, tmp_path, sample_summary_dict):
        """``_save_files`` produces summary.json and raw_output.txt."""

        def fake_paper_dir(aid):
            # Stand-in for summarizer.paper_dir that stays inside tmp_path.
            return tmp_path / aid

        parsed = SummarySchema.model_validate(sample_summary_dict)
        with patch("app.services.summarizer.paper_dir", fake_paper_dir):
            _save_files("2401.12345", parsed, "raw output text")

        target = tmp_path / "2401.12345"
        summary_file = target / "summary.json"
        assert summary_file.exists()
        assert (target / "raw_output.txt").exists()
        saved = json.loads(summary_file.read_text())
        assert saved["title_zh"] == "测试论文中文标题"

    def test_save_raw_output_only(self, tmp_path):
        """``_save_raw_output_only`` writes raw_output.txt but no summary.json."""

        def fake_paper_dir(aid):
            # Stand-in for summarizer.paper_dir that stays inside tmp_path.
            return tmp_path / aid

        with patch("app.services.summarizer.paper_dir", fake_paper_dir):
            _save_raw_output_only("2401.12345", "raw output")

        target = tmp_path / "2401.12345"
        assert (target / "raw_output.txt").exists()
        assert not (target / "summary.json").exists()

    def test_cleanup_tmp(self, tmp_path):
        """``cleanup_tmp`` removes the per-paper directory under TMP_DIR."""
        doomed = tmp_path / "2401.12345"
        doomed.mkdir()
        fake_pdf = doomed / "paper.pdf"
        fake_pdf.write_bytes(b"%PDF-fake")

        # Point the downloader's TMP_DIR at tmp_path so the directory created
        # above is the one the cleanup looks for.
        with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
            _cleanup_tmp("2401.12345")

        assert not doomed.exists()

    def test_cleanup_tmp_nonexistent(self, tmp_path):
        """Cleaning up a directory that does not exist must not raise."""
        with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
            # No assertion needed: the test passes if no exception escapes.
            _cleanup_tmp("nonexistent")
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
# End-to-end status transition tests
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestSummarizeOneFlow:
    """Status transitions of ``summarize_one`` with pi and the PDF mocked."""

    @pytest.fixture
    def _patch_paths(self, tmp_path):
        """Redirect the data directories to ``tmp_path`` while a test runs."""
        with (
            patch(
                "app.services.summarizer.paper_dir",
                lambda aid: tmp_path / "papers" / aid,
            ),
            patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
            patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
            patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
            patch("app.utils.TMP_DIR", tmp_path / "tmp"),
        ):
            yield

    @pytest.mark.asyncio
    async def test_full_success_path(
        self, db_session, sample_paper, mock_pi_output, _patch_paths
    ):
        """A successful run ends with status ``done`` and all data written."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value=mock_pi_output,
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "done"
        assert result["quality"] == "normal"

        # Verify the persisted status.
        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "done"
        assert sample_paper.summary_status.quality == "normal"
        assert sample_paper.title_zh == "测试论文中文标题"

        # Verify the summary row was written.
        summary = db_session.get(PaperSummary, sample_paper.id)
        assert summary is not None
        assert summary.one_line

        # Verify the FTS index was updated.
        fts_row = db_session.execute(
            text("SELECT title_zh FROM papers_fts WHERE rowid = :id"),
            {"id": sample_paper.id},
        ).fetchone()
        # Guard first so a missing row fails as an assertion, not a TypeError.
        assert fts_row is not None
        assert fts_row[0] == "测试论文中文标题"

    @pytest.mark.asyncio
    async def test_pdf_download_failure(self, db_session, sample_paper, _patch_paths):
        """A PdfDownloadError is reported as ``pdf_download_failed``."""
        with (
            patch(
                "app.services.summarizer.download_pdf",
                new_callable=AsyncMock,
                side_effect=PdfDownloadError("network error"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "pdf_download_failed"

        # The error type must also be stored on the status row.
        db_session.refresh(sample_paper)
        status = sample_paper.summary_status
        assert status.error_type == "pdf_download_failed"

    @pytest.mark.asyncio
    async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
        """A PiTimeoutError yields a ``timeout`` error and retry_count 1."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                side_effect=PiTimeoutError("timeout after 300s"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "timeout"
        assert result["retry_count"] == 1

    @pytest.mark.asyncio
    async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
        """pi output without any JSON is reported as ``json_not_found``."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value="No JSON in this output at all.",
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "json_not_found"

    @pytest.mark.asyncio
    async def test_field_missing_and_retry(
        self, db_session, sample_paper, _patch_paths
    ):
        """An empty required field fails twice and ends as permanent_failure."""
        bad_json = json.dumps(
            {
                "title_zh": "",  # required field left empty on purpose
                "one_line": "valid line",
                "tags": ["tag1"],
                "motivation": {"problem": "valid problem"},
                "method": {"key_idea": "valid idea"},
            },
            ensure_ascii=False,
        )
        bad_output = f"```json\n{bad_json}\n```"

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value=bad_output,
            ),
        ):
            # First attempt fails; the original note says the paper goes back
            # to pending for a retry at this point.
            result1 = await summarize_one(db_session, sample_paper)
            assert result1["status"] == "failed"
            assert result1["error_type"] == "field_missing"
            assert result1["retry_count"] == 1

            # Second attempt fails and the status becomes permanent_failure.
            # The original note attributes this to SUMMARY_MAX_RETRIES=1; that
            # constant is not visible in this file and is not patched here —
            # TODO confirm the test still holds if the setting changes.
            db_session.refresh(sample_paper)
            result2 = await summarize_one(db_session, sample_paper)
            assert result2["status"] == "failed"
            assert result2["retry_count"] == 2

        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "permanent_failure"

    @pytest.mark.asyncio
    async def test_raw_output_saved_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """raw_output.txt is still written when the run fails."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value="Some output without JSON",
            ),
        ):
            await summarize_one(db_session, sample_paper)

        raw_file = tmp_path / "papers" / sample_paper.arxiv_id / "raw_output.txt"
        assert raw_file.exists()
        assert "Some output without JSON" in raw_file.read_text()

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_success(
        self, db_session, sample_paper, mock_pi_output, tmp_path, _patch_paths
    ):
        """No per-paper tmp directory is left behind after a success."""
        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value=mock_pi_output,
            ),
        ):
            await summarize_one(db_session, sample_paper)

        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """No per-paper tmp directory is left behind after a failure."""
        with (
            patch(
                "app.services.summarizer.download_pdf",
                new_callable=AsyncMock,
                side_effect=PdfDownloadError("fail"),
            ),
        ):
            await summarize_one(db_session, sample_paper)

        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_skips_done_paper(self, db_session, sample_paper, _patch_paths):
        """A paper whose status is already ``done`` is skipped."""
        sample_paper.summary_status.status = "done"
        db_session.commit()

        result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "skipped"
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
# Batch operation tests
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestBatchSummarize:
    """Tests for ``summarize_batch``."""

    @pytest.fixture
    def _patch_paths(self, tmp_path):
        """Redirect the data directories to ``tmp_path`` while a test runs."""
        with (
            patch(
                "app.services.summarizer.paper_dir",
                lambda aid: tmp_path / "papers" / aid,
            ),
            patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
            patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
            patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
            patch("app.utils.TMP_DIR", tmp_path / "tmp"),
        ):
            yield

    @staticmethod
    def _seed_pending_papers(db_session, count, id_prefix, title_prefix):
        """Insert ``count`` papers, each with a ``pending`` SummaryStatus.

        Paper ``i`` gets arxiv_id ``f"{id_prefix}{i}"`` and title
        ``f"{title_prefix} {i}"``. Everything is committed before returning.
        """
        now = datetime.now(timezone.utc)
        for i in range(count):
            paper = Paper(
                arxiv_id=f"{id_prefix}{i}",
                title_en=f"{title_prefix} {i}",
                abstract=f"Abstract {i}",
                paper_date=date(2024, 1, 15),
                crawled_at=now,
                pdf_url=f"https://arxiv.org/pdf/{id_prefix}{i}.pdf",
            )
            db_session.add(paper)
            # Flush so paper.id is populated before the status row uses it.
            db_session.flush()
            db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db_session.commit()

    @staticmethod
    def _worker_session_factory(db_engine):
        """Build the session factory handed to ``summarize_batch``.

        Per the original note, each worker uses its own session bound to the
        same in-memory engine.
        """
        # Imported here because the module's import block does not provide it.
        from sqlalchemy.orm import sessionmaker

        return sessionmaker(bind=db_engine, autoflush=False, autocommit=False)

    @pytest.mark.asyncio
    async def test_batch_multiple_papers(
        self, db_session, db_engine, mock_pi_output, _patch_paths
    ):
        """Three pending papers are all summarized and logged."""
        self._seed_pending_papers(db_session, 3, "2401.1234", "Test Paper")
        session_factory = self._worker_session_factory(db_engine)

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value=mock_pi_output,
            ),
        ):
            result = await summarize_batch(
                db_session, _session_factory=session_factory
            )

        assert result["status"] == "success"
        assert result["done"] == 3
        assert result["failed"] == 0

        # Verify the CrawlLog entry for the run.
        log = db_session.query(CrawlLog).filter(CrawlLog.task == "summarize").first()
        assert log is not None
        assert log.status == "success"
        assert log.papers_found == 3

    @pytest.mark.asyncio
    async def test_single_failure_no_block(
        self, db_session, db_engine, mock_pi_output, _patch_paths
    ):
        """One failing paper does not prevent the other from completing."""
        self._seed_pending_papers(db_session, 2, "2401.5678", "Paper")
        session_factory = self._worker_session_factory(db_engine)

        call_count = 0

        async def _mock_call_pi(meta_path, pdf_path):
            # Fail only the first invocation; every later call succeeds.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise PiTimeoutError("timeout")
            return mock_pi_output

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            # Explicit AsyncMock, matching how call_pi is mocked elsewhere in
            # this module; the async side_effect is awaited by the mock.
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                side_effect=_mock_call_pi,
            ),
        ):
            result = await summarize_batch(
                db_session, _session_factory=session_factory
            )

        assert result["done"] == 1
        assert result["failed"] == 1

    @pytest.mark.asyncio
    async def test_task_lock_conflict(self, db_session, _patch_paths):
        """An existing running TaskLock makes the batch report a conflict."""
        # Insert a lock that is already in the running state.
        db_session.add(
            TaskLock(
                task="summarize",
                lock_key="batch",
                status="running",
                acquired_at=datetime.now(timezone.utc),
            )
        )
        db_session.commit()

        result = await summarize_batch(db_session)
        assert result["status"] == "conflict"

    @pytest.mark.asyncio
    async def test_task_lock_released(
        self, db_session, db_engine, mock_pi_output, _patch_paths
    ):
        """Every batch TaskLock is finished and released after the run."""
        session_factory = self._worker_session_factory(db_engine)

        with (
            patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer.call_pi",
                new_callable=AsyncMock,
                return_value=mock_pi_output,
            ),
        ):
            await summarize_batch(db_session, _session_factory=session_factory)

        locks = (
            db_session.query(TaskLock)
            .filter(
                TaskLock.task == "summarize",
                TaskLock.lock_key == "batch",
            )
            .all()
        )
        # NOTE(review): this test seeds no papers, and the loop passes
        # vacuously when no TaskLock row exists — confirm summarize_batch
        # records a lock even for an empty batch.
        for lock in locks:
            assert lock.status == "finished"
            assert lock.released_at is not None

    @pytest.mark.asyncio
    async def test_batch_empty(self, db_session, _patch_paths):
        """With no pending papers the batch succeeds with a total of 0."""
        result = await summarize_batch(db_session)
        assert result["status"] == "success"
        assert result["total"] == 0
|