Files
daily-paper/app/services/summarizer.py
T
Rain-Bus 21f16e6756 feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
2026-06-13 13:16:47 +08:00

332 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 总结编排服务 — 协调生成器、持久化、批量处理的顶层入口。"""
from __future__ import annotations
import asyncio
import logging
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.config import settings
from app.database import SessionLocal
from app.exceptions import ConflictError, NotFoundError
from app.models import (
PAPER_DEFAULT_LOAD,
CrawlLog,
Paper,
SummaryState,
SummaryStatus,
TaskLock,
get_paper_by_arxiv_id,
get_paper_by_id,
)
from app.services.pdf_downloader import download_pdf
from app.services.summary_utils import write_meta_json
from app.services.summary_generator import (
_generate_with_retry,
)
from app.services.summary_persister import (
_cleanup_old_images,
_handle_summary_failure,
_persist_summary,
)
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
# ── 单篇总结 ────────────────────────────────────────────────────────────
async def summarize_one(
    db: Session,
    paper: Paper,
    *,
    force: bool = False,
    pdf_mode: str = "auto",
) -> dict:
    """Summarize a single paper, skipping it when nothing needs to be done.

    Args:
        db: Session that owns ``paper``.
        paper: Paper to summarize. A PENDING ``SummaryStatus`` row is created
            and committed when the paper has none yet.
        force: When true, papers whose status is DONE or PERMANENT_FAILURE are
            summarized again instead of being skipped.
        pdf_mode: Forwarded unchanged to ``_do_summarize_one``.

    Returns:
        ``{"arxiv_id", "status": "skipped", "reason"}`` when the paper is
        skipped, otherwise the result of ``_do_summarize_one``.
    """
    paper_key = paper.arxiv_id

    # Make sure a status row exists before its state is inspected.
    if not paper.summary_status:
        pending_row = SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING)
        db.add(pending_row)
        db.commit()
        db.refresh(paper)

    current_state = paper.summary_status.status

    if not force:
        # Already summarized: nothing to do unless the caller forces a rerun.
        if current_state == SummaryState.DONE:
            return {
                "arxiv_id": paper_key,
                "status": "skipped",
                "reason": "already_done",
            }
        # Permanently failed papers are only retried on explicit request.
        if current_state == SummaryState.PERMANENT_FAILURE:
            return {
                "arxiv_id": paper_key,
                "status": "skipped",
                "reason": "permanent_failure",
            }

    return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -> dict:
    """Run the summarization steps for one paper and record the outcome.

    The paper's summary status is switched to PROCESSING, previously stored
    images are cleaned up, and then meta.json is written, the PDF is
    downloaded, the summary is generated and the result is persisted. The
    duration of every step is logged.

    Args:
        db: Session that owns ``paper``.
        paper: Paper to summarize; ``paper.summary_status`` must already exist.
        pdf_mode: Forwarded to ``_generate_with_retry``.

    Returns:
        ``{"arxiv_id", "status": "done", "quality"}`` on success, otherwise
        whatever ``_handle_summary_failure`` returns.
    """
    import time as _time

    tick = _time.monotonic
    paper_key = paper.arxiv_id
    short_title = (paper.title_en or "")[:50]

    # Status -> processing.
    state_row = paper.summary_status
    state_row.status = SummaryState.PROCESSING
    state_row.started_at = utc_now()
    db.commit()
    logger.info("▶ [%s] 开始总结: %s", paper_key, short_title)

    # Remove old image files and figures_json so that a re-run does not leave
    # stale artefacts behind.
    cleanup_began = tick()
    _cleanup_old_images(db, paper)
    cleanup_ended = tick()
    logger.info(" [%s] 清理旧数据: %.2fs", paper_key, cleanup_ended - cleanup_began)

    raw_output = ""
    try:
        began = tick()
        meta_path = write_meta_json(paper)
        after_meta = tick()
        logger.info(" [%s] meta.json: %.2fs", paper_key, after_meta - began)

        await download_pdf(paper_key, paper.pdf_url)
        after_download = tick()
        logger.info(" [%s] 下载PDF: %.2fs", paper_key, after_download - after_meta)

        logger.info(" [%s] 调用 pi 生成总结...", paper_key)
        pdf_path = TMP_DIR / paper_key / "paper.pdf"
        json_data, raw_output = await _generate_with_retry(
            paper_key,
            meta_path,
            pdf_path,
            pdf_mode=pdf_mode,
        )
        after_generate = tick()
        logger.info(" [%s] pi生成: %.2fs", paper_key, after_generate - after_download)

        quality = _persist_summary(db, paper, json_data, raw_output)
        after_persist = tick()
        logger.info(" [%s] 持久化: %.2fs", paper_key, after_persist - after_generate)

        logger.info(
            "✅ [%s] 完成: quality=%s 总耗时: %.2fs",
            paper_key,
            quality,
            after_persist - began,
        )
        return {"arxiv_id": paper_key, "status": "done", "quality": quality}
    except Exception as exc:
        # Prefer the raw output attached to the exception: a failed
        # _generate_with_retry may still carry output worth recording.
        partial_output = getattr(exc, "raw_output", raw_output)
        return _handle_summary_failure(db, paper, exc, partial_output)
    # NOTE: cleanup_tmp(arxiv_id) is temporarily disabled so the downloaded
    # PDF stays available for debugging image extraction.
# ── 单篇入口 ────────────────────────────────────────────────────────────
async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Summarize one paper identified by its arXiv id.

    Entry point for the admin routes and the CLI.

    Args:
        db: Session used only to verify that the paper exists.
        arxiv_id: arXiv identifier of the paper to summarize.
        force: Forwarded to ``summarize_one``; defaults to ``True`` here.
        pdf_mode: Forwarded to ``summarize_one``.
        _session_factory: Optional session factory; tests inject one bound to
            an in-memory database. Defaults to ``SessionLocal``.

    Returns:
        The result dict produced by ``summarize_one``.

    Raises:
        NotFoundError: The paper cannot be found, either through ``db`` or
            through the dedicated session opened for the summarization.
    """
    paper = get_paper_by_arxiv_id(db, arxiv_id)
    if not paper:
        raise NotFoundError(f"Paper not found: {arxiv_id}")

    make_session = _session_factory or SessionLocal

    # Work in a dedicated session to avoid concurrency problems.
    paper_db = make_session()
    try:
        fresh_paper = get_paper_by_arxiv_id(paper_db, arxiv_id)
        # The row may be invisible to (or gone from) the new session; report
        # that explicitly instead of failing on ``None.arxiv_id`` downstream.
        if not fresh_paper:
            raise NotFoundError(f"Paper not found: {arxiv_id}")
        return await summarize_one(
            paper_db, fresh_paper, force=force, pdf_mode=pdf_mode
        )
    finally:
        paper_db.close()
