"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。""" from __future__ import annotations import asyncio import json import logging from pathlib import Path from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session from app.config import settings from app.database import SessionLocal from app.models import ( PAPER_DEFAULT_LOAD, CrawlLog, Paper, PaperSummary, PaperTag, SummaryState, SummaryStatus, TaskLock, ) from app.services.pdf_downloader import ( PdfDownloadError, cleanup_tmp, download_pdf, paper_dir, ) from app.services.summary_utils import ( JsonNotFoundError, build_prompt, extract_json, write_meta_json, extract_pdf_text, ) from app.services.pi_client import ( PiProcessError, PiTimeoutError, call_pi, ) from app.services import claude_backend from app.services.schemas import ( SummarySchema, assess_quality, classify_validation_error, flatten_for_db, ) from app.utils import TMP_DIR, release_lock, utc_now logger = logging.getLogger(__name__) # ── 错误分类 ──────────────────────────────────────────────────────────── def _classify_error(exc: Exception) -> str: """将异常映射到 error_type 枚举值。""" if isinstance(exc, PdfDownloadError): return "pdf_download_failed" if isinstance(exc, PiTimeoutError): return "timeout" if isinstance(exc, PiProcessError): return "process_error" if isinstance(exc, JsonNotFoundError): return "json_not_found" if isinstance(exc, json.JSONDecodeError): return "json_invalid" if isinstance(exc, ValidationError): return classify_validation_error(exc) return "unknown" # ── FTS5 文本构建 ─────────────────────────────────────────────────────── def _build_fts_summary_text(schema: SummarySchema) -> str: """拼接用于 FTS5 索引的总结文本。""" parts = [ schema.one_line or "", schema.motivation.problem or "", schema.motivation.goal or "", schema.method.overview or "", schema.method.key_idea or "", schema.results.main_findings or "", ] return " ".join(p for p in parts if p) # ── DB 更新 
# ── DB updates ──────────────────────────────────────────────────────────


