"""AI 总结生成器 — AI 后端调用、重试循环、JSON 验证、错误分类。""" from __future__ import annotations import json import logging from pathlib import Path from pydantic import ValidationError from app.config import settings from app.services.pdf_downloader import ( PdfDownloadError, paper_dir, ) from app.services.summary_utils import ( JsonNotFoundError, build_prompt, extract_json, extract_pdf_text, ) from app.services.pi_client import ( PiProcessError, PiTimeoutError, call_pi, ) from app.services import claude_backend from app.services.schemas import classify_validation_error from app.utils import truncate_error logger = logging.getLogger(__name__) # ── 错误分类 ──────────────────────────────────────────────────────────── def _classify_error(exc: Exception) -> str: """将异常映射到 error_type 枚举值。""" if isinstance(exc, PdfDownloadError): return "pdf_download_failed" if isinstance(exc, PiTimeoutError): return "timeout" if isinstance(exc, PiProcessError): return "process_error" if isinstance(exc, JsonNotFoundError): return "json_not_found" if isinstance(exc, json.JSONDecodeError): return "json_invalid" if isinstance(exc, ValidationError): return classify_validation_error(exc) return "unknown" # ── JSON 验证 ────────────────────────────────────────────────────────── def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]: """验证 JSON 数据是否符合要求,返回错误列表(空=通过)。""" errors: list[str] = [] if not isinstance(json_data, dict): return ["顶层必须是 JSON 对象"] # 必填字段 for f in ["arxiv_id", "title_zh", "one_line", "tags"]: if f not in json_data or not json_data[f]: errors.append(f"缺少必填字段: {f}") # tags 必须是非空数组 tags = json_data.get("tags") if not isinstance(tags, list) or len(tags) == 0: errors.append("tags 必须是非空数组") # 字符串段落字段(必须是 str 且 ≥50 字) string_fields = [ ("motivation", "problem"), ("motivation", "goal"), ("motivation", "gap"), ("method", "overview"), ("method", "key_idea"), ("method", "steps"), ("method", "novelty"), ("results", "main_findings"), ("results", "limitations"), ("improvements", "weaknesses"), ("improvements", "future_work"), ("improvements", "reproducibility"), ] for section, field in string_fields: val = json_data.get(section, {}).get(field) if isinstance(val, list): errors.append(f"{section}.{field} 应该是字符串段落,不能是数组") elif not isinstance(val, str) or len(val.strip()) < 50: errors.append( f"{section}.{field} 必须是详细段落(≥50字)," f"当前: {type(val).__name__} ({len(str(val))}字)" ) # benchmarks 必须是数组 benchmarks = json_data.get("results", {}).get("benchmarks") if benchmarks is not None and not isinstance(benchmarks, list): errors.append("results.benchmarks 必须是数组") # prerequisites.concepts 必须是对象数组,每个有 term concepts = json_data.get("prerequisites", {}).get("concepts") if concepts is not None: if not isinstance(concepts, list): errors.append("prerequisites.concepts 必须是数组") elif len(concepts) == 0: errors.append("prerequisites.concepts 不能为空") else: for i, c in enumerate(concepts): if isinstance(c, str): errors.append( f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串" ) elif isinstance(c, dict) and not c.get("term"): errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段") # figures 必须是数组,每个元素应有 id figures = json_data.get("figures") if figures is not None: if not isinstance(figures, list): errors.append("figures 必须是数组") else: for i, fig in enumerate(figures): if isinstance(fig, dict) and not fig.get("id"): errors.append(f"figures[{i}] 缺少 id 字段") return errors # ── AI 调用 + 重试 ────────────────────────────────────────────────────── async def _generate_with_retry( arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto" ) -> tuple[dict, str]: """调用 AI 后端生成总结,最多 4 轮验证循环。 根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。 Returns: (json_data, raw_output) Raises: ValueError: 4 轮验证仍未通过 """ import time as _time backend = settings.SUMMARY_BACKEND validation_errors: list[str] = [] json_data: dict | None = None raw_output = "" session_id = None summary_file = paper_dir(arxiv_id) / "summary.json" # claude 后端需要预构建 prompt(pi 后端在 call_pi 内部构建) claude_prompt: str | None = None if backend == "claude": _t0 = _time.monotonic() txt_path = extract_pdf_text(pdf_path, max_chars=None) body = txt_path.read_text(encoding="utf-8") if len(body) > 80_000: trimmed = body[:80_000].rstrip() txt_path.write_text(trimmed, encoding="utf-8") claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None) logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0) for attempt in range(1, settings.SUMMARY_MAX_RETRIES + 1): # 清理上一轮写入的不完整文件 if summary_file.exists(): summary_file.unlink() # 记录 AI 调用开始时间 _t_call_start = _time.monotonic() if backend == "claude": if attempt == 1: raw_output, session_id = await claude_backend.call_claude( claude_prompt, session_id=None, ) else: retry_prompt = build_prompt( arxiv_id, meta_path, extract_pdf_text(pdf_path, max_chars=80000), "inject", fix_errors=validation_errors, ) raw_output, session_id = await claude_backend.call_claude( retry_prompt, session_id=session_id, fix_errors=validation_errors, ) else: if attempt == 1: raw_output, session_id = await call_pi( meta_path, pdf_path, pdf_mode=pdf_mode ) else: raw_output, session_id = await call_pi( meta_path, pdf_path, fix_errors=validation_errors, session_id=session_id, pdf_mode=pdf_mode, ) _t_call_end = _time.monotonic() # 检查 summary.json 是否由 AI 子进程写入 file_written_by_ai = summary_file.exists() file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None file_size = summary_file.stat().st_size if file_written_by_ai else 0 logger.info( " [%s] attempt %d AI调用: %.2fs summary.json=%s%s", arxiv_id, attempt, _t_call_end - _t_call_start, f"已写入({file_size}B)" if file_written_by_ai else "未写入", f" mtime={file_mtime:.2f}" if file_mtime else "", ) # 提取 JSON _t_json_start = _time.monotonic() try: if file_written_by_ai: json_data = json.loads(summary_file.read_text(encoding="utf-8")) logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id) else: json_data = extract_json(raw_output) except (json.JSONDecodeError, JsonNotFoundError) as exc: _t_json_end = _time.monotonic() logger.warning( " [%s] JSON提取失败: %.2fs %s", arxiv_id, _t_json_end - _t_json_start, str(exc)[:200], ) validation_errors = [f"无法提取有效 JSON: {truncate_error(exc)}"] continue _t_json_end = _time.monotonic() # 验证 _t_val_start = _time.monotonic() validation_errors = _validate_summary(json_data, arxiv_id) _t_val_end = _time.monotonic() if not validation_errors: logger.info( " [%s] JSON提取: %.2fs 验证: %.2fs ✅", arxiv_id, _t_json_end - _t_json_start, _t_val_end - _t_val_start, ) break logger.warning( " [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s", arxiv_id, _t_json_end - _t_json_start, _t_val_end - _t_val_start, "; ".join(validation_errors)[:200], ) if validation_errors: exc = ValueError( f"Summary validation failed after {settings.SUMMARY_MAX_RETRIES} attempts: {'; '.join(validation_errors)}" ) exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用 raise exc return json_data, raw_output