Files

352 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 总结编排服务 — 协调生成器、持久化、批量处理的顶层入口。"""
from __future__ import annotations

import asyncio
import logging
import time

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.config import settings
from app.database import SessionLocal
from app.exceptions import ConflictError, NotFoundError
from app.models import (
    PAPER_DEFAULT_LOAD,
    CrawlLog,
    Paper,
    SummaryState,
    SummaryStatus,
    TaskLock,
    get_paper_by_arxiv_id,
    get_paper_by_id,
)
from app.services.pdf_downloader import download_pdf
from app.services.summary_utils import write_meta_json
from app.services.summary_generator import (
    _generate_with_retry,
)
from app.services.summary_persister import (
    _cleanup_old_images,
    _handle_summary_failure,
    _persist_summary,
    _run_post_processing,
)
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
# ── 单篇总结 ────────────────────────────────────────────────────────────
async def summarize_one(
    db: Session,
    paper: Paper,
    *,
    force: bool = False,
    pdf_mode: str = "auto",
) -> dict:
    """Summarize one paper, applying the skip rules first.

    If the paper has no ``SummaryStatus`` row yet, a PENDING one is created
    and committed. Papers already in DONE or PERMANENT_FAILURE are skipped
    unless ``force`` is true; everything else is handed to
    ``_do_summarize_one``.

    Args:
        db: Session that owns ``paper``.
        paper: The paper to summarize.
        force: Re-run even for DONE / PERMANENT_FAILURE papers.
        pdf_mode: Passed through to ``_do_summarize_one``.

    Returns:
        ``{"arxiv_id", "status": "skipped", "reason"}`` for a skipped paper,
        otherwise the result dict of ``_do_summarize_one``.
    """
    arxiv_id = paper.arxiv_id

    if not paper.summary_status:
        # First time this paper reaches the summarizer: persist a PENDING
        # status row, then refresh so the relationship is populated.
        db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
        db.commit()
        db.refresh(paper)

    state = paper.summary_status.status

    if force:
        return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)

    # Without force, finished and permanently failed papers are left alone.
    if state == SummaryState.DONE:
        return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"}
    if state == SummaryState.PERMANENT_FAILURE:
        return {
            "arxiv_id": arxiv_id,
            "status": "skipped",
            "reason": "permanent_failure",
        }
    return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -> dict:
    """Run the summarization pipeline for one paper, unconditionally.

    The caller (``summarize_one``) decides whether the paper should be
    processed at all; this function does not acquire any semaphore or lock.

    Steps:
      1. Mark ``paper.summary_status`` as PROCESSING and commit.
      2. Remove old image files / ``figures_json`` via ``_cleanup_old_images``
         so a re-run does not keep stale artifacts.
      3. Write meta.json, download the PDF, generate the summary with
         ``_generate_with_retry`` and persist it with ``_persist_summary``.
      4. Run ``_run_post_processing`` in a worker thread (best effort).

    Any exception raised in steps 2-3 is handed to
    ``_handle_summary_failure``; errors in step 4 are only logged.

    Args:
        db: Session that owns ``paper``.
        paper: Paper to summarize; ``paper.summary_status`` must exist.
        pdf_mode: Passed through to ``_generate_with_retry``.

    Returns:
        ``{"arxiv_id", "status": "done", "quality"}`` on success, otherwise
        whatever ``_handle_summary_failure`` returns.
    """
    arxiv_id = paper.arxiv_id
    title_short = (paper.title_en or "")[:50]

    paper.summary_status.status = SummaryState.PROCESSING
    paper.summary_status.started_at = utc_now()
    db.commit()
    logger.info("▶ [%s] 开始总结: %s", arxiv_id, title_short)

    raw_output = ""
    try:
        # Cleanup runs inside the try: if it raises, the failure is recorded
        # through _handle_summary_failure instead of leaving the status stuck
        # at PROCESSING (batch runs only pick up PENDING/FAILED papers).
        t_cleanup = time.monotonic()
        _cleanup_old_images(db, paper)
        t0 = time.monotonic()
        logger.info(" [%s] 清理旧数据: %.2fs", arxiv_id, t0 - t_cleanup)

        meta_path = write_meta_json(paper)
        t1 = time.monotonic()
        logger.info(" [%s] meta.json: %.2fs", arxiv_id, t1 - t0)

        await download_pdf(arxiv_id, paper.pdf_url)
        t2 = time.monotonic()
        logger.info(" [%s] 下载PDF: %.2fs", arxiv_id, t2 - t1)

        logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
        json_data, raw_output = await _generate_with_retry(
            arxiv_id,
            meta_path,
            TMP_DIR / arxiv_id / "paper.pdf",
            pdf_mode=pdf_mode,
        )
        t3 = time.monotonic()
        logger.info(" [%s] pi生成: %.2fs", arxiv_id, t3 - t2)

        quality, schema = _persist_summary(db, paper, json_data, raw_output)
        t4 = time.monotonic()
        logger.info(" [%s] 持久化: %.2fs", arxiv_id, t4 - t3)
    except Exception as exc:
        # _generate_with_retry can attach the raw model output to the
        # exception it raises; prefer it over the (possibly empty) local.
        fail_output = getattr(exc, "raw_output", raw_output)
        return _handle_summary_failure(db, paper, exc, fail_output)

    # The summary is persisted at this point. Everything below is best
    # effort and must never turn a DONE summary into a failure.
    t5 = time.monotonic()
    try:
        # Post-processing (image extraction + ChromaDB indexing) is CPU heavy,
        # so it runs in a worker thread to keep the event loop responsive.
        # ORM attributes are read here, on the event-loop thread, and passed
        # as plain values so the worker thread never touches the ORM object
        # (avoids DetachedInstanceError). _run_post_processing presumably
        # serializes model inference itself — confirm in summary_persister.
        paper_meta = {
            "title_en": paper.title_en or "",
            "tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
            "paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
        }
        await asyncio.to_thread(_run_post_processing, arxiv_id, schema, paper_meta)
    except Exception:
        logger.warning("Post-processing error for %s", arxiv_id, exc_info=True)
    t6 = time.monotonic()
    logger.info(" [%s] 后处理(线程池): %.2fs", arxiv_id, t6 - t5)
    logger.info(
        "✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, t6 - t0
    )
    # NOTE: temp-dir cleanup (cleanup_tmp(arxiv_id)) is intentionally disabled
    # so the downloaded PDF stays available for debugging image extraction.
    return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
# ── 单篇入口 ────────────────────────────────────────────────────────────
async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Summarize a single paper (entry point for the admin route and CLI).

    The work runs in a dedicated session created from ``_session_factory``,
    which is always closed afterwards.

    Args:
        db: Caller's session, used only to check that the paper exists.
        arxiv_id: arXiv identifier of the paper.
        force: Re-run even if the paper is DONE / PERMANENT_FAILURE
            (defaults to True here, unlike ``summarize_one``).
        pdf_mode: Passed through to ``summarize_one``.
        _session_factory: Optional session factory; tests inject one bound
            to an in-memory DB. Defaults to ``SessionLocal``.

    Returns:
        The result dict from ``summarize_one``.

    Raises:
        NotFoundError: If no paper with ``arxiv_id`` exists.
    """
    if not get_paper_by_arxiv_id(db, arxiv_id):
        raise NotFoundError(f"Paper not found: {arxiv_id}")

    make_session = _session_factory or SessionLocal
    # Use a dedicated session so this run is isolated from the caller's.
    paper_db = make_session()
    try:
        paper = get_paper_by_arxiv_id(paper_db, arxiv_id)
        if not paper:
            # The paper vanished between the two lookups; report it the same
            # way instead of crashing inside summarize_one.
            raise NotFoundError(f"Paper not found: {arxiv_id}")
        return await summarize_one(paper_db, paper, force=force, pdf_mode=pdf_mode)
    finally:
        paper_db.close()
