29e6797c12
- Add /admin routes for manual trigger and status inspection - Add summarizer service with batch/single summary support - Add summarize CLI command (single arxiv_id or batch pending) - Register admin router in main app - Add tests for summarizer
726 lines
29 KiB
Python
726 lines
29 KiB
Python
"""AI 总结服务测试 — Mock 全链路,不调用真实 pi。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
from datetime import date, datetime, timezone
|
||
from pathlib import Path
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
from pydantic import ValidationError
|
||
from sqlalchemy import text
|
||
|
||
from app.models import (
|
||
CrawlLog,
|
||
Paper,
|
||
PaperSummary,
|
||
PaperTag,
|
||
SummaryStatus,
|
||
TaskLock,
|
||
)
|
||
from app.services.schemas import (
|
||
SummarySchema,
|
||
assess_quality,
|
||
classify_validation_error,
|
||
flatten_for_db,
|
||
)
|
||
from app.services.summarizer import (
|
||
JsonNotFoundError,
|
||
PdfDownloadError,
|
||
PiProcessError,
|
||
PiTimeoutError,
|
||
_call_pi,
|
||
_classify_error,
|
||
_cleanup_tmp,
|
||
_extract_json,
|
||
_save_files,
|
||
_save_raw_output_only,
|
||
_update_summary_in_db,
|
||
summarize_batch,
|
||
summarize_one,
|
||
summarize_single,
|
||
)
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
# Schema validation tests
# ═══════════════════════════════════════════════════════════════════════


class TestSummarySchema:
    """Validation rules enforced by the Pydantic ``SummarySchema``."""

    @staticmethod
    def _expect_rejected(payload):
        """Assert that *payload* fails schema validation and return the error."""
        with pytest.raises(ValidationError) as caught:
            SummarySchema.model_validate(payload)
        return caught.value

    def test_valid_summary(self, sample_summary_dict):
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert parsed.title_zh == "测试论文中文标题"
        assert len(parsed.tags) == 3
        assert parsed.motivation.problem

    def test_missing_title_zh(self, sample_summary_dict):
        sample_summary_dict.pop("title_zh")
        error = self._expect_rejected(sample_summary_dict)
        assert classify_validation_error(error) == "field_missing"

    def test_empty_one_line(self, sample_summary_dict):
        sample_summary_dict["one_line"] = ""
        self._expect_rejected(sample_summary_dict)

    def test_empty_tags(self, sample_summary_dict):
        sample_summary_dict["tags"] = []
        self._expect_rejected(sample_summary_dict)

    def test_empty_motivation_problem(self, sample_summary_dict):
        sample_summary_dict["motivation"]["problem"] = ""
        self._expect_rejected(sample_summary_dict)

    def test_empty_method_key_idea(self, sample_summary_dict):
        sample_summary_dict["method"]["key_idea"] = ""
        self._expect_rejected(sample_summary_dict)

    def test_extra_fields_ignored(self, sample_summary_dict):
        sample_summary_dict.update(figures=["fig1.png"], takeaway="important paper")
        parsed = SummarySchema.model_validate(sample_summary_dict)
        assert not hasattr(parsed, "figures")
        assert parsed.title_zh  # the known fields still parse normally

    def test_flatten_for_db(self, sample_summary_dict):
        parsed = SummarySchema.model_validate(sample_summary_dict)
        row = flatten_for_db(parsed)

        assert row["one_line"] == parsed.one_line
        assert row["motivation_problem"] == parsed.motivation.problem
        assert row["method_key_idea"] == parsed.method.key_idea
        for required_key in ("full_json", "updated_at"):
            assert required_key in row

        # JSON-encoded columns must be parseable back into Python objects.
        assert isinstance(json.loads(row["prerequisites_json"]), dict)
        assert isinstance(json.loads(row["method_steps_json"]), list)


class TestQualityAssessment:
    """Quality grading performed by ``assess_quality``."""

    @staticmethod
    def _grade(payload):
        """Validate *payload* and return the quality grade assigned to it."""
        return assess_quality(SummarySchema.model_validate(payload))

    def test_quality_normal(self, sample_summary_dict):
        assert self._grade(sample_summary_dict) == "normal"

    def test_quality_degraded_missing_goal(self, sample_summary_dict):
        motivation = sample_summary_dict["motivation"]
        motivation["goal"] = ""
        motivation["gap"] = ""
        sample_summary_dict["method"]["overview"] = ""
        sample_summary_dict["results"]["main_findings"] = []
        assert self._grade(sample_summary_dict) == "degraded"

    def test_quality_low_short_one_line(self, sample_summary_dict):
        sample_summary_dict["one_line"] = "短"
        assert self._grade(sample_summary_dict) == "low"

    def test_quality_low_short_key_idea(self, sample_summary_dict):
        sample_summary_dict["method"]["key_idea"] = "短"
        assert self._grade(sample_summary_dict) == "low"


# ═══════════════════════════════════════════════════════════════════════
# JSON extraction tests
# ═══════════════════════════════════════════════════════════════════════


class TestJsonExtraction:
    """Extracting the summary JSON object from raw pi output."""

    EXPECTED_TITLE = "测试论文中文标题"

    def _assert_extracts_summary(self, raw):
        """Run the extractor on *raw* and check it returned the summary object."""
        assert _extract_json(raw)["title_zh"] == self.EXPECTED_TITLE

    def test_direct_json(self, sample_summary_json):
        self._assert_extracts_summary(sample_summary_json)

    def test_fenced_code_block(self, sample_summary_json):
        self._assert_extracts_summary(
            f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字"
        )

    def test_fenced_without_lang(self, sample_summary_json):
        self._assert_extracts_summary(f"文字\n```\n{sample_summary_json}\n```")

    def test_embedded_braces(self, sample_summary_dict):
        payload = json.dumps(sample_summary_dict, ensure_ascii=False)
        self._assert_extracts_summary(f"Here is the summary:\n{payload}\nEnd.")

    def test_no_json_raises(self):
        with pytest.raises(JsonNotFoundError):
            _extract_json("No JSON here at all.")

    def test_json_without_title_zh_falls_through(self):
        """A JSON object lacking ``title_zh`` is not what we want, but is the last resort."""
        # If another valid JSON block existed it could be returned instead; with
        # none, extraction falls back to the largest block. Here the input is a
        # JSON dict without title_zh: strategy 1 skips it (no title_zh), strategy 2
        # finds no code block, strategy 3 picks the largest block.
        raw = json.dumps({"other": "data"})
        assert _extract_json(raw) == {"other": "data"}  # largest-block fallback


# ═══════════════════════════════════════════════════════════════════════
# Error classification tests
# ═══════════════════════════════════════════════════════════════════════


class TestErrorClassification:
    """Mapping from raised exceptions to ``error_type`` strings."""

    def test_pdf_download_error(self):
        assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed"

    def test_timeout_error(self):
        assert _classify_error(PiTimeoutError("timeout")) == "timeout"

    def test_process_error(self):
        assert _classify_error(PiProcessError(1, "stderr")) == "process_error"

    def test_json_not_found(self):
        assert _classify_error(JsonNotFoundError("not found")) == "json_not_found"

    def test_json_invalid(self):
        assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid"

    def test_field_missing(self):
        # pytest.raises makes the test fail if validation unexpectedly succeeds,
        # instead of silently passing with no assertion executed.
        with pytest.raises(ValidationError) as exc_info:
            SummarySchema.model_validate({"title_zh": ""})
        assert _classify_error(exc_info.value) == "field_missing"

    def test_unknown_error(self):
        assert _classify_error(RuntimeError("boom")) == "unknown"


# ═══════════════════════════════════════════════════════════════════════
# DB update tests
# ═══════════════════════════════════════════════════════════════════════


class TestDbUpdate:
    """Database effects of ``_update_summary_in_db``."""

    @staticmethod
    def _write(db_session, paper, payload):
        """Validate *payload*, store it for *paper* with quality "normal", return the schema."""
        parsed = SummarySchema.model_validate(payload)
        _update_summary_in_db(db_session, paper, parsed, "normal", "raw")
        return parsed

    @staticmethod
    def _tag_names(db_session, paper_id, *criteria):
        """Return the tag names of *paper_id*, optionally narrowed by extra filters."""
        rows = (
            db_session.query(PaperTag)
            .filter(PaperTag.paper_id == paper_id, *criteria)
            .all()
        )
        return [row.tag for row in rows]

    def test_summary_written(self, db_session, sample_paper, sample_summary_dict):
        parsed = self._write(db_session, sample_paper, sample_summary_dict)

        stored = db_session.get(PaperSummary, sample_paper.id)
        assert stored is not None
        assert stored.one_line == parsed.one_line
        assert stored.motivation_problem == parsed.motivation.problem
        assert json.loads(stored.full_json)["title_zh"] == parsed.title_zh

    def test_paper_title_zh_updated(self, db_session, sample_paper, sample_summary_dict):
        self._write(db_session, sample_paper, sample_summary_dict)

        db_session.refresh(sample_paper)
        assert sample_paper.title_zh == "测试论文中文标题"
        assert sample_paper.summary_quality == "normal"

    def test_fts_updated(self, db_session, sample_paper, sample_summary_dict):
        parsed = self._write(db_session, sample_paper, sample_summary_dict)

        fts_query = text("SELECT title_zh, summary_text FROM papers_fts WHERE rowid = :id")
        fts_row = db_session.execute(fts_query, {"id": sample_paper.id}).fetchone()
        assert fts_row is not None
        title_zh, summary_text = fts_row[0], fts_row[1]
        assert title_zh == "测试论文中文标题"
        assert parsed.one_line in summary_text

    def test_ai_tags_added(self, db_session, sample_paper, sample_summary_dict):
        self._write(db_session, sample_paper, sample_summary_dict)

        ai_tags = set(self._tag_names(db_session, sample_paper.id, PaperTag.source == "ai"))
        # AI tags originate from schema.tags.
        assert "自然语言处理" in ai_tags
        assert "大语言模型" in ai_tags

    def test_existing_tags_not_duplicated(self, db_session, sample_paper, sample_summary_dict):
        """A tag name the paper already has is not inserted again with source "ai"."""
        # sample_paper already carries NLP (hf) and LLM (hf). The AI output
        # repeats NLP (duplicating the HF tag) and adds one new tag.
        sample_summary_dict["tags"] = ["NLP", "新标签"]
        self._write(db_session, sample_paper, sample_summary_dict)

        names = self._tag_names(db_session, sample_paper.id)
        # NLP appears once (the original HF tag); AI adds no duplicate.
        assert names.count("NLP") == 1
        # The new tag is the AI-contributed addition.
        assert "新标签" in names


# ═══════════════════════════════════════════════════════════════════════
# File operation tests
# ═══════════════════════════════════════════════════════════════════════


class TestFileOperations:
    """Saving summarizer output files and cleaning the tmp directory."""

    def test_save_files(self, tmp_path, sample_summary_dict):
        schema = SummarySchema.model_validate(sample_summary_dict)
        with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
            _save_files("2401.12345", schema, "raw output text")

        paper_dir = tmp_path / "2401.12345"
        summary_file = paper_dir / "summary.json"
        raw_file = paper_dir / "raw_output.txt"
        assert summary_file.exists()
        assert raw_file.exists()
        # Decode explicitly: the summary holds CJK text, and read_text() with no
        # encoding uses the platform locale codec (not UTF-8 on e.g. Windows).
        # NOTE(review): assumes _save_files writes UTF-8 — confirm.
        saved = json.loads(summary_file.read_text(encoding="utf-8"))
        assert saved["title_zh"] == "测试论文中文标题"
        assert "raw output text" in raw_file.read_text(encoding="utf-8")

    def test_save_raw_output_only(self, tmp_path):
        with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
            _save_raw_output_only("2401.12345", "raw output")

        paper_dir = tmp_path / "2401.12345"
        raw_file = paper_dir / "raw_output.txt"
        assert raw_file.exists()
        assert "raw output" in raw_file.read_text(encoding="utf-8")
        assert not (paper_dir / "summary.json").exists()

    def test_cleanup_tmp(self, tmp_path):
        tmp_paper = tmp_path / "2401.12345"
        tmp_paper.mkdir()
        (tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake")
        with patch("app.services.summarizer._TMP_DIR", tmp_path):
            _cleanup_tmp("2401.12345")
        assert not tmp_paper.exists()

    def test_cleanup_tmp_nonexistent(self, tmp_path):
        """Cleaning up a directory that does not exist does not raise."""
        with patch("app.services.summarizer._TMP_DIR", tmp_path):
            _cleanup_tmp("nonexistent")  # must not raise


# ═══════════════════════════════════════════════════════════════════════
# End-to-end status transition tests
# ═══════════════════════════════════════════════════════════════════════


class TestSummarizeOneFlow:
    """Status transitions of ``summarize_one`` with pi and the PDF download mocked."""

    @pytest.fixture
    def _patch_paths(self, tmp_path):
        """Redirect the summarizer's data directories into ``tmp_path``."""
        with (
            patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
            patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
            patch("app.services.summarizer._DATA_DIR", tmp_path),
        ):
            yield

    @staticmethod
    def _download_leaving_residue(tmp_paper_dir, error=None):
        """Build a ``_download_pdf`` stand-in that writes a fake PDF into *tmp_paper_dir*.

        A bare AsyncMock never touches the filesystem, so a "tmp was cleaned"
        assertion would hold even if cleanup never ran. Writing the file here
        makes those assertions meaningful. When *error* is given it is raised
        after the file is written, modelling a failed download that still left
        data behind.
        """
        from unittest.mock import DEFAULT

        async def _fake_download(*args, **kwargs):
            tmp_paper_dir.mkdir(parents=True, exist_ok=True)
            (tmp_paper_dir / "paper.pdf").write_bytes(b"%PDF-fake")
            if error is not None:
                raise error
            # Fall back to the mock's default return value, matching the plain
            # AsyncMock used by the other tests.
            return DEFAULT

        return _fake_download

    @pytest.mark.asyncio
    async def test_full_success_path(
        self, db_session, sample_paper, mock_pi_output, _patch_paths
    ):
        """Full pending → processing → done path."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "done"
        assert result["quality"] == "normal"

        # Status row and paper fields in the DB
        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "done"
        assert sample_paper.summary_status.quality == "normal"
        assert sample_paper.title_zh == "测试论文中文标题"

        # Summary row was written
        summary = db_session.get(PaperSummary, sample_paper.id)
        assert summary is not None
        assert summary.one_line

        # FTS index was updated
        fts_row = db_session.execute(
            text("SELECT title_zh FROM papers_fts WHERE rowid = :id"),
            {"id": sample_paper.id},
        ).fetchone()
        assert fts_row is not None
        assert fts_row[0] == "测试论文中文标题"

    @pytest.mark.asyncio
    async def test_pdf_download_failure(
        self, db_session, sample_paper, _patch_paths
    ):
        """A PDF download failure is reported as error_type=pdf_download_failed."""
        with patch(
            "app.services.summarizer._download_pdf",
            new_callable=AsyncMock,
            side_effect=PdfDownloadError("network error"),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "pdf_download_failed"

        db_session.refresh(sample_paper)
        status = sample_paper.summary_status
        assert status.error_type == "pdf_download_failed"

    @pytest.mark.asyncio
    async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
        """A pi timeout yields error_type=timeout and increments retry_count."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                side_effect=PiTimeoutError("timeout after 300s"),
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "timeout"
        assert result["retry_count"] == 1

    @pytest.mark.asyncio
    async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
        """pi output without any JSON yields error_type=json_not_found."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                return_value="No JSON in this output at all.",
            ),
        ):
            result = await summarize_one(db_session, sample_paper)

        assert result["status"] == "failed"
        assert result["error_type"] == "json_not_found"

    @pytest.mark.asyncio
    async def test_field_missing_and_retry(
        self, db_session, sample_paper, _patch_paths
    ):
        """Missing required field → field_missing → retry → permanent_failure."""
        bad_json = json.dumps({
            "title_zh": "",  # required field left empty
            "one_line": "valid line",
            "tags": ["tag1"],
            "motivation": {"problem": "valid problem"},
            "method": {"key_idea": "valid idea"},
        }, ensure_ascii=False)
        bad_output = f"```json\n{bad_json}\n```"

        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                return_value=bad_output,
            ),
        ):
            # First failure: retry_count becomes 1 and the paper is queued for a retry.
            result1 = await summarize_one(db_session, sample_paper)
            assert result1["status"] == "failed"
            assert result1["error_type"] == "field_missing"
            assert result1["retry_count"] == 1

            # Second failure: retry_count (2) now exceeds SUMMARY_MAX_RETRIES,
            # which this test assumes is configured as 1 → permanent_failure.
            db_session.refresh(sample_paper)
            result2 = await summarize_one(db_session, sample_paper)
            assert result2["status"] == "failed"
            assert result2["retry_count"] == 2

        db_session.refresh(sample_paper)
        assert sample_paper.summary_status.status == "permanent_failure"

    @pytest.mark.asyncio
    async def test_raw_output_saved_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """raw_output.txt is still saved when summarization fails."""
        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch(
                "app.services.summarizer._call_pi",
                new_callable=AsyncMock,
                return_value="Some output without JSON",
            ),
        ):
            await summarize_one(db_session, sample_paper)

        raw_file = tmp_path / "papers" / sample_paper.arxiv_id / "raw_output.txt"
        assert raw_file.exists()
        assert "Some output without JSON" in raw_file.read_text(encoding="utf-8")

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_success(
        self, db_session, sample_paper, mock_pi_output, tmp_path, _patch_paths
    ):
        """The per-paper tmp directory is removed after a successful run."""
        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        with (
            patch(
                "app.services.summarizer._download_pdf",
                new_callable=AsyncMock,
                side_effect=self._download_leaving_residue(tmp_paper),
            ) as download_mock,
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            await summarize_one(db_session, sample_paper)

        # The stub must have run; otherwise the directory never existed.
        assert download_mock.await_count >= 1
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_tmp_cleaned_on_failure(
        self, db_session, sample_paper, tmp_path, _patch_paths
    ):
        """The per-paper tmp directory is also removed after a failed run."""
        tmp_paper = tmp_path / "tmp" / sample_paper.arxiv_id
        with patch(
            "app.services.summarizer._download_pdf",
            new_callable=AsyncMock,
            side_effect=self._download_leaving_residue(
                tmp_paper, error=PdfDownloadError("fail")
            ),
        ) as download_mock:
            await summarize_one(db_session, sample_paper)

        # The stub must have run; otherwise the directory never existed.
        assert download_mock.await_count >= 1
        assert not tmp_paper.exists()

    @pytest.mark.asyncio
    async def test_skips_done_paper(self, db_session, sample_paper, _patch_paths):
        """A paper whose summary is already done is skipped."""
        sample_paper.summary_status.status = "done"
        db_session.commit()

        result = await summarize_one(db_session, sample_paper)
        assert result["status"] == "skipped"


# ═══════════════════════════════════════════════════════════════════════
# Batch operation tests
# ═══════════════════════════════════════════════════════════════════════


class TestBatchSummarize:
    """Batch summarization tests."""

    @pytest.fixture
    def _patch_paths(self, tmp_path):
        """Redirect the summarizer's data directories into ``tmp_path``."""
        with (
            patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
            patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
            patch("app.services.summarizer._DATA_DIR", tmp_path),
        ):
            yield

    @pytest.fixture
    def worker_sessions(self, db_engine):
        """Session factory for batch workers: each worker gets its own session on the same engine."""
        from sqlalchemy.orm import sessionmaker

        return sessionmaker(bind=db_engine, autoflush=False, autocommit=False)

    @staticmethod
    def _add_pending_papers(db_session, arxiv_prefix, title_prefix, count):
        """Insert *count* papers, each with a pending SummaryStatus, then commit.

        Paper *i* gets arxiv_id ``f"{arxiv_prefix}{i}"`` and title
        ``f"{title_prefix} {i}"``.
        """
        now = datetime.now(timezone.utc)
        for i in range(count):
            paper = Paper(
                arxiv_id=f"{arxiv_prefix}{i}",
                title_en=f"{title_prefix} {i}",
                abstract=f"Abstract {i}",
                paper_date=date(2024, 1, 15),
                crawled_at=now,
                pdf_url=f"https://arxiv.org/pdf/{arxiv_prefix}{i}.pdf",
            )
            db_session.add(paper)
            db_session.flush()  # assigns paper.id for the status row
            db_session.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db_session.commit()

    @pytest.mark.asyncio
    async def test_batch_multiple_papers(
        self, db_session, worker_sessions, mock_pi_output, _patch_paths
    ):
        """Several pending papers are all processed in one batch."""
        self._add_pending_papers(db_session, "2401.1234", "Test Paper", 3)

        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            result = await summarize_batch(db_session, _session_factory=worker_sessions)

        assert result["status"] == "success"
        assert result["done"] == 3
        assert result["failed"] == 0

        # The run is recorded in CrawlLog
        log = db_session.query(CrawlLog).filter(CrawlLog.task == "summarize").first()
        assert log is not None
        assert log.status == "success"
        assert log.papers_found == 3

    @pytest.mark.asyncio
    async def test_single_failure_no_block(
        self, db_session, worker_sessions, mock_pi_output, _patch_paths
    ):
        """One failing paper does not block the others."""
        self._add_pending_papers(db_session, "2401.5678", "Paper", 2)

        call_count = 0

        async def _mock_call_pi(meta_path, pdf_path):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise PiTimeoutError("timeout")
            return mock_pi_output

        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", side_effect=_mock_call_pi),
        ):
            result = await summarize_batch(db_session, _session_factory=worker_sessions)

        assert result["done"] == 1
        assert result["failed"] == 1

    @pytest.mark.asyncio
    async def test_task_lock_conflict(self, db_session, _patch_paths):
        """An existing running TaskLock prevents a concurrent batch."""
        # Insert a lock that is already running.
        db_session.add(
            TaskLock(
                task="summarize",
                lock_key="batch",
                status="running",
                acquired_at=datetime.now(timezone.utc),
            )
        )
        db_session.commit()

        result = await summarize_batch(db_session)
        assert result["status"] == "conflict"

    @pytest.mark.asyncio
    async def test_task_lock_released(
        self, db_session, worker_sessions, mock_pi_output, _patch_paths
    ):
        """The batch TaskLock is marked finished once the run completes."""
        # Seed a pending paper so the batch does real work rather than
        # possibly short-circuiting on an empty queue.
        self._add_pending_papers(db_session, "2401.9012", "Lock Paper", 1)

        with (
            patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
            patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
        ):
            await summarize_batch(db_session, _session_factory=worker_sessions)

        locks = db_session.query(TaskLock).filter(
            TaskLock.task == "summarize",
            TaskLock.lock_key == "batch",
        ).all()
        # Without this guard a missing lock row would make the loop below vacuous.
        assert locks, "no batch TaskLock row was recorded"
        for lock in locks:
            assert lock.status == "finished"
            assert lock.released_at is not None

    @pytest.mark.asyncio
    async def test_batch_empty(self, db_session, _patch_paths):
        """With no pending papers the batch returns an empty result."""
        result = await summarize_batch(db_session)
        assert result["status"] == "success"
        assert result["total"] == 0


