Files
daily-paper/app/services/summarizer.py
T

520 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。"""
from __future__ import annotations
import json
import logging
import shutil
from datetime import datetime, timezone
from pathlib import Path
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import SessionLocal
from app.models import (
CrawlLog,
Paper,
PaperSummary,
PaperTag,
SummaryStatus,
TaskLock,
)
from app.services.image_extractor import extract_images_from_source
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp,
download_pdf,
paper_dir,
)
from app.services.pi_client import (
JsonNotFoundError,
PiProcessError,
PiTimeoutError,
call_pi,
extract_json,
write_meta_json,
)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.utils import PAPERS_DIR, release_lock
logger = logging.getLogger(__name__)
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
    """Translate an exception into the ``error_type`` label stored with the status.

    The checks run in a fixed order and the first matching exception type
    wins.  Pydantic validation errors are delegated to
    ``classify_validation_error``; anything unrecognised maps to ``"unknown"``.
    """
    # Ordered (exception type, label) pairs; order matches the lookup priority.
    fixed_labels = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_type, label in fixed_labels:
        if isinstance(exc, exc_type):
            return label

    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)
    return "unknown"
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
def _build_fts_summary_text(schema: SummarySchema) -> str:
    """Build the space-separated summary text that is written to the FTS5 index.

    Empty or missing fields are dropped, so the result never contains
    doubled separators.
    """
    candidates = [
        schema.one_line,
        schema.motivation.problem,
        schema.motivation.goal,
        # ``method_overview`` is not guaranteed to exist on the schema object.
        getattr(schema, "method_overview", ""),
        schema.method.overview,
        schema.method.key_idea,
        " ".join(schema.results.main_findings or []),
    ]
    # filter(None, ...) keeps only the truthy (non-empty) fragments.
    return " ".join(filter(None, candidates))
# ── DB 更新 ─────────────────────────────────────────────────────────────
def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Persist a validated summary (paper_summaries + papers + paper_tags + FTS5).

    All four writes are committed together at the end of the function.

    Args:
        db: Session that owns ``paper``.
        paper: The paper row being summarised.
        schema: The validated summary.
        quality: Quality label stored on the paper row.
        raw_output: Not used by this function; kept so the call signature
            stays unchanged for existing callers.
    """
    from sqlalchemy import text

    # 1. paper_summaries: update the existing row in place, or insert one.
    existing = db.get(PaperSummary, paper.id)
    flat = flatten_for_db(schema)
    if existing:
        for column, value in flat.items():
            setattr(existing, column, value)
    else:
        db.add(PaperSummary(paper_id=paper.id, **flat))

    # 2. papers: translated title, quality label and artefact paths.
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    p_dir = paper_dir(paper.arxiv_id)
    paper.summary_path = str(p_dir / "summary.json")
    paper.raw_output_path = str(p_dir / "raw_output.txt")

    # 3. AI tags: add only tags the paper does not already carry.  The set is
    #    extended while iterating so duplicates inside ``schema.tags`` are
    #    inserted once.  ``or ()`` tolerates a schema whose tag list is None.
    known_tags = {t.tag for t in paper.tags}
    for tag_name in schema.tags or ():
        if tag_name not in known_tags:
            db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
            known_tags.add(tag_name)

    # 4. FTS5: refresh the searchable columns of the paper's index row.
    db.execute(
        text(
            "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
            "WHERE rowid=:paper_id"
        ),
        {
            "title_zh": schema.title_zh,
            "summary_text": _build_fts_summary_text(schema),
            "paper_id": paper.id,
        },
    )

    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
    """Write ``summary.json`` and ``raw_output.txt`` into the paper's directory.

    The directory is created on demand; both files are written as UTF-8.
    """
    target_dir = paper_dir(arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): ``ensure_ascii`` must be accepted by the installed
    # pydantic's ``model_dump_json`` — confirm against the pinned version.
    summary_json = schema.model_dump_json(ensure_ascii=False, indent=2)

    outputs = {"summary.json": summary_json, "raw_output.txt": raw_output}
    for filename, content in outputs.items():
        (target_dir / filename).write_text(content, encoding="utf-8")
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
    """Write just ``raw_output.txt`` (used when summarisation failed)."""
    target = paper_dir(arxiv_id) / "raw_output.txt"
    # Create the paper directory if it is not there yet.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(raw_output, encoding="utf-8")
