Files
daily-paper/app/services/summarizer.py
T

659 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。"""
from __future__ import annotations
import asyncio
import json
import logging
from pathlib import Path
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import settings
from app.database import SessionLocal
from app.models import (
PAPER_DEFAULT_LOAD,
CrawlLog,
Paper,
PaperSummary,
PaperTag,
SummaryState,
SummaryStatus,
TaskLock,
)
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp,
download_pdf,
paper_dir,
)
from app.services.pi_client import (
JsonNotFoundError,
PiProcessError,
PiTimeoutError,
call_pi,
extract_json,
write_meta_json,
)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.utils import TMP_DIR, release_lock, utc_now
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
    """Map an exception from the summarize pipeline to an ``error_type`` label.

    Checks are made in a fixed order, so the first matching type wins.
    Pydantic validation errors are delegated to ``classify_validation_error``.
    Anything unrecognised is labelled ``"unknown"``.
    """
    ordered_labels = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_type, label in ordered_labels:
        if isinstance(exc, exc_type):
            return label
    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)
    return "unknown"
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
def _build_fts_summary_text(schema: SummarySchema) -> str:
"""拼接用于 FTS5 索引的总结文本。"""
parts = [
schema.one_line or "",
schema.motivation.problem or "",
schema.motivation.goal or "",
schema.method.overview or "",
schema.method.key_idea or "",
schema.results.main_findings or "",
]
return " ".join(p for p in parts if p)
# ── DB 更新 ─────────────────────────────────────────────────────────────
def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Write a validated summary into the database and commit.

    Touches four places:
    - the ``paper_summaries`` row (updated in place, or inserted);
    - summary columns on ``papers``: the Chinese title, the quality and the
      summary.json / raw_output.txt paths;
    - AI tags not yet attached to the paper (``source="ai"``);
    - the ``papers_fts`` row whose rowid equals the paper id.

    ``raw_output`` is accepted but not used here.
    """
    from sqlalchemy import text

    # paper_summaries: upsert keyed by paper id.
    summary_row = db.get(PaperSummary, paper.id)
    columns = flatten_for_db(schema)
    if not summary_row:
        db.add(PaperSummary(paper_id=paper.id, **columns))
    else:
        for name, value in columns.items():
            setattr(summary_row, name, value)

    # papers: denormalized summary fields and on-disk artifact locations.
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    out_dir = paper_dir(paper.arxiv_id)
    paper.summary_path = str(out_dir / "summary.json")
    paper.raw_output_path = str(out_dir / "raw_output.txt")

    # paper_tags: add each AI tag once, skipping tags the paper already has.
    known_tags = {existing.tag for existing in paper.tags}
    for tag_name in schema.tags:
        if tag_name in known_tags:
            continue
        known_tags.add(tag_name)
        db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))

    # papers_fts: refresh the full-text columns for this paper.
    fts_update = text(
        "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
        "WHERE rowid=:paper_id"
    )
    db.execute(
        fts_update,
        {
            "title_zh": schema.title_zh,
            "summary_text": _build_fts_summary_text(schema),
            "paper_id": paper.id,
        },
    )
    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── JSON 验证 ──────────────────────────────────────────────────────────
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
errors: list[str] = []
if not isinstance(json_data, dict):
return ["顶层必须是 JSON 对象"]
# 必填字段
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
if f not in json_data or not json_data[f]:
errors.append(f"缺少必填字段: {f}")
# tags 必须是非空数组
tags = json_data.get("tags")
if not isinstance(tags, list) or len(tags) == 0:
errors.append("tags 必须是非空数组")
# 字符串段落字段(必须是 str 且 ≥50 字)
string_fields = [
("motivation", "problem"), ("motivation", "goal"), ("motivation", "gap"),
("method", "overview"), ("method", "key_idea"), ("method", "steps"),
("method", "novelty"),
("results", "main_findings"), ("results", "limitations"),
("improvements", "weaknesses"), ("improvements", "future_work"),
("improvements", "reproducibility"),
]
for section, field in string_fields:
val = json_data.get(section, {}).get(field)
if isinstance(val, list):
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
elif not isinstance(val, str) or len(val.strip()) < 50:
errors.append(
f"{section}.{field} 必须是详细段落(≥50字),"
f"当前: {type(val).__name__} ({len(str(val))}字)"
)
# benchmarks 必须是数组
benchmarks = json_data.get("results", {}).get("benchmarks")
if benchmarks is not None and not isinstance(benchmarks, list):
errors.append("results.benchmarks 必须是数组")
# prerequisites.concepts 必须是对象数组,每个有 term
concepts = json_data.get("prerequisites", {}).get("concepts")
if concepts is not None:
if not isinstance(concepts, list):
errors.append("prerequisites.concepts 必须是数组")
elif len(concepts) == 0:
errors.append("prerequisites.concepts 不能为空")
else:
for i, c in enumerate(concepts):
if isinstance(c, str):
errors.append(f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串")
elif isinstance(c, dict) and not c.get("term"):
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
# figures 必须是数组,每个元素应有 id
figures = json_data.get("figures")
if figures is not None:
if not isinstance(figures, list):
errors.append("figures 必须是数组")
else:
for i, fig in enumerate(figures):
if isinstance(fig, dict) and not fig.get("id"):
errors.append(f"figures[{i}] 缺少 id 字段")
return errors
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
    """Write the paper's output files into its per-paper directory.

    The directory is created if missing. ``summary.json`` is written only
    when ``schema`` is given; ``raw_output.txt`` is always (over)written.
    """
    target = paper_dir(arxiv_id)
    target.mkdir(parents=True, exist_ok=True)
    if schema:
        payload = schema.model_dump_json(ensure_ascii=False, indent=2)
        (target / "summary.json").write_text(payload, encoding="utf-8")
    (target / "raw_output.txt").write_text(raw_output, encoding="utf-8")
# ── 单篇总结 ────────────────────────────────────────────────────────────
async def summarize_one(
    db: Session,
    paper: Paper,
    semaphore: asyncio.Semaphore | None = None,
    *,
    force: bool = False,
    pdf_mode: str = "auto",
) -> dict:
    """Summarize one paper unless it is already in a terminal state.

    If the paper has no ``SummaryStatus`` row, a PENDING one is created and
    committed. Papers already DONE or PERMANENT_FAILURE are skipped unless
    ``force`` is true. The real work runs in ``_do_summarize_one``, inside
    ``semaphore`` when one is given.

    Returns:
        A result dict. Skipped papers get ``status="skipped"`` and a
        ``reason`` of ``"already_done"`` or ``"permanent_failure"``.
    """
    arxiv_id = paper.arxiv_id

    # Ensure there is a status row to track this paper's progress.
    if not paper.summary_status:
        db.add(SummaryStatus(paper_id=paper.id, status=SummaryState.PENDING))
        db.commit()
        db.refresh(paper)
    state = paper.summary_status.status

    # Terminal states are left alone unless the caller forces a re-run.
    if not force:
        if state == SummaryState.DONE:
            return {"arxiv_id": arxiv_id, "status": "skipped", "reason": "already_done"}
        if state == SummaryState.PERMANENT_FAILURE:
            return {
                "arxiv_id": arxiv_id,
                "status": "skipped",
                "reason": "permanent_failure",
            }

    if semaphore is None:
        return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
    async with semaphore:
        return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
async def _generate_with_retry(
    arxiv_id: str,
    meta_path: Path,
    pdf_path: Path,
    pdf_mode: str = "auto",
    *,
    max_attempts: int = 4,
) -> tuple[dict, str]:
    """Call the pi CLI, re-prompting it until its JSON output validates.

    The first attempt is a plain call. Each later attempt resumes the
    previous pi session (``session_id``) and passes the previous attempt's
    validation errors as ``fix_errors``. A ``summary.json`` written by pi
    into the paper directory takes precedence over JSON extracted from
    stdout. Any such file is deleted before each attempt, so a stale one is
    never read.

    Args:
        arxiv_id: Paper id; selects the paper directory.
        meta_path: meta.json passed to pi.
        pdf_path: Downloaded PDF passed to pi.
        pdf_mode: Forwarded to ``call_pi``.
        max_attempts: Total number of pi calls before giving up (default 4).

    Returns:
        ``(json_data, raw_output)`` from the first attempt that validates.

    Raises:
        ValueError: if ``max_attempts`` < 1, or if no attempt validated. In
            the latter case the exception carries a ``raw_output`` attribute
            (the last pi output) for ``_handle_summary_failure``.
        Exceptions raised by ``call_pi`` propagate unchanged.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    validation_errors: list[str] = []
    raw_output = ""
    session_id = None

    for attempt in range(1, max_attempts + 1):
        summary_file = paper_dir(arxiv_id) / "summary.json"
        # Remove any (possibly incomplete) summary.json left by the previous round.
        summary_file.unlink(missing_ok=True)

        if attempt == 1:
            raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
        else:
            raw_output, session_id = await call_pi(
                meta_path,
                pdf_path,
                fix_errors=validation_errors,
                session_id=session_id,
                pdf_mode=pdf_mode,
            )

        # Prefer a summary.json written by pi; otherwise extract JSON from stdout.
        try:
            if summary_file.exists():
                json_data = json.loads(summary_file.read_text(encoding="utf-8"))
                logger.info("Read summary.json written by pi for %s", arxiv_id)
            else:
                json_data = extract_json(raw_output)
        except (json.JSONDecodeError, JsonNotFoundError, UnicodeDecodeError) as exc:
            # A non-UTF-8 file from pi is just another unusable output: retry
            # instead of aborting the whole loop.
            logger.warning(
                "JSON extraction failed for %s (attempt %d): %s",
                arxiv_id, attempt, str(exc)[:200],
            )
            validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
            continue

        validation_errors = _validate_summary(json_data, arxiv_id)
        if not validation_errors:
            return json_data, raw_output
        logger.warning(
            "Validation failed for %s (attempt %d): %s",
            arxiv_id, attempt, "; ".join(validation_errors),
        )

    error = ValueError(
        f"Summary validation failed after {max_attempts} attempts: {'; '.join(validation_errors)}"
    )
    error.raw_output = raw_output  # consumed by _handle_summary_failure via the caller
    raise error
def _persist_summary(
    db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
    """Validate, store and index a generated summary; return its quality label.

    Steps, in order:
    1. Validate with pydantic (``SummarySchema``) and assess quality.
    2. Write summary.json / raw_output.txt.
    3. Update the DB rows.
    4. Mark the status DONE and commit.
    5. Run best-effort image extraction and ChromaDB indexing.

    Raises:
        pydantic.ValidationError: if ``json_data`` does not match the schema.
    """
    schema = SummarySchema.model_validate(json_data)
    quality = assess_quality(schema)

    _save_files(paper.arxiv_id, schema, raw_output)
    _update_summary_in_db(db, paper, schema, quality, raw_output)

    # Mark the job finished.
    finished = paper.summary_status
    finished.status = SummaryState.DONE
    finished.quality = quality
    finished.completed_at = utc_now()
    finished.raw_output_saved = True
    db.commit()

    # Optional enrichments; their failures must not affect the summary.
    _maybe_extract_images(paper.arxiv_id, schema)
    _maybe_index_chroma(paper.arxiv_id, paper, schema)
    return quality
def _handle_summary_failure(
    db: Session, paper: Paper, exc: Exception, raw_output: str,
) -> dict:
    """Record a failed summarization on the paper's SummaryStatus and commit.

    Steps, in order:
    1. Roll back whatever the failing step left in the session.
    2. Save ``raw_output`` for debugging (best-effort).
    3. Increment ``retry_count`` and store the classified ``error_type``
       and a truncated error message.
    4. Set the status: PERMANENT_FAILURE once ``retry_count`` exceeds
       ``settings.SUMMARY_MAX_RETRIES``, otherwise back to PENDING.

    Returns:
        A result dict with ``status="failed"``, ``error_type``, a
        truncated ``error`` and the new ``retry_count``.
    """
    # The failing step may have left the session in a broken transaction
    # (e.g. an error during flush/commit). Without this rollback the commit
    # below would fail too, and the paper would stay marked PROCESSING.
    db.rollback()

    error_type = _classify_error(exc)
    arxiv_id = paper.arxiv_id
    logger.error(
        "Summarize failed: %s error_type=%s %s",
        arxiv_id, error_type, str(exc)[:200],
        # Unclassified errors are most likely bugs: keep their traceback.
        exc_info=exc if error_type == "unknown" else None,
    )

    status = paper.summary_status
    if raw_output:
        try:
            _save_files(arxiv_id, None, raw_output)
        except OSError:
            # Saving debug output is best-effort; still record the failure below.
            logger.warning("Failed to save raw output for %s", arxiv_id, exc_info=True)
        else:
            status.raw_output_saved = True

    status.retry_count = (status.retry_count or 0) + 1
    status.error_type = error_type
    status.error = str(exc)[:2000]
    # Give up for good once retries exceed the configured maximum.
    if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
        status.status = SummaryState.PERMANENT_FAILURE
    else:
        status.status = SummaryState.PENDING
    status.completed_at = utc_now()
    db.commit()

    return {
        "arxiv_id": arxiv_id,
        "status": "failed",
        "error_type": error_type,
        "error": str(exc)[:200],
        "retry_count": status.retry_count,
    }
def _cleanup_old_images(db: Session, paper: Paper) -> None:
    """Remove image artifacts left by a previous summarization run.

    From the paper's ``images`` directory, deletes image files (.png, .jpg,
    .jpeg, .gif, .svg; suffix matched case-insensitively) and
    ``manifest.json``. Then clears ``paper.summary.figures_json``, committing
    only if it was set.
    """
    images_dir = paper_dir(paper.arxiv_id) / "images"
    if images_dir.exists():
        image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".svg")
        for entry in images_dir.iterdir():
            is_image = entry.suffix.lower() in image_suffixes
            if is_image or entry.name == "manifest.json":
                entry.unlink(missing_ok=True)

    # Drop the stale figure metadata stored in the DB.
    summary = paper.summary
    if summary and summary.figures_json:
        summary.figures_json = None
        db.commit()
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
    """Best-effort: extract figures and tables from the downloaded PDF.

    When the summary lists figures, the extracted images are then filtered
    against them. Any exception, including a failed import of the extractor
    module, is logged as a warning and swallowed, so the summary is
    unaffected.
    """
    try:
        from app.services.pdf_image_extractor import (
            extract_images_from_pdf,
            filter_images_by_summary,
        )

        extract_images_from_pdf(arxiv_id, TMP_DIR / arxiv_id / "paper.pdf")
        wanted_figures = schema.figures
        if wanted_figures:
            filter_images_by_summary(arxiv_id, wanted_figures)
    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
    """Best-effort: add the paper to the ChromaDB semantic index.

    Collects title, tag, summary and date text into a flat dict of strings,
    with missing values as "", and passes it to ``embedder.index_paper``.
    Any exception, including failing to import the embedder, is logged as a
    warning and swallowed, so the summary is unaffected.
    """
    try:
        from app.services.embedder import index_paper

        tags = paper.tags
        published = paper.paper_date
        fields = {
            "arxiv_id": arxiv_id,
            "title_zh": schema.title_zh or "",
            "title_en": paper.title_en or "",
            "tags": " ".join(t.tag for t in tags) if tags else "",
            "one_line": schema.one_line or "",
            "motivation_problem": schema.motivation.problem or "",
            "method_key_idea": schema.method.key_idea or "",
            "paper_date": published.isoformat() if published else "",
        }
        index_paper(arxiv_id, fields)
    except Exception:
        logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
async def _do_summarize_one(
    db: Session, paper: Paper, pdf_mode: str = "auto"
) -> dict:
    """Run the full summarization pipeline for one paper.

    Intended to run inside the caller's concurrency slot. First the status
    is marked PROCESSING and committed. Then, all inside one ``try``:
    - images left by a previous run are cleared;
    - meta.json is written and the PDF downloaded;
    - the summary is generated and validated (with retries);
    - the summary is persisted.

    Any exception is recorded via ``_handle_summary_failure`` instead of
    propagating. The paper's temp directory is always cleaned up.

    Returns:
        ``{"arxiv_id", "status": "done", "quality"}`` on success, otherwise
        the failure dict from ``_handle_summary_failure``.
    """
    arxiv_id = paper.arxiv_id

    status = paper.summary_status
    status.status = SummaryState.PROCESSING
    status.started_at = utc_now()
    db.commit()

    raw_output = ""
    try:
        # Inside the try: if this raises (file or DB error), the failure is
        # recorded instead of leaving the paper stuck in PROCESSING.
        _cleanup_old_images(db, paper)

        meta_path = write_meta_json(paper)
        await download_pdf(arxiv_id, paper.pdf_url)
        json_data, raw_output = await _generate_with_retry(
            arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf",
            pdf_mode=pdf_mode,
        )
        quality = _persist_summary(db, paper, json_data, raw_output)
        logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
        return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
    except Exception as exc:
        # _generate_with_retry attaches the last pi output to its ValueError.
        fail_output = getattr(exc, "raw_output", raw_output)
        return _handle_summary_failure(db, paper, exc, fail_output)
    finally:
        cleanup_tmp(arxiv_id)
# ── 单篇入口 ────────────────────────────────────────────────────────────
async def summarize_single(
    db: Session,
    arxiv_id: str,
    *,
    force: bool = True,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Summarize one paper by arXiv id (entry point for the admin route and CLI).

    Args:
        db: Session used only to check that the paper exists.
        arxiv_id: The paper to summarize.
        force: Re-run even if the paper is DONE or PERMANENT_FAILURE.
        pdf_mode: Forwarded to ``summarize_one``.
        _session_factory: Optional session factory; tests inject one bound to
            an in-memory DB. Defaults to ``SessionLocal``.

    Returns:
        ``{"status": "not_found", "arxiv_id": ...}`` if the paper is missing,
        otherwise the result dict from ``summarize_one``.
    """
    # Existence check only needs the primary key; skip the eager relationship loads.
    paper_id = db.execute(
        select(Paper.id).where(Paper.arxiv_id == arxiv_id)
    ).scalar_one_or_none()
    if paper_id is None:
        return {"status": "not_found", "arxiv_id": arxiv_id}

    make_session = _session_factory or SessionLocal
    # Use a dedicated session per paper to avoid concurrency issues.
    paper_db = make_session()
    try:
        paper = paper_db.execute(
            select(Paper)
            .where(Paper.arxiv_id == arxiv_id)
            .options(*PAPER_DEFAULT_LOAD)
        ).unique().scalar_one_or_none()
        if paper is None:
            # Deleted between the two lookups: report it rather than crash
            # inside summarize_one on a None paper.
            return {"status": "not_found", "arxiv_id": arxiv_id}
        return await summarize_one(paper_db, paper, force=force, pdf_mode=pdf_mode)
    finally:
        paper_db.close()
