179 lines
6.3 KiB
Python
179 lines
6.3 KiB
Python
"""AI 总结 schema — Pydantic 校验模型、质量评估、DB 展平。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timezone
|
|
|
|
from pydantic import BaseModel, Field, ValidationError, field_validator
|
|
|
|
|
|
# ── 子模型 ──────────────────────────────────────────────────────────────
|
|
|
|
|
|
class PrerequisitesSchema(BaseModel):
    """Prerequisites section of a paper summary: a concept list and a level string."""

    # Concept names; an empty list when the section omits them.
    concepts: list[str] = Field(default_factory=list)
    # Free-form level descriptor; an empty string when not provided.
    level: str = Field(default="")
|
|
|
|
|
|
class MotivationSchema(BaseModel):
    """Motivation section of a paper summary; only ``problem`` is mandatory."""

    # Required; must contain non-whitespace text (see non_empty_problem).
    problem: str
    # Optional free-form text; an empty string when not provided.
    goal: str = Field(default="")
    gap: str = Field(default="")

    @field_validator("problem")
    @classmethod
    def non_empty_problem(cls, v: str) -> str:
        """Return ``problem`` without surrounding whitespace; reject blank input."""
        text = v.strip() if v else ""
        if text:
            return text
        raise ValueError("motivation.problem cannot be empty")
|
|
|
|
|
|
class MethodSchema(BaseModel):
    """Method section of a paper summary; only ``key_idea`` is mandatory."""

    # Optional free-form overview; an empty string when not provided.
    overview: str = Field(default="")
    # Required; must contain non-whitespace text (see non_empty_key_idea).
    key_idea: str
    # List of step descriptions; an empty list when not provided.
    steps: list[str] = Field(default_factory=list)
    novelty: str = Field(default="")

    @field_validator("key_idea")
    @classmethod
    def non_empty_key_idea(cls, v: str) -> str:
        """Return ``key_idea`` without surrounding whitespace; reject blank input."""
        text = v.strip() if v else ""
        if text:
            return text
        raise ValueError("method.key_idea cannot be empty")
|
|
|
|
|
|
class ResultsSchema(BaseModel):
    """Results section of a paper summary; every field defaults to an empty list."""

    main_findings: list[str] = Field(default_factory=list)
    # Free-form dicts; no key layout is enforced here.
    # NOTE(review): presumably one dict per benchmark result — confirm with the prompt.
    benchmarks: list[dict] = Field(default_factory=list)
    limitations: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class ImprovementsSchema(BaseModel):
    """Improvements section of a paper summary; all fields are optional."""

    weaknesses: list[str] = Field(default_factory=list)
    future_work: list[str] = Field(default_factory=list)
    # Free-form text; an empty string when not provided.
    reproducibility: str = Field(default="")
|
|
|
|
|
|
# ── 顶层 schema ─────────────────────────────────────────────────────────
|
|
|
|
|
|
class SummarySchema(BaseModel):
    """Top-level validated AI summary of a paper.

    ``title_zh``, ``one_line``, ``tags``, ``motivation`` and ``method`` are
    required; every other section falls back to an empty default.
    """

    # Keys in the input that are not declared below are dropped silently.
    model_config = {"extra": "ignore"}

    title_zh: str
    one_line: str
    tags: list[str]
    difficulty: str = Field(default="")
    paper_date: str | None = Field(default=None)
    prerequisites: PrerequisitesSchema = Field(default_factory=PrerequisitesSchema)
    motivation: MotivationSchema
    method: MethodSchema
    results: ResultsSchema = Field(default_factory=ResultsSchema)
    improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema)

    @field_validator("title_zh", "one_line")
    @classmethod
    def non_empty_text(cls, v: str) -> str:
        """Return the text without surrounding whitespace; reject blank input."""
        text = v.strip() if v else ""
        if not text:
            raise ValueError("field cannot be empty")
        return text

    @field_validator("tags")
    @classmethod
    def non_empty_tags(cls, v: list[str]) -> list[str]:
        """Strip every tag and drop blank ones; at least one tag must survive."""
        stripped = (tag.strip() for tag in v if tag)
        tags = [tag for tag in stripped if tag]
        if tags:
            return tags
        raise ValueError("tags cannot be empty")
|
|
|
|
|
|
# ── 质量评估 ────────────────────────────────────────────────────────────
|
|
|
|
# Required fields: title_zh, one_line, tags, motivation.problem, method.key_idea
# — if one is missing, Pydantic validation already fails, so assess_quality
#   never receives such a summary.
# Important fields: the dotted paths in _OPTIONAL_BUT_IMPORTANT_FIELDS
# — a summary missing some of them may still be stored, but is marked "degraded".
_OPTIONAL_BUT_IMPORTANT_FIELDS = [
    "motivation.goal",
    "motivation.gap",
    "method.overview",
    "results.main_findings",
]

# one_line / method.key_idea shorter than this (after stripping) count as hollow.
_MIN_CORE_TEXT_LEN = 10


def _has_content(value: object) -> bool:
    """Return whether *value* carries real content.

    A string counts only if it has non-whitespace characters. A list counts
    only if at least one of its items has content. Any other value counts
    unless it is None.
    """
    if isinstance(value, str):
        return bool(value.strip())
    if isinstance(value, list):
        return any(_has_content(item) for item in value)
    return value is not None


def assess_quality(schema: SummarySchema) -> str:
    """Grade a validated summary as "normal", "degraded" or "low".

    Args:
        schema: A summary that already passed SummarySchema validation.

    Returns:
        - "low" if ``one_line`` or ``method.key_idea`` is shorter than
          ``_MIN_CORE_TEXT_LEN`` characters after stripping (a heuristic for a
          hollow summary);
        - otherwise "degraded" if any path in ``_OPTIONAL_BUT_IMPORTANT_FIELDS``
          is empty; a list whose items are all blank strings counts as empty;
        - otherwise "normal".
    """
    from operator import attrgetter  # resolves dotted paths like "motivation.goal"

    # Hollow-content heuristic takes precedence over missing optional fields.
    if (
        len(schema.one_line.strip()) < _MIN_CORE_TEXT_LEN
        or len(schema.method.key_idea.strip()) < _MIN_CORE_TEXT_LEN
    ):
        return "low"

    # Driven by the declared list so the checks and the list cannot drift apart.
    missing_important = [
        path
        for path in _OPTIONAL_BUT_IMPORTANT_FIELDS
        if not _has_content(attrgetter(path)(schema))
    ]
    return "degraded" if missing_important else "normal"
|
|
|
|
|
|
# ── DB 展平 ─────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _to_json_text(value: object) -> str:
    """Serialize *value* with json.dumps, leaving non-ASCII text unescaped."""
    return json.dumps(value, ensure_ascii=False)


def flatten_for_db(schema: SummarySchema) -> dict:
    """Flatten a SummarySchema into column values for the paper_summaries table.

    Scalar text fields map to plain columns. List and dict sections are stored
    as JSON text in the ``*_json`` columns. ``full_json`` holds the whole model
    serialized by Pydantic, and ``updated_at`` is the current timezone-aware
    UTC time.

    Args:
        schema: A validated summary.

    Returns:
        A dict keyed by paper_summaries column name.
    """
    motivation = schema.motivation
    method = schema.method
    results = schema.results
    improvements = schema.improvements
    return {
        "one_line": schema.one_line,
        "difficulty": schema.difficulty,
        "prerequisites_json": _to_json_text(schema.prerequisites.model_dump()),
        "motivation_problem": motivation.problem,
        "motivation_goal": motivation.goal,
        "motivation_gap": motivation.gap,
        "method_overview": method.overview,
        "method_key_idea": method.key_idea,
        "method_steps_json": _to_json_text(method.steps),
        "method_novelty": method.novelty,
        "results_main_json": _to_json_text(results.main_findings),
        "results_benchmarks_json": _to_json_text(results.benchmarks),
        "limitations_json": _to_json_text(results.limitations),
        "weaknesses_json": _to_json_text(improvements.weaknesses),
        "future_work_json": _to_json_text(improvements.future_work),
        "reproducibility": improvements.reproducibility,
        # Pydantic v2's model_dump_json already writes non-ASCII characters
        # verbatim. Some v2 releases reject a json.dumps-style `ensure_ascii`
        # keyword with TypeError, so none is passed.
        "full_json": schema.model_dump_json(),
        "updated_at": datetime.now(timezone.utc),
    }
|
|
|
|
|
|
# ── 错误分类 ────────────────────────────────────────────────────────────
|
|
|
|
# Leaf field names (last element of a Pydantic error `loc`) of required fields.
_REQUIRED_FIELDS = {"title_zh", "one_line", "tags", "problem", "key_idea"}

# Top-level sub-models that hold a required field (motivation.problem,
# method.key_idea). When the whole object is absent, Pydantic reports one
# "missing" error at the parent's loc, so the child name never appears.
_REQUIRED_PARENT_MODELS = {"motivation", "method"}

# Error types that mean "required value absent": "missing" for an absent key,
# "value_error" for the ValueError raised by the non-empty field validators.
_MISSING_ERROR_TYPES = ("missing", "value_error")


def classify_validation_error(exc: ValidationError) -> str:
    """Classify a summary validation failure.

    Args:
        exc: The ValidationError raised while validating a SummarySchema.

    Returns:
        "field_missing" if any error reports a required field (or a top-level
        sub-model containing one) as absent or blank; otherwise "schema_error"
        (wrong types, malformed structure, and so on).
    """
    for err in exc.errors():
        loc = err["loc"]
        if not loc or err["type"] not in _MISSING_ERROR_TYPES:
            continue
        field_name = loc[-1]
        if field_name in _REQUIRED_FIELDS:
            return "field_missing"
        # Only a top-level loc like ("motivation",) implies the required child is gone.
        if len(loc) == 1 and field_name in _REQUIRED_PARENT_MODELS:
            return "field_missing"
    return "schema_error"
|