feat: add claude backend, refactor summary utilities, improve batch worker pattern, add pymupdf4llm

This commit is contained in:
2026-06-12 22:25:57 +08:00
parent b42e9149e5
commit e2f0e1a8be
13 changed files with 1350 additions and 1010 deletions
+200 -64
View File
@@ -29,14 +29,19 @@ from app.services.pdf_downloader import (
download_pdf,
paper_dir,
)
from app.services.pi_client import (
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
extract_json,
write_meta_json,
extract_pdf_text,
)
from app.services.pi_client import (
PiProcessError,
PiTimeoutError,
call_pi,
extract_json,
write_meta_json,
)
from app.services import claude_backend
from app.services.schemas import (
SummarySchema,
assess_quality,
@@ -229,7 +234,6 @@ def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) ->
async def summarize_one(
db: Session,
paper: Paper,
semaphore: asyncio.Semaphore | None = None,
*,
force: bool = False,
pdf_mode: str = "auto",
@@ -257,68 +261,128 @@ async def summarize_one(
"reason": "permanent_failure",
}
if semaphore:
await semaphore.acquire()
try:
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
finally:
if semaphore:
semaphore.release()
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
async def _generate_with_retry(
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
"""调用 pi CLI 生成总结,最多 4 轮验证循环。
"""调用 AI 后端生成总结,最多 4 轮验证循环。
根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。
Returns:
(json_data, raw_output)
Raises:
ValueError: 4 轮验证仍未通过
"""
import time as _time
backend = settings.SUMMARY_BACKEND
validation_errors: list[str] = []
json_data: dict | None = None
raw_output = ""
session_id = None
summary_file = paper_dir(arxiv_id) / "summary.json"
# claude 后端需要预构建 prompt(pi 后端在 call_pi 内部构建)
claude_prompt: str | None = None
if backend == "claude":
_t0 = _time.monotonic()
txt_path = extract_pdf_text(pdf_path, max_chars=None)
body = txt_path.read_text(encoding="utf-8")
if len(body) > 80_000:
trimmed = body[:80_000].rstrip()
txt_path.write_text(trimmed, encoding="utf-8")
claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)
for attempt in range(1, 5):
# 清理上一轮 pi 写的不完整文件
stale = paper_dir(arxiv_id) / "summary.json"
if stale.exists():
stale.unlink()
# 清理上一轮写的不完整文件
if summary_file.exists():
summary_file.unlink()
if attempt == 1:
raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
# 记录 AI 调用开始时间
_t_call_start = _time.monotonic()
if backend == "claude":
if attempt == 1:
raw_output, session_id = await claude_backend.call_claude(
claude_prompt, session_id=None,
)
else:
retry_prompt = build_prompt(
arxiv_id, meta_path,
extract_pdf_text(pdf_path, max_chars=80000),
"inject", fix_errors=validation_errors,
)
raw_output, session_id = await claude_backend.call_claude(
retry_prompt, session_id=session_id, fix_errors=validation_errors,
)
else:
raw_output, session_id = await call_pi(
meta_path, pdf_path,
fix_errors=validation_errors,
session_id=session_id,
pdf_mode=pdf_mode,
)
if attempt == 1:
raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
else:
raw_output, session_id = await call_pi(
meta_path, pdf_path,
fix_errors=validation_errors,
session_id=session_id,
pdf_mode=pdf_mode,
)
# 优先读取 pi 写入的 summary.json,否则从 stdout 提取
summary_file = paper_dir(arxiv_id) / "summary.json"
_t_call_end = _time.monotonic()
# 检查 summary.json 是否由 AI 子进程写入
file_written_by_ai = summary_file.exists()
file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
file_size = summary_file.stat().st_size if file_written_by_ai else 0
logger.info(
" [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
arxiv_id, attempt,
_t_call_end - _t_call_start,
f"已写入({file_size}B)" if file_written_by_ai else "未写入",
f" mtime={file_mtime:.2f}" if file_mtime else "",
)
# 提取 JSON
_t_json_start = _time.monotonic()
try:
if summary_file.exists():
if file_written_by_ai:
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
logger.info("Read summary.json written by pi for %s", arxiv_id)
logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
else:
json_data = extract_json(raw_output)
except (json.JSONDecodeError, JsonNotFoundError) as exc:
_t_json_end = _time.monotonic()
logger.warning(
"JSON extraction failed for %s (attempt %d): %s",
arxiv_id, attempt, str(exc)[:200],
" [%s] JSON提取失败: %.2fs %s",
arxiv_id, _t_json_end - _t_json_start, str(exc)[:200],
)
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
continue
_t_json_end = _time.monotonic()
# 验证
_t_val_start = _time.monotonic()
validation_errors = _validate_summary(json_data, arxiv_id)
_t_val_end = _time.monotonic()
if not validation_errors:
logger.info(
" [%s] JSON提取: %.2fs 验证: %.2fs ✅",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
)
break
logger.warning(
"Validation failed for %s (attempt %d): %s",
arxiv_id, attempt, "; ".join(validation_errors),
" [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
"; ".join(validation_errors)[:200],
)
if validation_errors:
@@ -335,11 +399,19 @@ def _persist_summary(
db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
import time as _time
arxiv_id = paper.arxiv_id
_t0 = _time.monotonic()
schema = SummarySchema.model_validate(json_data)
quality = assess_quality(schema)
_t1 = _time.monotonic()
_save_files(arxiv_id, schema, raw_output)
_t2 = _time.monotonic()
_save_files(paper.arxiv_id, schema, raw_output)
_update_summary_in_db(db, paper, schema, quality, raw_output)
_t3 = _time.monotonic()
# 状态 → done
paper.summary_status.status = SummaryState.DONE
@@ -347,10 +419,30 @@ def _persist_summary(
paper.summary_status.completed_at = utc_now()
paper.summary_status.raw_output_saved = True
db.commit()
_t4 = _time.monotonic()
logger.info(
" [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
arxiv_id,
_t1 - _t0,
_t2 - _t1,
_t3 - _t2,
_t4 - _t3,
)
# 触发性增强(失败不影响总结)
_maybe_extract_images(paper.arxiv_id, schema)
_maybe_index_chroma(paper.arxiv_id, paper, schema)
_t5 = _time.monotonic()
_maybe_extract_images(arxiv_id, schema)
_t6 = _time.monotonic()
_maybe_index_chroma(arxiv_id, paper, schema)
_t7 = _time.monotonic()
logger.info(
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
arxiv_id,
_t6 - _t5,
_t7 - _t6,
)
return quality
@@ -445,28 +537,47 @@ async def _do_summarize_one(
) -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
arxiv_id = paper.arxiv_id
title_short = (paper.title_en or "")[:50]
# 状态 → processing
paper.summary_status.status = SummaryState.PROCESSING
paper.summary_status.started_at = utc_now()
db.commit()
logger.info("▶ [%s] 开始总结: %s", arxiv_id, title_short)
# 清理旧的图片文件和 figures_json,避免重新总结时残留
import time as _time
_t_cleanup_start = _time.monotonic()
_cleanup_old_images(db, paper)
_t_cleanup_end = _time.monotonic()
logger.info(" [%s] 清理旧数据: %.2fs", arxiv_id, _t_cleanup_end - _t_cleanup_start)
raw_output = ""
try:
meta_path = write_meta_json(paper)
await download_pdf(arxiv_id, paper.pdf_url)
_t0 = _time.monotonic()
meta_path = write_meta_json(paper)
_t1 = _time.monotonic()
logger.info(" [%s] meta.json: %.2fs", arxiv_id, _t1 - _t0)
await download_pdf(arxiv_id, paper.pdf_url)
_t2 = _time.monotonic()
logger.info(" [%s] 下载PDF: %.2fs", arxiv_id, _t2 - _t1)
logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
json_data, raw_output = await _generate_with_retry(
arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf",
pdf_mode=pdf_mode,
)
_t3 = _time.monotonic()
logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2)
quality = _persist_summary(db, paper, json_data, raw_output)
_t4 = _time.monotonic()
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
logger.info("✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0)
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
except Exception as exc:
@@ -588,42 +699,67 @@ async def summarize_batch(
"total": 0,
}
# 并发控制
semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY)
# 并发控制:worker 模式,避免 573 个协程同时打开 DB 连接耗尽连接池
concurrency = settings.SUMMARY_CONCURRENCY
make_session = _session_factory or SessionLocal
async def _process_paper(paper: Paper) -> dict:
paper_db = make_session()
try:
p = paper_db.execute(
select(Paper)
.where(Paper.id == paper.id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
return await summarize_one(paper_db, p, semaphore, pdf_mode=pdf_mode)
finally:
paper_db.close()
# 进度追踪
progress = {"done": 0, "failed": 0, "skipped": 0}
paper_queue: asyncio.Queue[Paper | None] = asyncio.Queue()
for p in papers:
paper_queue.put_nowait(p)
results = await asyncio.gather(
*[_process_paper(p) for p in papers],
async def _worker() -> list[dict]:
    """Drain the shared paper queue and summarize papers one at a time.

    Each worker pulls papers from ``paper_queue`` until it is empty, opens
    its own DB session per paper via ``make_session``, reloads the paper in
    that session and awaits ``summarize_one``.  Every outcome -- including
    an exception -- is recorded both in the returned list and in the shared
    ``progress`` counters, because the batch totals are read from
    ``progress`` after all workers finish.

    Returns:
        One result dict per paper this worker handled.  Failures caught
        here are reported as ``{"status": "failed", "error": <message>}``.
    """
    results: list[dict] = []
    while True:
        try:
            paper = paper_queue.get_nowait()
        except asyncio.QueueEmpty:
            break
        # The queue type allows a None sentinel; treat it as "stop".
        if paper is None:
            break

        paper_db = None
        try:
            # Opened inside the try so that a session failure is recorded
            # as a failed paper instead of killing the worker and losing
            # the results it has already collected.
            paper_db = make_session()
            p = paper_db.execute(
                select(Paper)
                .where(Paper.id == paper.id)
                .options(*PAPER_DEFAULT_LOAD)
            ).unique().scalar_one_or_none()
            if p is None:
                # The row vanished between queueing and processing.
                raise LookupError(f"paper id={paper.id} not found")
            # Read before the counters are touched so nothing after the
            # increment below can raise and cause a double count.
            arxiv_id = p.arxiv_id

            result = await summarize_one(paper_db, p, pdf_mode=pdf_mode)

            # Anything that is neither "done" nor "skipped" counts as
            # failed, so unknown statuses cannot escape the totals.
            status = result.get("status", "failed")
            if status not in ("done", "skipped"):
                status = "failed"
            progress[status] = progress.get(status, 0) + 1
            finished = sum(progress.values())
            logger.info(
                "📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s",
                finished, total,
                progress["done"], progress["failed"], progress["skipped"],
                arxiv_id,
            )
            results.append(result)
        except Exception as exc:
            # Keep the shared counters in step with the result list;
            # otherwise the batch could be reported as fully successful.
            progress["failed"] = progress.get("failed", 0) + 1
            logger.exception("Worker error: %s", exc)
            results.append({"status": "failed", "error": str(exc)})
        finally:
            if paper_db is not None:
                paper_db.close()
    return results
worker_results = await asyncio.gather(
*[_worker() for _ in range(concurrency)],
return_exceptions=True,
)
results = []
for r in worker_results:
if isinstance(r, Exception):
logger.error("Unexpected error in batch: %s", r)
results.append(r)
elif isinstance(r, list):
results.extend(r)
# 统计结果
done = 0
failed = 0
skipped = 0
# 统计结果(progress 已在 worker 中实时更新)
done = progress["done"]
failed = progress["failed"]
skipped = progress["skipped"]
for r in results:
if isinstance(r, Exception):
logger.error("Unexpected error in batch: %s", r)
failed += 1
elif isinstance(r, dict):
if r.get("status") == "done":
done += 1
elif r.get("status") == "skipped":
skipped += 1
else:
failed += 1
log_entry.status = "success" if failed == 0 else "failed"
log_entry.papers_found = total