# ═══════════════════════════════════════════════════════════════════════
# Admin route authentication tests
# ═══════════════════════════════════════════════════════════════════════


class TestAdminAuth:
    """Admin endpoint authentication — HTTP layer only; service calls are mocked."""

    # Installed into settings.ADMIN_TOKEN by ``_configured_token``; presumably the
    # same value the shared ``admin_headers`` fixture sends — confirm in conftest.
    ADMIN_TOKEN = "test-admin-token-12345"

    @pytest.fixture
    def _configured_token(self):
        """Temporarily set ``settings.ADMIN_TOKEN`` to a known value and restore it afterwards."""
        import app.config as config_mod

        original = config_mod.settings.ADMIN_TOKEN
        config_mod.settings.ADMIN_TOKEN = self.ADMIN_TOKEN
        try:
            yield
        finally:
            config_mod.settings.ADMIN_TOKEN = original

    def test_no_token_returns_401(self, client, _configured_token):
        """A request without a Bearer token is rejected (401 or 403 accepted)."""
        resp = client.post("/admin/summarize")
        assert resp.status_code in (401, 403)

    def test_wrong_token_returns_401(self, client, _configured_token):
        """A Bearer token that differs from the configured one gets 401."""
        resp = client.post(
            "/admin/summarize",
            headers={"Authorization": "Bearer wrong-token"},
        )
        assert resp.status_code == 401

    def test_correct_token_batch(self, client, admin_headers, _configured_token):
        """The correct token reaches batch summarize (service layer mocked)."""
        with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
            mock.return_value = {"status": "success", "done": 0, "failed": 0, "total": 0}
            resp = client.post("/admin/summarize", headers=admin_headers)
        assert resp.status_code == 200
        assert resp.json()["status"] == "success"

    def test_single_paper_not_found(self, client, admin_headers, _configured_token):
        """Summarizing a paper that does not exist returns 404."""
        with patch(
            "app.routes.admin.summarize_single",
            new_callable=AsyncMock,
            return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
        ):
            resp = client.post(
                "/admin/summarize/nonexistent.99999",
                headers=admin_headers,
            )
        assert resp.status_code == 404