def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Write a validated summary to the database and commit once.

    Touches the ``paper_summaries`` row (upserted), the summary-related
    ``papers`` columns, any new AI tags in ``paper_tags`` and the
    ``papers_fts`` row whose rowid equals the paper id.

    Args:
        db: Session used for every write; committed at the end.
        paper: The paper being summarized.
        schema: Pydantic-validated summary.
        quality: Quality label produced by ``assess_quality``.
        raw_output: Not used here (kept for interface compatibility); the
            raw text is written to disk by ``_save_files``.
    """
    from sqlalchemy import text

    # 1. paper_summaries: update the existing row in place, or insert one.
    existing = db.get(PaperSummary, paper.id)
    flat = flatten_for_db(schema)
    if existing:
        for k, v in flat.items():
            setattr(existing, k, v)
    else:
        db.add(PaperSummary(paper_id=paper.id, **flat))

    # 2. papers: translated title, quality and on-disk artifact paths.
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    p_dir = paper_dir(paper.arxiv_id)
    paper.summary_path = str(p_dir / "summary.json")
    paper.raw_output_path = str(p_dir / "raw_output.txt")

    # 3. AI tags: only add tags the paper does not already carry. The set is
    #    extended as we go so duplicates inside schema.tags are added once.
    existing_tag_names = {t.tag for t in paper.tags}
    for tag_name in schema.tags:
        if tag_name not in existing_tag_names:
            db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
            existing_tag_names.add(tag_name)

    # 4. FTS5: refresh the searchable Chinese title and summary text.
    db.execute(
        text(
            "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
            "WHERE rowid=:paper_id"
        ),
        {
            "title_zh": schema.title_zh,
            "summary_text": _build_fts_summary_text(schema),
            "paper_id": paper.id,
        },
    )

    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)


# ── JSON validation ─────────────────────────────────────────────────────

# Maximum number of AI generation / validation rounds per paper.
_MAX_SUMMARY_ATTEMPTS = 4
# Character cap on the PDF text injected into the claude prompt.
_CLAUDE_MAX_TEXT_CHARS = 80_000
# Minimum stripped length of every prose paragraph field.
_MIN_PARAGRAPH_CHARS = 50
# Top-level keys that, when present, must be JSON objects.
_OBJECT_SECTIONS = ("motivation", "method", "results", "improvements", "prerequisites")
# (section, field) pairs that must be prose paragraphs, not lists.
_PARAGRAPH_FIELDS: tuple[tuple[str, str], ...] = (
    ("motivation", "problem"),
    ("motivation", "goal"),
    ("motivation", "gap"),
    ("method", "overview"),
    ("method", "key_idea"),
    ("method", "steps"),
    ("method", "novelty"),
    ("results", "main_findings"),
    ("results", "limitations"),
    ("improvements", "weaknesses"),
    ("improvements", "future_work"),
    ("improvements", "reproducibility"),
)


def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
    """Check AI-produced summary JSON against the structural rules.

    The returned messages are fed back to the model on the next attempt, so
    a malformed document must yield messages rather than raise.

    Args:
        json_data: Parsed JSON (may be any JSON value, not only a dict).
        arxiv_id: Paper id (currently unused; kept for interface stability).

    Returns:
        A list of human-readable errors; an empty list means the JSON passed.
    """
    if not isinstance(json_data, dict):
        return ["顶层必须是 JSON 对象"]

    errors: list[str] = []

    # Required top-level fields must be present and truthy.
    for f in ("arxiv_id", "title_zh", "one_line", "tags"):
        if not json_data.get(f):
            errors.append(f"缺少必填字段: {f}")

    tags = json_data.get("tags")
    if not isinstance(tags, list) or not tags:
        errors.append("tags 必须是非空数组")

    # Resolve each object section once. A missing/null section counts as {}
    # (its fields are then reported individually); any other non-object
    # value is reported instead of crashing on ``.get``.
    sections: dict[str, dict] = {}
    for name in _OBJECT_SECTIONS:
        value = json_data.get(name)
        if isinstance(value, dict):
            sections[name] = value
        else:
            if value is not None:
                errors.append(f"{name} 必须是 JSON 对象")
            sections[name] = {}

    # Prose paragraphs: must be a str with enough non-whitespace content.
    for section, field in _PARAGRAPH_FIELDS:
        val = sections[section].get(field)
        if isinstance(val, list):
            errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
            continue
        length = len(val.strip()) if isinstance(val, str) else 0
        if not isinstance(val, str) or length < _MIN_PARAGRAPH_CHARS:
            errors.append(
                f"{section}.{field} 必须是详细段落(≥{_MIN_PARAGRAPH_CHARS}字),"
                f"当前: {type(val).__name__} ({length}字)"
            )

    # results.benchmarks, if present, must be an array.
    benchmarks = sections["results"].get("benchmarks")
    if benchmarks is not None and not isinstance(benchmarks, list):
        errors.append("results.benchmarks 必须是数组")

    # prerequisites.concepts, if present, must be a non-empty array of objects with a term.
    concepts = sections["prerequisites"].get("concepts")
    if concepts is not None:
        if not isinstance(concepts, list):
            errors.append("prerequisites.concepts 必须是数组")
        elif not concepts:
            errors.append("prerequisites.concepts 不能为空")
        else:
            for i, c in enumerate(concepts):
                if isinstance(c, str):
                    errors.append(
                        f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串"
                    )
                elif isinstance(c, dict) and not c.get("term"):
                    errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")

    # figures, if present, must be an array; object entries need an id.
    figures = json_data.get("figures")
    if figures is not None:
        if not isinstance(figures, list):
            errors.append("figures 必须是数组")
        else:
            for i, fig in enumerate(figures):
                if isinstance(fig, dict) and not fig.get("id"):
                    errors.append(f"figures[{i}] 缺少 id 字段")

    return errors


# ── File operations ─────────────────────────────────────────────────────


def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
    """Write ``summary.json`` (only when ``schema`` is given) and ``raw_output.txt``.

    Both files go into ``paper_dir(arxiv_id)``, which is created if needed.
    """
    d = paper_dir(arxiv_id)
    d.mkdir(parents=True, exist_ok=True)
    if schema:
        (d / "summary.json").write_text(
            schema.model_dump_json(ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
    (d / "raw_output.txt").write_text(raw_output, encoding="utf-8")


# ── Single-paper summarization ──────────────────────────────────────────


async def summarize_one(
    db: Session,
    paper: Paper,
    *,
    force: bool = False,
    pdf_mode: str = "auto",
) -> dict:
    """Summarize one paper, honoring the skip rules.

    Creates a PENDING ``SummaryStatus`` row if the paper has none. Papers that
    are already DONE or PERMANENT_FAILURE are skipped unless ``force`` is set.

    Returns:
        A result dict with at least ``arxiv_id`` and ``status``
        (``"skipped"``, ``"done"`` or ``"failed"``).
    """
    arxiv_id = paper.arxiv_id

    # Get or create the summary_status row.
    if not paper.summary_status:
        db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
        db.commit()
        db.refresh(paper)

    status = paper.summary_status

    if status.status == SummaryState.DONE and not force:
        return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"}

    if status.status == SummaryState.PERMANENT_FAILURE and not force:
        return {
            "arxiv_id": arxiv_id,
            "status": "skipped",
            "reason": "permanent_failure",
        }

    return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)


async def _generate_with_retry(
    arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
    """Ask the AI backend for a summary, re-prompting until it validates.

    The backend is chosen by ``settings.SUMMARY_BACKEND``: ``"claude"`` uses
    ``claude_backend.call_claude``; any other value uses ``call_pi``. Every
    attempt after the first passes the previous validation errors back to
    the model and reuses the returned session id. At most
    ``_MAX_SUMMARY_ATTEMPTS`` attempts are made.

    JSON is read from ``summary.json`` in the paper directory when the AI
    process wrote that file during the attempt, otherwise it is extracted
    from the raw text output.

    Returns:
        ``(json_data, raw_output)`` from the first attempt that validates.

    Raises:
        ValueError: No attempt validated. The exception carries the last raw
            output as ``exc.raw_output`` for ``_handle_summary_failure``.
        Exceptions raised by the backend call itself are not caught here.
    """
    import time as _time

    backend = settings.SUMMARY_BACKEND
    validation_errors: list[str] = []
    json_data: dict | None = None
    raw_output = ""
    session_id = None
    summary_file = paper_dir(arxiv_id) / "summary.json"

    # The claude backend needs a pre-built prompt (pi builds its own inside
    # call_pi). Oversized PDF text is truncated in place before prompting.
    claude_prompt: str | None = None
    if backend == "claude":
        _t0 = _time.monotonic()
        txt_path = extract_pdf_text(pdf_path, max_chars=None)
        body = txt_path.read_text(encoding="utf-8")
        if len(body) > _CLAUDE_MAX_TEXT_CHARS:
            txt_path.write_text(body[:_CLAUDE_MAX_TEXT_CHARS].rstrip(), encoding="utf-8")
        claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
        logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)

    for attempt in range(1, _MAX_SUMMARY_ATTEMPTS + 1):
        # Drop any summary.json left by the previous attempt, so that its
        # presence after the call means *this* attempt's AI process wrote it.
        summary_file.unlink(missing_ok=True)

        _t_call_start = _time.monotonic()
        if backend == "claude":
            if attempt == 1:
                raw_output, session_id = await claude_backend.call_claude(
                    claude_prompt,
                    session_id=None,
                )
            else:
                retry_prompt = build_prompt(
                    arxiv_id,
                    meta_path,
                    extract_pdf_text(pdf_path, max_chars=_CLAUDE_MAX_TEXT_CHARS),
                    "inject",
                    fix_errors=validation_errors,
                )
                raw_output, session_id = await claude_backend.call_claude(
                    retry_prompt,
                    session_id=session_id,
                    fix_errors=validation_errors,
                )
        elif attempt == 1:
            raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
        else:
            raw_output, session_id = await call_pi(
                meta_path,
                pdf_path,
                fix_errors=validation_errors,
                session_id=session_id,
                pdf_mode=pdf_mode,
            )
        _t_call_end = _time.monotonic()

        # Did the AI subprocess write summary.json itself?
        file_written_by_ai = summary_file.exists()
        file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
        file_size = summary_file.stat().st_size if file_written_by_ai else 0
        logger.info(
            " [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
            arxiv_id,
            attempt,
            _t_call_end - _t_call_start,
            f"已写入({file_size}B)" if file_written_by_ai else "未写入",
            f" mtime={file_mtime:.2f}" if file_mtime else "",
        )

        # Extract JSON; unparsable output becomes feedback for the next round.
        _t_json_start = _time.monotonic()
        try:
            if file_written_by_ai:
                json_data = json.loads(summary_file.read_text(encoding="utf-8"))
                logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
            else:
                json_data = extract_json(raw_output)
        except (json.JSONDecodeError, JsonNotFoundError) as exc:
            logger.warning(
                " [%s] JSON提取失败: %.2fs %s",
                arxiv_id,
                _time.monotonic() - _t_json_start,
                str(exc)[:200],
            )
            validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
            continue
        _t_json_end = _time.monotonic()

        _t_val_start = _time.monotonic()
        validation_errors = _validate_summary(json_data, arxiv_id)
        _t_val_end = _time.monotonic()

        if not validation_errors:
            logger.info(
                " [%s] JSON提取: %.2fs 验证: %.2fs ✅",
                arxiv_id,
                _t_json_end - _t_json_start,
                _t_val_end - _t_val_start,
            )
            break

        logger.warning(
            " [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
            arxiv_id,
            _t_json_end - _t_json_start,
            _t_val_end - _t_val_start,
            "; ".join(validation_errors)[:200],
        )

    if validation_errors:
        exc = ValueError(
            f"Summary validation failed after {_MAX_SUMMARY_ATTEMPTS} attempts: "
            f"{'; '.join(validation_errors)}"
        )
        exc.raw_output = raw_output  # read by _do_summarize_one / _handle_summary_failure
        raise exc

    return json_data, raw_output


def _persist_summary(
    db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
    """Persist a validated summary and mark the paper DONE.

    Steps: pydantic validation, quality assessment, file save, DB update,
    status commit. Then image extraction and ChromaDB indexing run as
    best-effort post-processing (their failures are logged, not raised).

    Returns:
        The assessed quality label.

    Raises:
        pydantic.ValidationError: ``json_data`` does not match ``SummarySchema``.
    """
    import time as _time

    arxiv_id = paper.arxiv_id

    _t0 = _time.monotonic()
    schema = SummarySchema.model_validate(json_data)
    quality = assess_quality(schema)
    _t1 = _time.monotonic()

    _save_files(arxiv_id, schema, raw_output)
    _t2 = _time.monotonic()

    _update_summary_in_db(db, paper, schema, quality, raw_output)
    _t3 = _time.monotonic()

    # Status → done.
    status = paper.summary_status
    status.status = SummaryState.DONE
    status.quality = quality
    status.completed_at = utc_now()
    status.raw_output_saved = True
    db.commit()
    _t4 = _time.monotonic()

    logger.info(
        " [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
        arxiv_id, _t1 - _t0, _t2 - _t1, _t3 - _t2, _t4 - _t3,
    )

    # Optional enrichment; failures do not affect the summary.
    _t5 = _time.monotonic()
    _maybe_extract_images(arxiv_id, schema)
    _t6 = _time.monotonic()
    _maybe_index_chroma(arxiv_id, paper, schema)
    _t7 = _time.monotonic()
    logger.info(
        " [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
        arxiv_id, _t6 - _t5, _t7 - _t6,
    )

    return quality


def _handle_summary_failure(
    db: Session,
    paper: Paper,
    exc: Exception,
    raw_output: str,
) -> dict:
    """Record a failed summarization attempt and return a failure result.

    Saves ``raw_output`` to disk (best effort), increments ``retry_count``,
    stores the classified ``error_type`` and message, and sets the status to
    PERMANENT_FAILURE once ``retry_count >= SUMMARY_MAX_RETRIES + 1``,
    otherwise back to PENDING.
    """
    # The failure may come from a failed flush/commit, which leaves the
    # session unusable until rolled back. Discard any partial writes first
    # so the failure itself can be recorded.
    db.rollback()

    error_type = _classify_error(exc)
    arxiv_id = paper.arxiv_id
    logger.error(
        "Summarize failed: %s error_type=%s %s",
        arxiv_id, error_type, str(exc)[:200],
    )

    status = paper.summary_status
    if raw_output:
        # A disk error here must not prevent the failure from being recorded.
        try:
            _save_files(arxiv_id, None, raw_output)
        except OSError:
            logger.warning("Failed to save raw_output for %s", arxiv_id, exc_info=True)
        else:
            status.raw_output_saved = True

    status.retry_count = (status.retry_count or 0) + 1
    status.error_type = error_type
    status.error = str(exc)[:2000]
    if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
        status.status = SummaryState.PERMANENT_FAILURE
    else:
        status.status = SummaryState.PENDING
    status.completed_at = utc_now()
    db.commit()

    return {
        "arxiv_id": arxiv_id,
        "status": "failed",
        "error_type": error_type,
        "error": str(exc)[:200],
        "retry_count": status.retry_count,
    }


def _cleanup_old_images(db: Session, paper: Paper) -> None:
    """Remove images and figures_json left over from a previous summary.

    Deletes image files (png/jpg/jpeg/gif/svg) and ``manifest.json`` from the
    paper's ``images`` directory, and clears ``paper.summary.figures_json``.
    """
    images_dir = paper_dir(paper.arxiv_id) / "images"
    if images_dir.exists():
        for old_file in images_dir.iterdir():
            if (
                old_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg")
                or old_file.name == "manifest.json"
            ):
                old_file.unlink(missing_ok=True)

    if paper.summary and paper.summary.figures_json:
        paper.summary.figures_json = None
        db.commit()


def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
    """Extract figures/tables from the downloaded PDF (best effort).

    Any exception is logged and swallowed; a failed extraction must not undo
    an otherwise successful summary.
    """
    try:
        from app.services.pdf_image_extractor import (
            extract_images_from_pdf,
            filter_images_by_summary,
        )

        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
        extract_images_from_pdf(arxiv_id, pdf_path)
        if schema.figures:
            filter_images_by_summary(arxiv_id, schema.figures)
    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)


def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
    """Write the paper into the ChromaDB semantic index (best effort).

    Any exception is logged and swallowed so indexing problems never fail
    the summary.
    """
    try:
        from app.services.embedder import index_paper

        texts_dict = {
            "arxiv_id": arxiv_id,
            "title_zh": schema.title_zh or "",
            "title_en": paper.title_en or "",
            "tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
            "one_line": schema.one_line or "",
            "motivation_problem": schema.motivation.problem or "",
            "method_key_idea": schema.method.key_idea or "",
            "paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
        }
        index_paper(arxiv_id, texts_dict)
    except Exception:
        logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)


async def _do_summarize_one(
    db: Session, paper: Paper, pdf_mode: str = "auto"
) -> dict:
    """Run the full summarization pipeline for one paper.

    Marks the paper PROCESSING, then runs these steps: clean old images,
    write meta.json, download the PDF, generate and validate, and persist.
    Every failure after the PROCESSING commit is routed to
    ``_handle_summary_failure``. Temporary files are always cleaned up.
    Callers bound concurrency (``summarize_batch`` runs a fixed number of
    workers).
    """
    import time as _time

    arxiv_id = paper.arxiv_id
    title_short = (paper.title_en or "")[:50]

    # Status → processing.
    paper.summary_status.status = SummaryState.PROCESSING
    paper.summary_status.started_at = utc_now()
    db.commit()
    logger.info("▶ [%s] 开始总结: %s", arxiv_id, title_short)

    raw_output = ""
    try:
        # Inside the try so a cleanup failure is recorded as a failure
        # instead of leaving the paper stuck in PROCESSING.
        _t_cleanup_start = _time.monotonic()
        _cleanup_old_images(db, paper)
        logger.info(
            " [%s] 清理旧数据: %.2fs", arxiv_id, _time.monotonic() - _t_cleanup_start
        )

        _t0 = _time.monotonic()
        meta_path = write_meta_json(paper)
        _t1 = _time.monotonic()
        logger.info(" [%s] meta.json: %.2fs", arxiv_id, _t1 - _t0)

        await download_pdf(arxiv_id, paper.pdf_url)
        _t2 = _time.monotonic()
        logger.info(" [%s] 下载PDF: %.2fs", arxiv_id, _t2 - _t1)

        logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
        json_data, raw_output = await _generate_with_retry(
            arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf", pdf_mode=pdf_mode,
        )
        _t3 = _time.monotonic()
        logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2)

        quality = _persist_summary(db, paper, json_data, raw_output)
        _t4 = _time.monotonic()
        logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
        logger.info("✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0)

        return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}

    except Exception as exc:
        # _generate_with_retry attaches the last raw output to its exception.
        fail_output = getattr(exc, "raw_output", raw_output)
        return _handle_summary_failure(db, paper, exc, fail_output)
    finally:
        cleanup_tmp(arxiv_id)


# ── Single-paper entry point ────────────────────────────────────────────


async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Summarize one paper by arXiv id (used by the admin route and the CLI).

    The work runs in a dedicated session from ``_session_factory`` (or
    ``SessionLocal``); tests can inject an in-memory DB session factory.

    Returns:
        ``{"status": "not_found", ...}`` when the paper does not exist,
        otherwise the result of ``summarize_one``.
    """
    paper = db.execute(
        select(Paper)
        .where(Paper.arxiv_id == arxiv_id)
        .options(*PAPER_DEFAULT_LOAD)
    ).unique().scalar_one_or_none()
    if not paper:
        return {"status": "not_found", "arxiv_id": arxiv_id}

    make_session = _session_factory or SessionLocal
    # A dedicated session per paper avoids sharing one session across tasks.
    paper_db = make_session()
    try:
        paper_in_new_session = paper_db.execute(
            select(Paper)
            .where(Paper.arxiv_id == arxiv_id)
            .options(*PAPER_DEFAULT_LOAD)
        ).unique().scalar_one_or_none()
        if paper_in_new_session is None:
            # Deleted between the two lookups.
            return {"status": "not_found", "arxiv_id": arxiv_id}
        return await summarize_one(
            paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode
        )
    finally:
        paper_db.close()


