"""AI 总结编排服务 — 协调生成器、持久化、批量处理的顶层入口。""" from __future__ import annotations import asyncio import logging from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from app.config import settings from app.database import SessionLocal from app.exceptions import ConflictError, NotFoundError from app.models import ( PAPER_DEFAULT_LOAD, CrawlLog, Paper, SummaryState, SummaryStatus, TaskLock, get_paper_by_arxiv_id, get_paper_by_id, ) from app.services.pdf_downloader import download_pdf from app.services.summary_utils import write_meta_json from app.services.summary_generator import ( _generate_with_retry, ) from app.services.summary_persister import ( _cleanup_old_images, _handle_summary_failure, _persist_summary, _run_post_processing, ) from app.utils import TMP_DIR, release_lock, truncate_error, utc_now logger = logging.getLogger(__name__) # ── 单篇总结 ──────────────────────────────────────────────────────────── async def summarize_one( db: Session, paper: Paper, *, force: bool = False, pdf_mode: str = "auto", ) -> dict: """总结单篇论文的完整流程。""" arxiv_id = paper.arxiv_id # 获取或创建 summary_status if not paper.summary_status: db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING)) db.commit() db.refresh(paper) status = paper.summary_status # 跳过已完成的(除非 force) if status.status == SummaryState.DONE and not force: return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"} # 跳过 permanent_failure(除非 force) if status.status == SummaryState.PERMANENT_FAILURE and not force: return { "arxiv_id": arxiv_id, "status": "skipped", "reason": "permanent_failure", } return await _do_summarize_one(db, paper, pdf_mode=pdf_mode) async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -> dict: """实际的单篇总结执行(在 semaphore 保护下)。""" arxiv_id = paper.arxiv_id title_short = (paper.title_en or "")[:50] # 状态 → processing paper.summary_status.status = SummaryState.PROCESSING paper.summary_status.started_at = utc_now() db.commit() logger.info("▶ [%s] 开始总结: %s", arxiv_id, title_short) # 清理旧的图片文件和 figures_json,避免重新总结时残留 import time as _time _t_cleanup_start = _time.monotonic() _cleanup_old_images(db, paper) _t_cleanup_end = _time.monotonic() logger.info(" [%s] 清理旧数据: %.2fs", arxiv_id, _t_cleanup_end - _t_cleanup_start) raw_output = "" try: _t0 = _time.monotonic() meta_path = write_meta_json(paper) _t1 = _time.monotonic() logger.info(" [%s] meta.json: %.2fs", arxiv_id, _t1 - _t0) await download_pdf(arxiv_id, paper.pdf_url) _t2 = _time.monotonic() logger.info(" [%s] 下载PDF: %.2fs", arxiv_id, _t2 - _t1) logger.info(" [%s] 调用 pi 生成总结...", arxiv_id) json_data, raw_output = await _generate_with_retry( arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf", pdf_mode=pdf_mode, ) _t3 = _time.monotonic() logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2) quality, schema = _persist_summary(db, paper, json_data, raw_output) _t4 = _time.monotonic() logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3) # 后处理(图片提取 + ChromaDB 索引)搬到线程池跑,避免 CPU 密集推理冻结 # 事件循环。paper 字段在此(事件循环线程)提取成纯值再传入,规避 worker # 线程跨线程访问 ORM 的 DetachedInstanceError。DocLayout 推理由单例的 # threading.Lock 串行化,并发 worker 不会同时压模型。 paper_meta = { "title_en": paper.title_en or "", "tags": " ".join(t.tag for t in paper.tags) if paper.tags else "", "paper_date": paper.paper_date.isoformat() if paper.paper_date else "", } _t5 = _time.monotonic() try: await asyncio.to_thread(_run_post_processing, arxiv_id, schema, paper_meta) except Exception: # 双保险:_run_post_processing 内部已 try/except,此处兜底, # 确保后处理失败绝不影响已 DONE 的总结。 logger.warning("Post-processing error for %s", arxiv_id, exc_info=True) _t6 = _time.monotonic() logger.info(" [%s] 后处理(线程池): %.2fs", arxiv_id, _t6 - _t5) logger.info( "✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t6 - _t0 ) return {"arxiv_id": arxiv_id, "status": "done", "quality": quality} except Exception as exc: # 从异常对象获取 raw_output(_generate_with_retry 失败时仍有输出) fail_output = getattr(exc, "raw_output", raw_output) return _handle_summary_failure(db, paper, exc, fail_output) finally: pass # cleanup_tmp(arxiv_id) # 暂时禁用,保留 PDF 用于调试图片提取 # ── 单篇入口 ──────────────────────────────────────────────────────────── async def summarize_single( db: Session, arxiv_id: str, *, force: bool = True, pdf_mode: str = "auto", _session_factory=None, ) -> dict: """单篇总结入口(供 admin 路由和 CLI 调用)。 _session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。 """ paper = get_paper_by_arxiv_id(db, arxiv_id) if not paper: raise NotFoundError(f"Paper not found: {arxiv_id}") make_session = _session_factory or SessionLocal # 每篇用独立 session 避免并发问题 paper_db = make_session() try: paper_in_new_session = get_paper_by_arxiv_id(paper_db, arxiv_id) result = await summarize_one( paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode ) finally: paper_db.close() return result # ── 批量总结 ──────────────────────────────────────────────────────────── async def summarize_batch( db: Session, arxiv_ids: list[str] | None = None, *, pdf_mode: str = "auto", _session_factory=None, ) -> dict: """批量总结入口。arxiv_ids=None 时处理所有 pending 论文。 _session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。 """ now = utc_now() # TaskLock 防重入 lock = TaskLock( task="summarize", lock_key="batch", status="running", owner="summarize_batch", acquired_at=now, ) try: db.add(lock) db.commit() except IntegrityError: db.rollback() logger.warning("Summarize batch already running (lock conflict)") raise ConflictError("summarize batch already running") # CrawlLog log_entry = CrawlLog( task="summarize", status="running", started_at=now, ) db.add(log_entry) db.commit() try: # 查询待总结论文 stmt = select(Paper).options(*PAPER_DEFAULT_LOAD) if arxiv_ids: stmt = stmt.where(Paper.arxiv_id.in_(arxiv_ids)) else: # 只处理 pending 或 failed(可重试的) stmt = stmt.join(SummaryStatus).where( SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.FAILED]) ) papers = db.execute(stmt).unique().scalars().all() total = len(papers) logger.info("Summarize batch: %d papers to process", total) if total == 0: log_entry.status = "success" log_entry.papers_found = 0 log_entry.papers_new = 0 log_entry.completed_at = utc_now() release_lock(db, lock) return { "status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0, } # 并发控制:worker 模式,避免 573 个协程同时打开 DB 连接耗尽连接池 concurrency = settings.SUMMARY_CONCURRENCY make_session = _session_factory or SessionLocal # 进度追踪 progress = {"done": 0, "failed": 0, "skipped": 0} paper_queue: asyncio.Queue[Paper | None] = asyncio.Queue() for p in papers: paper_queue.put_nowait(p) async def _worker() -> list[dict]: results: list[dict] = [] while True: paper = paper_queue.get_nowait() if not paper_queue.empty() else None if paper is None: break paper_db = make_session() try: p = get_paper_by_id(paper_db, paper.id) result = await summarize_one(paper_db, p, pdf_mode=pdf_mode) status = result.get("status", "failed") progress[status] = progress.get(status, 0) + 1 finished = sum(progress.values()) logger.info( "📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s", finished, total, progress["done"], progress["failed"], progress["skipped"], paper.arxiv_id, ) results.append(result) except Exception as exc: logger.error("Worker error: %s", exc) results.append({"status": "failed", "error": str(exc)}) finally: paper_db.close() return results worker_results = await asyncio.gather( *[_worker() for _ in range(concurrency)], return_exceptions=True, ) results = [] for r in worker_results: if isinstance(r, Exception): logger.error("Unexpected error in batch: %s", r) results.append(r) elif isinstance(r, list): results.extend(r) # 统计结果(progress 已在 worker 中实时更新) done = progress["done"] failed = progress["failed"] skipped = progress["skipped"] for r in results: if isinstance(r, Exception): logger.error("Unexpected error in batch: %s", r) failed += 1 log_entry.status = "success" if failed == 0 else "failed" log_entry.papers_found = total log_entry.papers_new = done log_entry.completed_at = utc_now() db.commit() logger.info( "Summarize batch done: total=%d done=%d failed=%d skipped=%d", total, done, failed, skipped, ) return { "status": "success" if failed == 0 else "partial", "total": total, "done": done, "failed": failed, "skipped": skipped, } except Exception as exc: logger.exception("Summarize batch failed") log_entry.status = "failed" log_entry.error = truncate_error(exc, limit=2000) log_entry.completed_at = utc_now() db.commit() return {"status": "failed", "error": truncate_error(exc)} finally: release_lock(db, lock)