"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。""" from __future__ import annotations import asyncio import json import logging import re import shutil from datetime import datetime, timezone from pathlib import Path import httpx from pydantic import ValidationError from sqlalchemy import select, text from sqlalchemy.orm import Session, joinedload from app.config import settings from app.database import SessionLocal from app.models import ( CrawlLog, Paper, PaperSummary, PaperTag, SummaryStatus, TaskLock, ) from app.services.schemas import ( SummarySchema, assess_quality, classify_validation_error, flatten_for_db, ) logger = logging.getLogger(__name__) # ── 自定义异常 ────────────────────────────────────────────────────────── class PdfDownloadError(Exception): pass class PiTimeoutError(Exception): pass class PiProcessError(Exception): def __init__(self, returncode: int, stderr: str): self.returncode = returncode self.stderr = stderr super().__init__(f"pi exited with code {returncode}: {stderr[:500]}") class JsonNotFoundError(Exception): pass # ── 路径工具 ──────────────────────────────────────────────────────────── _DATA_DIR = Path("data") _PAPERS_DIR = _DATA_DIR / "papers" _TMP_DIR = _DATA_DIR / "tmp" def _paper_dir(arxiv_id: str) -> Path: return _PAPERS_DIR / arxiv_id def _tmp_dir(arxiv_id: str) -> Path: return _TMP_DIR / arxiv_id # ── PDF 下载 ──────────────────────────────────────────────────────────── async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path: """下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。""" if not pdf_url: raise PdfDownloadError(f"no pdf_url for {arxiv_id}") tmp = _tmp_dir(arxiv_id) tmp.mkdir(parents=True, exist_ok=True) dest = tmp / "paper.pdf" transport = None if settings.http_proxy: transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy) try: async with httpx.AsyncClient( timeout=settings.HTTP_TIMEOUT_SECONDS, headers={"User-Agent": settings.HTTP_USER_AGENT}, transport=transport, follow_redirects=True, ) as client: resp = await client.get(pdf_url) 
resp.raise_for_status() dest.write_bytes(resp.content) except Exception as exc: raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size) return dest # ── meta.json ─────────────────────────────────────────────────────────── def _write_meta_json(paper: Paper) -> Path: """写入 data/papers/{arxiv_id}/meta.json,返回路径。""" d = _paper_dir(paper.arxiv_id) d.mkdir(parents=True, exist_ok=True) meta_path = d / "meta.json" authors = [a.name for a in paper.authors] tags = [t.tag for t in paper.tags] meta = { "arxiv_id": paper.arxiv_id, "title_en": paper.title_en, "abstract": paper.abstract or "", "published_at": paper.published_at.isoformat() if paper.published_at else None, "authors": authors, "tags": tags, "upvotes": paper.upvotes, } meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8") return meta_path # ── pi CLI 调用 ──────────────────────────────────────────────────────── async def _call_pi(meta_path: Path, pdf_path: Path) -> str: """调用 pi CLI 非交互模式,返回 stdout 文本。""" cmd = [ settings.PI_BIN, "-p", "--no-tools", "--skill", settings.SUMMARY_SKILL, "请深度解读以下论文,并按指定 JSON schema 输出:", f"@{meta_path}", f"@{pdf_path}", ] logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4])) proc = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) try: stdout, stderr = await asyncio.wait_for( proc.communicate(), timeout=settings.SUMMARY_TIMEOUT_SECONDS, ) except asyncio.TimeoutError: proc.kill() await proc.wait() raise PiTimeoutError( f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s" ) if proc.returncode != 0: raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace")) return stdout.decode("utf-8", errors="replace") def paper_id_from_path(meta_path: Path) -> str: """从 meta.json 路径反推 arxiv_id。""" return meta_path.parent.name # ── JSON 提取 
# ── JSON extraction ──────────────────────────────────────────────────────

