"""AI 总结服务测试 — Mock 全链路,不调用真实 pi。""" from __future__ import annotations import asyncio import json from datetime import date, datetime, timezone from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import ValidationError from sqlalchemy import text from app.models import ( CrawlLog, Paper, PaperSummary, PaperTag, SummaryStatus, TaskLock, ) from app.services.schemas import ( SummarySchema, assess_quality, classify_validation_error, flatten_for_db, ) from app.services.summarizer import ( JsonNotFoundError, PdfDownloadError, PiProcessError, PiTimeoutError, _call_pi, _classify_error, _cleanup_tmp, _extract_json, _save_files, _save_raw_output_only, _update_summary_in_db, summarize_batch, summarize_one, summarize_single, ) # ═══════════════════════════════════════════════════════════════════════ # Schema 校验测试 # ═══════════════════════════════════════════════════════════════════════ class TestSummarySchema: """Pydantic schema 校验。""" def test_valid_summary(self, sample_summary_dict): schema = SummarySchema.model_validate(sample_summary_dict) assert schema.title_zh == "测试论文中文标题" assert len(schema.tags) == 3 assert schema.motivation.problem def test_missing_title_zh(self, sample_summary_dict): del sample_summary_dict["title_zh"] with pytest.raises(ValidationError) as exc_info: SummarySchema.model_validate(sample_summary_dict) assert classify_validation_error(exc_info.value) == "field_missing" def test_empty_one_line(self, sample_summary_dict): sample_summary_dict["one_line"] = "" with pytest.raises(ValidationError): SummarySchema.model_validate(sample_summary_dict) def test_empty_tags(self, sample_summary_dict): sample_summary_dict["tags"] = [] with pytest.raises(ValidationError): SummarySchema.model_validate(sample_summary_dict) def test_empty_motivation_problem(self, sample_summary_dict): sample_summary_dict["motivation"]["problem"] = "" with pytest.raises(ValidationError): 
SummarySchema.model_validate(sample_summary_dict) def test_empty_method_key_idea(self, sample_summary_dict): sample_summary_dict["method"]["key_idea"] = "" with pytest.raises(ValidationError): SummarySchema.model_validate(sample_summary_dict) def test_extra_fields_ignored(self, sample_summary_dict): sample_summary_dict["figures"] = ["fig1.png"] sample_summary_dict["takeaway"] = "important paper" schema = SummarySchema.model_validate(sample_summary_dict) assert not hasattr(schema, "figures") assert schema.title_zh # 正常解析 def test_flatten_for_db(self, sample_summary_dict): schema = SummarySchema.model_validate(sample_summary_dict) flat = flatten_for_db(schema) assert flat["one_line"] == schema.one_line assert flat["motivation_problem"] == schema.motivation.problem assert flat["method_key_idea"] == schema.method.key_idea assert "full_json" in flat assert "updated_at" in flat # JSON 字段可解析 assert isinstance(json.loads(flat["prerequisites_json"]), dict) assert isinstance(json.loads(flat["method_steps_json"]), list) class TestQualityAssessment: """质量分级测试。""" def test_quality_normal(self, sample_summary_dict): schema = SummarySchema.model_validate(sample_summary_dict) assert assess_quality(schema) == "normal" def test_quality_degraded_missing_goal(self, sample_summary_dict): sample_summary_dict["motivation"]["goal"] = "" sample_summary_dict["motivation"]["gap"] = "" sample_summary_dict["method"]["overview"] = "" sample_summary_dict["results"]["main_findings"] = [] schema = SummarySchema.model_validate(sample_summary_dict) assert assess_quality(schema) == "degraded" def test_quality_low_short_one_line(self, sample_summary_dict): sample_summary_dict["one_line"] = "短" schema = SummarySchema.model_validate(sample_summary_dict) assert assess_quality(schema) == "low" def test_quality_low_short_key_idea(self, sample_summary_dict): sample_summary_dict["method"]["key_idea"] = "短" schema = SummarySchema.model_validate(sample_summary_dict) assert assess_quality(schema) == "low" # 
# ═══════════════════════════════════════════════════════════════════════
# JSON extraction tests
# ═══════════════════════════════════════════════════════════════════════


