"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。""" from __future__ import annotations import asyncio import json import logging from pathlib import Path from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session from app.config import settings from app.database import SessionLocal from app.models import ( PAPER_DEFAULT_LOAD, CrawlLog, Paper, PaperSummary, PaperTag, SummaryState, SummaryStatus, TaskLock, ) from app.services.pdf_downloader import ( PdfDownloadError, cleanup_tmp, download_pdf, paper_dir, ) from app.services.pi_client import ( JsonNotFoundError, PiProcessError, PiTimeoutError, call_pi, extract_json, write_meta_json, ) from app.services.schemas import ( SummarySchema, assess_quality, classify_validation_error, flatten_for_db, ) from app.utils import TMP_DIR, release_lock, utc_now logger = logging.getLogger(__name__) # ── 错误分类 ──────────────────────────────────────────────────────────── def _classify_error(exc: Exception) -> str: """将异常映射到 error_type 枚举值。""" if isinstance(exc, PdfDownloadError): return "pdf_download_failed" if isinstance(exc, PiTimeoutError): return "timeout" if isinstance(exc, PiProcessError): return "process_error" if isinstance(exc, JsonNotFoundError): return "json_not_found" if isinstance(exc, json.JSONDecodeError): return "json_invalid" if isinstance(exc, ValidationError): return classify_validation_error(exc) return "unknown" # ── FTS5 文本构建 ─────────────────────────────────────────────────────── def _build_fts_summary_text(schema: SummarySchema) -> str: """拼接用于 FTS5 索引的总结文本。""" parts = [ schema.one_line or "", schema.motivation.problem or "", schema.motivation.goal or "", schema.method.overview or "", schema.method.key_idea or "", schema.results.main_findings or "", ] return " ".join(p for p in parts if p) # ── DB 更新 ───────────────────────────────────────────────────────────── def _update_summary_in_db( db: Session, paper: Paper, schema: SummarySchema, quality: str, raw_output: str, ) -> None: """将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。""" from sqlalchemy import text # 1. paper_summaries:upsert existing = db.get(PaperSummary, paper.id) flat = flatten_for_db(schema) if existing: for k, v in flat.items(): setattr(existing, k, v) else: db.add(PaperSummary(paper_id=paper.id, **flat)) # 2. papers 表 paper.title_zh = schema.title_zh paper.summary_quality = quality p_dir = paper_dir(paper.arxiv_id) paper.summary_path = str(p_dir / "summary.json") paper.raw_output_path = str(p_dir / "raw_output.txt") # 3. AI 标签 existing_tag_names = {t.tag for t in paper.tags} for tag_name in schema.tags: if tag_name not in existing_tag_names: db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai")) existing_tag_names.add(tag_name) # 4. FTS5 更新 summary_text = _build_fts_summary_text(schema) db.execute( text( "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text " "WHERE rowid=:paper_id" ), { "title_zh": schema.title_zh, "summary_text": summary_text, "paper_id": paper.id, }, ) db.commit() logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality) # ── JSON 验证 ────────────────────────────────────────────────────────── def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]: """验证 JSON 数据是否符合要求,返回错误列表(空=通过)。""" errors: list[str] = [] if not isinstance(json_data, dict): return ["顶层必须是 JSON 对象"] # 必填字段 for f in ["arxiv_id", "title_zh", "one_line", "tags"]: if f not in json_data or not json_data[f]: errors.append(f"缺少必填字段: {f}") # tags 必须是非空数组 tags = json_data.get("tags") if not isinstance(tags, list) or len(tags) == 0: errors.append("tags 必须是非空数组") # 字符串段落字段(必须是 str 且 ≥50 字) string_fields = [ ("motivation", "problem"), ("motivation", "goal"), ("motivation", "gap"), ("method", "overview"), ("method", "key_idea"), ("method", "steps"), ("method", "novelty"), ("results", "main_findings"), ("results", "limitations"), ("improvements", "weaknesses"), ("improvements", "future_work"), ("improvements", "reproducibility"), ] for section, field in string_fields: val = json_data.get(section, {}).get(field) if isinstance(val, list): errors.append(f"{section}.{field} 应该是字符串段落,不能是数组") elif not isinstance(val, str) or len(val.strip()) < 50: errors.append( f"{section}.{field} 必须是详细段落(≥50字)," f"当前: {type(val).__name__} ({len(str(val))}字)" ) # benchmarks 必须是数组 benchmarks = json_data.get("results", {}).get("benchmarks") if benchmarks is not None and not isinstance(benchmarks, list): errors.append("results.benchmarks 必须是数组") # prerequisites.concepts 必须是对象数组,每个有 term concepts = json_data.get("prerequisites", {}).get("concepts") if concepts is not None: if not isinstance(concepts, list): errors.append("prerequisites.concepts 必须是数组") elif len(concepts) == 0: errors.append("prerequisites.concepts 不能为空") else: for i, c in enumerate(concepts): if isinstance(c, str): errors.append(f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串") elif isinstance(c, dict) and not c.get("term"): errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段") # figures 必须是数组,每个元素应有 id figures = json_data.get("figures") if figures is not None: if not isinstance(figures, list): errors.append("figures 必须是数组") else: for i, fig in enumerate(figures): if isinstance(fig, dict) and not fig.get("id"): errors.append(f"figures[{i}] 缺少 id 字段") return errors # ── 文件操作 ──────────────────────────────────────────────────────────── def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None: d = paper_dir(arxiv_id) d.mkdir(parents=True, exist_ok=True) if schema: (d / "summary.json").write_text( schema.model_dump_json(ensure_ascii=False, indent=2), encoding="utf-8", ) (d / "raw_output.txt").write_text(raw_output, encoding="utf-8") # ── 单篇总结 ──────────────────────────────────────────────────────────── async def summarize_one( db: Session, paper: Paper, semaphore: asyncio.Semaphore | None = None, *, force: bool = False, pdf_mode: str = "auto", ) -> dict: """总结单篇论文的完整流程。""" arxiv_id = paper.arxiv_id # 获取或创建 summary_status if not paper.summary_status: db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING)) db.commit() db.refresh(paper) status = paper.summary_status # 跳过已完成的(除非 force) if status.status == SummaryState.DONE and not force: return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"} # 跳过 permanent_failure(除非 force) if status.status == SummaryState.PERMANENT_FAILURE and not force: return { "arxiv_id": arxiv_id, "status": "skipped", "reason": "permanent_failure", } if semaphore: await semaphore.acquire() try: return await _do_summarize_one(db, paper, pdf_mode=pdf_mode) finally: if semaphore: semaphore.release() async def _generate_with_retry( arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto" ) -> tuple[dict, str]: """调用 pi CLI 生成总结,最多 4 轮验证循环。 Returns: (json_data, raw_output) Raises: ValueError: 4 轮验证仍未通过 """ validation_errors: list[str] = [] json_data: dict | None = None raw_output = "" session_id = None for attempt in range(1, 5): # 清理上一轮 pi 写的不完整文件 stale = paper_dir(arxiv_id) / "summary.json" if stale.exists(): stale.unlink() if attempt == 1: raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode) else: raw_output, session_id = await call_pi( meta_path, pdf_path, fix_errors=validation_errors, session_id=session_id, pdf_mode=pdf_mode, ) # 优先读取 pi 写入的 summary.json,否则从 stdout 提取 summary_file = paper_dir(arxiv_id) / "summary.json" try: if summary_file.exists(): json_data = json.loads(summary_file.read_text(encoding="utf-8")) logger.info("Read summary.json written by pi for %s", arxiv_id) else: json_data = extract_json(raw_output) except (json.JSONDecodeError, JsonNotFoundError) as exc: logger.warning( "JSON extraction failed for %s (attempt %d): %s", arxiv_id, attempt, str(exc)[:200], ) validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"] continue validation_errors = _validate_summary(json_data, arxiv_id) if not validation_errors: break logger.warning( "Validation failed for %s (attempt %d): %s", arxiv_id, attempt, "; ".join(validation_errors), ) if validation_errors: exc = ValueError( f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}" ) exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用 raise exc return json_data, raw_output def _persist_summary( db: Session, paper: Paper, json_data: dict, raw_output: str ) -> str: """Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。""" schema = SummarySchema.model_validate(json_data) quality = assess_quality(schema) _save_files(paper.arxiv_id, schema, raw_output) _update_summary_in_db(db, paper, schema, quality, raw_output) # 状态 → done paper.summary_status.status = SummaryState.DONE paper.summary_status.quality = quality paper.summary_status.completed_at = utc_now() paper.summary_status.raw_output_saved = True db.commit() # 触发性增强(失败不影响总结) _maybe_extract_images(paper.arxiv_id, schema) _maybe_index_chroma(paper.arxiv_id, paper, schema) return quality def _handle_summary_failure( db: Session, paper: Paper, exc: Exception, raw_output: str, ) -> dict: """记录失败:保存 raw_output、重试计数、错误分类。""" error_type = _classify_error(exc) logger.error( "Summarize failed: %s error_type=%s %s", paper.arxiv_id, error_type, str(exc)[:200], ) status = paper.summary_status if raw_output: _save_files(paper.arxiv_id, None, raw_output) status.raw_output_saved = True status.retry_count = (status.retry_count or 0) + 1 status.error_type = error_type status.error = str(exc)[:2000] if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1: status.status = SummaryState.PERMANENT_FAILURE else: status.status = SummaryState.PENDING status.completed_at = utc_now() db.commit() return { "arxiv_id": paper.arxiv_id, "status": "failed", "error_type": error_type, "error": str(exc)[:200], "retry_count": status.retry_count, } def _cleanup_old_images(db: Session, paper: Paper) -> None: """清理旧的图片文件和 figures_json,避免重新总结时残留。""" arxiv_id = paper.arxiv_id images_dir = paper_dir(arxiv_id) / "images" if images_dir.exists(): for old_file in images_dir.iterdir(): if old_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg") or old_file.name == "manifest.json": old_file.unlink(missing_ok=True) # 清除数据库中的 figures_json if paper.summary and paper.summary.figures_json: paper.summary.figures_json = None db.commit() def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None: """从 PDF 提取图片和表格(失败不影响总结)。""" try: from app.services.pdf_image_extractor import ( extract_images_from_pdf, filter_images_by_summary, ) pdf_path = TMP_DIR / arxiv_id / "paper.pdf" extract_images_from_pdf(arxiv_id, pdf_path) if schema.figures: filter_images_by_summary(arxiv_id, schema.figures) except Exception: logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True) def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None: """写入 ChromaDB 语义索引(失败不影响总结)。""" try: from app.services.embedder import index_paper texts_dict = { "arxiv_id": arxiv_id, "title_zh": schema.title_zh or "", "title_en": paper.title_en or "", "tags": " ".join(t.tag for t in paper.tags) if paper.tags else "", "one_line": schema.one_line or "", "motivation_problem": schema.motivation.problem or "", "method_key_idea": schema.method.key_idea or "", "paper_date": paper.paper_date.isoformat() if paper.paper_date else "", } index_paper(arxiv_id, texts_dict) except Exception: logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True) async def _do_summarize_one( db: Session, paper: Paper, pdf_mode: str = "auto" ) -> dict: """实际的单篇总结执行(在 semaphore 保护下)。""" arxiv_id = paper.arxiv_id # 状态 → processing paper.summary_status.status = SummaryState.PROCESSING paper.summary_status.started_at = utc_now() db.commit() # 清理旧的图片文件和 figures_json,避免重新总结时残留 _cleanup_old_images(db, paper) raw_output = "" try: meta_path = write_meta_json(paper) await download_pdf(arxiv_id, paper.pdf_url) json_data, raw_output = await _generate_with_retry( arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf", pdf_mode=pdf_mode, ) quality = _persist_summary(db, paper, json_data, raw_output) logger.info("Summarize done: %s quality=%s", arxiv_id, quality) return {"arxiv_id": arxiv_id, "status": "done", "quality": quality} except Exception as exc: # 从异常对象获取 raw_output(_generate_with_retry 失败时仍有输出) fail_output = getattr(exc, "raw_output", raw_output) return _handle_summary_failure(db, paper, exc, fail_output) finally: cleanup_tmp(arxiv_id) # ── 单篇入口 ──────────────────────────────────────────────────────────── async def summarize_single( db: Session, arxiv_id: str, *, force: bool = True, pdf_mode: str = "auto", _session_factory=None, ) -> dict: """单篇总结入口(供 admin 路由和 CLI 调用)。 _session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。 """ paper = db.execute( select(Paper) .where(Paper.arxiv_id == arxiv_id) .options(*PAPER_DEFAULT_LOAD) ).unique().scalar_one_or_none() if not paper: return {"status": "not_found", "arxiv_id": arxiv_id} make_session = _session_factory or SessionLocal # 每篇用独立 session 避免并发问题 paper_db = make_session() try: paper_in_new_session = paper_db.execute( select(Paper) .where(Paper.arxiv_id == arxiv_id) .options(*PAPER_DEFAULT_LOAD) ).unique().scalar_one_or_none() result = await summarize_one(paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode) finally: paper_db.close() return result # ── 批量总结 ──────────────────────────────────────────────────────────── async def summarize_batch( db: Session, arxiv_ids: list[str] | None = None, *, pdf_mode: str = "auto", _session_factory=None, ) -> dict: """批量总结入口。arxiv_ids=None 时处理所有 pending 论文。 _session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。 """ now = utc_now() # TaskLock 防重入 lock = TaskLock( task="summarize", lock_key="batch", status="running", owner="summarize_batch", acquired_at=now, ) try: db.add(lock) db.commit() except Exception: db.rollback() logger.warning("Summarize batch already running (lock conflict)") return {"status": "conflict", "error": "summarize batch already running"} # CrawlLog log_entry = CrawlLog( task="summarize", status="running", started_at=now, ) db.add(log_entry) db.commit() try: # 查询待总结论文 stmt = select(Paper).options(*PAPER_DEFAULT_LOAD) if arxiv_ids: stmt = stmt.where(Paper.arxiv_id.in_(arxiv_ids)) else: # 只处理 pending 或 failed(可重试的) stmt = stmt.join(SummaryStatus).where( SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.FAILED]) ) papers = db.execute(stmt).unique().scalars().all() total = len(papers) logger.info("Summarize batch: %d papers to process", total) if total == 0: log_entry.status = "success" log_entry.papers_found = 0 log_entry.papers_new = 0 log_entry.completed_at = utc_now() release_lock(db, lock) return { "status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0, } # 并发控制 semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY) make_session = _session_factory or SessionLocal async def _process_paper(paper: Paper) -> dict: paper_db = make_session() try: p = paper_db.execute( select(Paper) .where(Paper.id == paper.id) .options(*PAPER_DEFAULT_LOAD) ).unique().scalar_one_or_none() return await summarize_one(paper_db, p, semaphore, pdf_mode=pdf_mode) finally: paper_db.close() results = await asyncio.gather( *[_process_paper(p) for p in papers], return_exceptions=True, ) # 统计结果 done = 0 failed = 0 skipped = 0 for r in results: if isinstance(r, Exception): logger.error("Unexpected error in batch: %s", r) failed += 1 elif isinstance(r, dict): if r.get("status") == "done": done += 1 elif r.get("status") == "skipped": skipped += 1 else: failed += 1 log_entry.status = "success" if failed == 0 else "failed" log_entry.papers_found = total log_entry.papers_new = done log_entry.completed_at = utc_now() db.commit() logger.info( "Summarize batch done: total=%d done=%d failed=%d skipped=%d", total, done, failed, skipped, ) return { "status": "success" if failed == 0 else "partial", "total": total, "done": done, "failed": failed, "skipped": skipped, } except Exception as exc: logger.exception("Summarize batch failed") log_entry.status = "failed" log_entry.error = str(exc)[:2000] log_entry.completed_at = utc_now() db.commit() return {"status": "failed", "error": str(exc)} finally: release_lock(db, lock)