Files
Rain-Bus 21f16e6756 feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
2026-06-13 13:16:47 +08:00

172 lines
6.2 KiB
Python

"""AI 总结 schema — Pydantic 校验模型、质量评估、DB 展平。"""
from __future__ import annotations
import json
from pydantic import BaseModel, Field, ValidationError, field_validator
from app.utils import sanitize_html, utc_now
# ── 子模型 ──────────────────────────────────────────────────────────────
class PrerequisitesSchema(BaseModel):
    """Prerequisite-knowledge section of an AI summary.

    ``concepts`` is a free-form list of dicts (defaults to an empty list).
    ``flatten_for_db`` HTML-sanitizes the ``explanation`` and ``why_matters``
    keys of these dicts when they are present.
    """

    # No per-item schema: any dict is accepted. Presumably one dict per concept
    # (name / explanation / why_matters) — TODO confirm against the LLM prompt.
    concepts: list[dict] = Field(default_factory=list)
class MotivationSchema(BaseModel):
    """Motivation section: the problem tackled (required), plus optional goal and gap."""

    problem: str
    goal: str = Field(default="")
    gap: str = Field(default="")

    @field_validator("problem")
    @classmethod
    def non_empty_problem(cls, v: str) -> str:
        """Reject a blank ``problem``; otherwise return it with surrounding whitespace removed."""
        cleaned = (v or "").strip()
        if not cleaned:
            raise ValueError("motivation.problem cannot be empty")
        return cleaned
class MethodSchema(BaseModel):
    """Method section: the key idea (required) plus optional overview, steps and novelty."""

    overview: str = Field(default="")
    key_idea: str
    steps: str = Field(default="")
    novelty: str = Field(default="")

    @field_validator("key_idea")
    @classmethod
    def non_empty_key_idea(cls, v: str) -> str:
        """Reject a blank ``key_idea``; otherwise return it with surrounding whitespace removed."""
        cleaned = (v or "").strip()
        if not cleaned:
            raise ValueError("method.key_idea cannot be empty")
        return cleaned
class ResultsSchema(BaseModel):
    """Results section; every field is optional and defaults to empty."""

    main_findings: str = Field(default="")
    # Each benchmark entry may be either a plain string or a structured dict.
    benchmarks: list[str | dict] = Field(default_factory=list)
    limitations: str = Field(default="")
class ImprovementsSchema(BaseModel):
    """Critique / follow-up section; every field is optional and defaults to ``""``."""

    weaknesses: str = Field(default="")
    future_work: str = Field(default="")
    reproducibility: str = Field(default="")
# ── 顶层 schema ─────────────────────────────────────────────────────────
class SummarySchema(BaseModel):
    """Top-level AI summary payload.

    Unknown keys in the input are ignored (``extra="ignore"``).
    ``title_zh``, ``one_line``, ``tags``, ``motivation`` and ``method`` have no
    default and are therefore required; every other section falls back to an
    empty default.
    """

    model_config = {"extra": "ignore"}

    title_zh: str
    one_line: str
    tags: list[str]
    difficulty: str = Field(default="")
    paper_date: str | None = Field(default=None)
    prerequisites: PrerequisitesSchema = Field(default_factory=PrerequisitesSchema)
    motivation: MotivationSchema
    method: MethodSchema
    results: ResultsSchema = Field(default_factory=ResultsSchema)
    improvements: ImprovementsSchema = Field(default_factory=ImprovementsSchema)
    figures: list[dict] = Field(default_factory=list)

    @field_validator("title_zh", "one_line")
    @classmethod
    def non_empty_text(cls, v: str) -> str:
        """Strip ``title_zh`` / ``one_line`` and reject them when nothing is left."""
        trimmed = v.strip() if v else ""
        if trimmed:
            return trimmed
        raise ValueError("field cannot be empty")

    @field_validator("tags")
    @classmethod
    def non_empty_tags(cls, v: list[str]) -> list[str]:
        """Strip each tag, drop blank ones, and reject the list if none survive."""
        kept: list[str] = []
        for raw in v:
            if not raw:
                continue
            stripped = raw.strip()
            if stripped:
                kept.append(stripped)
        if not kept:
            raise ValueError("tags cannot be empty")
        return kept