class TestJsonExtraction:
    """Extracting the summary JSON object from raw pi output."""

    def test_direct_json(self, sample_summary_json):
        result = _extract_json(sample_summary_json)
        assert result["title_zh"] == "测试论文中文标题"

    def test_fenced_code_block(self, sample_summary_json):
        raw = f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字"
        result = _extract_json(raw)
        assert result["title_zh"] == "测试论文中文标题"

    def test_fenced_without_lang(self, sample_summary_json):
        raw = f"文字\n```\n{sample_summary_json}\n```"
        result = _extract_json(raw)
        assert result["title_zh"] == "测试论文中文标题"

    def test_embedded_braces(self, sample_summary_dict):
        json_str = json.dumps(sample_summary_dict, ensure_ascii=False)
        raw = f"Here is the summary:\n{json_str}\nEnd."
        result = _extract_json(raw)
        assert result["title_zh"] == "测试论文中文标题"

    def test_no_json_raises(self):
        with pytest.raises(JsonNotFoundError):
            _extract_json("No JSON here at all.")

    def test_json_without_title_zh_falls_through(self):
        """JSON lacking ``title_zh`` is not preferred, but the largest-block fallback still returns it."""
        raw = json.dumps({"other": "data"})
        # Strategy order as described by the original author: strategy 1 skips
        # this object (no title_zh), strategy 2 finds no fenced code block, and
        # strategy 3 falls back to the largest {...} block.
        result = _extract_json(raw)
        assert result == {"other": "data"}  # largest-block fallback


# ═══════════════════════════════════════════════════════════════════════
# Error classification tests
# ═══════════════════════════════════════════════════════════════════════


class TestErrorClassification:
    """Mapping of exceptions to ``error_type`` strings."""

    def test_pdf_download_error(self):
        assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed"

    def test_timeout_error(self):
        assert _classify_error(PiTimeoutError("timeout")) == "timeout"

    def test_process_error(self):
        assert _classify_error(PiProcessError(1, "stderr")) == "process_error"

    def test_json_not_found(self):
        assert _classify_error(JsonNotFoundError("not found")) == "json_not_found"

    def test_json_invalid(self):
        assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid"

    def test_field_missing(self):
        # pytest.raises makes the test fail if validation unexpectedly succeeds;
        # a bare try/except would silently pass in that case.
        with pytest.raises(ValidationError) as exc_info:
            SummarySchema.model_validate({"title_zh": ""})  # type: ignore
        assert _classify_error(exc_info.value) == "field_missing"

    def test_unknown_error(self):
        assert _classify_error(RuntimeError("boom")) == "unknown"


# ═══════════════════════════════════════════════════════════════════════
# DB update tests
# ═══════════════════════════════════════════════════════════════════════


class TestDbUpdate:
    """Effects of ``_update_summary_in_db`` on the database."""

    def test_summary_written(self, db_session, sample_paper, sample_summary_dict):
        schema = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")

        summary = db_session.get(PaperSummary, sample_paper.id)
        assert summary is not None
        assert summary.one_line == schema.one_line
        assert summary.motivation_problem == schema.motivation.problem
        assert json.loads(summary.full_json)["title_zh"] == schema.title_zh

    def test_paper_title_zh_updated(self, db_session, sample_paper, sample_summary_dict):
        schema = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")

        db_session.refresh(sample_paper)
        assert sample_paper.title_zh == "测试论文中文标题"
        assert sample_paper.summary_quality == "normal"

    def test_fts_updated(self, db_session, sample_paper, sample_summary_dict):
        schema = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")

        row = db_session.execute(
            text("SELECT title_zh, summary_text FROM papers_fts WHERE rowid = :id"),
            {"id": sample_paper.id},
        ).fetchone()
        assert row is not None
        assert row[0] == "测试论文中文标题"
        assert schema.one_line in row[1]

    def test_ai_tags_added(self, db_session, sample_paper, sample_summary_dict):
        schema = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")

        tags = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id, PaperTag.source == "ai")
            .all()
        )
        tag_names = {t.tag for t in tags}
        # AI tags come from schema.tags.
        assert "自然语言处理" in tag_names
        assert "大语言模型" in tag_names

    def test_existing_tags_not_duplicated(self, db_session, sample_paper, sample_summary_dict):
        """A tag name already present on the paper is not inserted again with source "ai"."""
        # Assumes the sample_paper fixture already carries NLP and LLM tags from
        # the "hf" source (fixture defined outside this file).
        # The AI output repeats NLP (duplicate of the HF tag) and adds "新标签" (new).
        sample_summary_dict["tags"] = ["NLP", "新标签"]
        schema = SummarySchema.model_validate(sample_summary_dict)
        _update_summary_in_db(db_session, sample_paper, schema, "normal", "raw")

        all_tags = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == sample_paper.id)
            .all()
        )
        tag_names = [t.tag for t in all_tags]
        # NLP appears exactly once (the original HF row); the AI did not add another.
        assert tag_names.count("NLP") == 1
        # "新标签" was newly added from the AI output.
        assert "新标签" in tag_names


