21f16e6756
- Split summarizer into summary_generator and summary_persister modules - Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection - Add layout_detector service for PicoDet-S_layout_3cls integration - Add exceptions module with ConflictError and NotFoundError - Improve admin dashboard with better statistics and task management - Add design review document with system optimization suggestions - Add new tests for crawler, pdf_downloader, pipeline, and summary_utils - Update dependencies and configuration - Clean up dead code and improve error handling
332 lines
11 KiB
Python
332 lines
11 KiB
Python
"""AI 总结编排服务 — 协调生成器、持久化、批量处理的顶层入口。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.exc import IntegrityError
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.config import settings
|
||
from app.database import SessionLocal
|
||
from app.exceptions import ConflictError, NotFoundError
|
||
from app.models import (
|
||
PAPER_DEFAULT_LOAD,
|
||
CrawlLog,
|
||
Paper,
|
||
SummaryState,
|
||
SummaryStatus,
|
||
TaskLock,
|
||
get_paper_by_arxiv_id,
|
||
get_paper_by_id,
|
||
)
|
||
from app.services.pdf_downloader import download_pdf
|
||
from app.services.summary_utils import write_meta_json
|
||
from app.services.summary_generator import (
|
||
_generate_with_retry,
|
||
)
|
||
from app.services.summary_persister import (
|
||
_cleanup_old_images,
|
||
_handle_summary_failure,
|
||
_persist_summary,
|
||
)
|
||
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
|
||
|
||
# Module-level logger; the progress and per-stage timing messages of this module go through it.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ── 单篇总结 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
async def summarize_one(
    db: Session,
    paper: Paper,
    *,
    force: bool = False,
    pdf_mode: str = "auto",
) -> dict:
    """Summarize one paper, honouring the skip rules for finished papers.

    If the paper has no ``SummaryStatus`` row yet, a PENDING one is created
    and committed first. Papers whose summary is DONE or PERMANENT_FAILURE
    are reported as skipped unless ``force`` is true; everything else is
    handed to ``_do_summarize_one``.

    Returns:
        ``{"arxiv_id", "status": "skipped", "reason"}`` for skipped papers,
        otherwise the result dict produced by ``_do_summarize_one``.
    """
    paper_key = paper.arxiv_id

    if not paper.summary_status:
        # Create the status row lazily so the state checks below always have one.
        db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
        db.commit()
        db.refresh(paper)

    current_state = paper.summary_status.status
    skip_rules = (
        (SummaryState.DONE, "already_done"),
        (SummaryState.PERMANENT_FAILURE, "permanent_failure"),
    )
    if not force:
        for terminal_state, reason in skip_rules:
            if current_state == terminal_state:
                return {"arxiv_id": paper_key, "status": "skipped", "reason": reason}

    return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
|
||
|
||
|
||
async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -> dict:
    """Run the summarization pipeline for one paper, without any skip checks.

    Steps: mark the summary PROCESSING (committed), clear image files and
    ``figures_json`` left by a previous run, write ``meta.json``, download
    the PDF, generate the summary with ``_generate_with_retry`` and persist
    it with ``_persist_summary``. The duration of each stage is logged.

    Any ``Exception`` raised after the status switch -- including one from
    the image cleanup -- is handed to ``_handle_summary_failure`` together
    with whatever raw model output is available.

    Args:
        db: Session that owns ``paper``.
        paper: Paper to summarize; its ``summary_status`` row must exist.
        pdf_mode: Passed through to ``_generate_with_retry``.

    Returns:
        ``{"arxiv_id", "status": "done", "quality"}`` on success, otherwise
        the dict returned by ``_handle_summary_failure``.
    """
    import time as _time

    arxiv_id = paper.arxiv_id
    title_short = (paper.title_en or "")[:50]

    # Status -> processing
    paper.summary_status.status = SummaryState.PROCESSING
    paper.summary_status.started_at = utc_now()
    db.commit()

    logger.info("▶ [%s] 开始总结: %s", arxiv_id, title_short)

    # NOTE: the temp dir is intentionally not cleaned up afterwards
    # (cleanup_tmp(arxiv_id) is disabled) so the PDF stays available for
    # debugging image extraction.
    raw_output = ""
    try:
        # Clear old image files / figures_json so a re-run leaves nothing stale.
        # This sits inside the try on purpose: outside it, a cleanup error would
        # escape with the status still PROCESSING, and the default batch query
        # (PENDING/FAILED only) would never pick the paper up again.
        _t_cleanup_start = _time.monotonic()
        _cleanup_old_images(db, paper)
        _t_cleanup_end = _time.monotonic()
        logger.info(" [%s] 清理旧数据: %.2fs", arxiv_id, _t_cleanup_end - _t_cleanup_start)

        _t0 = _time.monotonic()
        meta_path = write_meta_json(paper)
        _t1 = _time.monotonic()
        logger.info(" [%s] meta.json: %.2fs", arxiv_id, _t1 - _t0)

        await download_pdf(arxiv_id, paper.pdf_url)
        _t2 = _time.monotonic()
        logger.info(" [%s] 下载PDF: %.2fs", arxiv_id, _t2 - _t1)

        logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
        json_data, raw_output = await _generate_with_retry(
            arxiv_id,
            meta_path,
            TMP_DIR / arxiv_id / "paper.pdf",
            pdf_mode=pdf_mode,
        )
        _t3 = _time.monotonic()
        logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2)

        quality = _persist_summary(db, paper, json_data, raw_output)
        _t4 = _time.monotonic()
        logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)

        logger.info(
            "✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0
        )
        return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}

    except Exception as exc:
        # _generate_with_retry may attach the model output to its exception as
        # ``raw_output``; otherwise fall back to the last output seen here.
        fail_output = getattr(exc, "raw_output", raw_output)
        return _handle_summary_failure(db, paper, exc, fail_output)
|
||
|
||
|
||
# ── 单篇入口 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Summarize a single paper (entry point for the admin routes and the CLI).

    The paper's existence is checked with the caller's ``db`` session, but the
    work itself runs in a freshly created session that is always closed on
    exit.

    Args:
        db: Caller's session, used only for the existence check.
        arxiv_id: arXiv identifier of the paper to summarize.
        force: Re-summarize even when already DONE / PERMANENT_FAILURE.
            Defaults to True here, unlike ``summarize_one``.
        pdf_mode: Passed through to ``summarize_one``.
        _session_factory: Optional session factory; tests inject one that
            yields in-memory DB sessions. Defaults to ``SessionLocal``.

    Returns:
        The result dict produced by ``summarize_one``.

    Raises:
        NotFoundError: No paper with ``arxiv_id`` exists (checked in both
            sessions).
    """
    if not get_paper_by_arxiv_id(db, arxiv_id):
        raise NotFoundError(f"Paper not found: {arxiv_id}")

    make_session = _session_factory or SessionLocal

    # Dedicated session for the work, so it shares no state with the caller's session.
    paper_db = make_session()
    try:
        paper = get_paper_by_arxiv_id(paper_db, arxiv_id)
        if not paper:
            # The paper may have been deleted between the two lookups; without
            # this check summarize_one would fail with an AttributeError.
            raise NotFoundError(f"Paper not found: {arxiv_id}")
        return await summarize_one(paper_db, paper, force=force, pdf_mode=pdf_mode)
    finally:
        paper_db.close()