# ── 批量总结 ────────────────────────────────────────────────────────────
async def _run_summary_workers(
    papers: list[Paper],
    *,
    concurrency: int,
    make_session,
    pdf_mode: str,
) -> dict[str, int]:
    """Summarize ``papers`` with a bounded worker pool; return status counts.

    At most ``concurrency`` workers (at least one, never more than there are
    papers) pull from a shared queue. Only that many DB sessions are open at
    once, instead of one per paper. Each worker re-loads its paper by id in
    its own session.

    Returns:
        Counts keyed by result status. Always contains ``done``, ``failed``
        and ``skipped``. Any exception escaping ``summarize_one`` counts as
        ``failed``.
    """
    total = len(papers)
    progress: dict[str, int] = {"done": 0, "failed": 0, "skipped": 0}
    queue: asyncio.Queue[Paper] = asyncio.Queue()
    for queued in papers:
        queue.put_nowait(queued)

    async def _worker() -> None:
        while True:
            try:
                paper = queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            arxiv_id = paper.arxiv_id
            paper_db = make_session()
            try:
                fresh = get_paper_by_id(paper_db, paper.id)
                result = await summarize_one(paper_db, fresh, pdf_mode=pdf_mode)
                status = result.get("status", "failed")
            except Exception as exc:
                # Count it so the batch outcome reflects this failure.
                logger.exception("Worker error [%s]: %s", arxiv_id, exc)
                status = "failed"
            finally:
                paper_db.close()
            progress[status] = progress.get(status, 0) + 1
            logger.info(
                "📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s",
                sum(progress.values()),
                total,
                progress["done"],
                progress["failed"],
                progress["skipped"],
                arxiv_id,
            )

    worker_count = max(1, min(concurrency, total))
    outcomes = await asyncio.gather(
        *(_worker() for _ in range(worker_count)),
        return_exceptions=True,
    )
    for outcome in outcomes:
        # Workers handle per-paper errors themselves; anything surfacing here
        # escaped that handling and is counted as one failure.
        if isinstance(outcome, Exception):
            logger.error("Unexpected error in batch: %s", outcome)
            progress["failed"] += 1
    return progress


