feat: refactor summarizer and PDF extraction pipeline

- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
This commit is contained in:
2026-06-13 13:16:47 +08:00
parent e2f0e1a8be
commit 21f16e6756
43 changed files with 3304 additions and 1494 deletions
+275
View File
@@ -0,0 +1,275 @@
"""AI 总结生成器 — AI 后端调用、重试循环、JSON 验证、错误分类。"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from pydantic import ValidationError
from app.config import settings
from app.services.pdf_downloader import (
PdfDownloadError,
paper_dir,
)
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
extract_json,
extract_pdf_text,
)
from app.services.pi_client import (
PiProcessError,
PiTimeoutError,
call_pi,
)
from app.services import claude_backend
from app.services.schemas import classify_validation_error
from app.utils import truncate_error
logger = logging.getLogger(__name__)
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
"""将异常映射到 error_type 枚举值。"""
if isinstance(exc, PdfDownloadError):
return "pdf_download_failed"
if isinstance(exc, PiTimeoutError):
return "timeout"
if isinstance(exc, PiProcessError):
return "process_error"
if isinstance(exc, JsonNotFoundError):
return "json_not_found"
if isinstance(exc, json.JSONDecodeError):
return "json_invalid"
if isinstance(exc, ValidationError):
return classify_validation_error(exc)
return "unknown"
# ── JSON 验证 ──────────────────────────────────────────────────────────
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
errors: list[str] = []
if not isinstance(json_data, dict):
return ["顶层必须是 JSON 对象"]
# 必填字段
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
if f not in json_data or not json_data[f]:
errors.append(f"缺少必填字段: {f}")
# tags 必须是非空数组
tags = json_data.get("tags")
if not isinstance(tags, list) or len(tags) == 0:
errors.append("tags 必须是非空数组")
# 字符串段落字段(必须是 str 且 ≥50 字)
string_fields = [
("motivation", "problem"),
("motivation", "goal"),
("motivation", "gap"),
("method", "overview"),
("method", "key_idea"),
("method", "steps"),
("method", "novelty"),
("results", "main_findings"),
("results", "limitations"),
("improvements", "weaknesses"),
("improvements", "future_work"),
("improvements", "reproducibility"),
]
for section, field in string_fields:
val = json_data.get(section, {}).get(field)
if isinstance(val, list):
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
elif not isinstance(val, str) or len(val.strip()) < 50:
errors.append(
f"{section}.{field} 必须是详细段落(≥50字),"
f"当前: {type(val).__name__} ({len(str(val))}字)"
)
# benchmarks 必须是数组
benchmarks = json_data.get("results", {}).get("benchmarks")
if benchmarks is not None and not isinstance(benchmarks, list):
errors.append("results.benchmarks 必须是数组")
# prerequisites.concepts 必须是对象数组,每个有 term
concepts = json_data.get("prerequisites", {}).get("concepts")
if concepts is not None:
if not isinstance(concepts, list):
errors.append("prerequisites.concepts 必须是数组")
elif len(concepts) == 0:
errors.append("prerequisites.concepts 不能为空")
else:
for i, c in enumerate(concepts):
if isinstance(c, str):
errors.append(
f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串"
)
elif isinstance(c, dict) and not c.get("term"):
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
# figures 必须是数组,每个元素应有 id
figures = json_data.get("figures")
if figures is not None:
if not isinstance(figures, list):
errors.append("figures 必须是数组")
else:
for i, fig in enumerate(figures):
if isinstance(fig, dict) and not fig.get("id"):
errors.append(f"figures[{i}] 缺少 id 字段")
return errors
# ── AI 调用 + 重试 ──────────────────────────────────────────────────────
async def _generate_with_retry(
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
"""调用 AI 后端生成总结,最多 4 轮验证循环。
根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。
Returns:
(json_data, raw_output)
Raises:
ValueError: 4 轮验证仍未通过
"""
import time as _time
backend = settings.SUMMARY_BACKEND
validation_errors: list[str] = []
json_data: dict | None = None
raw_output = ""
session_id = None
summary_file = paper_dir(arxiv_id) / "summary.json"
# claude 后端需要预构建 promptpi 后端在 call_pi 内部构建)
claude_prompt: str | None = None
if backend == "claude":
_t0 = _time.monotonic()
txt_path = extract_pdf_text(pdf_path, max_chars=None)
body = txt_path.read_text(encoding="utf-8")
if len(body) > 80_000:
trimmed = body[:80_000].rstrip()
txt_path.write_text(trimmed, encoding="utf-8")
claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)
for attempt in range(1, settings.SUMMARY_MAX_RETRIES + 1):
# 清理上一轮写入的不完整文件
if summary_file.exists():
summary_file.unlink()
# 记录 AI 调用开始时间
_t_call_start = _time.monotonic()
if backend == "claude":
if attempt == 1:
raw_output, session_id = await claude_backend.call_claude(
claude_prompt,
session_id=None,
)
else:
retry_prompt = build_prompt(
arxiv_id,
meta_path,
extract_pdf_text(pdf_path, max_chars=80000),
"inject",
fix_errors=validation_errors,
)
raw_output, session_id = await claude_backend.call_claude(
retry_prompt,
session_id=session_id,
fix_errors=validation_errors,
)
else:
if attempt == 1:
raw_output, session_id = await call_pi(
meta_path, pdf_path, pdf_mode=pdf_mode
)
else:
raw_output, session_id = await call_pi(
meta_path,
pdf_path,
fix_errors=validation_errors,
session_id=session_id,
pdf_mode=pdf_mode,
)
_t_call_end = _time.monotonic()
# 检查 summary.json 是否由 AI 子进程写入
file_written_by_ai = summary_file.exists()
file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
file_size = summary_file.stat().st_size if file_written_by_ai else 0
logger.info(
" [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
arxiv_id,
attempt,
_t_call_end - _t_call_start,
f"已写入({file_size}B)" if file_written_by_ai else "未写入",
f" mtime={file_mtime:.2f}" if file_mtime else "",
)
# 提取 JSON
_t_json_start = _time.monotonic()
try:
if file_written_by_ai:
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
else:
json_data = extract_json(raw_output)
except (json.JSONDecodeError, JsonNotFoundError) as exc:
_t_json_end = _time.monotonic()
logger.warning(
" [%s] JSON提取失败: %.2fs %s",
arxiv_id,
_t_json_end - _t_json_start,
str(exc)[:200],
)
validation_errors = [f"无法提取有效 JSON: {truncate_error(exc)}"]
continue
_t_json_end = _time.monotonic()
# 验证
_t_val_start = _time.monotonic()
validation_errors = _validate_summary(json_data, arxiv_id)
_t_val_end = _time.monotonic()
if not validation_errors:
logger.info(
" [%s] JSON提取: %.2fs 验证: %.2fs ✅",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
)
break
logger.warning(
" [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
"; ".join(validation_errors)[:200],
)
if validation_errors:
exc = ValueError(
f"Summary validation failed after {settings.SUMMARY_MAX_RETRIES} attempts: {'; '.join(validation_errors)}"
)
exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用
raise exc
return json_data, raw_output