# ── Batch summarization ─────────────────────────────────────────────────


async def summarize_batch(
    db: Session,
    arxiv_ids: list[str] | None = None,
    *,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Summarize many papers with a bounded worker pool.

    When ``arxiv_ids`` is falsy, every paper whose status is PENDING or FAILED
    is processed. A ``TaskLock`` row prevents overlapping batches, and a
    ``CrawlLog`` row records the run. Each paper is handled in its own
    session from ``_session_factory`` (or ``SessionLocal``).

    Returns:
        ``{"status": "conflict", ...}`` if another batch holds the lock.
        ``{"status": "failed", "error": ...}`` if the batch itself crashed.
        Otherwise a summary with ``status`` (``"success"`` / ``"partial"``),
        ``total``, ``done``, ``failed`` and ``skipped``.
    """
    now = utc_now()

    # TaskLock re-entrancy guard: failing to insert/commit the lock row is
    # treated as "another batch is already running".
    lock = TaskLock(
        task="summarize",
        lock_key="batch",
        status="running",
        owner="summarize_batch",
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except Exception:
        db.rollback()
        logger.warning("Summarize batch already running (lock conflict)")
        return {"status": "conflict", "error": "summarize batch already running"}

    log_entry = CrawlLog(task="summarize", status="running", started_at=now)
    db.add(log_entry)
    db.commit()

    try:
        stmt = select(Paper).options(*PAPER_DEFAULT_LOAD)
        if arxiv_ids:
            stmt = stmt.where(Paper.arxiv_id.in_(arxiv_ids))
        else:
            # Only pending or (retryable) failed papers.
            stmt = stmt.join(SummaryStatus).where(
                SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.FAILED])
            )

        papers = db.execute(stmt).unique().scalars().all()
        total = len(papers)
        logger.info("Summarize batch: %d papers to process", total)

        if total == 0:
            log_entry.status = "success"
            log_entry.papers_found = 0
            log_entry.papers_new = 0
            log_entry.completed_at = utc_now()
            db.commit()
            # The lock is released exactly once, in the finally clause.
            return {
                "status": "success",
                "done": 0,
                "failed": 0,
                "skipped": 0,
                "total": 0,
            }

        # Worker pool rather than one coroutine per paper, so the number of
        # simultaneously open DB sessions stays at `concurrency`.
        concurrency = max(1, settings.SUMMARY_CONCURRENCY)
        make_session = _session_factory or SessionLocal

        # Live progress counters, keyed by result status.
        progress = {"done": 0, "failed": 0, "skipped": 0}
        paper_queue: asyncio.Queue[Paper] = asyncio.Queue()
        for queued in papers:
            paper_queue.put_nowait(queued)

        async def _worker() -> None:
            """Drain the queue, summarizing each paper in its own session."""
            while True:
                try:
                    paper = paper_queue.get_nowait()
                except asyncio.QueueEmpty:
                    break

                paper_db = make_session()
                try:
                    p = paper_db.execute(
                        select(Paper)
                        .where(Paper.id == paper.id)
                        .options(*PAPER_DEFAULT_LOAD)
                    ).unique().scalar_one_or_none()
                    result = await summarize_one(paper_db, p, pdf_mode=pdf_mode)
                    status = result.get("status", "failed")
                    progress[status] = progress.get(status, 0) + 1
                    logger.info(
                        "📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s",
                        sum(progress.values()), total,
                        progress["done"], progress["failed"], progress["skipped"],
                        paper.arxiv_id,
                    )
                except Exception as exc:
                    # Count it, otherwise the batch summary under-reports failures.
                    progress["failed"] += 1
                    logger.exception("Worker error: %s", exc)
                finally:
                    paper_db.close()

        worker_outcomes = await asyncio.gather(
            *(_worker() for _ in range(concurrency)),
            return_exceptions=True,
        )

        # Workers catch per-paper errors themselves; anything returned here
        # escaped a worker entirely and is counted as one failure.
        crashed_workers = 0
        for outcome in worker_outcomes:
            if isinstance(outcome, BaseException):
                logger.error("Unexpected error in batch: %s", outcome)
                crashed_workers += 1

        done = progress["done"]
        failed = progress["failed"] + crashed_workers
        skipped = progress["skipped"]

        log_entry.status = "success" if failed == 0 else "failed"
        log_entry.papers_found = total
        log_entry.papers_new = done
        log_entry.completed_at = utc_now()
        db.commit()

        logger.info(
            "Summarize batch done: total=%d done=%d failed=%d skipped=%d",
            total, done, failed, skipped,
        )
        return {
            "status": "success" if failed == 0 else "partial",
            "total": total,
            "done": done,
            "failed": failed,
            "skipped": skipped,
        }

    except Exception as exc:
        logger.exception("Summarize batch failed")
        # The session may be unusable after a failed flush/commit.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = str(exc)[:2000]
        log_entry.completed_at = utc_now()
        db.commit()
        return {"status": "failed", "error": str(exc)}
    finally:
        release_lock(db, lock)