# ── 单篇总结 ────────────────────────────────────────────────────────────
async def summarize_one(
db: Session,
paper: Paper,
semaphore: asyncio.Semaphore | None = None,
*,
force: bool = False,
) -> dict:
"""总结单篇论文的完整流程。"""
import asyncio
arxiv_id = paper.arxiv_id
# 获取或创建 summary_status
if not paper.summary_status:
db.add(SummaryStatus(paper_id=paper.id, status="pending"))
db.commit()
db.refresh(paper)
status = paper.summary_status
# 跳过已完成的(除非 force
if status.status == "done" and not force:
return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"}
# 跳过 permanent_failure(除非 force
if status.status == "permanent_failure" and not force:
return {
"arxiv_id": arxiv_id,
"status": "skipped",
"reason": "permanent_failure",
}
if semaphore:
await semaphore.acquire()
try:
return await _do_summarize_one(db, paper)
finally:
if semaphore:
semaphore.release()
def _schema_field(schema: SummarySchema, section: str, field: str) -> str:
    """Return ``schema.<section>.<field>`` as text, or ``""`` when it is unset.

    Falls back to a flat ``schema.<section>_<field>`` attribute when the
    nested one is absent, so either schema layout yields the value.
    """
    value = getattr(getattr(schema, section, None), field, None)
    if value is None:
        value = getattr(schema, f"{section}_{field}", None)
    return value or ""


async def _do_summarize_one(db: Session, paper: Paper) -> dict:
    """Run the summarisation pipeline for one paper and record the outcome.

    Steps: mark the status ``processing``; write meta.json; download the PDF;
    call pi; extract and validate the JSON; assess quality; save the files;
    update the DB; mark the status ``done``.  Image extraction and semantic
    indexing follow as best-effort extras whose failures are only logged.

    On any failure the error is classified, the raw output (if any) is saved,
    the retry counter is bumped and the status becomes ``pending`` or, once
    the retry budget is used up, ``permanent_failure``.

    Args:
        db: Session that owns ``paper``.
        paper: The paper to summarise; its ``summary_status`` must exist.

    Returns:
        ``{"arxiv_id", "status": "done", "quality"}`` on success, otherwise
        ``{"arxiv_id", "status": "failed", "error_type", "error",
        "retry_count"}``.
    """
    arxiv_id = paper.arxiv_id
    status = paper.summary_status

    # Status -> processing, committed before the slow work starts.
    status.status = "processing"
    status.started_at = datetime.now(timezone.utc)
    db.commit()

    raw_output = ""
    try:
        # Write meta.json.
        meta_path = write_meta_json(paper)
        # Download the PDF.
        await download_pdf(arxiv_id, paper.pdf_url)
        # Call pi.
        # NOTE(review): this path is presumably where download_pdf stores the
        # file — confirm, or use download_pdf's return value if it has one.
        raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf")
        # Extract the JSON payload and validate it.
        json_data = extract_json(raw_output)
        schema = SummarySchema.model_validate(json_data)
        # Assess quality.
        quality = assess_quality(schema)
        # Save files, then update the DB.
        _save_files(arxiv_id, schema, raw_output)
        _update_summary_in_db(db, paper, schema, quality, raw_output)

        # Status -> done.
        status.status = "done"
        status.quality = quality
        status.completed_at = datetime.now(timezone.utc)
        status.raw_output_saved = True
        db.commit()

        # Phase 5: LaTeX image extraction (optional; failure does not affect
        # the summary).
        try:
            await extract_images_from_source(arxiv_id)
        except Exception:
            logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)

        # Phase 5: write to the semantic index (failure is only logged).
        try:
            from app.services.embedder import index_paper

            texts_dict = {
                "arxiv_id": arxiv_id,
                "title_zh": schema.title_zh or "",
                "title_en": paper.title_en or "",
                "tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
                "one_line": schema.one_line or "",
                # Read these the same way _build_fts_summary_text does
                # (schema.motivation.problem / schema.method.key_idea); a flat
                # attribute lookup that does not exist would be swallowed by
                # the except below and indexing would fail for every paper.
                "motivation_problem": _schema_field(schema, "motivation", "problem"),
                "method_key_idea": _schema_field(schema, "method", "key_idea"),
                "paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
            }
            index_paper(arxiv_id, texts_dict)
        except Exception:
            logger.warning(
                "Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True
            )

        logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
        return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
    except Exception as exc:
        error_type = _classify_error(exc)
        logger.error(
            "Summarize failed: %s error_type=%s %s",
            arxiv_id,
            error_type,
            str(exc)[:200],
        )
        # Drop whatever the failed attempt left pending in the session.  This
        # makes the session usable again after a failed flush/commit and keeps
        # half-applied summary rows from being committed together with the
        # failure status below.
        db.rollback()

        # Keep the raw output (if there is any) for later inspection.
        if raw_output:
            _save_raw_output_only(arxiv_id, raw_output)
            status.raw_output_saved = True

        # Retry bookkeeping.
        status.retry_count = (status.retry_count or 0) + 1
        status.error_type = error_type
        status.error = str(exc)[:2000]
        if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
            status.status = "permanent_failure"
        else:
            status.status = "pending"
        status.completed_at = datetime.now(timezone.utc)
        db.commit()
        return {
            "arxiv_id": arxiv_id,
            "status": "failed",
            "error_type": error_type,
            "error": str(exc)[:200],
            "retry_count": status.retry_count,
        }
    finally:
        cleanup_tmp(arxiv_id)
