feat: add claude backend, refactor summary utilities, improve batch worker pattern, add pymupdf4llm
This commit is contained in:
+200
-64
@@ -29,14 +29,19 @@ from app.services.pdf_downloader import (
|
||||
download_pdf,
|
||||
paper_dir,
|
||||
)
|
||||
from app.services.pi_client import (
|
||||
from app.services.summary_utils import (
|
||||
JsonNotFoundError,
|
||||
build_prompt,
|
||||
extract_json,
|
||||
write_meta_json,
|
||||
extract_pdf_text,
|
||||
)
|
||||
from app.services.pi_client import (
|
||||
PiProcessError,
|
||||
PiTimeoutError,
|
||||
call_pi,
|
||||
extract_json,
|
||||
write_meta_json,
|
||||
)
|
||||
from app.services import claude_backend
|
||||
from app.services.schemas import (
|
||||
SummarySchema,
|
||||
assess_quality,
|
||||
@@ -229,7 +234,6 @@ def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) ->
|
||||
async def summarize_one(
|
||||
db: Session,
|
||||
paper: Paper,
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
*,
|
||||
force: bool = False,
|
||||
pdf_mode: str = "auto",
|
||||
@@ -257,68 +261,128 @@ async def summarize_one(
|
||||
"reason": "permanent_failure",
|
||||
}
|
||||
|
||||
if semaphore:
|
||||
await semaphore.acquire()
|
||||
try:
|
||||
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
|
||||
finally:
|
||||
if semaphore:
|
||||
semaphore.release()
|
||||
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
|
||||
|
||||
|
||||
async def _generate_with_retry(
|
||||
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
|
||||
) -> tuple[dict, str]:
|
||||
"""调用 pi CLI 生成总结,最多 4 轮验证循环。
|
||||
"""调用 AI 后端生成总结,最多 4 轮验证循环。
|
||||
|
||||
根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。
|
||||
|
||||
Returns:
|
||||
(json_data, raw_output)
|
||||
Raises:
|
||||
ValueError: 4 轮验证仍未通过
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
backend = settings.SUMMARY_BACKEND
|
||||
validation_errors: list[str] = []
|
||||
json_data: dict | None = None
|
||||
raw_output = ""
|
||||
session_id = None
|
||||
|
||||
summary_file = paper_dir(arxiv_id) / "summary.json"
|
||||
|
||||
# claude 后端需要预构建 prompt(pi 后端在 call_pi 内部构建)
|
||||
claude_prompt: str | None = None
|
||||
if backend == "claude":
|
||||
_t0 = _time.monotonic()
|
||||
txt_path = extract_pdf_text(pdf_path, max_chars=None)
|
||||
body = txt_path.read_text(encoding="utf-8")
|
||||
if len(body) > 80_000:
|
||||
trimmed = body[:80_000].rstrip()
|
||||
txt_path.write_text(trimmed, encoding="utf-8")
|
||||
claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
|
||||
logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)
|
||||
|
||||
for attempt in range(1, 5):
|
||||
# 清理上一轮 pi 写的不完整文件
|
||||
stale = paper_dir(arxiv_id) / "summary.json"
|
||||
if stale.exists():
|
||||
stale.unlink()
|
||||
# 清理上一轮写入的不完整文件
|
||||
if summary_file.exists():
|
||||
summary_file.unlink()
|
||||
|
||||
if attempt == 1:
|
||||
raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
|
||||
# 记录 AI 调用开始时间
|
||||
_t_call_start = _time.monotonic()
|
||||
|
||||
if backend == "claude":
|
||||
if attempt == 1:
|
||||
raw_output, session_id = await claude_backend.call_claude(
|
||||
claude_prompt, session_id=None,
|
||||
)
|
||||
else:
|
||||
retry_prompt = build_prompt(
|
||||
arxiv_id, meta_path,
|
||||
extract_pdf_text(pdf_path, max_chars=80000),
|
||||
"inject", fix_errors=validation_errors,
|
||||
)
|
||||
raw_output, session_id = await claude_backend.call_claude(
|
||||
retry_prompt, session_id=session_id, fix_errors=validation_errors,
|
||||
)
|
||||
else:
|
||||
raw_output, session_id = await call_pi(
|
||||
meta_path, pdf_path,
|
||||
fix_errors=validation_errors,
|
||||
session_id=session_id,
|
||||
pdf_mode=pdf_mode,
|
||||
)
|
||||
if attempt == 1:
|
||||
raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
|
||||
else:
|
||||
raw_output, session_id = await call_pi(
|
||||
meta_path, pdf_path,
|
||||
fix_errors=validation_errors,
|
||||
session_id=session_id,
|
||||
pdf_mode=pdf_mode,
|
||||
)
|
||||
|
||||
# 优先读取 pi 写入的 summary.json,否则从 stdout 提取
|
||||
summary_file = paper_dir(arxiv_id) / "summary.json"
|
||||
_t_call_end = _time.monotonic()
|
||||
|
||||
# 检查 summary.json 是否由 AI 子进程写入
|
||||
file_written_by_ai = summary_file.exists()
|
||||
file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
|
||||
file_size = summary_file.stat().st_size if file_written_by_ai else 0
|
||||
|
||||
logger.info(
|
||||
" [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
|
||||
arxiv_id, attempt,
|
||||
_t_call_end - _t_call_start,
|
||||
f"已写入({file_size}B)" if file_written_by_ai else "未写入",
|
||||
f" mtime={file_mtime:.2f}" if file_mtime else "",
|
||||
)
|
||||
|
||||
# 提取 JSON
|
||||
_t_json_start = _time.monotonic()
|
||||
try:
|
||||
if summary_file.exists():
|
||||
if file_written_by_ai:
|
||||
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
|
||||
logger.info("Read summary.json written by pi for %s", arxiv_id)
|
||||
logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
|
||||
else:
|
||||
json_data = extract_json(raw_output)
|
||||
except (json.JSONDecodeError, JsonNotFoundError) as exc:
|
||||
_t_json_end = _time.monotonic()
|
||||
logger.warning(
|
||||
"JSON extraction failed for %s (attempt %d): %s",
|
||||
arxiv_id, attempt, str(exc)[:200],
|
||||
" [%s] JSON提取失败: %.2fs %s",
|
||||
arxiv_id, _t_json_end - _t_json_start, str(exc)[:200],
|
||||
)
|
||||
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
|
||||
continue
|
||||
_t_json_end = _time.monotonic()
|
||||
|
||||
# 验证
|
||||
_t_val_start = _time.monotonic()
|
||||
validation_errors = _validate_summary(json_data, arxiv_id)
|
||||
_t_val_end = _time.monotonic()
|
||||
|
||||
if not validation_errors:
|
||||
logger.info(
|
||||
" [%s] JSON提取: %.2fs 验证: %.2fs ✅",
|
||||
arxiv_id,
|
||||
_t_json_end - _t_json_start,
|
||||
_t_val_end - _t_val_start,
|
||||
)
|
||||
break
|
||||
logger.warning(
|
||||
"Validation failed for %s (attempt %d): %s",
|
||||
arxiv_id, attempt, "; ".join(validation_errors),
|
||||
" [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
|
||||
arxiv_id,
|
||||
_t_json_end - _t_json_start,
|
||||
_t_val_end - _t_val_start,
|
||||
"; ".join(validation_errors)[:200],
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
@@ -335,11 +399,19 @@ def _persist_summary(
|
||||
db: Session, paper: Paper, json_data: dict, raw_output: str
|
||||
) -> str:
|
||||
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
|
||||
import time as _time
|
||||
arxiv_id = paper.arxiv_id
|
||||
|
||||
_t0 = _time.monotonic()
|
||||
schema = SummarySchema.model_validate(json_data)
|
||||
quality = assess_quality(schema)
|
||||
_t1 = _time.monotonic()
|
||||
|
||||
_save_files(arxiv_id, schema, raw_output)
|
||||
_t2 = _time.monotonic()
|
||||
|
||||
_save_files(paper.arxiv_id, schema, raw_output)
|
||||
_update_summary_in_db(db, paper, schema, quality, raw_output)
|
||||
_t3 = _time.monotonic()
|
||||
|
||||
# 状态 → done
|
||||
paper.summary_status.status = SummaryState.DONE
|
||||
@@ -347,10 +419,30 @@ def _persist_summary(
|
||||
paper.summary_status.completed_at = utc_now()
|
||||
paper.summary_status.raw_output_saved = True
|
||||
db.commit()
|
||||
_t4 = _time.monotonic()
|
||||
|
||||
logger.info(
|
||||
" [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
|
||||
arxiv_id,
|
||||
_t1 - _t0,
|
||||
_t2 - _t1,
|
||||
_t3 - _t2,
|
||||
_t4 - _t3,
|
||||
)
|
||||
|
||||
# 触发性增强(失败不影响总结)
|
||||
_maybe_extract_images(paper.arxiv_id, schema)
|
||||
_maybe_index_chroma(paper.arxiv_id, paper, schema)
|
||||
_t5 = _time.monotonic()
|
||||
_maybe_extract_images(arxiv_id, schema)
|
||||
_t6 = _time.monotonic()
|
||||
_maybe_index_chroma(arxiv_id, paper, schema)
|
||||
_t7 = _time.monotonic()
|
||||
|
||||
logger.info(
|
||||
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
|
||||
arxiv_id,
|
||||
_t6 - _t5,
|
||||
_t7 - _t6,
|
||||
)
|
||||
|
||||
return quality
|
||||
|
||||
@@ -445,28 +537,47 @@ async def _do_summarize_one(
|
||||
) -> dict:
|
||||
"""实际的单篇总结执行(在 semaphore 保护下)。"""
|
||||
arxiv_id = paper.arxiv_id
|
||||
title_short = (paper.title_en or "")[:50]
|
||||
|
||||
# 状态 → processing
|
||||
paper.summary_status.status = SummaryState.PROCESSING
|
||||
paper.summary_status.started_at = utc_now()
|
||||
db.commit()
|
||||
|
||||
logger.info("▶ [%s] 开始总结: %s", arxiv_id, title_short)
|
||||
|
||||
# 清理旧的图片文件和 figures_json,避免重新总结时残留
|
||||
import time as _time
|
||||
_t_cleanup_start = _time.monotonic()
|
||||
_cleanup_old_images(db, paper)
|
||||
_t_cleanup_end = _time.monotonic()
|
||||
logger.info(" [%s] 清理旧数据: %.2fs", arxiv_id, _t_cleanup_end - _t_cleanup_start)
|
||||
|
||||
raw_output = ""
|
||||
try:
|
||||
meta_path = write_meta_json(paper)
|
||||
await download_pdf(arxiv_id, paper.pdf_url)
|
||||
_t0 = _time.monotonic()
|
||||
|
||||
meta_path = write_meta_json(paper)
|
||||
_t1 = _time.monotonic()
|
||||
logger.info(" [%s] meta.json: %.2fs", arxiv_id, _t1 - _t0)
|
||||
|
||||
await download_pdf(arxiv_id, paper.pdf_url)
|
||||
_t2 = _time.monotonic()
|
||||
logger.info(" [%s] 下载PDF: %.2fs", arxiv_id, _t2 - _t1)
|
||||
|
||||
logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
|
||||
json_data, raw_output = await _generate_with_retry(
|
||||
arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf",
|
||||
pdf_mode=pdf_mode,
|
||||
)
|
||||
_t3 = _time.monotonic()
|
||||
logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2)
|
||||
|
||||
quality = _persist_summary(db, paper, json_data, raw_output)
|
||||
_t4 = _time.monotonic()
|
||||
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
|
||||
|
||||
logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
|
||||
logger.info("✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0)
|
||||
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
|
||||
|
||||
except Exception as exc:
|
||||
@@ -588,42 +699,67 @@ async def summarize_batch(
|
||||
"total": 0,
|
||||
}
|
||||
|
||||
# 并发控制
|
||||
semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY)
|
||||
# 并发控制:worker 模式,避免 573 个协程同时打开 DB 连接耗尽连接池
|
||||
concurrency = settings.SUMMARY_CONCURRENCY
|
||||
make_session = _session_factory or SessionLocal
|
||||
|
||||
async def _process_paper(paper: Paper) -> dict:
|
||||
paper_db = make_session()
|
||||
try:
|
||||
p = paper_db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.id == paper.id)
|
||||
.options(*PAPER_DEFAULT_LOAD)
|
||||
).unique().scalar_one_or_none()
|
||||
return await summarize_one(paper_db, p, semaphore, pdf_mode=pdf_mode)
|
||||
finally:
|
||||
paper_db.close()
|
||||
# 进度追踪
|
||||
progress = {"done": 0, "failed": 0, "skipped": 0}
|
||||
paper_queue: asyncio.Queue[Paper | None] = asyncio.Queue()
|
||||
for p in papers:
|
||||
paper_queue.put_nowait(p)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[_process_paper(p) for p in papers],
|
||||
async def _worker() -> list[dict]:
|
||||
results: list[dict] = []
|
||||
while True:
|
||||
paper = paper_queue.get_nowait() if not paper_queue.empty() else None
|
||||
if paper is None:
|
||||
break
|
||||
paper_db = make_session()
|
||||
try:
|
||||
p = paper_db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.id == paper.id)
|
||||
.options(*PAPER_DEFAULT_LOAD)
|
||||
).unique().scalar_one_or_none()
|
||||
result = await summarize_one(paper_db, p, pdf_mode=pdf_mode)
|
||||
status = result.get("status", "failed")
|
||||
progress[status] = progress.get(status, 0) + 1
|
||||
finished = sum(progress.values())
|
||||
logger.info(
|
||||
"📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s",
|
||||
finished, total,
|
||||
progress["done"], progress["failed"], progress["skipped"],
|
||||
paper.arxiv_id,
|
||||
)
|
||||
results.append(result)
|
||||
except Exception as exc:
|
||||
logger.error("Worker error: %s", exc)
|
||||
results.append({"status": "failed", "error": str(exc)})
|
||||
finally:
|
||||
paper_db.close()
|
||||
return results
|
||||
|
||||
worker_results = await asyncio.gather(
|
||||
*[_worker() for _ in range(concurrency)],
|
||||
return_exceptions=True,
|
||||
)
|
||||
results = []
|
||||
for r in worker_results:
|
||||
if isinstance(r, Exception):
|
||||
logger.error("Unexpected error in batch: %s", r)
|
||||
results.append(r)
|
||||
elif isinstance(r, list):
|
||||
results.extend(r)
|
||||
|
||||
# 统计结果
|
||||
done = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
# 统计结果(progress 已在 worker 中实时更新)
|
||||
done = progress["done"]
|
||||
failed = progress["failed"]
|
||||
skipped = progress["skipped"]
|
||||
for r in results:
|
||||
if isinstance(r, Exception):
|
||||
logger.error("Unexpected error in batch: %s", r)
|
||||
failed += 1
|
||||
elif isinstance(r, dict):
|
||||
if r.get("status") == "done":
|
||||
done += 1
|
||||
elif r.get("status") == "skipped":
|
||||
skipped += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
log_entry.status = "success" if failed == 0 else "failed"
|
||||
log_entry.papers_found = total
|
||||
|
||||
Reference in New Issue
Block a user