# Source: daily-paper/app/services/schemas.py (Python, 179 lines, 6.3 KiB)

"""AI 总结 schema — Pydantic 校验模型、质量评估、DB 展平。"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from pydantic import BaseModel, Field, ValidationError, field_validator
# ── 子模型 ──────────────────────────────────────────────────────────────
class PrerequisitesSchema(BaseModel):
    # "prerequisites" section of a summary. Both fields are optional:
    # ``concepts`` defaults to a fresh empty list and ``level`` to "".
    concepts: list[str] = Field(default_factory=list)
    level: str = ""
class MotivationSchema(BaseModel):
    # "motivation" section of a summary. ``problem`` is the only required
    # field; ``goal`` and ``gap`` fall back to the empty string.
    problem: str
    goal: str = ""
    gap: str = ""

    @field_validator("problem")
    @classmethod
    def non_empty_problem(cls, v: str) -> str:
        """Return ``v`` with surrounding whitespace removed.

        Raises:
            ValueError: if ``v`` is empty or contains only whitespace.
        """
        stripped = v.strip() if v else ""
        if stripped:
            return stripped
        raise ValueError("motivation.problem cannot be empty")
class MethodSchema(BaseModel):
    # "method" section of a summary. ``key_idea`` is the only required field;
    # the text fields default to "" and ``steps`` to a fresh empty list.
    overview: str = ""
    key_idea: str
    steps: list[str] = Field(default_factory=list)
    novelty: str = ""

    @field_validator("key_idea")
    @classmethod
    def non_empty_key_idea(cls, v: str) -> str:
        """Return ``v`` with surrounding whitespace removed.

        Raises:
            ValueError: if ``v`` is empty or contains only whitespace.
        """
        stripped = v.strip() if v else ""
        if stripped:
            return stripped
        raise ValueError("method.key_idea cannot be empty")
class ResultsSchema(BaseModel):
    # "results" section of a summary. Every field is optional and defaults to
    # a fresh empty list; ``benchmarks`` entries are free-form dicts whose
    # contents are not validated here.
    main_findings: list[str] = Field(default_factory=list)
    benchmarks: list[dict] = Field(default_factory=list)
    limitations: list[str] = Field(default_factory=list)
class ImprovementsSchema(BaseModel):
    # "improvements" section of a summary. Every field is optional: the two
    # lists default to fresh empty lists and ``reproducibility`` to "".
    weaknesses: list[str] = Field(default_factory=list)
    future_work: list[str] = Field(default_factory=list)
    reproducibility: str = ""
# ── 顶层 schema ─────────────────────────────────────────────────────────
class SummarySchema(BaseModel):
    # Top-level model for one AI-generated summary.
    # Required: title_zh, one_line, tags, motivation, method.
    # Optional (with defaults): difficulty, paper_date, prerequisites,
    # results, improvements.

    # Keys that are not declared below are ignored rather than rejected.
    model_config = {"extra": "ignore"}

    title_zh: str
    one_line: str
    tags: list[str]
    difficulty: str = ""
    paper_date: str | None = None
    prerequisites: PrerequisitesSchema = Field(default_factory=PrerequisitesSchema)
    motivation: MotivationSchema
    method: MethodSchema
    results: ResultsSchema = Field(default_factory=ResultsSchema)
    improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema)

    @field_validator("title_zh", "one_line")
    @classmethod
    def non_empty_text(cls, v: str) -> str:
        """Return ``v`` with surrounding whitespace removed.

        Raises:
            ValueError: if ``v`` is empty or contains only whitespace.
        """
        stripped = v.strip() if v else ""
        if stripped:
            return stripped
        raise ValueError("field cannot be empty")

    @field_validator("tags")
    @classmethod
    def non_empty_tags(cls, v: list[str]) -> list[str]:
        """Strip every tag, drop blank ones, and require at least one survivor.

        Raises:
            ValueError: if no non-blank tag remains.
        """
        stripped = (tag.strip() for tag in v if tag)
        tags = [tag for tag in stripped if tag]
        if tags:
            return tags
        raise ValueError("tags cannot be empty")
# ── 质量评估 ────────────────────────────────────────────────────────────
# 必填字段:title_zh, one_line, tags, motivation.problem, method.key_idea
# — 缺失时 Pydantic 校验就会报错,不会走到 assess_quality
# 重要字段:motivation.goal, motivation.gap, method.overview, results.main_findings
# — 缺失可入库,标记 degraded
# Dotted attribute paths (relative to SummarySchema) of the fields that may be
# absent without failing validation, but whose absence downgrades the summary
# to "degraded". assess_quality() iterates this list, so it is the single
# source of truth for that check.
_OPTIONAL_BUT_IMPORTANT_FIELDS = [
    "motivation.goal",
    "motivation.gap",
    "method.overview",
    "results.main_findings",
]

# one_line / method.key_idea shorter than this (after stripping) are treated
# as too thin to be useful.
_MIN_MEANINGFUL_CHARS = 10


def _has_text(value: object) -> bool:
    """Return True when *value* carries at least one non-blank piece of content.

    Strings count when they are non-empty after stripping; lists/tuples count
    when at least one element does, so ``[]`` and ``["", "  "]`` are both
    considered empty. Anything else falls back to plain truthiness.
    """
    if isinstance(value, str):
        return bool(value.strip())
    if isinstance(value, (list, tuple)):
        return any(_has_text(item) for item in value)
    return bool(value)


def _resolve_field(schema: object, dotted_path: str) -> object:
    """Follow a dotted attribute path such as ``"method.overview"`` from *schema*."""
    target = schema
    for attr in dotted_path.split("."):
        target = getattr(target, attr)
    return target


def assess_quality(schema: SummarySchema) -> str:
    """Grade a validated summary as ``"normal"``, ``"degraded"`` or ``"low"``.

    Args:
        schema: The summary to grade.

    Returns:
        ``"low"`` when ``one_line`` or ``method.key_idea`` is shorter than 10
        characters after stripping (heuristic for hollow content);
        ``"degraded"`` when any path in ``_OPTIONAL_BUT_IMPORTANT_FIELDS`` is
        blank (empty string, empty list, or a list of only blank strings);
        ``"normal"`` otherwise.
    """
    headline_len = len(schema.one_line.strip())
    key_idea_len = len(schema.method.key_idea.strip())
    if headline_len < _MIN_MEANINGFUL_CHARS or key_idea_len < _MIN_MEANINGFUL_CHARS:
        return "low"

    if all(
        _has_text(_resolve_field(schema, path))
        for path in _OPTIONAL_BUT_IMPORTANT_FIELDS
    ):
        return "normal"
    return "degraded"
# ── DB 展平 ─────────────────────────────────────────────────────────────
def flatten_for_db(schema: SummarySchema) -> dict:
    """Flatten a ``SummarySchema`` into column values for ``paper_summaries``.

    Scalar text fields are copied as-is; list/dict fields are serialized into
    ``*_json`` string columns.

    Args:
        schema: The validated summary to flatten.

    Returns:
        A dict keyed by column name. ``full_json`` holds the whole model as a
        JSON string and ``updated_at`` is the current timezone-aware UTC time.
    """

    def to_json(value: object) -> str:
        # ensure_ascii=False stores non-ASCII text verbatim instead of as
        # \uXXXX escapes.
        return json.dumps(value, ensure_ascii=False)

    motivation = schema.motivation
    method = schema.method
    results = schema.results
    improvements = schema.improvements

    return {
        "one_line": schema.one_line,
        "difficulty": schema.difficulty,
        "prerequisites_json": to_json(schema.prerequisites.model_dump()),
        "motivation_problem": motivation.problem,
        "motivation_goal": motivation.goal,
        "motivation_gap": motivation.gap,
        "method_overview": method.overview,
        "method_key_idea": method.key_idea,
        "method_steps_json": to_json(method.steps),
        "method_novelty": method.novelty,
        "results_main_json": to_json(results.main_findings),
        "results_benchmarks_json": to_json(results.benchmarks),
        "limitations_json": to_json(results.limitations),
        "weaknesses_json": to_json(improvements.weaknesses),
        "future_work_json": to_json(improvements.future_work),
        "reproducibility": improvements.reproducibility,
        # No ensure_ascii keyword here: pydantic's JSON serializer already
        # writes non-ASCII characters unescaped by default, and pydantic 2.x
        # releases that predate the keyword reject it with a TypeError.
        "full_json": schema.model_dump_json(),
        "updated_at": datetime.now(timezone.utc),
    }
# ── 错误分类 ────────────────────────────────────────────────────────────
# Leaf names of the required fields; compared against the LAST element of an
# error's ``loc`` tuple (so "problem" matches ("motivation", "problem")).
_REQUIRED_FIELDS = {"title_zh", "one_line", "tags", "problem", "key_idea"}

# Required nested objects. When one of them is omitted entirely, the error is
# reported at the parent (loc == ("motivation",) or ("method",)) and the
# required leaf inside it is never visited, so the parent must be recognised
# as well.
_REQUIRED_PARENTS = {"motivation", "method"}


def classify_validation_error(exc: ValidationError) -> str:
    """Classify a validation failure as ``"field_missing"`` or ``"schema_error"``.

    Args:
        exc: The validation error to inspect; only ``exc.errors()`` is read,
            and each entry's ``"loc"`` and ``"type"`` keys are used.

    Returns:
        ``"field_missing"`` as soon as one error shows a required field to be
        absent or blank: either a required leaf with type ``"missing"`` /
        ``"value_error"`` (the latter is what the non-empty validators raise),
        or a required nested object with type ``"missing"``.
        ``"schema_error"`` for everything else (wrong types, etc.).
    """
    for err in exc.errors():
        loc = err["loc"]
        # loc entries may be ints (list indexes); those never match a name.
        field_name = loc[-1] if loc else ""
        err_type = err["type"]
        if field_name in _REQUIRED_FIELDS and err_type in ("missing", "value_error"):
            return "field_missing"
        if field_name in _REQUIRED_PARENTS and err_type == "missing":
            return "field_missing"
    return "schema_error"