# ── 批量总结 ────────────────────────────────────────────────────────────
async def summarize_batch(
    db: Session,
    arxiv_ids: list[str] | None = None,
    *,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Summarize many papers concurrently, guarded by a re-entrancy lock.

    When ``arxiv_ids`` is ``None`` or empty, every paper whose summary status
    is PENDING or FAILED is processed; otherwise only the listed papers are.

    Args:
        db: Session used for the lock row, the ``CrawlLog`` entry and the
            paper query. Every paper is summarized in its own session.
        arxiv_ids: Optional explicit selection of papers.
        pdf_mode: Forwarded to ``summarize_one``.
        _session_factory: Optional session factory; tests inject one bound to
            an in-memory database. Defaults to ``SessionLocal``.

    Returns:
        ``{"status", "total", "done", "failed", "skipped"}`` with status
        ``"success"`` or ``"partial"``, or ``{"status": "failed", "error"}``
        when the batch itself raised.

    Raises:
        ConflictError: Inserting the lock row hit an ``IntegrityError``, i.e.
            another batch is already running.
    """
    now = utc_now()

    # TaskLock row prevents two batches from running at the same time.
    lock = TaskLock(
        task="summarize",
        lock_key="batch",
        status="running",
        owner="summarize_batch",
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except IntegrityError as lock_exc:
        db.rollback()
        logger.warning("Summarize batch already running (lock conflict)")
        raise ConflictError("summarize batch already running") from lock_exc

    # CrawlLog entry for this run. If it cannot be stored, give the lock back
    # before propagating; otherwise every later batch would hit a conflict.
    log_entry = CrawlLog(
        task="summarize",
        status="running",
        started_at=now,
    )
    try:
        db.add(log_entry)
        db.commit()
    except Exception:
        db.rollback()
        release_lock(db, lock)
        raise

    try:
        # Select the papers to summarize.
        stmt = select(Paper).options(*PAPER_DEFAULT_LOAD)
        if arxiv_ids:
            stmt = stmt.where(Paper.arxiv_id.in_(arxiv_ids))
        else:
            # Only pending or failed (i.e. retryable) papers.
            stmt = stmt.join(SummaryStatus).where(
                SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.FAILED])
            )
        papers = db.execute(stmt).unique().scalars().all()
        total = len(papers)
        logger.info("Summarize batch: %d papers to process", total)

        if total == 0:
            log_entry.status = "success"
            log_entry.papers_found = 0
            log_entry.papers_new = 0
            log_entry.completed_at = utc_now()
            # Persist the log entry here; the lock is released exactly once by
            # the ``finally`` clause below.
            db.commit()
            return {
                "status": "success",
                "done": 0,
                "failed": 0,
                "skipped": 0,
                "total": 0,
            }

        # Worker pool: a fixed number of workers drain a queue so that a large
        # batch does not open one DB connection per paper and exhaust the
        # connection pool. A non-positive setting is clamped to one worker so
        # the batch cannot silently process nothing.
        concurrency = max(1, min(settings.SUMMARY_CONCURRENCY, total))
        make_session = _session_factory or SessionLocal

        # Progress counters shared by all workers.
        progress = {"done": 0, "failed": 0, "skipped": 0}

        # Queue plain identifiers rather than ORM objects so workers never
        # touch instances owned by the outer session.
        pending = asyncio.Queue()
        for queued in papers:
            pending.put_nowait((queued.id, queued.arxiv_id))

        async def _worker() -> None:
            while True:
                try:
                    paper_id, paper_arxiv_id = pending.get_nowait()
                except asyncio.QueueEmpty:
                    return

                paper_db = None
                try:
                    paper_db = make_session()
                    paper = get_paper_by_id(paper_db, paper_id)
                    if paper is None:
                        raise NotFoundError(f"Paper not found: {paper_arxiv_id}")
                    result = await summarize_one(paper_db, paper, pdf_mode=pdf_mode)
                    outcome = result.get("status", "failed")
                except Exception as exc:
                    # A crash while handling one paper counts as a failure of
                    # that paper; the worker moves on to the next one.
                    logger.exception("Worker error [%s]: %s", paper_arxiv_id, exc)
                    outcome = "failed"
                finally:
                    if paper_db is not None:
                        paper_db.close()

                progress[outcome] = progress.get(outcome, 0) + 1
                logger.info(
                    "📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s",
                    sum(progress.values()),
                    total,
                    progress["done"],
                    progress["failed"],
                    progress["skipped"],
                    paper_arxiv_id,
                )

        worker_outcomes = await asyncio.gather(
            *(_worker() for _ in range(concurrency)),
            return_exceptions=True,
        )

        # Anything a worker let escape is reported once and counted as failed.
        crashed = [o for o in worker_outcomes if isinstance(o, BaseException)]
        for err in crashed:
            logger.error("Unexpected error in batch: %s", err)

        done = progress["done"]
        failed = progress["failed"] + len(crashed)
        skipped = progress["skipped"]

        log_entry.status = "success" if failed == 0 else "failed"
        log_entry.papers_found = total
        log_entry.papers_new = done
        log_entry.completed_at = utc_now()
        db.commit()

        logger.info(
            "Summarize batch done: total=%d done=%d failed=%d skipped=%d",
            total,
            done,
            failed,
            skipped,
        )
        return {
            "status": "success" if failed == 0 else "partial",
            "total": total,
            "done": done,
            "failed": failed,
            "skipped": skipped,
        }
    except Exception as exc:
        logger.exception("Summarize batch failed")
        # The failure may have left the session in a failed transaction; roll
        # back first so the log entry update below can be committed.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = truncate_error(exc, limit=2000)
        log_entry.completed_at = utc_now()
        db.commit()
        return {"status": "failed", "error": truncate_error(exc)}
    finally:
        release_lock(db, lock)