"""AI 总结持久化 — DB 写入、文件保存、FTS 索引、图片提取、ChromaDB 索引。""" from __future__ import annotations import logging from sqlalchemy.orm import Session from app.models import ( Paper, PaperSummary, PaperTag, SummaryState, ) from app.services.derived import reindex_paper_fts from app.services.pdf_downloader import paper_dir from app.services.schemas import ( SummarySchema, assess_quality, flatten_for_db, ) from app.services.summary_generator import _classify_error from app.utils import TMP_DIR, truncate_error, utc_now logger = logging.getLogger(__name__) # ── FTS5 文本构建 ─────────────────────────────────────────────────────── def _build_fts_summary_text(schema: SummarySchema) -> str: """拼接用于 FTS5 索引的总结文本。""" parts = [ schema.one_line or "", schema.motivation.problem or "", schema.motivation.goal or "", schema.method.overview or "", schema.method.key_idea or "", schema.results.main_findings or "", ] return " ".join(p for p in parts if p) # ── DB 更新 ───────────────────────────────────────────────────────────── def _update_summary_in_db( db: Session, paper: Paper, schema: SummarySchema, quality: str, raw_output: str, ) -> None: """将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。""" # 1. paper_summaries:upsert existing = db.get(PaperSummary, paper.id) flat = flatten_for_db(schema) if existing: for k, v in flat.items(): setattr(existing, k, v) else: db.add(PaperSummary(paper_id=paper.id, **flat)) # 2. papers 表 paper.title_zh = schema.title_zh paper.summary_quality = quality p_dir = paper_dir(paper.arxiv_id) paper.summary_path = str(p_dir / "summary.json") paper.raw_output_path = str(p_dir / "raw_output.txt") # 3. AI 标签 existing_tag_names = {t.tag for t in paper.tags} for tag_name in schema.tags: if tag_name not in existing_tag_names: db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai")) existing_tag_names.add(tag_name) # 4. FTS5 派生索引 db.flush() reindex_paper_fts(db, paper) db.commit() logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality) # ── 文件操作 ──────────────────────────────────────────────────────────── def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None: d = paper_dir(arxiv_id) d.mkdir(parents=True, exist_ok=True) if schema: (d / "summary.json").write_text( schema.model_dump_json(ensure_ascii=False, indent=2), encoding="utf-8", ) (d / "raw_output.txt").write_text(raw_output, encoding="utf-8") # ── 失败处理 ──────────────────────────────────────────────────────────── def _handle_summary_failure( db: Session, paper: Paper, exc: Exception, raw_output: str, ) -> dict: """记录失败:保存 raw_output、重试计数、错误分类。""" from app.config import settings error_type = _classify_error(exc) logger.error( "Summarize failed: %s error_type=%s %s", paper.arxiv_id, error_type, truncate_error(exc), ) status = paper.summary_status if raw_output: _save_files(paper.arxiv_id, None, raw_output) status.raw_output_saved = True status.retry_count = (status.retry_count or 0) + 1 status.error_type = error_type status.error = truncate_error(exc, limit=2000) if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1: status.status = SummaryState.PERMANENT_FAILURE else: status.status = SummaryState.PENDING status.completed_at = utc_now() db.commit() return { "arxiv_id": paper.arxiv_id, "status": "failed", "error_type": error_type, "error": truncate_error(exc), "retry_count": status.retry_count, } # ── 持久化 ────────────────────────────────────────────────────────────── def _persist_summary( db: Session, paper: Paper, json_data: dict, raw_output: str ) -> str: """Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。""" import time as _time arxiv_id = paper.arxiv_id _t0 = _time.monotonic() schema = SummarySchema.model_validate(json_data) quality = assess_quality(schema) _t1 = _time.monotonic() _save_files(arxiv_id, schema, raw_output) _t2 = _time.monotonic() _update_summary_in_db(db, paper, schema, quality, raw_output) _t3 = _time.monotonic() # 状态 → done paper.summary_status.status = SummaryState.DONE paper.summary_status.quality = quality paper.summary_status.completed_at = utc_now() paper.summary_status.raw_output_saved = True db.commit() _t4 = _time.monotonic() logger.info( " [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs", arxiv_id, _t1 - _t0, _t2 - _t1, _t3 - _t2, _t4 - _t3, ) # 触发性增强(失败不影响总结) _t5 = _time.monotonic() _maybe_extract_images(arxiv_id, schema) _t6 = _time.monotonic() _maybe_index_chroma(arxiv_id, paper, schema) _t7 = _time.monotonic() logger.info( " [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs", arxiv_id, _t6 - _t5, _t7 - _t6, ) return quality # ── 清理 ──────────────────────────────────────────────────────────────── def _cleanup_old_images(db: Session, paper: Paper) -> None: """清理旧的图片文件和 figures_json,避免重新总结时残留。""" arxiv_id = paper.arxiv_id images_dir = paper_dir(arxiv_id) / "images" if images_dir.exists(): for old_file in images_dir.iterdir(): if ( old_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg") or old_file.name == "manifest.json" ): old_file.unlink(missing_ok=True) # 清除数据库中的 figures_json if paper.summary and paper.summary.figures_json: paper.summary.figures_json = None db.commit() # ── 触发性增强 ────────────────────────────────────────────────────────── def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None: """从 PDF 提取图片和表格(失败不影响总结)。 两阶段流水线: 1. DocLayout-YOLO 检测 + 渲染截图(通用标签) 2. 用 summary 的 figures ID 在 PDF 中搜索定位 → 重命名 """ try: from app.services.pdf_image_extractor import ( extract_images_from_pdf, label_images_by_summary, ) pdf_path = TMP_DIR / arxiv_id / "paper.pdf" extract_images_from_pdf(arxiv_id, pdf_path) if schema.figures: label_images_by_summary(arxiv_id, schema.figures, pdf_path) except Exception: logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True) def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None: """写入 ChromaDB 语义索引(失败不影响总结)。""" try: from app.services.embedder import index_paper texts_dict = { "arxiv_id": arxiv_id, "title_zh": schema.title_zh or "", "title_en": paper.title_en or "", "tags": " ".join(t.tag for t in paper.tags) if paper.tags else "", "one_line": schema.one_line or "", "motivation_problem": schema.motivation.problem or "", "method_key_idea": schema.method.key_idea or "", "paper_date": paper.paper_date.isoformat() if paper.paper_date else "", } index_paper(arxiv_id, texts_dict) except Exception: logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)