async def summarize_batch(
    db: Session,
    arxiv_ids: list[str] | None = None,
    *,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Batch summarization entry point.

    The function proceeds in three stages:

    1. Take the ``summarize``/``batch`` TaskLock.
    2. Record the run in a CrawlLog entry.
    3. Summarize the selected papers with ``settings.SUMMARY_CONCURRENCY``
       workers. Each worker uses its own session. The lock is always
       released at the end.

    Args:
        db: Session used for the lock, the CrawlLog entry and the paper query.
        arxiv_ids: Papers to process (regardless of their summary status).
            ``None`` or an empty list selects every paper whose summary
            status is PENDING or FAILED.
        pdf_mode: Passed through to ``summarize_one``.
        _session_factory: Optional session factory; tests inject one bound
            to an in-memory DB. Defaults to ``SessionLocal``.

    Returns:
        ``{"status": "success" | "partial", "total", "done", "failed",
        "skipped"}``, or ``{"status": "failed", "error"}`` if the batch
        itself crashed.

    Raises:
        ConflictError: If another batch already holds the lock.
    """
    now = utc_now()

    # Re-entrancy guard: inserting a second summarize/batch lock row is
    # expected to violate a uniqueness constraint (IntegrityError).
    lock = TaskLock(
        task="summarize",
        lock_key="batch",
        status="running",
        owner="summarize_batch",
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except IntegrityError:
        db.rollback()
        logger.warning("Summarize batch already running (lock conflict)")
        raise ConflictError("summarize batch already running")

    log_entry = CrawlLog(
        task="summarize",
        status="running",
        started_at=now,
    )
    try:
        db.add(log_entry)
        db.commit()
    except Exception:
        # Never leak the lock just because the run could not be logged.
        db.rollback()
        release_lock(db, lock)
        raise

    try:
        stmt = select(Paper).options(*PAPER_DEFAULT_LOAD)
        if arxiv_ids:
            stmt = stmt.where(Paper.arxiv_id.in_(arxiv_ids))
        else:
            # Only papers still pending or retryable (failed).
            stmt = stmt.join(SummaryStatus).where(
                SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.FAILED])
            )
        papers = db.execute(stmt).unique().scalars().all()
        total = len(papers)
        logger.info("Summarize batch: %d papers to process", total)

        if total == 0:
            log_entry.status = "success"
            log_entry.papers_found = 0
            log_entry.papers_new = 0
            log_entry.completed_at = utc_now()
            # Commit here; the lock is released exactly once, in ``finally``.
            db.commit()
            return {
                "status": "success",
                "done": 0,
                "failed": 0,
                "skipped": 0,
                "total": 0,
            }

        counts = await _run_summary_workers(
            papers,
            concurrency=settings.SUMMARY_CONCURRENCY,
            make_session=_session_factory or SessionLocal,
            pdf_mode=pdf_mode,
        )
        done = counts["done"]
        failed = counts["failed"]
        skipped = counts["skipped"]

        log_entry.status = "success" if failed == 0 else "failed"
        log_entry.papers_found = total
        log_entry.papers_new = done
        log_entry.completed_at = utc_now()
        db.commit()
        logger.info(
            "Summarize batch done: total=%d done=%d failed=%d skipped=%d",
            total,
            done,
            failed,
            skipped,
        )
        return {
            "status": "success" if failed == 0 else "partial",
            "total": total,
            "done": done,
            "failed": failed,
            "skipped": skipped,
        }
    except Exception as exc:
        logger.exception("Summarize batch failed")
        # The session may be unusable after the error (e.g. a failed
        # commit); roll back so the failure can still be recorded.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = truncate_error(exc, limit=2000)
        log_entry.completed_at = utc_now()
        db.commit()
        return {"status": "failed", "error": truncate_error(exc)}
    finally:
        release_lock(db, lock)