Files
daily-paper/app/services/summary_generator.py
T
Rain-Bus 21f16e6756 feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
2026-06-13 13:16:47 +08:00

276 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 总结生成器 — AI 后端调用、重试循环、JSON 验证、错误分类。"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from pydantic import ValidationError
from app.config import settings
from app.services.pdf_downloader import (
PdfDownloadError,
paper_dir,
)
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
extract_json,
extract_pdf_text,
)
from app.services.pi_client import (
PiProcessError,
PiTimeoutError,
call_pi,
)
from app.services import claude_backend
from app.services.schemas import classify_validation_error
from app.utils import truncate_error
logger = logging.getLogger(__name__)
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
    """Map an exception to the ``error_type`` label recorded for a failure.

    Args:
        exc: The exception to classify.

    Returns:
        A fixed label for the known download / backend / JSON exception
        types, the result of ``classify_validation_error`` for a pydantic
        ``ValidationError``, and ``"unknown"`` for anything else.
    """
    # Checked top to bottom; the first matching type decides the label.
    fixed_labels = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_type, label in fixed_labels:
        if isinstance(exc, exc_type):
            return label

    # Validation failures are sub-classified by the schemas module.
    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)

    return "unknown"
# ── JSON 验证 ──────────────────────────────────────────────────────────
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
errors: list[str] = []
if not isinstance(json_data, dict):
return ["顶层必须是 JSON 对象"]
# 必填字段
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
if f not in json_data or not json_data[f]:
errors.append(f"缺少必填字段: {f}")
# tags 必须是非空数组
tags = json_data.get("tags")
if not isinstance(tags, list) or len(tags) == 0:
errors.append("tags 必须是非空数组")
# 字符串段落字段(必须是 str 且 ≥50 字)
string_fields = [
("motivation", "problem"),
("motivation", "goal"),
("motivation", "gap"),
("method", "overview"),
("method", "key_idea"),
("method", "steps"),
("method", "novelty"),
("results", "main_findings"),
("results", "limitations"),
("improvements", "weaknesses"),
("improvements", "future_work"),
("improvements", "reproducibility"),
]
for section, field in string_fields:
val = json_data.get(section, {}).get(field)
if isinstance(val, list):
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
elif not isinstance(val, str) or len(val.strip()) < 50:
errors.append(
f"{section}.{field} 必须是详细段落(≥50字),"
f"当前: {type(val).__name__} ({len(str(val))}字)"
)
# benchmarks 必须是数组
benchmarks = json_data.get("results", {}).get("benchmarks")
if benchmarks is not None and not isinstance(benchmarks, list):
errors.append("results.benchmarks 必须是数组")
# prerequisites.concepts 必须是对象数组,每个有 term
concepts = json_data.get("prerequisites", {}).get("concepts")
if concepts is not None:
if not isinstance(concepts, list):
errors.append("prerequisites.concepts 必须是数组")
elif len(concepts) == 0:
errors.append("prerequisites.concepts 不能为空")
else:
for i, c in enumerate(concepts):
if isinstance(c, str):
errors.append(
f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串"
)
elif isinstance(c, dict) and not c.get("term"):
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
# figures 必须是数组,每个元素应有 id
figures = json_data.get("figures")
if figures is not None:
if not isinstance(figures, list):
errors.append("figures 必须是数组")
else:
for i, fig in enumerate(figures):
if isinstance(fig, dict) and not fig.get("id"):
errors.append(f"figures[{i}] 缺少 id 字段")
return errors
# ── AI 调用 + 重试 ──────────────────────────────────────────────────────
async def _generate_with_retry(
    arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
    """Ask the configured AI backend for a summary, retrying until it validates.

    The backend is selected by ``settings.SUMMARY_BACKEND``: ``"claude"`` goes
    through ``claude_backend.call_claude``; any other value goes through
    ``call_pi``. At most ``settings.SUMMARY_MAX_RETRIES`` attempts are made.
    From the second attempt on, the previous attempt's validation errors and
    session id are passed back to the backend.

    After each call, a ``summary.json`` found in the paper directory is parsed
    in preference to the raw output; otherwise JSON is extracted from the raw
    output. Exceptions raised by the backend calls themselves are not caught
    here.

    Args:
        arxiv_id: Paper identifier; used for the paper directory, prompt
            building, validation and log lines.
        meta_path: Metadata file handed to the prompt builder / ``call_pi``.
        pdf_path: PDF handed to the text extractor / ``call_pi``.
        pdf_mode: Forwarded to ``call_pi`` only; the claude path does not use it.

    Returns:
        ``(json_data, raw_output)`` for the first attempt that validates.

    Raises:
        ValueError: No attempt produced a valid summary, including the case
            where ``settings.SUMMARY_MAX_RETRIES`` allows no attempt at all.
            The exception carries the last raw backend output in its
            ``raw_output`` attribute.
    """
    import time as _time

    backend = settings.SUMMARY_BACKEND
    validation_errors: list[str] = []
    json_data: dict | None = None
    raw_output = ""
    session_id = None
    summary_file = paper_dir(arxiv_id) / "summary.json"

    # The claude backend needs its prompt built up front (the pi backend
    # builds its prompt inside call_pi).
    claude_prompt: str | None = None
    if backend == "claude":
        _t0 = _time.monotonic()
        txt_path = extract_pdf_text(pdf_path, max_chars=None)
        body = txt_path.read_text(encoding="utf-8")
        if len(body) > 80_000:
            # Cap the injected text by rewriting the extracted text file in place.
            trimmed = body[:80_000].rstrip()
            txt_path.write_text(trimmed, encoding="utf-8")
        claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
        logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)

    for attempt in range(1, settings.SUMMARY_MAX_RETRIES + 1):
        # Remove any incomplete file left behind by the previous attempt so
        # that its presence after the call means this attempt wrote it.
        if summary_file.exists():
            summary_file.unlink()

        # Time the backend call.
        _t_call_start = _time.monotonic()
        if backend == "claude":
            if attempt == 1:
                raw_output, session_id = await claude_backend.call_claude(
                    claude_prompt,
                    session_id=None,
                )
            else:
                # Retries rebuild the prompt with the errors to fix and
                # continue the session returned by the previous call.
                retry_prompt = build_prompt(
                    arxiv_id,
                    meta_path,
                    extract_pdf_text(pdf_path, max_chars=80000),
                    "inject",
                    fix_errors=validation_errors,
                )
                raw_output, session_id = await claude_backend.call_claude(
                    retry_prompt,
                    session_id=session_id,
                    fix_errors=validation_errors,
                )
        else:
            if attempt == 1:
                raw_output, session_id = await call_pi(
                    meta_path, pdf_path, pdf_mode=pdf_mode
                )
            else:
                raw_output, session_id = await call_pi(
                    meta_path,
                    pdf_path,
                    fix_errors=validation_errors,
                    session_id=session_id,
                    pdf_mode=pdf_mode,
                )
        _t_call_end = _time.monotonic()

        # Check whether summary.json was written during the call (by the AI
        # subprocess, per the original comment). One stat() serves both the
        # size and the mtime.
        file_written_by_ai = summary_file.exists()
        file_stat = summary_file.stat() if file_written_by_ai else None
        file_mtime = file_stat.st_mtime if file_stat is not None else None
        file_size = file_stat.st_size if file_stat is not None else 0
        logger.info(
            " [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
            arxiv_id,
            attempt,
            _t_call_end - _t_call_start,
            f"已写入({file_size}B)" if file_written_by_ai else "未写入",
            f" mtime={file_mtime:.2f}" if file_mtime else "",
        )

        # Obtain the JSON: prefer the file, fall back to the raw output.
        _t_json_start = _time.monotonic()
        try:
            if file_written_by_ai:
                json_data = json.loads(summary_file.read_text(encoding="utf-8"))
                logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
            else:
                json_data = extract_json(raw_output)
        except (json.JSONDecodeError, JsonNotFoundError) as exc:
            _t_json_end = _time.monotonic()
            logger.warning(
                " [%s] JSON提取失败: %.2fs %s",
                arxiv_id,
                _t_json_end - _t_json_start,
                str(exc)[:200],
            )
            # Feed the extraction failure to the next attempt as the error to fix.
            validation_errors = [f"无法提取有效 JSON: {truncate_error(exc)}"]
            continue
        _t_json_end = _time.monotonic()

        # Validate the structure.
        _t_val_start = _time.monotonic()
        validation_errors = _validate_summary(json_data, arxiv_id)
        _t_val_end = _time.monotonic()
        if not validation_errors:
            logger.info(
                " [%s] JSON提取: %.2fs 验证: %.2fs ✅",
                arxiv_id,
                _t_json_end - _t_json_start,
                _t_val_end - _t_val_start,
            )
            break
        logger.warning(
            " [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
            arxiv_id,
            _t_json_end - _t_json_start,
            _t_val_end - _t_val_start,
            "; ".join(validation_errors)[:200],
        )

    # json_data is still None only when the loop body never ran (retry limit
    # below 1); treat that as a failure instead of returning None as the
    # summary.
    if validation_errors or json_data is None:
        detail = "; ".join(validation_errors) or "no attempt was made"
        exc = ValueError(
            f"Summary validation failed after {settings.SUMMARY_MAX_RETRIES} attempts: {detail}"
        )
        # Exposed for the upstream failure handler (the original comment names
        # _handle_summary_failure, which is defined elsewhere).
        exc.raw_output = raw_output
        raise exc
    return json_data, raw_output