21f16e6756
- Split summarizer into summary_generator and summary_persister modules - Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection - Add layout_detector service for PicoDet-S_layout_3cls integration - Add exceptions module with ConflictError and NotFoundError - Improve admin dashboard with better statistics and task management - Add design review document with system optimization suggestions - Add new tests for crawler, pdf_downloader, pipeline, and summary_utils - Update dependencies and configuration - Clean up dead code and improve error handling
276 lines
9.8 KiB
Python
276 lines
9.8 KiB
Python
"""AI 总结生成器 — AI 后端调用、重试循环、JSON 验证、错误分类。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from pathlib import Path
|
||
|
||
from pydantic import ValidationError
|
||
|
||
from app.config import settings
|
||
from app.services.pdf_downloader import (
|
||
PdfDownloadError,
|
||
paper_dir,
|
||
)
|
||
from app.services.summary_utils import (
|
||
JsonNotFoundError,
|
||
build_prompt,
|
||
extract_json,
|
||
extract_pdf_text,
|
||
)
|
||
from app.services.pi_client import (
|
||
PiProcessError,
|
||
PiTimeoutError,
|
||
call_pi,
|
||
)
|
||
from app.services import claude_backend
|
||
from app.services.schemas import classify_validation_error
|
||
from app.utils import truncate_error
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _classify_error(exc: Exception) -> str:
    """Map an exception to its ``error_type`` enum value.

    The exception classes are checked in a fixed order and the first match
    wins. Pydantic validation failures are sub-classified by
    ``classify_validation_error``; anything unrecognised yields ``"unknown"``.
    """
    # Ordered (exception class, error_type) pairs; the order is the lookup priority.
    fixed_labels = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_class, label in fixed_labels:
        if isinstance(exc, exc_class):
            return label

    # Schema validation errors get a finer-grained label from the schema helper.
    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)

    return "unknown"
|
||
|
||
|
||
# ── JSON 验证 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
|
||
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
|
||
errors: list[str] = []
|
||
|
||
if not isinstance(json_data, dict):
|
||
return ["顶层必须是 JSON 对象"]
|
||
|
||
# 必填字段
|
||
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
|
||
if f not in json_data or not json_data[f]:
|
||
errors.append(f"缺少必填字段: {f}")
|
||
|
||
# tags 必须是非空数组
|
||
tags = json_data.get("tags")
|
||
if not isinstance(tags, list) or len(tags) == 0:
|
||
errors.append("tags 必须是非空数组")
|
||
|
||
# 字符串段落字段(必须是 str 且 ≥50 字)
|
||
string_fields = [
|
||
("motivation", "problem"),
|
||
("motivation", "goal"),
|
||
("motivation", "gap"),
|
||
("method", "overview"),
|
||
("method", "key_idea"),
|
||
("method", "steps"),
|
||
("method", "novelty"),
|
||
("results", "main_findings"),
|
||
("results", "limitations"),
|
||
("improvements", "weaknesses"),
|
||
("improvements", "future_work"),
|
||
("improvements", "reproducibility"),
|
||
]
|
||
for section, field in string_fields:
|
||
val = json_data.get(section, {}).get(field)
|
||
if isinstance(val, list):
|
||
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
|
||
elif not isinstance(val, str) or len(val.strip()) < 50:
|
||
errors.append(
|
||
f"{section}.{field} 必须是详细段落(≥50字),"
|
||
f"当前: {type(val).__name__} ({len(str(val))}字)"
|
||
)
|
||
|
||
# benchmarks 必须是数组
|
||
benchmarks = json_data.get("results", {}).get("benchmarks")
|
||
if benchmarks is not None and not isinstance(benchmarks, list):
|
||
errors.append("results.benchmarks 必须是数组")
|
||
|
||
# prerequisites.concepts 必须是对象数组,每个有 term
|
||
concepts = json_data.get("prerequisites", {}).get("concepts")
|
||
if concepts is not None:
|
||
if not isinstance(concepts, list):
|
||
errors.append("prerequisites.concepts 必须是数组")
|
||
elif len(concepts) == 0:
|
||
errors.append("prerequisites.concepts 不能为空")
|
||
else:
|
||
for i, c in enumerate(concepts):
|
||
if isinstance(c, str):
|
||
errors.append(
|
||
f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串"
|
||
)
|
||
elif isinstance(c, dict) and not c.get("term"):
|
||
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
|
||
|
||
# figures 必须是数组,每个元素应有 id
|
||
figures = json_data.get("figures")
|
||
if figures is not None:
|
||
if not isinstance(figures, list):
|
||
errors.append("figures 必须是数组")
|
||
else:
|
||
for i, fig in enumerate(figures):
|
||
if isinstance(fig, dict) and not fig.get("id"):
|
||
errors.append(f"figures[{i}] 缺少 id 字段")
|
||
|
||
return errors
|
||
|
||
|
||
# ── AI 调用 + 重试 ──────────────────────────────────────────────────────
|
||
|
||
|
||
async def _generate_with_retry(
    arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
    """Generate a summary via the configured AI backend, retrying until it validates.

    The backend is selected by ``settings.SUMMARY_BACKEND``: ``"claude"`` calls
    ``claude_backend.call_claude`` with a prompt built here, any other value
    calls ``call_pi``. At most ``settings.SUMMARY_MAX_RETRIES`` attempts are
    made; every attempt after the first passes the previous attempt's
    validation errors and session id back to the backend.

    Args:
        arxiv_id: Paper identifier; used to locate ``summary.json`` and in logs.
        meta_path: Path to the paper metadata, forwarded to the backend/prompt.
        pdf_path: Path to the paper PDF.
        pdf_mode: Forwarded to ``call_pi``; not used by the claude backend.

    Returns:
        ``(json_data, raw_output)`` from the first attempt that passes validation.

    Raises:
        ValueError: No attempt produced a valid summary, including the case
            where the configured attempt count is below 1. The exception
            carries the last backend output in its ``raw_output`` attribute.
    """
    import time as _time

    # Character budget for the PDF text injected into the claude prompt.
    max_prompt_chars = 80_000

    backend = settings.SUMMARY_BACKEND
    max_attempts = settings.SUMMARY_MAX_RETRIES
    validation_errors: list[str] = []
    json_data: dict | None = None
    raw_output = ""
    session_id = None

    summary_file = paper_dir(arxiv_id) / "summary.json"

    # The claude backend needs the prompt built up front (call_pi receives the
    # raw paths instead).
    claude_prompt: str | None = None
    if backend == "claude":
        _t0 = _time.monotonic()
        txt_path = extract_pdf_text(pdf_path, max_chars=None)
        body = txt_path.read_text(encoding="utf-8")
        if len(body) > max_prompt_chars:
            # Trim the extracted text file in place so the prompt stays within budget.
            trimmed = body[:max_prompt_chars].rstrip()
            txt_path.write_text(trimmed, encoding="utf-8")
        claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
        logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)

    for attempt in range(1, max_attempts + 1):
        # Remove any (possibly incomplete) file left by the previous attempt.
        # missing_ok avoids the exists()/unlink() check-then-act race.
        summary_file.unlink(missing_ok=True)

        # Time the backend call.
        _t_call_start = _time.monotonic()

        if backend == "claude":
            if attempt == 1:
                raw_output, session_id = await claude_backend.call_claude(
                    claude_prompt,
                    session_id=None,
                )
            else:
                retry_prompt = build_prompt(
                    arxiv_id,
                    meta_path,
                    extract_pdf_text(pdf_path, max_chars=max_prompt_chars),
                    "inject",
                    fix_errors=validation_errors,
                )
                raw_output, session_id = await claude_backend.call_claude(
                    retry_prompt,
                    session_id=session_id,
                    fix_errors=validation_errors,
                )
        else:
            if attempt == 1:
                raw_output, session_id = await call_pi(
                    meta_path, pdf_path, pdf_mode=pdf_mode
                )
            else:
                raw_output, session_id = await call_pi(
                    meta_path,
                    pdf_path,
                    fix_errors=validation_errors,
                    session_id=session_id,
                    pdf_mode=pdf_mode,
                )

        _t_call_end = _time.monotonic()

        # summary.json was deleted above, so if it exists now it was written
        # during this backend call.
        file_written_by_ai = summary_file.exists()
        if file_written_by_ai:
            file_stat = summary_file.stat()
            file_mtime = file_stat.st_mtime
            file_size = file_stat.st_size
        else:
            file_mtime = None
            file_size = 0

        logger.info(
            " [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
            arxiv_id,
            attempt,
            _t_call_end - _t_call_start,
            f"已写入({file_size}B)" if file_written_by_ai else "未写入",
            f" mtime={file_mtime:.2f}" if file_mtime else "",
        )

        # Extract the JSON: prefer the file written during the call, otherwise
        # parse it out of the raw backend output.
        _t_json_start = _time.monotonic()
        try:
            if file_written_by_ai:
                json_data = json.loads(summary_file.read_text(encoding="utf-8"))
                logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
            else:
                json_data = extract_json(raw_output)
        except (json.JSONDecodeError, JsonNotFoundError) as exc:
            _t_json_end = _time.monotonic()
            logger.warning(
                " [%s] JSON提取失败: %.2fs %s",
                arxiv_id,
                _t_json_end - _t_json_start,
                str(exc)[:200],
            )
            # Feed the extraction failure back to the backend on the next attempt.
            validation_errors = [f"无法提取有效 JSON: {truncate_error(exc)}"]
            continue
        _t_json_end = _time.monotonic()

        # Validate the extracted JSON.
        _t_val_start = _time.monotonic()
        validation_errors = _validate_summary(json_data, arxiv_id)
        _t_val_end = _time.monotonic()

        if not validation_errors:
            logger.info(
                " [%s] JSON提取: %.2fs 验证: %.2fs ✅",
                arxiv_id,
                _t_json_end - _t_json_start,
                _t_val_end - _t_val_start,
            )
            break
        logger.warning(
            " [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
            arxiv_id,
            _t_json_end - _t_json_start,
            _t_val_end - _t_val_start,
            "; ".join(validation_errors)[:200],
        )

    if json_data is None and not validation_errors:
        # The loop body never ran (attempt count < 1). Fail loudly instead of
        # returning (None, ""), which would violate the declared return type.
        validation_errors = [
            f"no generation attempt was made (SUMMARY_MAX_RETRIES={max_attempts})"
        ]

    if validation_errors:
        error = ValueError(
            f"Summary validation failed after {settings.SUMMARY_MAX_RETRIES} attempts: {'; '.join(validation_errors)}"
        )
        # Per the original note, the upper-layer _handle_summary_failure uses this.
        error.raw_output = raw_output
        raise error

    return json_data, raw_output
|