# ═══════════════════════════════════════════════════════════════════════
# File operation tests
# ═══════════════════════════════════════════════════════════════════════


class TestFileOperations:
    """Saving outputs and cleaning up the temp directory."""

    def test_save_files(self, tmp_path, sample_summary_dict):
        schema = SummarySchema.model_validate(sample_summary_dict)
        with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
            _save_files("2401.12345", schema, "raw output text")

        paper_dir = tmp_path / "2401.12345"
        assert (paper_dir / "summary.json").exists()
        assert (paper_dir / "raw_output.txt").exists()
        # Explicit UTF-8: the JSON may contain Chinese text, and the platform
        # default encoding is not UTF-8 everywhere.
        saved = json.loads((paper_dir / "summary.json").read_text(encoding="utf-8"))
        assert saved["title_zh"] == "测试论文中文标题"

    def test_save_raw_output_only(self, tmp_path):
        with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
            _save_raw_output_only("2401.12345", "raw output")

        paper_dir = tmp_path / "2401.12345"
        assert (paper_dir / "raw_output.txt").exists()
        assert not (paper_dir / "summary.json").exists()

    def test_cleanup_tmp(self, tmp_path):
        tmp_paper = tmp_path / "2401.12345"
        tmp_paper.mkdir()
        (tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake")

        with patch("app.services.summarizer._TMP_DIR", tmp_path):
            _cleanup_tmp("2401.12345")
        assert not tmp_paper.exists()

    def test_cleanup_tmp_nonexistent(self, tmp_path):
        """Cleaning up a directory that does not exist does not raise."""
        with patch("app.services.summarizer._TMP_DIR", tmp_path):
            _cleanup_tmp("nonexistent")  # must not raise


# ═══════════════════════════════════════════════════════════════════════
# Full-flow status transition tests
# ═══════════════════════════════════════════════════════════════════════


@pytest.fixture
def _patch_paths(tmp_path):
    """Redirect the summarizer's data/papers/tmp directories into ``tmp_path``."""
    with (
        patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
        patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
        patch("app.services.summarizer._DATA_DIR", tmp_path),
    ):
        yield


class TestSummarizeOneFlow:
    """Status transitions of ``summarize_one`` with pi and the PDF download mocked."""

    @pytest.mark.asyncio
    async def test_full_success_path(
        self, db_session, sample_paper, mock_pi_output, _patch_paths
    ):
        """pending → processing → done, end to end."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "done"
        assert result["quality"] == "normal"

        # Status and paper fields in the DB.
        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "done"
        assert sample_paper.summary_status.quality == "normal"
        assert sample_paper.title_zh == "测试论文中文标题"

        # Summary row was written.
        summary = db_session.get(PaperSummary, sample_paper.id)
        assert summary is not None
        assert summary.one_line

        # FTS index was updated; check the row exists before indexing into it.
        fts_row = db_session.execute(
            text("SELECT title_zh FROM papers_fts WHERE rowid = :id"),
            {"id": sample_paper.id},
        ).fetchone()
        assert fts_row is not None
        assert fts_row[0] == "测试论文中文标题"

    @pytest.mark.asyncio
    async def test_pdf_download_failure(
        self, db_session, sample_paper, _patch_paths
    ):
        """PDF download failure → error_type=pdf_download_failed (tmp cleanup is covered by test_tmp_cleaned_on_failure)."""
        with patch(
            "app.services.summarizer._download_pdf",
            new_callable=AsyncMock,
            side_effect=PdfDownloadError("network error"),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "pdf_download_failed"

        db_session.refresh(sample_paper)
        status = sample_paper.summary_status
        assert status.error_type == "pdf_download_failed"

    @pytest.mark.asyncio
    async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
        """pi timeout → error_type=timeout and retry_count is incremented."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                side_effect=PiTimeoutError("timeout after 300s"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "timeout"
        assert result["retry_count"] == 1

    @pytest.mark.asyncio
    async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
        """pi output containing no JSON → error_type=json_not_found."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                return_value="No JSON in this output at all.",
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "json_not_found"

    @pytest.mark.asyncio
    async def test_field_missing_and_retry(
        self, db_session, sample_paper, _patch_paths
    ):
        """Required field missing → field_missing → retry → permanent_failure."""
        bad_json = json.dumps({
            "title_zh": "",  # required field left empty
            "one_line": "valid line",
            "tags": ["tag1"],
            "motivation": {"problem": "valid problem"},
            "method": {"key_idea": "valid idea"},
        }, ensure_ascii=False)
        bad_output = f"```json\n{bad_json}\n```"

        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                return_value=bad_output,
            ),
        ):
            # First failure → eligible for a retry.
            result1 = await summarize_one(db_session, sample_paper)
            assert result1["status"] == "failed"
            assert result1["error_type"] == "field_missing"
            assert result1["retry_count"] == 1

            # Second failure → permanent_failure. Assumes SUMMARY_MAX_RETRIES=1 in
            # the test settings, so retry_count=2 exceeds the limit.
            db_session.refresh(sample_paper)
            result2 = await summarize_one(db_session, sample_paper)
            assert result2["status"] == "failed"
            assert result2["retry_count"] == 2

        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "permanent_failure"

    @pytest.mark.asyncio
    async def test_raw_output_saved_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """raw_output.txt is still saved when summarization fails."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                return_value="Some output without JSON",
            ),
        ):
            await summarize_one(db_session, sample_paper)

        raw_file = tmp_path / "papers" / sample_paper.arxiv_id / "raw_output.txt"
        assert raw_file.exists()
        assert "Some output without JSON" in raw_file.read_text(encoding="utf-8")

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_success(
        self, db_session, sample_paper, mock_pi_output, tmp_path, _patch_paths
    ):
        """The tmp directory is removed after a successful run."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            await summarize_one(db_session, sample_paper)

        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """The tmp directory is removed after a failed run as well."""
        with patch(
            "app.services.summarizer._download_pdf",
            new_callable=AsyncMock,
            side_effect=PdfDownloadError("fail"),
        ):
            await summarize_one(db_session, sample_paper)

        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_skips_done_paper(self, db_session, sample_paper, _patch_paths):
        """Papers already marked done are skipped."""
        sample_paper.summary_status.status = "done"
        db_session.commit()

        result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "skipped"


# ═══════════════════════════════════════════════════════════════════════
# Batch operation tests
# ═══════════════════════════════════════════════════════════════════════


class TestBatchSummarize:
    """Batch summarization via ``summarize_batch``."""

    @pytest.fixture
    def _worker_sessions(self, db_engine):
        """Session factory bound to the test engine, so each worker gets its own session."""
        from sqlalchemy.orm import sessionmaker

        return sessionmaker(bind=db_engine, autoflush=False, autocommit=False)

    @staticmethod
    def _add_pending_papers(db_session, id_prefix, title_prefix, count):
        """Insert *count* papers, each with a pending SummaryStatus row, then commit."""
        now = datetime.now(timezone.utc)
        for i in range(count):
            arxiv_id = f"{id_prefix}{i}"
            paper = Paper(
                arxiv_id=arxiv_id,
                title_en=f"{title_prefix} {i}",
                abstract=f"Abstract {i}",
                paper_date=date(2024, 1, 15),
                crawled_at=now,
                pdf_url=f"https://arxiv.org/pdf/{arxiv_id}.pdf",
            )
            db_session.add(paper)
            db_session.flush()  # populate paper.id before the status row references it
            db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db_session.commit()

    @pytest.mark.asyncio
    async def test_batch_multiple_papers(
        self, db_session, _worker_sessions, mock_pi_output, _patch_paths
    ):
        """Several pending papers are processed in one batch."""
        self._add_pending_papers(db_session, "2401.1234", "Test Paper", 3)

        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            result = await summarize_batch(
                db_session, _session_factory=_worker_sessions
            )

        assert result["status"] == "success"
        assert result["done"] == 3
        assert result["failed"] == 0

        # A CrawlLog entry records the batch run.
        log = db_session.query(CrawlLog).filter(CrawlLog.task == "summarize").first()
        assert log is not None
        assert log.status == "success"
        assert log.papers_found == 3

    @pytest.mark.asyncio
    async def test_single_failure_no_block(
        self, db_session, _worker_sessions, mock_pi_output, _patch_paths
    ):
        """One failing paper does not block the others."""
        self._add_pending_papers(db_session, "2401.5678", "Paper", 2)

        call_count = 0

        async def _mock_call_pi(meta_path, pdf_path):
            # Fail only the first invocation; every later call succeeds.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise PiTimeoutError("timeout")
            return mock_pi_output

        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", side_effect=_mock_call_pi),
        ):
            result = await summarize_batch(
                db_session, _session_factory=_worker_sessions
            )

        assert result["done"] == 1
        assert result["failed"] == 1

    @pytest.mark.asyncio
    async def test_task_lock_conflict(self, db_session, _patch_paths):
        """A running TaskLock prevents a concurrent batch."""
        # Pre-insert a running lock.
        db_session.add(
            TaskLock(
                task="summarize",
                lock_key="batch",
                status="running",
                acquired_at=datetime.now(timezone.utc),
            )
        )
        db_session.commit()

        result = await summarize_batch(db_session)
        assert result["status"] == "conflict"

    @pytest.mark.asyncio
    async def test_task_lock_released(
        self, db_session, _worker_sessions, mock_pi_output, _patch_paths
    ):
        """The TaskLock is released after the batch finishes."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            await summarize_batch(
                db_session, _session_factory=_worker_sessions
            )

        locks = (
            db_session.query(TaskLock)
            .filter(
                TaskLock.task == "summarize",
                TaskLock.lock_key == "batch",
            )
            .all()
        )
        # NOTE(review): this loop passes vacuously if the batch never writes a
        # lock row (e.g. if it returns early when nothing is pending). Consider
        # seeding a pending paper and asserting `locks` is non-empty.
        for lock in locks:
            assert lock.status == "finished"
            assert lock.released_at is not None

    @pytest.mark.asyncio
    async def test_batch_empty(self, db_session, _patch_paths):
        """With no pending papers the batch returns an empty result."""
        result = await summarize_batch(db_session)
        assert result["status"] == "success"
        assert result["total"] == 0


