"""AI 总结 schema — Pydantic 校验模型、质量评估、DB 展平。""" from __future__ import annotations import json from datetime import datetime, timezone from pydantic import BaseModel, Field, ValidationError, field_validator # ── 子模型 ────────────────────────────────────────────────────────────── class PrerequisitesSchema(BaseModel): concepts: list[str] = Field(default_factory=list) level: str = "" class MotivationSchema(BaseModel): problem: str goal: str = "" gap: str = "" @field_validator("problem") @classmethod def non_empty_problem(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("motivation.problem cannot be empty") return v.strip() class MethodSchema(BaseModel): overview: str = "" key_idea: str steps: list[str] = Field(default_factory=list) novelty: str = "" @field_validator("key_idea") @classmethod def non_empty_key_idea(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("method.key_idea cannot be empty") return v.strip() class ResultsSchema(BaseModel): main_findings: list[str] = Field(default_factory=list) benchmarks: list[dict] = Field(default_factory=list) limitations: list[str] = Field(default_factory=list) class ImprovementsSchema(BaseModel): weaknesses: list[str] = Field(default_factory=list) future_work: list[str] = Field(default_factory=list) reproducibility: str = "" # ── 顶层 schema ───────────────────────────────────────────────────────── class SummarySchema(BaseModel): model_config = {"extra": "ignore"} title_zh: str one_line: str tags: list[str] difficulty: str = "" paper_date: str | None = None prerequisites: PrerequisitesSchema = Field(default_factory=PrerequisitesSchema) motivation: MotivationSchema method: MethodSchema results: ResultsSchema = Field(default_factory=ResultsSchema) improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema) @field_validator("title_zh", "one_line") @classmethod def non_empty_text(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("field cannot be empty") return v.strip() @field_validator("tags") @classmethod def non_empty_tags(cls, v: list[str]) -> list[str]: tags = [tag.strip() for tag in v if tag and tag.strip()] if not tags: raise ValueError("tags cannot be empty") return tags # ── 质量评估 ──────────────────────────────────────────────────────────── # 必填字段:title_zh, one_line, tags, motivation.problem, method.key_idea # — 缺失时 Pydantic 校验就会报错,不会走到 assess_quality # 重要字段:motivation.goal, motivation.gap, method.overview, results.main_findings # — 缺失可入库,标记 degraded _OPTIONAL_BUT_IMPORTANT_FIELDS = [ "motivation.goal", "motivation.gap", "method.overview", "results.main_findings", ] def assess_quality(schema: SummarySchema) -> str: """评估总结质量:normal / degraded / low。""" # low:内容空洞的启发式判断 if len(schema.one_line.strip()) < 10 or len(schema.method.key_idea.strip()) < 10: return "low" # 检查重要字段是否缺失 missing_important = 0 if not schema.motivation.goal.strip(): missing_important += 1 if not schema.motivation.gap.strip(): missing_important += 1 if not schema.method.overview.strip(): missing_important += 1 if not schema.results.main_findings: missing_important += 1 if missing_important == 0: return "normal" return "degraded" # ── DB 展平 ───────────────────────────────────────────────────────────── def flatten_for_db(schema: SummarySchema) -> dict: """将 SummarySchema 展平为 paper_summaries 表的列值 dict。""" return { "one_line": schema.one_line, "difficulty": schema.difficulty, "prerequisites_json": json.dumps(schema.prerequisites.model_dump(), ensure_ascii=False), "motivation_problem": schema.motivation.problem, "motivation_goal": schema.motivation.goal, "motivation_gap": schema.motivation.gap, "method_overview": schema.method.overview, "method_key_idea": schema.method.key_idea, "method_steps_json": json.dumps(schema.method.steps, ensure_ascii=False), "method_novelty": schema.method.novelty, "results_main_json": json.dumps(schema.results.main_findings, ensure_ascii=False), "results_benchmarks_json": json.dumps(schema.results.benchmarks, ensure_ascii=False), "limitations_json": json.dumps(schema.results.limitations, ensure_ascii=False), "weaknesses_json": json.dumps(schema.improvements.weaknesses, ensure_ascii=False), "future_work_json": json.dumps(schema.improvements.future_work, ensure_ascii=False), "reproducibility": schema.improvements.reproducibility, "full_json": schema.model_dump_json(ensure_ascii=False), "updated_at": datetime.now(timezone.utc), } # ── 错误分类 ──────────────────────────────────────────────────────────── _REQUIRED_FIELDS = {"title_zh", "one_line", "tags", "problem", "key_idea"} def classify_validation_error(exc: ValidationError) -> str: """区分 field_missing(必填缺失)和 schema_error(类型不合法等)。""" for err in exc.errors(): field_name = err["loc"][-1] if err["loc"] else "" if field_name in _REQUIRED_FIELDS and err["type"] in ( "missing", "value_error", ): return "field_missing" return "schema_error"