feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules - Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection - Add layout_detector service for PicoDet-S_layout_3cls integration - Add exceptions module with ConflictError and NotFoundError - Improve admin dashboard with better statistics and task management - Add design review document with system optimization suggestions - Add new tests for crawler, pdf_downloader, pipeline, and summary_utils - Update dependencies and configuration - Clean up dead code and improve error handling
This commit is contained in:
@@ -0,0 +1,275 @@
|
||||
"""AI 总结生成器 — AI 后端调用、重试循环、JSON 验证、错误分类。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.config import settings
|
||||
from app.services.pdf_downloader import (
|
||||
PdfDownloadError,
|
||||
paper_dir,
|
||||
)
|
||||
from app.services.summary_utils import (
|
||||
JsonNotFoundError,
|
||||
build_prompt,
|
||||
extract_json,
|
||||
extract_pdf_text,
|
||||
)
|
||||
from app.services.pi_client import (
|
||||
PiProcessError,
|
||||
PiTimeoutError,
|
||||
call_pi,
|
||||
)
|
||||
from app.services import claude_backend
|
||||
from app.services.schemas import classify_validation_error
|
||||
from app.utils import truncate_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _classify_error(exc: Exception) -> str:
|
||||
"""将异常映射到 error_type 枚举值。"""
|
||||
if isinstance(exc, PdfDownloadError):
|
||||
return "pdf_download_failed"
|
||||
if isinstance(exc, PiTimeoutError):
|
||||
return "timeout"
|
||||
if isinstance(exc, PiProcessError):
|
||||
return "process_error"
|
||||
if isinstance(exc, JsonNotFoundError):
|
||||
return "json_not_found"
|
||||
if isinstance(exc, json.JSONDecodeError):
|
||||
return "json_invalid"
|
||||
if isinstance(exc, ValidationError):
|
||||
return classify_validation_error(exc)
|
||||
return "unknown"
|
||||
|
||||
|
||||
# ── JSON 验证 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
|
||||
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
|
||||
errors: list[str] = []
|
||||
|
||||
if not isinstance(json_data, dict):
|
||||
return ["顶层必须是 JSON 对象"]
|
||||
|
||||
# 必填字段
|
||||
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
|
||||
if f not in json_data or not json_data[f]:
|
||||
errors.append(f"缺少必填字段: {f}")
|
||||
|
||||
# tags 必须是非空数组
|
||||
tags = json_data.get("tags")
|
||||
if not isinstance(tags, list) or len(tags) == 0:
|
||||
errors.append("tags 必须是非空数组")
|
||||
|
||||
# 字符串段落字段(必须是 str 且 ≥50 字)
|
||||
string_fields = [
|
||||
("motivation", "problem"),
|
||||
("motivation", "goal"),
|
||||
("motivation", "gap"),
|
||||
("method", "overview"),
|
||||
("method", "key_idea"),
|
||||
("method", "steps"),
|
||||
("method", "novelty"),
|
||||
("results", "main_findings"),
|
||||
("results", "limitations"),
|
||||
("improvements", "weaknesses"),
|
||||
("improvements", "future_work"),
|
||||
("improvements", "reproducibility"),
|
||||
]
|
||||
for section, field in string_fields:
|
||||
val = json_data.get(section, {}).get(field)
|
||||
if isinstance(val, list):
|
||||
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
|
||||
elif not isinstance(val, str) or len(val.strip()) < 50:
|
||||
errors.append(
|
||||
f"{section}.{field} 必须是详细段落(≥50字),"
|
||||
f"当前: {type(val).__name__} ({len(str(val))}字)"
|
||||
)
|
||||
|
||||
# benchmarks 必须是数组
|
||||
benchmarks = json_data.get("results", {}).get("benchmarks")
|
||||
if benchmarks is not None and not isinstance(benchmarks, list):
|
||||
errors.append("results.benchmarks 必须是数组")
|
||||
|
||||
# prerequisites.concepts 必须是对象数组,每个有 term
|
||||
concepts = json_data.get("prerequisites", {}).get("concepts")
|
||||
if concepts is not None:
|
||||
if not isinstance(concepts, list):
|
||||
errors.append("prerequisites.concepts 必须是数组")
|
||||
elif len(concepts) == 0:
|
||||
errors.append("prerequisites.concepts 不能为空")
|
||||
else:
|
||||
for i, c in enumerate(concepts):
|
||||
if isinstance(c, str):
|
||||
errors.append(
|
||||
f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串"
|
||||
)
|
||||
elif isinstance(c, dict) and not c.get("term"):
|
||||
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
|
||||
|
||||
# figures 必须是数组,每个元素应有 id
|
||||
figures = json_data.get("figures")
|
||||
if figures is not None:
|
||||
if not isinstance(figures, list):
|
||||
errors.append("figures 必须是数组")
|
||||
else:
|
||||
for i, fig in enumerate(figures):
|
||||
if isinstance(fig, dict) and not fig.get("id"):
|
||||
errors.append(f"figures[{i}] 缺少 id 字段")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# ── AI 调用 + 重试 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _generate_with_retry(
|
||||
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
|
||||
) -> tuple[dict, str]:
|
||||
"""调用 AI 后端生成总结,最多 4 轮验证循环。
|
||||
|
||||
根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。
|
||||
|
||||
Returns:
|
||||
(json_data, raw_output)
|
||||
Raises:
|
||||
ValueError: 4 轮验证仍未通过
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
backend = settings.SUMMARY_BACKEND
|
||||
validation_errors: list[str] = []
|
||||
json_data: dict | None = None
|
||||
raw_output = ""
|
||||
session_id = None
|
||||
|
||||
summary_file = paper_dir(arxiv_id) / "summary.json"
|
||||
|
||||
# claude 后端需要预构建 prompt(pi 后端在 call_pi 内部构建)
|
||||
claude_prompt: str | None = None
|
||||
if backend == "claude":
|
||||
_t0 = _time.monotonic()
|
||||
txt_path = extract_pdf_text(pdf_path, max_chars=None)
|
||||
body = txt_path.read_text(encoding="utf-8")
|
||||
if len(body) > 80_000:
|
||||
trimmed = body[:80_000].rstrip()
|
||||
txt_path.write_text(trimmed, encoding="utf-8")
|
||||
claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
|
||||
logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)
|
||||
|
||||
for attempt in range(1, settings.SUMMARY_MAX_RETRIES + 1):
|
||||
# 清理上一轮写入的不完整文件
|
||||
if summary_file.exists():
|
||||
summary_file.unlink()
|
||||
|
||||
# 记录 AI 调用开始时间
|
||||
_t_call_start = _time.monotonic()
|
||||
|
||||
if backend == "claude":
|
||||
if attempt == 1:
|
||||
raw_output, session_id = await claude_backend.call_claude(
|
||||
claude_prompt,
|
||||
session_id=None,
|
||||
)
|
||||
else:
|
||||
retry_prompt = build_prompt(
|
||||
arxiv_id,
|
||||
meta_path,
|
||||
extract_pdf_text(pdf_path, max_chars=80000),
|
||||
"inject",
|
||||
fix_errors=validation_errors,
|
||||
)
|
||||
raw_output, session_id = await claude_backend.call_claude(
|
||||
retry_prompt,
|
||||
session_id=session_id,
|
||||
fix_errors=validation_errors,
|
||||
)
|
||||
else:
|
||||
if attempt == 1:
|
||||
raw_output, session_id = await call_pi(
|
||||
meta_path, pdf_path, pdf_mode=pdf_mode
|
||||
)
|
||||
else:
|
||||
raw_output, session_id = await call_pi(
|
||||
meta_path,
|
||||
pdf_path,
|
||||
fix_errors=validation_errors,
|
||||
session_id=session_id,
|
||||
pdf_mode=pdf_mode,
|
||||
)
|
||||
|
||||
_t_call_end = _time.monotonic()
|
||||
|
||||
# 检查 summary.json 是否由 AI 子进程写入
|
||||
file_written_by_ai = summary_file.exists()
|
||||
file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
|
||||
file_size = summary_file.stat().st_size if file_written_by_ai else 0
|
||||
|
||||
logger.info(
|
||||
" [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
|
||||
arxiv_id,
|
||||
attempt,
|
||||
_t_call_end - _t_call_start,
|
||||
f"已写入({file_size}B)" if file_written_by_ai else "未写入",
|
||||
f" mtime={file_mtime:.2f}" if file_mtime else "",
|
||||
)
|
||||
|
||||
# 提取 JSON
|
||||
_t_json_start = _time.monotonic()
|
||||
try:
|
||||
if file_written_by_ai:
|
||||
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
|
||||
logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
|
||||
else:
|
||||
json_data = extract_json(raw_output)
|
||||
except (json.JSONDecodeError, JsonNotFoundError) as exc:
|
||||
_t_json_end = _time.monotonic()
|
||||
logger.warning(
|
||||
" [%s] JSON提取失败: %.2fs %s",
|
||||
arxiv_id,
|
||||
_t_json_end - _t_json_start,
|
||||
str(exc)[:200],
|
||||
)
|
||||
validation_errors = [f"无法提取有效 JSON: {truncate_error(exc)}"]
|
||||
continue
|
||||
_t_json_end = _time.monotonic()
|
||||
|
||||
# 验证
|
||||
_t_val_start = _time.monotonic()
|
||||
validation_errors = _validate_summary(json_data, arxiv_id)
|
||||
_t_val_end = _time.monotonic()
|
||||
|
||||
if not validation_errors:
|
||||
logger.info(
|
||||
" [%s] JSON提取: %.2fs 验证: %.2fs ✅",
|
||||
arxiv_id,
|
||||
_t_json_end - _t_json_start,
|
||||
_t_val_end - _t_val_start,
|
||||
)
|
||||
break
|
||||
logger.warning(
|
||||
" [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
|
||||
arxiv_id,
|
||||
_t_json_end - _t_json_start,
|
||||
_t_val_end - _t_val_start,
|
||||
"; ".join(validation_errors)[:200],
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
exc = ValueError(
|
||||
f"Summary validation failed after {settings.SUMMARY_MAX_RETRIES} attempts: {'; '.join(validation_errors)}"
|
||||
)
|
||||
exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用
|
||||
raise exc
|
||||
|
||||
return json_data, raw_output
|
||||
Reference in New Issue
Block a user