|
||
|
||
|
||
# ── 批量总结 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _select_batch_papers(db: Session, arxiv_ids: list[str] | None) -> list[Paper]:
    """Load the papers a batch run should process (eager-loaded via PAPER_DEFAULT_LOAD).

    With explicit ``arxiv_ids`` every matching paper is returned regardless of
    its summary state (``summarize_one`` decides whether to skip it); otherwise
    only papers whose summary is PENDING or FAILED are selected.
    """
    stmt = select(Paper).options(*PAPER_DEFAULT_LOAD)
    if arxiv_ids:
        stmt = stmt.where(Paper.arxiv_id.in_(arxiv_ids))
    else:
        # Only retryable states. The inner join also excludes papers that have
        # no SummaryStatus row yet.
        stmt = stmt.join(SummaryStatus).where(
            SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.FAILED])
        )
    return db.execute(stmt).unique().scalars().all()


async def _run_batch_workers(
    papers: list[Paper],
    *,
    make_session,
    concurrency: int,
    pdf_mode: str,
) -> dict[str, int]:
    """Summarize ``papers`` with ``concurrency`` workers; return per-status counts.

    Each paper is processed in its own session from ``make_session``, so at
    most ``concurrency`` sessions are open at once. Per-paper errors are
    logged and counted as ``"failed"`` instead of stopping the batch.
    """
    total = len(papers)
    progress: dict[str, int] = {"done": 0, "failed": 0, "skipped": 0}

    # Queue plain (id, arxiv_id) pairs; each worker re-loads the paper in its own session.
    queue: asyncio.Queue = asyncio.Queue()
    for p in papers:
        queue.put_nowait((p.id, p.arxiv_id))

    def _record(status: str, arxiv_id: str) -> None:
        progress[status] = progress.get(status, 0) + 1
        logger.info(
            "📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s",
            sum(progress.values()),
            total,
            progress["done"],
            progress["failed"],
            progress["skipped"],
            arxiv_id,
        )

    async def _worker() -> None:
        while True:
            try:
                paper_id, arxiv_id = queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            status = "failed"
            paper_db = None
            try:
                paper_db = make_session()
                paper = get_paper_by_id(paper_db, paper_id)
                if paper is None:
                    # Deleted after the batch query ran.
                    raise NotFoundError(f"Paper not found: {arxiv_id}")
                result = await summarize_one(paper_db, paper, pdf_mode=pdf_mode)
                status = result.get("status", "failed")
            except Exception as exc:
                # Errors escaping summarize_one still count as failures.
                logger.exception("Worker error: %s", exc)
            finally:
                if paper_db is not None:
                    paper_db.close()
            _record(status, arxiv_id)

    outcomes = await asyncio.gather(
        *(_worker() for _ in range(concurrency)),
        return_exceptions=True,
    )
    for outcome in outcomes:
        if isinstance(outcome, BaseException):
            # Workers handle per-paper errors themselves; anything here is unexpected.
            logger.error("Unexpected error in batch: %s", outcome)
            progress["failed"] += 1
    return progress


