Files
Rain-Bus 1ccac1f29a refactor: replace Phase 2 label matching with PDF text-stream caption pairing
- Extract captions from PDF text dict instead of DocLayout caption boxes
- Use _CaptionBlock dataclass to carry authoritative ID, kind, text, bbox
- Pair captions to content boxes with directional preference (figure below, table above)
- Filter out uncaptioned boxes (Algorithm pseudo-code, unnumbered appendix tables, false positives)
- Remove label_images_by_summary and Phase 2 rename pipeline entirely
- Update tests to cover text-based caption pairing and filtering
2026-06-15 01:09:29 +08:00

258 lines
9.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 总结持久化 — DB 写入、文件保存、FTS 索引、图片提取、ChromaDB 索引。"""
from __future__ import annotations
import logging
from sqlalchemy.orm import Session
from app.models import (
Paper,
PaperSummary,
PaperTag,
SummaryState,
)
from app.services.derived import reindex_paper_fts
from app.services.pdf_downloader import paper_dir
from app.services.schemas import (
SummarySchema,
assess_quality,
flatten_for_db,
)
from app.services.summary_generator import _classify_error
from app.utils import TMP_DIR, truncate_error, utc_now
logger = logging.getLogger(__name__)
# ── DB 更新 ─────────────────────────────────────────────────────────────
def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Write a validated summary to the database and commit.

    Touches, in order: the ``PaperSummary`` row keyed by ``paper.id``
    (insert or overwrite), the summary-related columns on ``paper``, any
    tags from ``schema.tags`` not yet attached to the paper (added with
    ``source="ai"``), and finally the FTS index via ``reindex_paper_fts``.

    Args:
        db: Session used for every read and write; committed before return.
        paper: Paper row being updated.
        schema: Validated summary to persist.
        quality: Quality label stored in ``paper.summary_quality``.
        raw_output: Not referenced by this function.
    """
    # 1. Upsert the summary row: insert when absent, otherwise overwrite
    #    each flattened column on the row that already exists.
    summary_row = db.get(PaperSummary, paper.id)
    columns = flatten_for_db(schema)
    if not summary_row:
        db.add(PaperSummary(paper_id=paper.id, **columns))
    else:
        for column, value in columns.items():
            setattr(summary_row, column, value)

    # 2. Summary-related columns on the paper itself.
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    base_dir = paper_dir(paper.arxiv_id)
    paper.summary_path = str(base_dir / "summary.json")
    paper.raw_output_path = str(base_dir / "raw_output.txt")

    # 3. AI tags: add only names the paper does not carry yet; tracking the
    #    names seen so far also de-duplicates repeats inside schema.tags.
    seen_tags = {row.tag for row in paper.tags}
    for tag_name in schema.tags:
        if tag_name in seen_tags:
            continue
        db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
        seen_tags.add(tag_name)

    # 4. Flush pending rows, rebuild the derived FTS index, then commit.
    db.flush()
    reindex_paper_fts(db, paper)
    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
    """Write the summary artifacts for a paper into its directory.

    The directory returned by ``paper_dir(arxiv_id)`` is created if it does
    not exist. ``summary.json`` is written only when ``schema`` is truthy;
    ``raw_output.txt`` is always written. Both files are UTF-8 encoded.

    Args:
        arxiv_id: Identifier passed to ``paper_dir`` to locate the directory.
        schema: Validated summary to serialize, or ``None`` to skip it.
        raw_output: Text written verbatim to ``raw_output.txt``.
    """
    target_dir = paper_dir(arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    if schema:
        summary_json = schema.model_dump_json(ensure_ascii=False, indent=2)
        (target_dir / "summary.json").write_text(summary_json, encoding="utf-8")

    raw_path = target_dir / "raw_output.txt"
    raw_path.write_text(raw_output, encoding="utf-8")
# ── 失败处理 ────────────────────────────────────────────────────────────
def _handle_summary_failure(
    db: Session,
    paper: Paper,
    exc: Exception,
    raw_output: str,
) -> dict:
    """Record a failed summarization attempt and return a result payload.

    Logs the failure, saves ``raw_output`` to disk when it is non-empty,
    increments the retry counter on ``paper.summary_status``, stores the
    classified error type and a truncated error message (limit 2000), sets
    ``completed_at``, and commits.

    The status becomes ``SummaryState.PERMANENT_FAILURE`` once the retry
    count reaches ``settings.SUMMARY_MAX_RETRIES + 1``; otherwise it is set
    to ``SummaryState.PENDING``.

    Saving the raw output is best effort: an ``OSError`` while writing the
    file is logged and skipped so that the failure state is still committed.
    ``raw_output_saved`` is set only when the file was actually written.

    Args:
        db: Session used to commit the updated status.
        paper: Paper whose summarization failed.
        exc: The exception that caused the failure.
        raw_output: Raw model output to keep for debugging; may be empty.

    Returns:
        A dict with keys ``arxiv_id``, ``status`` (always ``"failed"``),
        ``error_type``, ``error`` and ``retry_count``.
    """
    # Function-scope import kept from the original code
    # (presumably avoids an import cycle — TODO confirm).
    from app.config import settings

    error_type = _classify_error(exc)
    short_error = truncate_error(exc)
    logger.error(
        "Summarize failed: %s error_type=%s %s",
        paper.arxiv_id,
        error_type,
        short_error,
    )

    status = paper.summary_status
    if raw_output:
        # A disk error here must not prevent the retry count / status / error
        # below from being committed.
        try:
            _save_files(paper.arxiv_id, None, raw_output)
        except OSError:
            logger.warning(
                "Failed to save raw output for %s", paper.arxiv_id, exc_info=True
            )
        else:
            status.raw_output_saved = True

    status.retry_count = (status.retry_count or 0) + 1
    status.error_type = error_type
    status.error = truncate_error(exc, limit=2000)
    if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
        status.status = SummaryState.PERMANENT_FAILURE
    else:
        status.status = SummaryState.PENDING
    status.completed_at = utc_now()
    db.commit()

    return {
        "arxiv_id": paper.arxiv_id,
        "status": "failed",
        "error_type": error_type,
        "error": short_error,
        "retry_count": status.retry_count,
    }
# ── 持久化 ──────────────────────────────────────────────────────────────
def _persist_summary(
    db: Session, paper: Paper, json_data: dict, raw_output: str
) -> tuple[str, SummarySchema]:
    """Validate, save and record a generated summary.

    Steps: validate ``json_data`` with ``SummarySchema`` and assess its
    quality, write the summary files, update the database, then mark the
    paper's summary status as done and commit. The duration of each step is
    logged.

    Post-processing (image extraction / ChromaDB indexing) is not performed
    here. Per the original notes, the caller runs it in a thread pool so the
    event loop is not blocked, and the returned schema is what the caller
    feeds into that step.

    Returns:
        ``(quality, schema)`` — the assessed quality and the validated schema.
    """
    from time import monotonic as _clock

    arxiv_id = paper.arxiv_id
    started = _clock()

    schema = SummarySchema.model_validate(json_data)
    quality = assess_quality(schema)
    validated = _clock()

    _save_files(arxiv_id, schema, raw_output)
    files_saved = _clock()

    _update_summary_in_db(db, paper, schema, quality, raw_output)
    db_updated = _clock()

    # Flip the status row to DONE and commit it.
    summary_status = paper.summary_status
    summary_status.status = SummaryState.DONE
    summary_status.quality = quality
    summary_status.completed_at = utc_now()
    summary_status.raw_output_saved = True
    db.commit()
    committed = _clock()

    logger.info(
        " [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
        arxiv_id,
        validated - started,
        files_saved - validated,
        db_updated - files_saved,
        committed - db_updated,
    )

    # Original note: post-processing (image extraction + ChromaDB indexing)
    # was moved up into the caller _do_summarize_one, which runs it through
    # asyncio.to_thread — the DB session has to stay on the event-loop thread,
    # while the CPU/IO-heavy work is moved off it.
    return quality, schema
# ── 清理 ────────────────────────────────────────────────────────────────
def _cleanup_old_images(db: Session, paper: Paper) -> None:
    """Remove previously extracted images and clear ``figures_json``.

    Deletes every entry in the paper's ``images`` directory whose suffix is
    one of .png/.jpg/.jpeg/.gif/.svg (case-insensitive) or whose name is
    ``manifest.json``. If the paper has a summary with a non-empty
    ``figures_json``, that field is set to ``None`` and the change committed.
    This keeps stale artifacts from surviving a re-summarization.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".svg")

    images_dir = paper_dir(paper.arxiv_id) / "images"
    if images_dir.exists():
        for entry in images_dir.iterdir():
            is_image = entry.suffix.lower() in image_suffixes
            if is_image or entry.name == "manifest.json":
                entry.unlink(missing_ok=True)

    # Drop the figure metadata stored in the database as well.
    summary = paper.summary
    if summary and summary.figures_json:
        summary.figures_json = None
        db.commit()
# ── 触发性增强 ──────────────────────────────────────────────────────────
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
    """Extract figures and tables from the paper's PDF, best effort.

    Calls ``extract_images_from_pdf`` on ``TMP_DIR / arxiv_id / "paper.pdf"``.
    Any exception (including a failed import of the extractor) is logged as a
    warning and swallowed, so a failed extraction never affects the summary.

    Per the original notes: DocLayout-YOLO detects figure/table content
    regions, captions are located from the PDF text, and only regions paired
    with a Figure/Table caption are rendered (algorithms, unnumbered appendix
    tables and false-positive fragments are filtered out). Captions now come
    from the PDF text, so ``schema`` no longer takes part in naming; the
    parameter is kept in reserve and is not used here.
    """
    try:
        from app.services.pdf_image_extractor import extract_images_from_pdf

        paper_tmp_dir = TMP_DIR / arxiv_id
        source_pdf = paper_tmp_dir / "paper.pdf"
        extract_images_from_pdf(arxiv_id, source_pdf)
    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
def _maybe_index_chroma(arxiv_id: str, schema: SummarySchema, paper_meta: dict) -> None:
    """Index the paper in ChromaDB for semantic search, best effort.

    Builds a flat dict of text fields from ``schema`` and ``paper_meta`` and
    hands it to ``index_paper``. Any exception is logged as a warning and
    swallowed, so a failed indexing never affects the summary.

    Per the original notes, ``paper_meta`` holds plain values
    (``title_en`` / ``tags`` / ``paper_date``) that the caller extracted from
    the ORM on the event-loop thread, which avoids cross-thread ORM access
    (``DetachedInstanceError``) when this function runs in the thread pool.
    """
    try:
        from app.services.embedder import index_paper

        meta = paper_meta.get
        fields = {
            "arxiv_id": arxiv_id,
            "title_zh": schema.title_zh or "",
            "title_en": meta("title_en", ""),
            "tags": meta("tags", ""),
            "one_line": schema.one_line or "",
            "motivation_problem": schema.motivation.problem or "",
            "method_key_idea": schema.method.key_idea or "",
            "paper_date": meta("paper_date", ""),
        }
        index_paper(arxiv_id, fields)
    except Exception:
        logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
def _run_post_processing(
    arxiv_id: str, schema: SummarySchema, paper_meta: dict
) -> None:
    """Run the CPU/IO-heavy steps that follow a persisted summary.

    Per the original docstring, this is meant to run in a thread pool
    (called by ``_do_summarize_one`` through ``asyncio.to_thread``).

    The order is image extraction first, then ChromaDB indexing. Each step
    already catches its own exceptions; the ``try`` here is a second safety
    net so that nothing raised by post-processing can affect a summary that
    has already been persisted.
    """
    try:
        _maybe_extract_images(arxiv_id, schema)
        _maybe_index_chroma(arxiv_id, schema, paper_meta)
    except Exception:
        message = "Post-processing failed for %s (summary already persisted)"
        logger.warning(message, arxiv_id, exc_info=True)