# ── 单篇入口 ────────────────────────────────────────────────────────────
async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    _session_factory=None,
) -> dict:
    """Single-paper entry point (called from the admin route and the CLI).

    Args:
        db: Caller's session; only used to check that the paper exists.
        arxiv_id: Identifier of the paper to summarise.
        force: Forwarded to ``summarize_one``.
        _session_factory: Optional session factory; tests inject one bound to
            an in-memory DB.  Defaults to ``SessionLocal``.

    Returns:
        ``{"status": "not_found", "arxiv_id": ...}`` when no such paper
        exists, otherwise the result dict of ``summarize_one``.
    """
    # Existence probe only: the row itself is not used, so no relationships
    # are eager-loaded here.
    if db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first() is None:
        return {"status": "not_found", "arxiv_id": arxiv_id}

    make_session = _session_factory or SessionLocal
    # The work runs in its own session, separate from the caller's.
    paper_db = make_session()
    try:
        paper_in_new_session = (
            paper_db.query(Paper)
            .filter(Paper.arxiv_id == arxiv_id)
            .options(
                joinedload(Paper.authors),
                joinedload(Paper.tags),
                joinedload(Paper.summary_status),
            )
            .first()
        )
        # The row can be gone by the time the second session reads it.
        if paper_in_new_session is None:
            return {"status": "not_found", "arxiv_id": arxiv_id}
        return await summarize_one(paper_db, paper_in_new_session, force=force)
    finally:
        paper_db.close()