# Fenced ```json ... ``` (or bare ```) blocks; the body starts on the line
# after the opening fence.
_FENCE_RE = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
# A flat (no nested braces) object that mentions "title_zh".
_FLAT_TITLE_OBJ_RE = re.compile(r'\{[^{}]*"title_zh"[^{}]*\}', re.DOTALL)


def _extract_json(raw_output: str) -> dict:
    """Extract the summary JSON object from pi's free-form output.

    Strategies, tried in order:
      1. Parse the whole (stripped) output; accept a dict with ``title_zh``.
      2. Parse each fenced code block; accept the first dict with ``title_zh``.
      3. Parse the first flat ``{...}`` block mentioning ``"title_zh"``.
      4. Fall back to the largest JSON object that starts at any ``{``.

    Raises:
        JsonNotFoundError: If no strategy yields a JSON object.
    """
    # Strategy 1: the output is the JSON document itself.
    try:
        result = json.loads(raw_output.strip())
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(result, dict) and "title_zh" in result:
            return result

    # Strategy 2: fenced code blocks.
    for match in _FENCE_RE.finditer(raw_output):
        try:
            result = json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            continue
        if isinstance(result, dict) and "title_zh" in result:
            return result

    # Strategy 3: a flat object containing "title_zh".
    for match in _FLAT_TITLE_OBJ_RE.finditer(raw_output):
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            continue

    # Strategy 4: the largest decodable object. raw_decode is string-aware,
    # so braces inside JSON string values (e.g. "a}b") do not truncate the
    # object the way naive brace counting would.
    decoder = json.JSONDecoder()
    best = None
    best_len = 0
    start = raw_output.find("{")
    while start != -1:
        try:
            parsed, end = decoder.raw_decode(raw_output, start)
        except json.JSONDecodeError:
            pass
        else:
            # Strict ">" keeps the earliest object on equal length.
            if isinstance(parsed, dict) and end - start > best_len:
                best, best_len = parsed, end - start
        start = raw_output.find("{", start + 1)

    if best is not None:
        return best

    raise JsonNotFoundError("no JSON object found in pi output")


# ── Error classification ─────────────────────────────────────────────────


def _classify_error(exc: Exception) -> str:
    """Map an exception from the summarize pipeline to an ``error_type`` value.

    Pydantic validation errors are further classified by
    ``classify_validation_error``; anything unrecognised is ``"unknown"``.
    """
    if isinstance(exc, PdfDownloadError):
        return "pdf_download_failed"
    if isinstance(exc, PiTimeoutError):
        return "timeout"
    if isinstance(exc, PiProcessError):
        return "process_error"
    if isinstance(exc, JsonNotFoundError):
        return "json_not_found"
    if isinstance(exc, json.JSONDecodeError):
        return "json_invalid"
    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)
    return "unknown"
# ── FTS5 text building ───────────────────────────────────────────────────


def _build_fts_summary_text(schema: SummarySchema) -> str:
    """Join the summary fields indexed by the FTS5 table into one string.

    Empty or None fields are dropped so no stray separators appear.
    """
    parts = [
        schema.one_line or "",
        schema.motivation.problem or "",
        schema.motivation.goal or "",
        # NOTE(review): read only if the schema defines a top-level
        # ``method_overview``; possibly a legacy field — confirm.
        getattr(schema, "method_overview", ""),
        schema.method.overview or "",
        schema.method.key_idea or "",
        " ".join(schema.results.main_findings or []),
    ]
    return " ".join(p for p in parts if p)


# ── DB updates ───────────────────────────────────────────────────────────