# ── 批量总结 ────────────────────────────────────────────────────────────
async def summarize_batch(
    db: Session,
    arxiv_ids: list[str] | None = None,
    *,
    pdf_mode: str = "auto",
    _session_factory=None,
) -> dict:
    """Summarize many papers concurrently under a batch-wide task lock.

    Args:
        db: Session used for the lock, the CrawlLog entry and the paper query.
        arxiv_ids: Papers to process. ``None`` selects every paper whose
            summary status is PENDING or FAILED. An explicit list, even an
            empty one, restricts the batch to exactly those ids.
        pdf_mode: Forwarded to ``summarize_one``.
        _session_factory: Optional session factory; tests inject one bound to
            an in-memory DB. Defaults to ``SessionLocal``.

    Returns:
        - ``{"status": "conflict", ...}`` if the lock could not be taken;
        - ``{"status": "failed", "error": ...}`` on an unexpected error;
        - otherwise ``total``/``done``/``failed``/``skipped`` counts, with
          status ``"success"`` (nothing failed) or ``"partial"``.
    """
    now = utc_now()

    # Re-entrancy guard: if inserting the TaskLock row fails, assume another
    # batch is already running.
    lock = TaskLock(
        task="summarize",
        lock_key="batch",
        status="running",
        owner="summarize_batch",
        acquired_at=now,
    )
    try:
        db.add(lock)
        db.commit()
    except Exception:
        db.rollback()
        logger.warning("Summarize batch already running (lock conflict)")
        return {"status": "conflict", "error": "summarize batch already running"}

    # The lock is ours from here on. It is released exactly once, in `finally`,
    # even if creating the CrawlLog entry fails.
    try:
        log_entry = CrawlLog(
            task="summarize",
            status="running",
            started_at=now,
        )
        db.add(log_entry)
        db.commit()

        try:
            papers = _select_batch_papers(db, arxiv_ids)
            total = len(papers)
            logger.info("Summarize batch: %d papers to process", total)

            results = await _summarize_papers_concurrently(
                papers,
                pdf_mode=pdf_mode,
                make_session=_session_factory or SessionLocal,
            )
            done, failed, skipped = _tally_batch_results(results)

            log_entry.status = "success" if failed == 0 else "failed"
            log_entry.papers_found = total
            log_entry.papers_new = done
            log_entry.completed_at = utc_now()
            db.commit()
            logger.info(
                "Summarize batch done: total=%d done=%d failed=%d skipped=%d",
                total,
                done,
                failed,
                skipped,
            )
            return {
                "status": "success" if failed == 0 else "partial",
                "total": total,
                "done": done,
                "failed": failed,
                "skipped": skipped,
            }
        except Exception as exc:
            logger.exception("Summarize batch failed")
            # The failing statement may have left the session unusable;
            # reset it before recording the failure.
            db.rollback()
            log_entry.status = "failed"
            log_entry.error = str(exc)[:2000]
            log_entry.completed_at = utc_now()
            db.commit()
            return {"status": "failed", "error": str(exc)}
    finally:
        release_lock(db, lock)