# ═══════════════════════════════════════════════════════════════════════
# Admin route authentication tests
# ═══════════════════════════════════════════════════════════════════════


class TestAdminAuth:
    """Admin endpoint authentication — HTTP layer only; the service calls are mocked."""

    @pytest.fixture
    def _admin_token(self, monkeypatch):
        """Pin ``settings.ADMIN_TOKEN`` for one test; monkeypatch restores it afterwards."""
        import app.config as config_mod

        monkeypatch.setattr(config_mod.settings, "ADMIN_TOKEN", "test-admin-token-12345")

    def test_no_token_returns_401(self, client):
        """A request without a Bearer token is rejected."""
        resp = client.post("/admin/summarize")
        # NOTE(review): 403 is accepted as well — presumably the auth dependency
        # answers 403 for a missing Authorization header; confirm in app.routes.admin.
        assert resp.status_code in (401, 403)

    def test_wrong_token_returns_401(self, client, _admin_token):
        resp = client.post(
            "/admin/summarize",
            headers={"Authorization": "Bearer wrong-token"},
        )
        assert resp.status_code == 401

    def test_correct_token_batch(self, client, admin_headers, _admin_token):
        """The correct token reaches the batch handler (service layer mocked)."""
        with patch(
            "app.routes.admin.summarize_batch",
            new_callable=AsyncMock,
            return_value={"status": "success", "done": 0, "failed": 0, "total": 0},
        ):
            resp = client.post("/admin/summarize", headers=admin_headers)

        assert resp.status_code == 200
        assert resp.json()["status"] == "success"

    def test_single_paper_not_found(self, client, admin_headers, _admin_token):
        """Summarizing a paper that does not exist returns 404."""
        with patch(
            "app.routes.admin.summarize_single",
            new_callable=AsyncMock,
            return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
        ):
            resp = client.post(
                "/admin/summarize/nonexistent.99999",
                headers=admin_headers,
            )

        assert resp.status_code == 404