async def summarize_batch(
    db: Session,
    arxiv_ids: list[str] | None = None,
    *,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Batch summarization entry point.

    ``arxiv_ids=None`` processes every paper whose summary is PENDING or
    FAILED. Otherwise only the listed papers are considered; already DONE or
    PERMANENT_FAILURE ones are skipped because ``force`` is not set.

    A ``TaskLock`` row (task="summarize", lock_key="batch") guards against
    concurrent batches and is released exactly once before returning. The
    run is recorded in a ``CrawlLog`` row.

    Args:
        db: Session used for the lock, the log entry and the paper query.
        arxiv_ids: Optional explicit list of papers to process.
        pdf_mode: Passed through to ``summarize_one``.
        _session_factory: Optional session factory for the per-paper sessions;
            tests inject one that yields in-memory DB sessions. Defaults to
            ``SessionLocal``.

    Returns:
        ``{"status", "total", "done", "failed", "skipped"}`` where ``status``
        is ``"success"`` when nothing failed and ``"partial"`` otherwise, or
        ``{"status": "failed", "error"}`` if the batch itself crashed.

    Raises:
        ConflictError: Another batch already holds the lock.
    """
    now = utc_now()

    # Re-entrancy guard: inserting a second identical lock raises IntegrityError
    # (presumably a unique constraint on the lock row -- confirm in the model).
    lock = TaskLock(
        task="summarize",
        lock_key="batch",
        status="running",
        owner="summarize_batch",
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except IntegrityError:
        db.rollback()
        logger.warning("Summarize batch already running (lock conflict)")
        raise ConflictError("summarize batch already running")

    log_entry = CrawlLog(
        task="summarize",
        status="running",
        started_at=now,
    )
    try:
        db.add(log_entry)
        db.commit()
    except Exception:
        # Otherwise the lock acquired above would never be released.
        db.rollback()
        release_lock(db, lock)
        raise

    try:
        papers = _select_batch_papers(db, arxiv_ids)
        total = len(papers)
        logger.info("Summarize batch: %d papers to process", total)

        if total == 0:
            log_entry.status = "success"
            log_entry.papers_found = 0
            log_entry.papers_new = 0
            log_entry.completed_at = utc_now()
            # Commit here; the lock is released once, in the finally below.
            db.commit()
            return {
                "status": "success",
                "done": 0,
                "failed": 0,
                "skipped": 0,
                "total": 0,
            }

        # Bounded worker pool so the DB connection pool is not exhausted:
        # at least one worker, and never more workers than papers.
        concurrency = max(1, min(settings.SUMMARY_CONCURRENCY, total))
        progress = await _run_batch_workers(
            papers,
            make_session=_session_factory or SessionLocal,
            concurrency=concurrency,
            pdf_mode=pdf_mode,
        )
        done = progress["done"]
        failed = progress["failed"]
        skipped = progress["skipped"]

        log_entry.status = "success" if failed == 0 else "failed"
        log_entry.papers_found = total
        log_entry.papers_new = done
        log_entry.completed_at = utc_now()
        db.commit()

        logger.info(
            "Summarize batch done: total=%d done=%d failed=%d skipped=%d",
            total,
            done,
            failed,
            skipped,
        )
        return {
            "status": "success" if failed == 0 else "partial",
            "total": total,
            "done": done,
            "failed": failed,
            "skipped": skipped,
        }

    except Exception as exc:
        logger.exception("Summarize batch failed")
        # After a DB error the session refuses further work until rolled back.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = truncate_error(exc, limit=2000)
        log_entry.completed_at = utc_now()
        db.commit()
        return {"status": "failed", "error": truncate_error(exc)}

    finally:
        release_lock(db, lock)
|