Files

191 lines
8.3 KiB
Python

"""SummarySchema 校验、quality 分级、JSON 提取、错误分类测试。"""
from __future__ import annotations
import json
import pytest
from pydantic import ValidationError
from app.services.summary_utils import (
JsonNotFoundError,
extract_json as _extract_json,
)
from app.services.pi_client import (
PiProcessError,
PiTimeoutError,
)
from app.services.pdf_downloader import PdfDownloadError
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.services.summary_generator import _classify_error
# ═══════════════════════════════════════════════════════════════════════
# SummarySchema 校验
# ═══════════════════════════════════════════════════════════════════════
class TestSummarySchema:
"""Pydantic schema 校验。"""
def test_valid_summary(self, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
assert schema.title_zh == "测试论文中文标题"
assert len(schema.tags) == 3
assert schema.motivation.problem
def test_missing_title_zh(self, sample_summary_dict):
del sample_summary_dict["title_zh"]
with pytest.raises(ValidationError) as exc_info:
SummarySchema.model_validate(sample_summary_dict)
assert classify_validation_error(exc_info.value) == "field_missing"
def test_empty_one_line(self, sample_summary_dict):
sample_summary_dict["one_line"] = ""
with pytest.raises(ValidationError):
SummarySchema.model_validate(sample_summary_dict)
def test_empty_tags(self, sample_summary_dict):
sample_summary_dict["tags"] = []
with pytest.raises(ValidationError):
SummarySchema.model_validate(sample_summary_dict)
def test_empty_motivation_problem(self, sample_summary_dict):
sample_summary_dict["motivation"]["problem"] = ""
with pytest.raises(ValidationError):
SummarySchema.model_validate(sample_summary_dict)
def test_empty_method_key_idea(self, sample_summary_dict):
sample_summary_dict["method"]["key_idea"] = ""
with pytest.raises(ValidationError):
SummarySchema.model_validate(sample_summary_dict)
def test_extra_fields_ignored(self, sample_summary_dict):
sample_summary_dict["takeaway"] = "important paper"
schema = SummarySchema.model_validate(sample_summary_dict)
assert not hasattr(schema, "takeaway")
assert schema.title_zh # 正常解析
def test_flatten_for_db(self, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
flat = flatten_for_db(schema)
assert flat["one_line"] == schema.one_line
assert flat["motivation_problem"] == schema.motivation.problem
assert flat["method_key_idea"] == schema.method.key_idea
assert "full_json" in flat
assert "updated_at" in flat
# JSON 字段可解析
assert isinstance(json.loads(flat["prerequisites_json"]), dict)
assert isinstance(flat["figures_json"], str) # figures 序列化为 JSON
# ═══════════════════════════════════════════════════════════════════════
# Quality 分级
# ═══════════════════════════════════════════════════════════════════════
class TestQualityAssessment:
"""质量分级测试。"""
def test_quality_normal(self, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
assert assess_quality(schema) == "normal"
def test_quality_degraded_missing_goal(self, sample_summary_dict):
sample_summary_dict["motivation"]["goal"] = ""
sample_summary_dict["motivation"]["gap"] = ""
sample_summary_dict["method"]["overview"] = ""
sample_summary_dict["results"]["main_findings"] = ""
schema = SummarySchema.model_validate(sample_summary_dict)
assert assess_quality(schema) == "degraded"
def test_quality_low_short_one_line(self, sample_summary_dict):
sample_summary_dict["one_line"] = ""
schema = SummarySchema.model_validate(sample_summary_dict)
assert assess_quality(schema) == "low"
def test_quality_low_short_key_idea(self, sample_summary_dict):
sample_summary_dict["method"]["key_idea"] = ""
schema = SummarySchema.model_validate(sample_summary_dict)
assert assess_quality(schema) == "low"
# ═══════════════════════════════════════════════════════════════════════
# JSON 提取
# ═══════════════════════════════════════════════════════════════════════
class TestJsonExtraction:
"""pi 输出的 JSON 提取。"""
def test_direct_json(self, sample_summary_json):
result = _extract_json(sample_summary_json)
assert result["title_zh"] == "测试论文中文标题"
def test_fenced_code_block(self, sample_summary_json):
raw = f"一些文字\n```json\n{sample_summary_json}\n```\n更多文字"
result = _extract_json(raw)
assert result["title_zh"] == "测试论文中文标题"
def test_fenced_without_lang(self, sample_summary_json):
raw = f"文字\n```\n{sample_summary_json}\n```"
result = _extract_json(raw)
assert result["title_zh"] == "测试论文中文标题"
def test_embedded_braces(self, sample_summary_dict):
json_str = json.dumps(sample_summary_dict, ensure_ascii=False)
raw = f"Here is the summary:\n{json_str}\nEnd."
result = _extract_json(raw)
assert result["title_zh"] == "测试论文中文标题"
def test_no_json_raises(self):
with pytest.raises(JsonNotFoundError):
_extract_json("No JSON here at all.")
def test_json_without_title_zh_falls_through(self):
"""不含 title_zh 的 JSON 不是我们要的。"""
raw = json.dumps({"other": "data"})
# 如果有其他合法 JSON 块也能返回,但没有就直接找最大块
# 此场景 raw 本身就是一个 JSON dict,但没有 title_zh
# 策略 1 会跳过(无 title_zh),策略 2 无代码块,策略 3 找到最大块
result = _extract_json(raw)
assert result == {"other": "data"} # 最大块兜底
# ═══════════════════════════════════════════════════════════════════════
# 错误分类
# ═══════════════════════════════════════════════════════════════════════
class TestErrorClassification:
"""异常 → error_type 映射。"""
def test_pdf_download_error(self):
assert _classify_error(PdfDownloadError("fail")) == "pdf_download_failed"
def test_timeout_error(self):
assert _classify_error(PiTimeoutError("timeout")) == "timeout"
def test_process_error(self):
assert _classify_error(PiProcessError(1, "stderr")) == "process_error"
def test_json_not_found(self):
assert _classify_error(JsonNotFoundError("not found")) == "json_not_found"
def test_json_invalid(self):
assert _classify_error(json.JSONDecodeError("bad", "", 0)) == "json_invalid"
def test_field_missing(self):
try:
SummarySchema.model_validate({"title_zh": ""}) # type: ignore
except ValidationError as exc:
assert _classify_error(exc) == "field_missing"
def test_unknown_error(self):
assert _classify_error(RuntimeError("boom")) == "unknown"