"""AI 总结服务测试 — summarize_one 状态流转、批量处理、DB 更新、文件操作。""" from __future__ import annotations import json from datetime import date from unittest.mock import AsyncMock, patch import pytest from sqlalchemy import text from app.models import ( CrawlLog, Paper, PaperSummary, PaperTag, SummaryStatus, TaskLock, ) from app.services.pdf_downloader import ( PdfDownloadError, cleanup_tmp as _cleanup_tmp, ) from app.services.pi_client import PiTimeoutError from app.services.schemas import SummarySchema from app.services.summarizer import summarize_batch, summarize_one from app.services.summary_persister import _save_files, _update_summary_in_db from app.utils import utc_now # ── 共享 fixture ────────────────────────────────────────────────────────── @pytest.fixture def _summarize_tmp_paths(tmp_path): """将 data 目录重定向到 tmp_path(供 summarizer 测试使用)。""" with ( patch( "app.services.summary_persister.paper_dir", lambda aid: tmp_path / "papers" / aid, ), patch( "app.services.summary_generator.paper_dir", lambda aid: tmp_path / "papers" / aid, ), patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"), patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"), patch("app.utils.PAPERS_DIR", tmp_path / "papers"), patch("app.utils.TMP_DIR", tmp_path / "tmp"), ): yield # ═══════════════════════════════════════════════════════════════════════ # DB 更新测试 # ═══════════════════════════════════════════════════════════════════════ class TestDbUpdate: """_update_summary_in_db 验证。""" def test_summary_written(self, db_session, sample_paper, sample_summary_dict): schema = SummarySchema.model_validate(sample_summary_dict) _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") summary = db_session.get(PaperSummary, sample_paper.id) assert summary is not None assert summary.one_line == schema.one_line assert summary.motivation_problem == schema.motivation.problem assert json.loads(summary.full_json)["title_zh"] == schema.title_zh def test_paper_title_zh_updated( self, db_session, sample_paper, sample_summary_dict ): schema = SummarySchema.model_validate(sample_summary_dict) _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") db_session.refresh(sample_paper) assert sample_paper.title_zh == "测试论文中文标题" assert sample_paper.summary_quality == "normal" def test_fts_updated(self, db_session, sample_paper, sample_summary_dict): schema = SummarySchema.model_validate(sample_summary_dict) _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") row = db_session.execute( text("SELECT title_zh, summary_text FROM papers_fts WHERE rowid = :id"), {"id": sample_paper.id}, ).fetchone() assert row is not None assert row[0] == "测试论文中文标题" assert schema.one_line in row[1] def test_ai_tags_added(self, db_session, sample_paper, sample_summary_dict): schema = SummarySchema.model_validate(sample_summary_dict) _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") tags = ( db_session.query(PaperTag) .filter(PaperTag.paper_id == sample_paper.id, PaperTag.source == "ai") .all() ) tag_names = {t.tag for t in tags} # AI tags 来自 schema.tags assert "自然语言处理" in tag_names assert "大语言模型" in tag_names def test_existing_tags_not_duplicated( self, db_session, sample_paper, sample_summary_dict ): """已存在的标签名(同 name)不会被 AI source 重复插入。""" # sample_paper 已有 NLP (hf)、LLM (hf) # 让 AI 输出包含 NLP(与 HF 重复)和 "新标签"(新的) sample_summary_dict["tags"] = ["NLP", "新标签"] schema = SummarySchema.model_validate(sample_summary_dict) _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw") all_tags = ( db_session.query(PaperTag) .filter(PaperTag.paper_id == sample_paper.id) .all() ) tag_names = [t.tag for t in all_tags] # NLP 只出现一次(HF 原有的),AI 不会重复加 assert tag_names.count("NLP") == 1 # "新标签" 是 AI 新加的 assert "新标签" in tag_names # ═══════════════════════════════════════════════════════════════════════ # 文件操作测试 # ═══════════════════════════════════════════════════════════════════════ class TestFileOperations: """文件保存和清理。""" def test_save_files(self, tmp_path, sample_summary_dict): schema = SummarySchema.model_validate(sample_summary_dict) with patch( "app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid ): _save_files("2401.12345", schema, "raw output text") paper_dir = tmp_path / "2401.12345" assert (paper_dir / "summary.json").exists() assert (paper_dir / "raw_output.txt").exists() saved = json.loads((paper_dir / "summary.json").read_text()) assert saved["title_zh"] == "测试论文中文标题" def test_save_raw_output_only(self, tmp_path): with patch( "app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid ): _save_files("2401.12345", None, "raw output") paper_dir = tmp_path / "2401.12345" assert (paper_dir / "raw_output.txt").exists() assert not (paper_dir / "summary.json").exists() def test_cleanup_tmp(self, tmp_path): tmp_paper = tmp_path / "2401.12345" tmp_paper.mkdir() (tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake") with patch("app.services.pdf_downloader.TMP_DIR", tmp_path): _cleanup_tmp("2401.12345") assert not tmp_paper.exists() def test_cleanup_tmp_nonexistent(self, tmp_path): """清理不存在的目录不报错。""" with patch("app.services.pdf_downloader.TMP_DIR", tmp_path): _cleanup_tmp("nonexistent") # 不抛异常 # ═══════════════════════════════════════════════════════════════════════ # 全流程状态流转测试 # ═══════════════════════════════════════════════════════════════════════ class TestSummarizeOneFlow: """summarize_one 的状态流转(mock pi 和 PDF)。""" @pytest.mark.asyncio async def test_full_success_path( self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths ): """pending → processing → done 全流程。""" with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( "app.services.summary_generator.call_pi", new_callable=AsyncMock, return_value=(mock_pi_output, "test-session-id"), ), ): result = await summarize_one(db_session, sample_paper) assert result["status"] == "done" assert result["quality"] == "normal" # 验证 DB 状态 db_session.refresh(sample_paper) assert sample_paper.summary_status.status == "done" assert sample_paper.summary_status.quality == "normal" assert sample_paper.title_zh == "测试论文中文标题" # 验证 summary 已写入 summary = db_session.get(PaperSummary, sample_paper.id) assert summary is not None assert summary.one_line # 验证 FTS 已更新 fts_row = db_session.execute( text("SELECT title_zh FROM papers_fts WHERE rowid = :id"), {"id": sample_paper.id}, ).fetchone() assert fts_row[0] == "测试论文中文标题" @pytest.mark.asyncio async def test_pdf_download_failure( self, db_session, sample_paper, _summarize_tmp_paths ): """PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。""" with ( patch( "app.services.summarizer.download_pdf", new_callable=AsyncMock, side_effect=PdfDownloadError("network error"), ), ): result = await summarize_one(db_session, sample_paper) assert result["status"] == "failed" assert result["error_type"] == "pdf_download_failed" db_session.refresh(sample_paper) status = sample_paper.summary_status assert status.error_type == "pdf_download_failed" @pytest.mark.asyncio async def test_pi_timeout(self, db_session, sample_paper, _summarize_tmp_paths): """pi 超时 → timeout 错误,retry_count 递增。""" with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( "app.services.summary_generator.call_pi", new_callable=AsyncMock, side_effect=PiTimeoutError("timeout after 300s"), ), ): result = await summarize_one(db_session, sample_paper) assert result["status"] == "failed" assert result["error_type"] == "timeout" assert result["retry_count"] == 1 @pytest.mark.asyncio async def test_json_not_found(self, db_session, sample_paper, _summarize_tmp_paths): """pi 输出无 JSON → 验证循环重试 4 次后 ValueError (unknown)。""" with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( "app.services.summary_generator.call_pi", new_callable=AsyncMock, return_value=("No JSON in this output at all.", "test-session-id"), ), ): result = await summarize_one(db_session, sample_paper) assert result["status"] == "failed" assert result["error_type"] == "unknown" @pytest.mark.asyncio async def test_validation_fails_and_retries( self, db_session, sample_paper, _summarize_tmp_paths ): """验证失败(字段不符合要求)→ 重试多次后失败。""" bad_json = json.dumps( { "arxiv_id": sample_paper.arxiv_id, "title_zh": "", # 空的必填字段 "one_line": "valid line", "tags": ["tag1"], "motivation": {"problem": "valid problem"}, "method": {"key_idea": "valid idea"}, }, ensure_ascii=False, ) bad_output = f"```json\n{bad_json}\n```" with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( "app.services.summary_generator.call_pi", new_callable=AsyncMock, return_value=(bad_output, "test-session-id"), ), ): # _validate_summary 先拦截,4 轮都失败后 ValueError → unknown result = await summarize_one(db_session, sample_paper) assert result["status"] == "failed" assert result["error_type"] == "unknown" assert result["retry_count"] == 1 @pytest.mark.asyncio async def test_raw_output_saved_on_failure( self, db_session, sample_paper, tmp_path, _summarize_tmp_paths ): """失败时仍保存 raw_output.txt。""" with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( "app.services.summary_generator.call_pi", new_callable=AsyncMock, return_value=("Some output without JSON", "test-session-id"), ), ): await summarize_one(db_session, sample_paper) raw_file = tmp_path / "papers" / sample_paper.arxiv_id / "raw_output.txt" assert raw_file.exists() assert "Some output without JSON" in raw_file.read_text() @pytest.mark.asyncio async def test_tmp_cleaned_on_success( self, db_session, sample_paper, mock_pi_output, tmp_path, _summarize_tmp_paths ): """成功后清理 tmp 目录。""" with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( "app.services.summary_generator.call_pi", new_callable=AsyncMock, return_value=(mock_pi_output, "test-session-id"), ), ): await summarize_one(db_session, sample_paper) tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id assert not tmp_paper.exists() @pytest.mark.asyncio async def test_tmp_cleaned_on_failure( self, db_session, sample_paper, tmp_path, _summarize_tmp_paths ): """失败后也清理 tmp 目录。""" with ( patch( "app.services.summarizer.download_pdf", new_callable=AsyncMock, side_effect=PdfDownloadError("fail"), ), ): await summarize_one(db_session, sample_paper) tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id assert not tmp_paper.exists() @pytest.mark.asyncio async def test_skips_done_paper( self, db_session, sample_paper, _summarize_tmp_paths ): """已完成的论文跳过。""" sample_paper.summary_status.status = "done" db_session.commit() result = await summarize_one(db_session, sample_paper) assert result["status"] == "skipped" @pytest.mark.asyncio async def test_post_processing_runs_in_thread( self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths ): """后处理(图片提取/ChromaDB)在工作线程而非事件循环线程执行。""" import threading seen_threads: list[int] = [] main_thread = threading.current_thread().ident def spy_extract(arxiv_id, schema): seen_threads.append(threading.current_thread().ident) with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( "app.services.summary_generator.call_pi", new_callable=AsyncMock, return_value=(mock_pi_output, "test-session-id"), ), patch( "app.services.summary_persister._maybe_extract_images", side_effect=spy_extract, ), patch("app.services.summary_persister._maybe_index_chroma"), ): result = await summarize_one(db_session, sample_paper) assert result["status"] == "done" assert seen_threads, "后处理未被调用" assert seen_threads[0] != main_thread, "后处理应在工作线程执行,不阻塞事件循环" # ═══════════════════════════════════════════════════════════════════════ # 批量操作测试 # ═══════════════════════════════════════════════════════════════════════ class TestBatchSummarize: """批量总结测试。""" @pytest.mark.asyncio async def test_batch_multiple_papers( self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths ): """批量处理多篇论文。""" now = utc_now() for i in range(3): p = Paper( arxiv_id=f"2401.1234{i}", title_en=f"Test Paper {i}", abstract=f"Abstract {i}", paper_date=date(2024, 1, 15), crawled_at=now, pdf_url=f"https://arxiv.org/pdf/2401.1234{i}.pdf", ) db_session.add(p) db_session.flush() db_session.add(SummaryStatus(paper_id=p.id, status="pending")) db_session.commit() # 每个 worker 用独立 session(同一个内存引擎) from sqlalchemy.orm import sessionmaker as _sm _TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False) with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( "app.services.summary_generator.call_pi", new_callable=AsyncMock, return_value=(mock_pi_output, "test-session-id"), ), ): result = await summarize_batch(db_session, _session_factory=_TestSession) assert result["status"] == "success" assert result["done"] == 3 assert result["failed"] == 0 # 验证 CrawlLog log = db_session.query(CrawlLog).filter(CrawlLog.task == "summarize").first() assert log is not None assert log.status == "success" assert log.papers_found == 3 @pytest.mark.asyncio async def test_single_failure_no_block( self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths ): """一篇失败不阻塞其他。""" now = utc_now() for i in range(2): p = Paper( arxiv_id=f"2401.5678{i}", title_en=f"Paper {i}", abstract=f"Abstract {i}", paper_date=date(2024, 1, 15), crawled_at=now, pdf_url=f"https://arxiv.org/pdf/2401.5678{i}.pdf", ) db_session.add(p) db_session.flush() db_session.add(SummaryStatus(paper_id=p.id, status="pending")) db_session.commit() from sqlalchemy.orm import sessionmaker as _sm _TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False) call_count = 0 async def _mock_call_pi(meta_path, pdf_path, **kwargs): nonlocal call_count call_count += 1 if call_count == 1: raise PiTimeoutError("timeout") return mock_pi_output, "test-session-id" with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch("app.services.summary_generator.call_pi", side_effect=_mock_call_pi), ): result = await summarize_batch(db_session, _session_factory=_TestSession) assert result["done"] == 1 assert result["failed"] == 1 @pytest.mark.asyncio async def test_task_lock_conflict(self, db_session, _summarize_tmp_paths): """TaskLock 防止并发 batch。""" from app.exceptions import ConflictError # 先插入一个 running 锁 db_session.add( TaskLock( task="summarize", lock_key="batch", status="running", acquired_at=utc_now(), ) ) db_session.commit() with pytest.raises(ConflictError): await summarize_batch(db_session) @pytest.mark.asyncio async def test_task_lock_released( self, db_session, db_engine, mock_pi_output, _summarize_tmp_paths ): """完成后释放 TaskLock。""" from sqlalchemy.orm import sessionmaker as _sm _TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False) with ( patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( "app.services.summary_generator.call_pi", new_callable=AsyncMock, return_value=(mock_pi_output, "test-session-id"), ), ): await summarize_batch(db_session, _session_factory=_TestSession) locks = ( db_session.query(TaskLock) .filter( TaskLock.task == "summarize", TaskLock.lock_key == "batch", ) .all() ) for lock in locks: assert lock.status == "finished" assert lock.released_at is not None @pytest.mark.asyncio async def test_batch_empty(self, db_session, _summarize_tmp_paths): """无 pending 论文时返回空结果。""" result = await summarize_batch(db_session) assert result["status"] == "success" assert result["total"] == 0