def _select_batch_papers(db: Session, arxiv_ids: list[str] | None) -> list[Paper]:
    """Load the batch's papers: the given ids, or all PENDING/FAILED papers when ids is None."""
    stmt = select(Paper).options(*PAPER_DEFAULT_LOAD)
    if arxiv_ids is not None:
        stmt = stmt.where(Paper.arxiv_id.in_(arxiv_ids))
    else:
        # Only papers that are pending or failed (i.e. retryable).
        stmt = stmt.join(SummaryStatus).where(
            SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.FAILED])
        )
    return list(db.execute(stmt).unique().scalars().all())


async def _summarize_papers_concurrently(
    papers: list[Paper],
    *,
    pdf_mode: str,
    make_session,
) -> list:
    """Summarize each paper in its own session, at most SUMMARY_CONCURRENCY at a time."""
    semaphore = asyncio.Semaphore(settings.SUMMARY_CONCURRENCY)

    async def _process_paper(paper_id) -> dict:
        # Take the slot *before* opening a session. Once a Session has run a
        # query it keeps a pooled connection checked out, so opening sessions
        # for every queued paper up front can exhaust the connection pool
        # while they all wait for a slot.
        async with semaphore:
            paper_db = make_session()
            try:
                p = paper_db.execute(
                    select(Paper)
                    .where(Paper.id == paper_id)
                    .options(*PAPER_DEFAULT_LOAD)
                ).unique().scalar_one_or_none()
                return await summarize_one(paper_db, p, pdf_mode=pdf_mode)
            finally:
                paper_db.close()

    return await asyncio.gather(
        *[_process_paper(paper.id) for paper in papers],
        return_exceptions=True,
    )


def _tally_batch_results(results: list) -> tuple[int, int, int]:
    """Count (done, failed, skipped) over gather() results; raised exceptions count as failed."""
    done = failed = skipped = 0
    for r in results:
        if isinstance(r, BaseException):
            # BaseException, not Exception, so a cancelled task is counted too.
            logger.error("Unexpected error in batch: %s", r, exc_info=r)
            failed += 1
        elif isinstance(r, dict):
            outcome = r.get("status")
            if outcome == "done":
                done += 1
            elif outcome == "skipped":
                skipped += 1
            else:
                failed += 1
    return done, failed, skipped