def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Persist a validated summary and commit once at the end.

    Updates paper_summaries (upsert), the papers row, AI-sourced paper_tags
    and the papers_fts row. ``raw_output`` is accepted for interface
    compatibility but unused here; the raw text is written by ``_save_files``.
    """
    # 1. paper_summaries: upsert keyed by the paper's id.
    existing = db.get(PaperSummary, paper.id)
    flat = flatten_for_db(schema)
    if existing:
        for key, value in flat.items():
            setattr(existing, key, value)
    else:
        db.add(PaperSummary(paper_id=paper.id, **flat))

    # 2. papers row.
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    paper_dir = _paper_dir(paper.arxiv_id)
    paper.summary_path = str(paper_dir / "summary.json")
    paper.raw_output_path = str(paper_dir / "raw_output.txt")

    # 3. AI tags: add only tags not already present (also dedups schema.tags).
    existing_tag_names = {t.tag for t in paper.tags}
    for tag_name in schema.tags:
        if tag_name not in existing_tag_names:
            db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
            existing_tag_names.add(tag_name)

    # 4. FTS5 row. NOTE(review): this UPDATE affects zero rows if no
    # papers_fts row exists for this rowid — confirm rows are inserted
    # elsewhere when papers are created.
    summary_text = _build_fts_summary_text(schema)
    db.execute(
        text(
            "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
            "WHERE rowid=:paper_id"
        ),
        {
            "title_zh": schema.title_zh,
            "summary_text": summary_text,
            "paper_id": paper.id,
        },
    )

    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)


# ── File operations ──────────────────────────────────────────────────────


def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
    """Write summary.json and raw_output.txt into the paper directory."""
    d = _paper_dir(arxiv_id)
    d.mkdir(parents=True, exist_ok=True)
    # Pydantic v2 emits non-ASCII characters unescaped by default; the
    # ``ensure_ascii`` keyword is only accepted by recent releases, so it is
    # not passed.
    (d / "summary.json").write_text(
        schema.model_dump_json(indent=2),
        encoding="utf-8",
    )
    (d / "raw_output.txt").write_text(raw_output, encoding="utf-8")


def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
    """Write only raw_output.txt (used on failure for later inspection)."""
    d = _paper_dir(arxiv_id)
    d.mkdir(parents=True, exist_ok=True)
    (d / "raw_output.txt").write_text(raw_output, encoding="utf-8")


def _cleanup_tmp(arxiv_id: str) -> None:
    """Best-effort removal of ``data/tmp/{arxiv_id}/``; failures are logged."""
    tmp = _tmp_dir(arxiv_id)
    if tmp.exists():
        try:
            shutil.rmtree(tmp)
            logger.debug("Cleaned tmp: %s", arxiv_id)
        except Exception:
            logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)


# ── Single-paper summarization ───────────────────────────────────────────


async def summarize_one(
    db: Session,
    paper: Paper,
    semaphore: asyncio.Semaphore | None = None,
    *,
    force: bool = False,
) -> dict:
    """Run the full summarize pipeline for one paper.

    Creates a ``pending`` SummaryStatus row if the paper has none. Papers
    already ``done`` or in ``permanent_failure`` are skipped unless ``force``.
    When ``semaphore`` is given, the actual work runs while holding it.

    Returns:
        A result dict whose ``status`` is ``"skipped"``, ``"done"`` or
        ``"failed"``.
    """
    arxiv_id = paper.arxiv_id

    if not paper.summary_status:
        db.add(SummaryStatus(paper_id=paper.id, status="pending"))
        db.commit()
        db.refresh(paper)

    status = paper.summary_status

    if status.status == "done" and not force:
        return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"}
    if status.status == "permanent_failure" and not force:
        return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "permanent_failure"}

    if semaphore is None:
        return await _do_summarize_one(db, paper)
    async with semaphore:
        return await _do_summarize_one(db, paper)


async def _do_summarize_one(db: Session, paper: Paper) -> dict:
    """Execute one summarization and record the outcome on its status row.

    Pipeline: meta.json -> PDF download -> pi -> JSON extraction -> schema
    validation -> quality assessment -> files -> DB. Any exception is turned
    into a failure result (see ``_record_failure``); the scratch directory is
    always cleaned up.
    """
    arxiv_id = paper.arxiv_id
    status = paper.summary_status

    status.status = "processing"
    status.started_at = datetime.now(timezone.utc)
    db.commit()

    raw_output = ""
    try:
        meta_path = _write_meta_json(paper)
        pdf_path = await _download_pdf(arxiv_id, paper.pdf_url)
        raw_output = await _call_pi(meta_path, pdf_path)

        json_data = _extract_json(raw_output)
        schema = SummarySchema.model_validate(json_data)
        quality = assess_quality(schema)

        _save_files(arxiv_id, schema, raw_output)
        _update_summary_in_db(db, paper, schema, quality, raw_output)

        status.status = "done"
        status.quality = quality
        status.completed_at = datetime.now(timezone.utc)
        status.raw_output_saved = True
        db.commit()

        logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
        return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}

    except Exception as exc:
        return _record_failure(db, arxiv_id, status, exc, raw_output)

    finally:
        _cleanup_tmp(arxiv_id)


def _record_failure(
    db: Session,
    arxiv_id: str,
    status: SummaryStatus,
    exc: Exception,
    raw_output: str,
) -> dict:
    """Persist a failed attempt on ``status`` and build the failure result.

    The retry count is incremented; once it exceeds SUMMARY_MAX_RETRIES the
    status becomes ``permanent_failure``, otherwise it returns to ``pending``.
    """
    error_type = _classify_error(exc)
    logger.error(
        "Summarize failed: %s error_type=%s %s",
        arxiv_id,
        error_type,
        str(exc)[:200],
    )

    # A failed flush/commit (e.g. inside _update_summary_in_db) leaves the
    # session unusable until rolled back; without this the commit below
    # raises and the paper stays stuck in "processing".
    db.rollback()

    if raw_output:
        try:
            _save_raw_output_only(arxiv_id, raw_output)
        except OSError:
            logger.warning("Failed to save raw output for %s", arxiv_id, exc_info=True)
        else:
            status.raw_output_saved = True

    status.retry_count = (status.retry_count or 0) + 1
    status.error_type = error_type
    status.error = str(exc)[:2000]
    if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
        status.status = "permanent_failure"
    else:
        status.status = "pending"
    status.completed_at = datetime.now(timezone.utc)
    db.commit()

    return {
        "arxiv_id": arxiv_id,
        "status": "failed",
        "error_type": error_type,
        "error": str(exc)[:200],
        "retry_count": status.retry_count,
    }


def _load_paper(session: Session, *criteria) -> Paper | None:
    """Load the first paper matching ``criteria`` with authors, tags and
    summary_status eagerly joined."""
    return (
        session.query(Paper)
        .filter(*criteria)
        .options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary_status),
        )
        .first()
    )


# ── Single-paper entry point ─────────────────────────────────────────────


async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    _session_factory=None,
) -> dict:
    """Summarize one paper (entry point for the admin route and the CLI).

    The work runs in a dedicated session so it does not share state with the
    caller's ``db``.

    Args:
        _session_factory: Optional session factory; tests inject one bound to
            an in-memory DB. Defaults to ``SessionLocal``.

    Returns:
        ``{"status": "not_found", ...}`` if no such paper exists, otherwise the
        result of ``summarize_one``.
    """
    if db.query(Paper.id).filter(Paper.arxiv_id == arxiv_id).first() is None:
        return {"status": "not_found", "arxiv_id": arxiv_id}

    make_session = _session_factory or SessionLocal
    paper_db = make_session()
    try:
        paper = _load_paper(paper_db, Paper.arxiv_id == arxiv_id)
        if paper is None:
            # Deleted between the existence check and this lookup.
            return {"status": "not_found", "arxiv_id": arxiv_id}
        return await summarize_one(paper_db, paper, force=force)
    finally:
        paper_db.close()


# ── Batch summarization ──────────────────────────────────────────────────


async def summarize_batch(
    db: Session,
    arxiv_ids: list[str] | None = None,
    *,
    _session_factory=None,
) -> dict:
    """Summarize many papers concurrently (at most SUMMARY_CONCURRENCY at once).

    With ``arxiv_ids=None``, papers whose status is ``pending`` or ``failed``
    are processed. A TaskLock row guards against concurrent batches, and the
    run is recorded in CrawlLog.

    Args:
        _session_factory: Optional session factory; tests inject one bound to
            an in-memory DB. Defaults to ``SessionLocal``.

    Returns:
        A summary dict with ``status`` (``conflict``/``success``/``partial``/
        ``failed``) and counts.
    """
    now = datetime.now(timezone.utc)

    # NOTE(review): conflict detection relies on this INSERT failing
    # (presumably a unique constraint on task/lock_key). _release_lock only
    # marks the row "finished" and does not delete it — confirm the
    # constraint ignores finished locks, or later runs will always conflict.
    lock = TaskLock(
        task="summarize",
        lock_key="batch",
        status="running",
        owner="summarize_batch",
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except Exception:
        db.rollback()
        logger.warning("Summarize batch already running (lock conflict)")
        return {"status": "conflict", "error": "summarize batch already running"}

    log_entry = CrawlLog(
        task="summarize",
        status="running",
        started_at=now,
    )
    db.add(log_entry)
    db.commit()

    try:
        # Only ids are needed here; each worker reloads its paper in its own
        # session.
        query = db.query(Paper)
        if arxiv_ids:
            query = query.filter(Paper.arxiv_id.in_(arxiv_ids))
        else:
            # Retryable papers only. NOTE(review): the inner join skips papers
            # that have no SummaryStatus row yet — confirm rows are created
            # when papers are inserted.
            query = query.join(SummaryStatus).filter(
                SummaryStatus.status.in_(["pending", "failed"])
            )
        paper_ids = [p.id for p in query.all()]
        total = len(paper_ids)
        logger.info("Summarize batch: %d papers to process", total)

        if total == 0:
            log_entry.status = "success"
            log_entry.papers_found = 0
            log_entry.papers_new = 0
            log_entry.completed_at = datetime.now(timezone.utc)
            db.commit()
            # The lock is released exactly once, by the finally clause.
            return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}

        semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY)
        make_session = _session_factory or SessionLocal

        async def _process_paper(paper_id) -> dict:
            # Throttle BEFORE opening the session, so at most
            # SUMMARY_CONCURRENCY sessions (and pooled connections) are live;
            # previously every paper's session was opened and queried up front.
            async with semaphore:
                paper_db = make_session()
                try:
                    paper = _load_paper(paper_db, Paper.id == paper_id)
                    return await summarize_one(paper_db, paper)
                finally:
                    paper_db.close()

        results = await asyncio.gather(
            *(_process_paper(pid) for pid in paper_ids),
            return_exceptions=True,
        )

        done = failed = skipped = 0
        for r in results:
            if isinstance(r, BaseException):
                logger.error("Unexpected error in batch: %s", r)
                failed += 1
            elif isinstance(r, dict):
                outcome = r.get("status")
                if outcome == "done":
                    done += 1
                elif outcome == "skipped":
                    skipped += 1
                else:
                    failed += 1

        log_entry.status = "success" if failed == 0 else "failed"
        log_entry.papers_found = total
        log_entry.papers_new = done
        log_entry.completed_at = datetime.now(timezone.utc)
        db.commit()

        logger.info(
            "Summarize batch done: total=%d done=%d failed=%d skipped=%d",
            total, done, failed, skipped,
        )
        return {
            "status": "success" if failed == 0 else "partial",
            "total": total,
            "done": done,
            "failed": failed,
            "skipped": skipped,
        }

    except Exception as exc:
        logger.exception("Summarize batch failed")
        # The session may be in a failed state; reset it before recording.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = str(exc)[:2000]
        log_entry.completed_at = datetime.now(timezone.utc)
        db.commit()
        return {"status": "failed", "error": str(exc)}

    finally:
        _release_lock(db, lock)


def _release_lock(db: Session, lock: TaskLock) -> None:
    """Mark the batch TaskLock as finished; failures are logged, not raised."""
    try:
        lock.status = "finished"
        lock.released_at = datetime.now(timezone.utc)
        db.commit()
    except Exception:
        db.rollback()
        logger.warning("Failed to release summarize lock", exc_info=True)