# ── 质量评估 ────────────────────────────────────────────────────────────
def assess_quality(schema: SummarySchema) -> str:
    """Grade a validated summary as ``"normal"``, ``"degraded"`` or ``"low"``.

    * ``"low"`` — ``one_line`` or ``method.key_idea`` has fewer than 10
      characters after stripping (a heuristic for hollow output).
    * ``"degraded"`` — not low, but at least one of ``motivation.goal``,
      ``motivation.gap``, ``method.overview`` or ``results.main_findings``
      is blank.
    * ``"normal"`` — otherwise.
    """
    min_len = 10
    too_thin = (
        len(schema.one_line.strip()) < min_len
        or len(schema.method.key_idea.strip()) < min_len
    )
    if too_thin:
        return "low"

    # Important-but-optional fields; any blank one downgrades the summary.
    important = (
        schema.motivation.goal,
        schema.motivation.gap,
        schema.method.overview,
        schema.results.main_findings,
    )
    blank_flags = [not text.strip() for text in important]
    return "degraded" if any(blank_flags) else "normal"
# ── DB 展平 ─────────────────────────────────────────────────────────────
def flatten_for_db(schema: SummarySchema) -> dict:
    """Flatten a ``SummarySchema`` into column values for the ``paper_summaries`` table.

    Every free-text field that the frontend renders with ``|safe`` is passed
    through ``sanitize_html`` first.  List/dict sections (prerequisites,
    benchmarks, figures) are stored as JSON strings, the whole model is kept in
    ``full_json``, and ``updated_at`` is stamped with ``utc_now()``.

    Returns:
        A dict mapping column name to value.
    """
    clean = sanitize_html
    motivation = schema.motivation
    method = schema.method
    results = schema.results
    improvements = schema.improvements

    # Prerequisite concepts carry nested free text as well; sanitize those two
    # keys in the dict returned by model_dump() before serializing it.
    prereq_data = schema.prerequisites.model_dump()
    for concept in prereq_data.get("concepts", []):
        if not isinstance(concept, dict):
            continue
        for text_key in ("explanation", "why_matters"):
            if concept.get(text_key):
                concept[text_key] = clean(concept[text_key])

    row = {
        "one_line": clean(schema.one_line),
        "difficulty": schema.difficulty,
        "prerequisites_json": json.dumps(prereq_data, ensure_ascii=False),
        "motivation_problem": clean(motivation.problem),
        "motivation_goal": clean(motivation.goal),
        "motivation_gap": clean(motivation.gap),
        "method_overview": clean(method.overview),
        "method_key_idea": clean(method.key_idea),
        # NOTE: several *_json columns hold sanitized plain text, not JSON.
        "method_steps_json": clean(method.steps),
        "method_novelty": clean(method.novelty),
        "results_main_json": clean(results.main_findings),
        "results_benchmarks_json": json.dumps(results.benchmarks, ensure_ascii=False),
        "limitations_json": clean(results.limitations),
        "weaknesses_json": clean(improvements.weaknesses),
        "future_work_json": clean(improvements.future_work),
        "reproducibility": clean(improvements.reproducibility),
        "figures_json": json.dumps(schema.figures, ensure_ascii=False),
        "full_json": schema.model_dump_json(ensure_ascii=False),
        "updated_at": utc_now(),
    }
    return row
# ── 错误分类 ────────────────────────────────────────────────────────────
# Required field names, matched against the LAST element of a pydantic error
# ``loc``.  Covers the required top-level scalars, the required section objects
# ``motivation`` / ``method`` (no defaults on SummarySchema), and their required
# leaves ``motivation.problem`` -> "problem" and ``method.key_idea`` -> "key_idea".
_REQUIRED_FIELDS = {
    "title_zh",
    "one_line",
    "tags",
    "motivation",
    "method",
    "problem",
    "key_idea",
}


def classify_validation_error(exc: ValidationError) -> str:
    """Classify a ``SummarySchema`` validation failure.

    Returns ``"field_missing"`` as soon as any error is a ``missing`` error, or a
    ``value_error`` (raised by the non-empty validators), on a required field.
    This includes a whole required section such as ``motivation`` being absent.
    Every other failure (wrong types, etc.) yields ``"schema_error"``.
    """
    for err in exc.errors():
        loc = err.get("loc") or ()
        field_name = loc[-1] if loc else ""
        if field_name in _REQUIRED_FIELDS and err.get("type") in (
            "missing",
            "value_error",
        ):
            return "field_missing"
    return "schema_error"