# ── 批量总结 ────────────────────────────────────────────────────────────
async def summarize_batch(
    db: Session,
    arxiv_ids: list[str] | None = None,
    *,
    _session_factory=None,
) -> dict:
    """Batch entry point; with ``arxiv_ids=None`` every pending paper is processed.

    A ``TaskLock`` row guards against a second batch starting while one is
    running, and a ``CrawlLog`` row records the run.  Each paper is processed
    in its own session, with at most ``settings.SUMMARY_CONCURRENCY`` papers
    in flight.

    Args:
        db: Session used for the lock, the log entry and the paper query.
        arxiv_ids: Restrict the batch to these papers; ``None`` or empty
            selects papers whose status is ``pending`` or ``failed``.
        _session_factory: Optional session factory; tests inject one bound to
            an in-memory DB.  Defaults to ``SessionLocal``.

    Returns:
        ``{"status": "conflict", ...}`` when the lock could not be taken,
        ``{"status": "failed", "error": ...}`` when the batch itself broke,
        otherwise counters ``total`` / ``done`` / ``failed`` / ``skipped``
        with ``status`` ``"success"`` or ``"partial"``.
    """
    import asyncio

    now = datetime.now(timezone.utc)

    # TaskLock prevents re-entry: inserting the row fails if one exists.
    lock = TaskLock(
        task="summarize",
        lock_key="batch",
        status="running",
        owner="summarize_batch",
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except Exception:
        db.rollback()
        logger.warning("Summarize batch already running (lock conflict)")
        return {"status": "conflict", "error": "summarize batch already running"}

    # CrawlLog entry for this run.
    log_entry = CrawlLog(
        task="summarize",
        status="running",
        started_at=now,
    )
    db.add(log_entry)
    db.commit()

    try:
        # Select the papers to summarise.  Only their ids are needed here;
        # each worker reloads its paper (with relationships) in its own
        # session.
        query = db.query(Paper)
        if arxiv_ids:
            query = query.filter(Paper.arxiv_id.in_(arxiv_ids))
        else:
            # Only pending or failed (retryable) papers.
            query = query.join(SummaryStatus).filter(
                SummaryStatus.status.in_(["pending", "failed"])
            )
        paper_ids = [p.id for p in query.all()]
        total = len(paper_ids)
        logger.info("Summarize batch: %d papers to process", total)

        if total == 0:
            log_entry.status = "success"
            log_entry.papers_found = 0
            log_entry.papers_new = 0
            log_entry.completed_at = datetime.now(timezone.utc)
            db.commit()
            # The lock is released by the ``finally`` below — releasing it
            # here as well would release it twice.
            return {
                "status": "success",
                "done": 0,
                "failed": 0,
                "skipped": 0,
                "total": 0,
            }

        # Concurrency control.
        semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY)
        make_session = _session_factory or SessionLocal

        async def _process_paper(paper_id) -> dict:
            # Take the slot before opening the session, so no more than
            # SUMMARY_CONCURRENCY per-paper sessions are open at once.
            async with semaphore:
                paper_db = make_session()
                try:
                    p = (
                        paper_db.query(Paper)
                        .filter(Paper.id == paper_id)
                        .options(
                            joinedload(Paper.authors),
                            joinedload(Paper.tags),
                            joinedload(Paper.summary_status),
                        )
                        .first()
                    )
                    if p is None:
                        # Collected by gather() and counted as a failure.
                        raise LookupError(
                            f"paper id={paper_id} no longer exists"
                        )
                    return await summarize_one(paper_db, p)
                finally:
                    paper_db.close()

        results = await asyncio.gather(
            *[_process_paper(pid) for pid in paper_ids],
            return_exceptions=True,
        )

        # Tally the results; anything that is neither done nor skipped
        # (including raised exceptions) counts as failed.
        done = 0
        failed = 0
        skipped = 0
        for r in results:
            if isinstance(r, Exception):
                logger.error("Unexpected error in batch: %s", r)
                failed += 1
            elif isinstance(r, dict):
                outcome = r.get("status")
                if outcome == "done":
                    done += 1
                elif outcome == "skipped":
                    skipped += 1
                else:
                    failed += 1

        log_entry.status = "success" if failed == 0 else "failed"
        log_entry.papers_found = total
        log_entry.papers_new = done
        log_entry.completed_at = datetime.now(timezone.utc)
        db.commit()

        logger.info(
            "Summarize batch done: total=%d done=%d failed=%d skipped=%d",
            total,
            done,
            failed,
            skipped,
        )
        return {
            "status": "success" if failed == 0 else "partial",
            "total": total,
            "done": done,
            "failed": failed,
            "skipped": skipped,
        }
    except Exception as exc:
        logger.exception("Summarize batch failed")
        # Reset the session first: if the failure came from the DB, the
        # session rejects further work until it is rolled back.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = str(exc)[:2000]
        log_entry.completed_at = datetime.now(timezone.utc)
        db.commit()
        return {"status": "failed", "error": str(exc)}
    finally:
        release_lock(db, lock)