"""AI 总结 schema — Pydantic 校验模型、质量评估、DB 展平。""" from __future__ import annotations import json from pydantic import BaseModel, Field, ValidationError, field_validator from app.utils import sanitize_html, utc_now # ── 子模型 ────────────────────────────────────────────────────────────── class PrerequisitesSchema(BaseModel): concepts: list[dict] = Field(default_factory=list) class MotivationSchema(BaseModel): problem: str goal: str = "" gap: str = "" @field_validator("problem") @classmethod def non_empty_problem(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("motivation.problem cannot be empty") return v.strip() class MethodSchema(BaseModel): overview: str = "" key_idea: str steps: str = "" novelty: str = "" @field_validator("key_idea") @classmethod def non_empty_key_idea(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("method.key_idea cannot be empty") return v.strip() class ResultsSchema(BaseModel): main_findings: str = "" benchmarks: list[str | dict] = Field(default_factory=list) limitations: str = "" class ImprovementsSchema(BaseModel): weaknesses: str = "" future_work: str = "" reproducibility: str = "" # ── 顶层 schema ───────────────────────────────────────────────────────── class SummarySchema(BaseModel): model_config = {"extra": "ignore"} title_zh: str one_line: str tags: list[str] difficulty: str = "" paper_date: str | None = None prerequisites: PrerequisitesSchema = Field(default_factory=PrerequisitesSchema) motivation: MotivationSchema method: MethodSchema results: ResultsSchema = Field(default_factory=ResultsSchema) improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema) figures: list[dict] = Field(default_factory=list) @field_validator("title_zh", "one_line") @classmethod def non_empty_text(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("field cannot be empty") return v.strip() @field_validator("tags") @classmethod def non_empty_tags(cls, v: list[str]) -> list[str]: tags = [tag.strip() for tag in v if tag and tag.strip()] if not tags: raise ValueError("tags cannot be empty") return tags # ── 质量评估 ──────────────────────────────────────────────────────────── def assess_quality(schema: SummarySchema) -> str: """评估总结质量:normal / degraded / low。""" # low:内容空洞的启发式判断 if len(schema.one_line.strip()) < 10 or len(schema.method.key_idea.strip()) < 10: return "low" # 检查重要字段是否缺失 missing_important = 0 if not schema.motivation.goal.strip(): missing_important += 1 if not schema.motivation.gap.strip(): missing_important += 1 if not schema.method.overview.strip(): missing_important += 1 if not schema.results.main_findings.strip(): missing_important += 1 if missing_important == 0: return "normal" return "degraded" # ── DB 展平 ───────────────────────────────────────────────────────────── def flatten_for_db(schema: SummarySchema) -> dict: """将 SummarySchema 展平为 paper_summaries 表的列值 dict。 所有供前端用 |safe 渲染的文本字段均经过 HTML 清洗。 """ # 清洗 prerequisites 嵌套文本 prereqs = schema.prerequisites.model_dump() for c in prereqs.get("concepts", []): if isinstance(c, dict): for key in ("explanation", "why_matters"): if key in c and c[key]: c[key] = sanitize_html(c[key]) return { "one_line": sanitize_html(schema.one_line), "difficulty": schema.difficulty, "prerequisites_json": json.dumps(prereqs, ensure_ascii=False), "motivation_problem": sanitize_html(schema.motivation.problem), "motivation_goal": sanitize_html(schema.motivation.goal), "motivation_gap": sanitize_html(schema.motivation.gap), "method_overview": sanitize_html(schema.method.overview), "method_key_idea": sanitize_html(schema.method.key_idea), "method_steps_json": sanitize_html(schema.method.steps), "method_novelty": sanitize_html(schema.method.novelty), "results_main_json": sanitize_html(schema.results.main_findings), "results_benchmarks_json": json.dumps( schema.results.benchmarks, ensure_ascii=False ), "limitations_json": sanitize_html(schema.results.limitations), "weaknesses_json": sanitize_html(schema.improvements.weaknesses), "future_work_json": sanitize_html(schema.improvements.future_work), "reproducibility": sanitize_html(schema.improvements.reproducibility), "figures_json": json.dumps(schema.figures, ensure_ascii=False), "full_json": schema.model_dump_json(ensure_ascii=False), "updated_at": utc_now(), } # ── 错误分类 ──────────────────────────────────────────────────────────── _REQUIRED_FIELDS = {"title_zh", "one_line", "tags", "problem", "key_idea"} def classify_validation_error(exc: ValidationError) -> str: """区分 field_missing(必填缺失)和 schema_error(类型不合法等)。""" for err in exc.errors(): field_name = err["loc"][-1] if err["loc"] else "" if field_name in _REQUIRED_FIELDS and err["type"] in ( "missing", "value_error", ): return "field_missing" return "schema_error"