feat: add concurrency safety, caption detection, admin enhancements, and performance improvements
This commit is contained in:
+129
-2
@@ -7,7 +7,7 @@ from datetime import date
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy import func, select, text, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
@@ -100,7 +100,9 @@ def get_admin_stats(db: Session) -> dict:
|
||||
|
||||
# ── 活跃锁 ────────────────────────────────────────────────────────
|
||||
active_locks = (
|
||||
db.execute(select(TaskLock).where(TaskLock.status == "running")).scalars().all()
|
||||
db.execute(select(TaskLock).where(TaskLock.status.in_(["running", "stale"])))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -124,6 +126,7 @@ def get_admin_stats(db: Session) -> dict:
|
||||
"recent_logs": recent_logs,
|
||||
"active_locks": active_locks,
|
||||
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
|
||||
"config_overview": get_config_overview(),
|
||||
}
|
||||
|
||||
|
||||
@@ -370,6 +373,7 @@ def get_logs_context(db: Session, *, page: int, per_page: int) -> dict:
|
||||
"summary_done": summary_done,
|
||||
"summary_pending": summary_pending,
|
||||
"summary_failed": summary_failed,
|
||||
"failure_breakdown": get_failure_breakdown(db),
|
||||
}
|
||||
|
||||
|
||||
@@ -511,3 +515,126 @@ def reset_summaries_pending(db: Session, arxiv_ids: list[str]) -> int:
|
||||
db.add(SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING))
|
||||
db.commit()
|
||||
return len(paper_ids)
|
||||
|
||||
|
||||
# ── 任务监控 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def query_jobs(
    db: Session,
    *,
    status: str | None = None,
    job_type: str | None = None,
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[dict], int]:
    """Return one page of background jobs together with the total match count.

    Args:
        db: Session used to run the queries.
        status: Keep only jobs with this status; ``None``, ``""`` or ``"all"``
            disables the filter.
        job_type: Keep only jobs of this type; ``None``, ``""`` or ``"all"``
            disables the filter.
        page: 1-based page number. Values below 1 are treated as 1.
        per_page: Page size. Values below 1 are treated as 1.

    Returns:
        ``(jobs, total)`` where ``jobs`` is the requested page, newest first,
        already serialized by ``serialize_job``, and ``total`` is the number of
        jobs matching the filters across all pages.
    """
    # Clamp so that a stray page=0 / per_page=-1 can never produce a negative
    # OFFSET or LIMIT in the generated SQL.
    page = max(page, 1)
    per_page = max(per_page, 1)

    query = select(Job)
    if status and status != "all":
        query = query.where(Job.status == status)
    if job_type and job_type != "all":
        query = query.where(Job.type == job_type)

    # Count over the filtered query, before ordering and pagination are applied.
    total = db.scalar(select(func.count()).select_from(query.subquery())) or 0
    jobs = (
        db.execute(
            # Job.id is a tie-breaker: rows sharing a created_at value would
            # otherwise have no defined order and could repeat or vanish
            # between pages.
            query.order_by(Job.created_at.desc(), Job.id.desc())
            .offset((page - 1) * per_page)
            .limit(per_page)
        )
        .scalars()
        .all()
    )
    return [serialize_job(j) for j in jobs], total
|
||||
|
||||
|
||||
def _as_naive(dt):
|
||||
"""去掉 tzinfo — SQLite 读回的 datetime 是 naive UTC,与 utc_now() 运算前需统一。"""
|
||||
if dt is not None and getattr(dt, "tzinfo", None) is not None:
|
||||
return dt.replace(tzinfo=None)
|
||||
return dt
|
||||
|
||||
|
||||
def serialize_job(job: Job) -> dict:
    """Convert one job row into a plain dict for display, including elapsed time.

    ``duration_seconds`` is ``None`` for a job that has not started. Otherwise
    it is the number of seconds (rounded to 0.1) from ``started_at`` to
    ``completed_at``, or to the current time when the job has no completion
    timestamp. Timestamps are passed through ``_as_naive`` before subtraction.
    """
    begun_at = _as_naive(job.started_at)
    if begun_at:
        # The current time is only looked up when there is no completion time.
        finished_at = _as_naive(job.completed_at) or _as_naive(utc_now())
        elapsed = round((finished_at - begun_at).total_seconds(), 1)
    else:
        elapsed = None

    payload = {
        "id": job.id,
        "type": job.type,
        "status": job.status,
        "owner": job.owner,
        "created_at": job.created_at,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
        "duration_seconds": elapsed,
        "error": job.error,
    }
    return payload
|
||||
|
||||
|
||||
def get_job_status_counts(db: Session) -> dict:
    """Count jobs per status.

    Returns:
        A dict mapping each status value present in the jobs table to the
        number of jobs carrying it (used by the small stats row on the jobs
        page, per the original note).
    """
    grouped = select(Job.status, func.count(Job.id)).group_by(Job.status)
    return {
        status_value: job_count
        for status_value, job_count in db.execute(grouped).fetchall()
    }
|
||||
|
||||
|
||||
# ── 锁管理 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def force_release_lock(db: Session, lock_id: int) -> bool:
    """Force a stuck TaskLock into the ``finished`` state.

    Only a lock whose status is ``running`` or ``stale`` is touched; the
    update also stamps ``released_at`` and is committed immediately.

    Returns:
        ``True`` if a row was updated, ``False`` if no such releasable lock
        exists.
    """
    releasable = TaskLock.status.in_(["running", "stale"])
    statement = (
        update(TaskLock)
        .where(TaskLock.id == lock_id, releasable)
        .values(status="finished", released_at=utc_now())
    )
    outcome = db.execute(statement)
    db.commit()
    affected_rows = outcome.rowcount or 0
    return affected_rows > 0
|
||||
|
||||
|
||||
# ── 失败原因分布 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_failure_breakdown(db: Session) -> list[dict]:
    """Group failed summaries by error type, most frequent first.

    Only rows whose status is FAILED or PERMANENT_FAILURE are counted. A NULL
    ``error_type`` is reported under the bucket name ``"unknown"``.

    Returns:
        A list of ``{"error_type": ..., "count": ...}`` dicts in descending
        order of count.
    """
    bucket = func.coalesce(SummaryStatus.error_type, "unknown")
    failed_states = [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
    statement = (
        select(bucket, func.count(SummaryStatus.id))
        .where(SummaryStatus.status.in_(failed_states))
        .group_by(bucket)
        .order_by(func.count(SummaryStatus.id).desc())
    )
    breakdown = []
    for record in db.execute(statement).fetchall():
        breakdown.append({"error_type": record[0], "count": record[1]})
    return breakdown
|
||||
|
||||
|
||||
# ── 运行配置概览 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_config_overview() -> dict:
    """Collect non-sensitive runtime settings for the admin dashboard.

    Sensitive values are never returned verbatim: the embedding API key is
    reported only as a configured/not-configured flag, and any password
    embedded in the database URL is masked.

    Returns:
        A flat dict of display-ready configuration values.
    """
    return {
        "summary_backend": settings.SUMMARY_BACKEND,
        "summary_pdf_mode": settings.SUMMARY_PDF_MODE,
        "summary_concurrency": settings.SUMMARY_CONCURRENCY,
        "summary_timeout_seconds": settings.SUMMARY_TIMEOUT_SECONDS,
        "summary_max_retries": settings.SUMMARY_MAX_RETRIES,
        "scheduler_enabled": settings.SCHEDULER_ENABLED,
        "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
        "chroma_enabled": settings.CHROMA_ENABLED,
        "embed_model": settings.EMBED_MODEL or "(未配置)",
        "top_n": settings.TOP_N,
        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
        "app_workers": settings.APP_WORKERS,
        # Only the file name is shown, not the full path on disk.
        "layout_model": Path(settings.LAYOUT_MODEL_PATH).name,
        # A database URL may carry "user:password@host"; mask the password so
        # the dashboard cannot leak it.
        "database_url": _redact_url_password(settings.DATABASE_URL),
        "api_key_configured": bool(settings.EMBED_API_KEY),
    }


def _redact_url_password(url):
    """Return *url* with any password in its authority part replaced by ``***``."""
    from urllib.parse import urlsplit, urlunsplit

    if not url:
        return url
    text = str(url)
    try:
        parts = urlsplit(text)
        has_password = parts.password is not None
    except ValueError:
        # Cannot tell where credentials might sit: show nothing rather than
        # risk echoing a secret.
        return "(unparseable)"
    if not has_password:
        return text
    # netloc is "user:password@host[:port]"; keep the user name and host part.
    userinfo, _, hostport = parts.netloc.rpartition("@")
    username = userinfo.partition(":")[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{hostport}"))
|
||||
|
||||
@@ -270,3 +270,67 @@ def _update_upvotes_only(db: Session, papers_raw: list[dict]) -> int:
|
||||
|
||||
db.commit()
|
||||
return updated
|
||||
|
||||
|
||||
async def recrawl_single(db: Session, arxiv_id: str) -> dict:
    """Re-fetch and overwrite the full metadata of one already-stored paper.

    The daily listing for ``paper.paper_date`` is fetched again via
    ``fetch_daily``. If the paper is found in it, title, abstract, publication
    time, links, upvotes, authors and tags are refreshed, the paper's FTS
    entry is rebuilt and the transaction is committed. A paper that is absent
    from the listing for its own date cannot be re-crawled.

    Args:
        db: Session used for reads and for the final commit.
        arxiv_id: Identifier of the paper to refresh.

    Returns:
        ``{"updated": True, "arxiv_id", "date"}`` on success, otherwise
        ``{"updated": False, "reason": ..., "arxiv_id": ...}`` where reason is
        ``"not_found"`` (no such paper), ``"no_paper_date"`` (the paper has no
        date to look up) or ``"not_in_daily"`` (not in that day's listing; the
        dict then also carries ``"date"``).
    """
    paper = db.execute(
        select(Paper).where(Paper.arxiv_id == arxiv_id)
    ).scalar_one_or_none()
    if not paper:
        return {"updated": False, "reason": "not_found", "arxiv_id": arxiv_id}
    if paper.paper_date is None:
        # Without a date there is no listing to fetch; report it instead of
        # failing with AttributeError on .isoformat().
        return {"updated": False, "reason": "no_paper_date", "arxiv_id": arxiv_id}

    target_date = paper.paper_date.isoformat()
    raw_papers = await fetch_daily(target_date)

    # Parse each listing entry once and keep the parsed form of the match.
    # NOTE(review): assumes _parse_paper has no side effects -- confirm.
    meta = next(
        (
            parsed
            for parsed in map(_parse_paper, raw_papers or [])
            if parsed["arxiv_id"] == arxiv_id
        ),
        None,
    )
    if meta is None:
        return {
            "updated": False,
            "reason": "not_in_daily",
            "arxiv_id": arxiv_id,
            "date": target_date,
        }

    now = utc_now()

    # Refresh every scalar field.
    paper.title_en = meta["title_en"]
    paper.abstract = meta["abstract"]
    paper.published_at = meta["published_at"]
    paper.hf_url = meta["hf_url"]
    paper.arxiv_url = meta["arxiv_url"]
    paper.pdf_url = meta["pdf_url"]
    paper.upvotes = meta["upvotes"]
    paper.crawled_at = now

    # Rebuild authors and tags: drop the old rows, then add the new ones.
    paper.authors.clear()
    paper.tags.clear()
    # Flush the removals on their own. Within one flush SQLAlchemy emits
    # INSERTs before DELETEs, so re-adding a row equal to one just removed
    # could violate a unique constraint if these tables define one
    # (NOTE(review): schema not visible here -- confirm).
    db.flush()

    seen_authors: set[str] = set()
    for idx, name in enumerate(meta["authors"]):
        if name and name not in seen_authors:
            seen_authors.add(name)
            db.add(PaperAuthor(paper_id=paper.id, name=name, position=idx))

    # Tags are de-duplicated the same way authors are.
    seen_tags: set[str] = set()
    for tag_name in meta["tags"]:
        if tag_name and tag_name not in seen_tags:
            seen_tags.add(tag_name)
            db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="hf"))

    # Make the new rows visible before the FTS entry is rebuilt.
    db.flush()
    reindex_paper_fts(db, paper)
    db.commit()
    logger.info("Re-crawled paper %s (full metadata refresh)", arxiv_id)
    return {"updated": True, "arxiv_id": arxiv_id, "date": target_date}
|
||||
|
||||
+78
-40
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -18,14 +19,27 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChromaManager:
|
||||
"""封装 ChromaDB 客户端和 collection 的生命周期。"""
|
||||
"""封装 ChromaDB 客户端和 collection 的生命周期。
|
||||
|
||||
所有客户端/集合访问经 ``self._lock`` 串行化:后处理经 ``asyncio.to_thread``
|
||||
多 worker 并发调 ``index_paper``,若不串行化会并发建连,触发 chromadb 1.5.x
|
||||
``SharedSystemClient`` 类级缓存的并发竞争(``_create_system_if_not_exists``
|
||||
无锁 + refcount release 弹 key)→ ``KeyError: '<persist_dir>'``。
|
||||
锁用 RLock:``index_paper`` 持锁后经 ``get_collection()`` 间接再调 ``init()``,
|
||||
同线程可重入,不死锁。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.RLock()
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
def init(self) -> None:
|
||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
|
||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。
|
||||
|
||||
双重检查锁串行化首次建连 —— 外层快判已建好就直接返回(快路径),抢到锁后再
|
||||
判一次(防并发下另一线程已建好),确保 ``PersistentClient`` 全进程只调一次。
|
||||
"""
|
||||
if not settings.CHROMA_ENABLED:
|
||||
logger.debug("ChromaDB disabled, skip init")
|
||||
return
|
||||
@@ -33,19 +47,23 @@ class ChromaManager:
|
||||
if self._client is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
with self._lock:
|
||||
if self._client is not None: # 双重检查:抢到锁后可能已被别的线程建好
|
||||
return
|
||||
|
||||
chroma_path = Path(settings.CHROMA_DIR)
|
||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
import chromadb
|
||||
|
||||
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
||||
self._collection = self._get_or_create_collection()
|
||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize ChromaDB")
|
||||
self._client = None
|
||||
self._collection = None
|
||||
chroma_path = Path(settings.CHROMA_DIR)
|
||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
||||
self._collection = self._get_or_create_collection()
|
||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize ChromaDB")
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
def _get_or_create_collection(self):
|
||||
"""获取或创建 papers_embeddings collection。"""
|
||||
@@ -102,6 +120,8 @@ def _get_embedding(text: str) -> list[float] | None:
|
||||
POST /v1/embeddings, model=EMBED_MODEL
|
||||
校验返回向量长度 == EMBED_DIMENSIONS
|
||||
失败时返回 None 并记录日志。
|
||||
|
||||
纯远程 HTTP 调用、线程安全 —— 留在锁外,让多 worker 并行调。
|
||||
"""
|
||||
if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
|
||||
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
|
||||
@@ -177,9 +197,11 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败或跳过。
|
||||
|
||||
并发设计:远程 embedding 调用在锁外(多 worker 并行),chroma 集合访问
|
||||
(含首次 init)在 ``_chroma._lock`` 内串行化。
|
||||
"""
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -227,17 +249,25 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
||||
logger.warning("Empty index text for %s, skip", paper_id)
|
||||
return False
|
||||
|
||||
vec = _get_embedding(index_text)
|
||||
vec = _get_embedding(index_text) # 远程 HTTP,锁外并行
|
||||
if vec is None:
|
||||
return False
|
||||
|
||||
col.upsert(
|
||||
ids=[arxiv_id],
|
||||
embeddings=[vec],
|
||||
metadatas=[
|
||||
{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}
|
||||
],
|
||||
)
|
||||
with _chroma._lock: # 串行化集合访问(首次含 init)
|
||||
col = _chroma.get_collection()
|
||||
if col is None:
|
||||
return False
|
||||
col.upsert(
|
||||
ids=[arxiv_id],
|
||||
embeddings=[vec],
|
||||
metadatas=[
|
||||
{
|
||||
"arxiv_id": arxiv_id,
|
||||
"title_zh": title_zh,
|
||||
"paper_date": paper_date,
|
||||
}
|
||||
],
|
||||
)
|
||||
logger.info("Indexed paper %s in ChromaDB", arxiv_id)
|
||||
return True
|
||||
|
||||
@@ -255,17 +285,20 @@ def delete_paper(paper_id: str) -> bool:
|
||||
Args:
|
||||
paper_id: arxiv_id
|
||||
"""
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
col.delete(ids=[paper_id])
|
||||
logger.info("Deleted paper %s from ChromaDB", paper_id)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
|
||||
return False
|
||||
with _chroma._lock:
|
||||
col = _chroma.get_collection()
|
||||
if col is None:
|
||||
return False
|
||||
try:
|
||||
col.delete(ids=[paper_id])
|
||||
logger.info("Deleted paper %s from ChromaDB", paper_id)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
|
||||
return False
|
||||
|
||||
|
||||
# ── 相似查询 ────────────────────────────────────────────────────────────
|
||||
@@ -280,21 +313,26 @@ def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
|
||||
|
||||
Returns:
|
||||
[{"arxiv_id": str, "distance": float}, ...]
|
||||
|
||||
并发设计:远程 embedding 在锁外,集合查询在 ``_chroma._lock`` 内。
|
||||
"""
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return []
|
||||
|
||||
try:
|
||||
vec = _get_embedding(query_text)
|
||||
vec = _get_embedding(query_text) # 远程 HTTP,锁外
|
||||
if vec is None:
|
||||
return []
|
||||
|
||||
results = col.query(
|
||||
query_embeddings=[vec],
|
||||
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
|
||||
include=["metadatas", "distances"],
|
||||
)
|
||||
with _chroma._lock:
|
||||
col = _chroma.get_collection()
|
||||
if col is None:
|
||||
return []
|
||||
results = col.query(
|
||||
query_embeddings=[vec],
|
||||
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
|
||||
include=["metadatas", "distances"],
|
||||
)
|
||||
|
||||
if not results["ids"] or not results["ids"][0]:
|
||||
return []
|
||||
|
||||
+15
-1
@@ -148,7 +148,7 @@ async def run_job(db: Session, job_id: int) -> dict:
|
||||
|
||||
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
|
||||
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
||||
from app.services.crawler import refresh_upvotes
|
||||
from app.services.crawler import recrawl_single, refresh_upvotes
|
||||
from app.services.derived import reindex_chroma, reindex_fts
|
||||
from app.services.pipeline import run_crawl, run_pipeline
|
||||
from app.services.summarizer import summarize_batch, summarize_single
|
||||
@@ -193,6 +193,20 @@ async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
|
||||
return reindex_fts(db)
|
||||
if job.type == "reindex_chroma":
|
||||
return reindex_chroma(db)
|
||||
if job.type == "recrawl_one":
|
||||
return await recrawl_single(db, payload["arxiv_id"])
|
||||
if job.type == "recrawl_batch":
|
||||
updated = 0
|
||||
skipped = 0
|
||||
results = []
|
||||
for arxiv_id in payload.get("arxiv_ids", []):
|
||||
res = await recrawl_single(db, arxiv_id)
|
||||
results.append(res)
|
||||
if res.get("updated"):
|
||||
updated += 1
|
||||
else:
|
||||
skipped += 1
|
||||
return {"updated": updated, "skipped": skipped, "results": results}
|
||||
|
||||
raise ValueError(f"Unsupported job type: {job.type}")
|
||||
|
||||
|
||||
@@ -23,8 +23,10 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
@@ -47,14 +49,18 @@ _FALLBACK_NAMES: dict[int, str] = {
|
||||
8: "isolate_formula",
|
||||
9: "formula_caption",
|
||||
}
|
||||
# 下游只需 picture/table —— 按 class name 字符串动态匹配(不依赖 class index,
|
||||
# 下游需要 picture/table 及其 caption —— 按 class name 字符串动态匹配(不依赖 class index,
|
||||
# 规避 DocStructBench 不同发布的类别顺序差异)
|
||||
_PICTURE_NAMES = {"figure", "figure_group"}
|
||||
_TABLE_NAMES = {"table", "table_group"}
|
||||
_FIGURE_CAPTION_NAMES = {"figure_caption"}
|
||||
_TABLE_CAPTION_NAMES = {"table_caption"}
|
||||
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
|
||||
_PAD_VALUE = 114
|
||||
# 最小 bbox 尺寸(PDF 点)
|
||||
_MIN_BOX_SIZE = 20
|
||||
_MIN_CAPTION_BOX_WIDTH = 30
|
||||
_MIN_CAPTION_BOX_HEIGHT = 6
|
||||
|
||||
# device → ExecutionProvider 映射
|
||||
_PROVIDER_MAP: dict[str, str] = {
|
||||
@@ -72,7 +78,7 @@ _AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
|
||||
|
||||
@dataclass
|
||||
class LayoutBox:
|
||||
"""检测到的布局区域,坐标为 PDF 点,boxclass ∈ {"picture", "table"}。"""
|
||||
"""检测到的布局区域,坐标为 PDF 点。"""
|
||||
|
||||
x0: float
|
||||
y0: float
|
||||
@@ -191,13 +197,17 @@ def _postprocess_output(
|
||||
|
||||
|
||||
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
    """Translate a detector class id into the box class used downstream.

    The class name looked up in *names* is trimmed and lower-cased, then
    matched against the module's name sets. Returns ``"picture"``,
    ``"table"``, ``"figure_caption"`` or ``"table_caption"``, or ``None`` for
    an unknown id or a class that downstream code does not use.
    """
    label = names.get(cls_id, "").strip().lower()
    lookup = (
        (_PICTURE_NAMES, "picture"),
        (_TABLE_NAMES, "table"),
        (_FIGURE_CAPTION_NAMES, "figure_caption"),
        (_TABLE_CAPTION_NAMES, "table_caption"),
    )
    for known_names, boxclass in lookup:
        if label in known_names:
            return boxclass
    return None
|
||||
|
||||
|
||||
@@ -220,15 +230,50 @@ def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
|
||||
# ── 检测器单例 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _LayoutDetector:
|
||||
"""单例:管理 ONNX InferenceSession 生命周期。"""
|
||||
class _Singleton(type):
|
||||
"""元类单例:``cls()`` 永远返回同一实例;``reset_instance()`` 清缓存以便重建。
|
||||
|
||||
生产代码只应在模块级 ``_detector = _LayoutDetector()`` 创建一次。任何第二处
|
||||
``_LayoutDetector()`` 都会拿到同一实例(含同一 ONNX session + 同一锁),杜绝
|
||||
并发推理时各建一份 session 导致内存峰值翻倍(8GB 机器崩溃根因)。双检锁保证
|
||||
首次实例化线程安全。
|
||||
"""
|
||||
|
||||
_instances: dict[type, Any] = {}
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls in _Singleton._instances:
|
||||
return _Singleton._instances[cls]
|
||||
with _Singleton._lock:
|
||||
if cls not in _Singleton._instances:
|
||||
_Singleton._instances[cls] = super().__call__(*args, **kwargs)
|
||||
return _Singleton._instances[cls]
|
||||
|
||||
|
||||
class _LayoutDetector(metaclass=_Singleton):
|
||||
"""强约束单例:管理 ONNX InferenceSession 生命周期。
|
||||
|
||||
由 ``_Singleton`` 元类保证全进程唯一实例 —— 重复 ``_LayoutDetector()`` 只会返回
|
||||
已有实例(含已加载的 session 和锁),不会新建。``reset_instance()`` 清缓存,仅供
|
||||
测试隔离用。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._session: ort.InferenceSession | None = None
|
||||
self._names: dict[int, str] = {}
|
||||
self._input_name: str = ""
|
||||
self._imgsz: int = settings.LAYOUT_IMGSZ
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""清空单例缓存,下次 ``_LayoutDetector()`` 重建新实例(含新锁 + 空 session)。
|
||||
|
||||
仅用于测试隔离 —— 生产代码永远不该调用(否则会丢掉已加载的模型 session)。
|
||||
"""
|
||||
_Singleton._instances.pop(cls, None)
|
||||
|
||||
def _init_session(self) -> ort.InferenceSession:
|
||||
if self._session is not None:
|
||||
return self._session
|
||||
@@ -275,7 +320,7 @@ class _LayoutDetector:
|
||||
self._imgsz = settings.LAYOUT_IMGSZ
|
||||
return self._session
|
||||
|
||||
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
def _detect_page_impl(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
"""检测单页 PDF 的 figure / table 区域。
|
||||
|
||||
流程:
|
||||
@@ -323,21 +368,38 @@ class _LayoutDetector:
|
||||
y0 = max(0.0, min(y0, page_h))
|
||||
x1 = max(0.0, min(x1, page_w))
|
||||
y1 = max(0.0, min(y1, page_h))
|
||||
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||
continue
|
||||
if boxclass in ("figure_caption", "table_caption"):
|
||||
if (x1 - x0) < _MIN_CAPTION_BOX_WIDTH or (
|
||||
y1 - y0
|
||||
) < _MIN_CAPTION_BOX_HEIGHT:
|
||||
continue
|
||||
else:
|
||||
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||
continue
|
||||
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
|
||||
|
||||
return result
|
||||
|
||||
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
"""公共入口:加锁串行化推理。
|
||||
|
||||
# 模块级单例
|
||||
包裹整段 _detect_page_impl(含 pixmap 渲染 + tensor 构造 + session.run),
|
||||
保证同一时刻只有一个推理在跑——避免 SUMMARY_CONCURRENCY>1 时多个 to_thread
|
||||
线程并发推理导致内存峰值翻倍(8GB 机器崩溃根因)。锁由 _detect_page_impl
|
||||
间接保护 _init_session,首次加载也串行,杜绝并发各建一份 session。
|
||||
"""
|
||||
with self._lock:
|
||||
return self._detect_page_impl(page)
|
||||
|
||||
|
||||
# 模块级单例 —— 生产代码唯一的实例化点(_Singleton 元类保证不会再有第二个)
|
||||
_detector = _LayoutDetector()
|
||||
|
||||
|
||||
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
    """Detect figure, table and caption regions on one PDF page.

    Delegates to the module-level detector instance.

    Returns:
        The detector's list of LayoutBox objects for *page*. Per the original
        note, coordinates are in PDF points and only picture/table boxes and
        their captions are included.
    """
    detected_boxes = _detector.detect_page(page)
    return detected_boxes
|
||||
|
||||
@@ -34,6 +34,8 @@ _CLUSTER_GAP = 15
|
||||
_MIN_BOX_AREA = 2000
|
||||
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
|
||||
_LABEL_MATCH_DISTANCE = 100
|
||||
# DocLayout caption 与 figure/table 匹配的最大距离(单位: pt)
|
||||
_CAPTION_MATCH_DISTANCE = 120
|
||||
|
||||
|
||||
# ── Box 聚类 ─────────────────────────────────────────────────────────
|
||||
@@ -53,6 +55,15 @@ class _BoxCluster:
|
||||
self.boxclass = "table" if raw == "table-fallback" else raw
|
||||
|
||||
|
||||
def _cluster_to_box(cluster: _BoxCluster) -> list[float]:
|
||||
return [
|
||||
round(float(cluster.x0), 1),
|
||||
round(float(cluster.y0), 1),
|
||||
round(float(cluster.x1), 1),
|
||||
round(float(cluster.y1), 1),
|
||||
]
|
||||
|
||||
|
||||
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
||||
"""将相邻的同类型 box 合并为聚类。"""
|
||||
if not boxes:
|
||||
@@ -92,6 +103,67 @@ def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
||||
return [_BoxCluster(members) for members in groups.values()]
|
||||
|
||||
|
||||
def _caption_class_for_content(boxclass: str) -> str:
|
||||
return "figure_caption" if boxclass == "picture" else "table_caption"
|
||||
|
||||
|
||||
def _caption_distance(content: _BoxCluster, caption: _BoxCluster) -> float | None:
|
||||
"""Return a spatial score for pairing a caption with a content box."""
|
||||
h_overlap = min(content.x1, caption.x1) - max(content.x0, caption.x0)
|
||||
min_width = min(content.x1 - content.x0, caption.x1 - caption.x0)
|
||||
if min_width <= 0 or h_overlap < min_width * 0.25:
|
||||
return None
|
||||
|
||||
if caption.y1 < content.y0:
|
||||
v_gap = content.y0 - caption.y1
|
||||
elif caption.y0 > content.y1:
|
||||
v_gap = caption.y0 - content.y1
|
||||
else:
|
||||
v_gap = 0.0
|
||||
|
||||
return v_gap if v_gap <= _CAPTION_MATCH_DISTANCE else None
|
||||
|
||||
|
||||
def _extract_caption_text(page, caption: _BoxCluster) -> str:
    """Read the page text inside a caption box and collapse its whitespace.

    Returns:
        The text with every run of whitespace reduced to a single space, or
        ``""`` when text extraction raises.
    """
    clip_rect = pymupdf.Rect(caption.x0, caption.y0, caption.x1, caption.y1)
    try:
        raw_text = page.get_text("text", clip=clip_rect)
    except Exception:
        # Best effort: a caption that cannot be read is treated as absent.
        return ""
    else:
        words = raw_text.split()
        return " ".join(words)
|
||||
|
||||
|
||||
def _match_captions(
|
||||
page,
|
||||
content_clusters: list[_BoxCluster],
|
||||
caption_clusters: list[_BoxCluster],
|
||||
) -> dict[int, tuple[_BoxCluster, str]]:
|
||||
"""Match each content cluster to its nearest same-type DocLayout caption."""
|
||||
matches: dict[int, tuple[_BoxCluster, str]] = {}
|
||||
used_captions: set[int] = set()
|
||||
candidates: list[tuple[float, int, int]] = []
|
||||
|
||||
for content_idx, content in enumerate(content_clusters):
|
||||
wanted_caption_class = _caption_class_for_content(content.boxclass)
|
||||
for caption_idx, caption in enumerate(caption_clusters):
|
||||
if caption.boxclass != wanted_caption_class:
|
||||
continue
|
||||
dist = _caption_distance(content, caption)
|
||||
if dist is not None:
|
||||
candidates.append((dist, content_idx, caption_idx))
|
||||
|
||||
for _dist, content_idx, caption_idx in sorted(candidates):
|
||||
if content_idx in matches or caption_idx in used_captions:
|
||||
continue
|
||||
text = _extract_caption_text(page, caption_clusters[caption_idx])
|
||||
if not text:
|
||||
continue
|
||||
matches[content_idx] = (caption_clusters[caption_idx], text)
|
||||
used_captions.add(caption_idx)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -102,14 +174,25 @@ def _render_box(
|
||||
filename: str,
|
||||
cap_type: str,
|
||||
page_num: int,
|
||||
caption: _BoxCluster | None = None,
|
||||
) -> bool:
|
||||
"""渲染单个 box 区域并保存 JPEG,成功返回 True。"""
|
||||
"""渲染单个 box 区域并保存 JPEG,成功返回 True。
|
||||
|
||||
若提供 caption,则将内容与 caption 区域合并后一起截取,
|
||||
使同一张截图同时包含图/表及其标题文字。
|
||||
"""
|
||||
page_width = page.rect.width
|
||||
x0, y0, x1, y1 = box.x0, box.y0, box.x1, box.y1
|
||||
if caption is not None:
|
||||
x0 = min(x0, caption.x0)
|
||||
y0 = min(y0, caption.y0)
|
||||
x1 = max(x1, caption.x1)
|
||||
y1 = max(y1, caption.y1)
|
||||
clip = pymupdf.Rect(
|
||||
max(0, box.x0 - _REGION_PADDING),
|
||||
max(0, box.y0 - _REGION_PADDING),
|
||||
min(page_width, box.x1 + _REGION_PADDING),
|
||||
box.y1 + _REGION_PADDING,
|
||||
max(0, x0 - _REGION_PADDING),
|
||||
max(0, y0 - _REGION_PADDING),
|
||||
min(page_width, x1 + _REGION_PADDING),
|
||||
y1 + _REGION_PADDING,
|
||||
)
|
||||
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
|
||||
try:
|
||||
@@ -136,25 +219,31 @@ def _process_page(
|
||||
fig_counter = 0
|
||||
tbl_counter = 0
|
||||
|
||||
# 收集本页的 table/picture box(跳过极小区域)
|
||||
# 收集本页的 table/picture box 与 caption box(跳过极小区域)
|
||||
raw_boxes = []
|
||||
raw_caption_boxes = []
|
||||
for box in page_boxes:
|
||||
if box.boxclass not in ("table", "table-fallback", "picture"):
|
||||
continue
|
||||
w = box.x1 - box.x0
|
||||
h = box.y1 - box.y0
|
||||
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
|
||||
continue
|
||||
raw_boxes.append(box)
|
||||
if box.boxclass in ("table", "table-fallback", "picture"):
|
||||
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
|
||||
continue
|
||||
raw_boxes.append(box)
|
||||
elif box.boxclass in ("figure_caption", "table_caption"):
|
||||
if w < 30 or h < 6:
|
||||
continue
|
||||
raw_caption_boxes.append(box)
|
||||
|
||||
if not raw_boxes:
|
||||
return 0
|
||||
|
||||
# 聚类:将同一 figure/table 的碎片 box 合并
|
||||
clusters = _cluster_boxes(raw_boxes)
|
||||
caption_clusters = _cluster_boxes(raw_caption_boxes)
|
||||
caption_matches = _match_captions(page, clusters, caption_clusters)
|
||||
|
||||
extracted = 0
|
||||
for cluster in clusters:
|
||||
for cluster_idx, cluster in enumerate(clusters):
|
||||
cap_type = "figure" if cluster.boxclass == "picture" else "table"
|
||||
|
||||
if cap_type == "figure":
|
||||
@@ -168,21 +257,33 @@ def _process_page(
|
||||
continue
|
||||
seen_labels.add(label)
|
||||
|
||||
caption_match = caption_matches.get(cluster_idx)
|
||||
caption_cluster = caption_match[0] if caption_match else None
|
||||
|
||||
filename = f"{label.replace(' ', '_').lower()}.jpg"
|
||||
if not _render_box(page, cluster, images_dest, filename, cap_type, page_num):
|
||||
if not _render_box(
|
||||
page,
|
||||
cluster,
|
||||
images_dest,
|
||||
filename,
|
||||
cap_type,
|
||||
page_num,
|
||||
caption=caption_cluster,
|
||||
):
|
||||
continue
|
||||
|
||||
manifest[filename] = {
|
||||
info = {
|
||||
"page": page_num,
|
||||
"type": cap_type,
|
||||
"label": label,
|
||||
"box": [
|
||||
round(float(cluster.x0), 1),
|
||||
round(float(cluster.y0), 1),
|
||||
round(float(cluster.x1), 1),
|
||||
round(float(cluster.y1), 1),
|
||||
],
|
||||
"box": _cluster_to_box(cluster),
|
||||
}
|
||||
if caption_match:
|
||||
info["caption_text"] = caption_match[1][:500]
|
||||
info["caption_box"] = _cluster_to_box(caption_cluster)
|
||||
info["caption_source"] = "doclayout"
|
||||
|
||||
manifest[filename] = info
|
||||
extracted += 1
|
||||
|
||||
return extracted
|
||||
@@ -446,14 +547,20 @@ def label_images_by_summary(
|
||||
cap_type = info.get("type", "figure")
|
||||
|
||||
# 读取 caption 文本(从 figures 列表)
|
||||
caption_text = ""
|
||||
summary_caption_text = ""
|
||||
for fig in figures:
|
||||
if fig.get("id") == fig_id:
|
||||
caption_text = fig.get("caption", "")
|
||||
summary_caption_text = fig.get("caption", "")
|
||||
break
|
||||
|
||||
info["label"] = fig_id
|
||||
info["caption_text"] = caption_text[:200] if caption_text else ""
|
||||
existing_caption_text = info.get("caption_text", "")
|
||||
if existing_caption_text and summary_caption_text:
|
||||
info["summary_caption_text"] = summary_caption_text[:500]
|
||||
else:
|
||||
info["caption_text"] = (
|
||||
summary_caption_text[:500] if summary_caption_text else ""
|
||||
)
|
||||
info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
|
||||
fig_id
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ from app.services.summary_persister import (
|
||||
_cleanup_old_images,
|
||||
_handle_summary_failure,
|
||||
_persist_summary,
|
||||
_run_post_processing,
|
||||
)
|
||||
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
|
||||
|
||||
@@ -115,12 +116,31 @@ async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -
|
||||
_t3 = _time.monotonic()
|
||||
logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2)
|
||||
|
||||
quality = _persist_summary(db, paper, json_data, raw_output)
|
||||
quality, schema = _persist_summary(db, paper, json_data, raw_output)
|
||||
_t4 = _time.monotonic()
|
||||
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
|
||||
|
||||
# 后处理(图片提取 + ChromaDB 索引)搬到线程池跑,避免 CPU 密集推理冻结
|
||||
# 事件循环。paper 字段在此(事件循环线程)提取成纯值再传入,规避 worker
|
||||
# 线程跨线程访问 ORM 的 DetachedInstanceError。DocLayout 推理由单例的
|
||||
# threading.Lock 串行化,并发 worker 不会同时压模型。
|
||||
paper_meta = {
|
||||
"title_en": paper.title_en or "",
|
||||
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
|
||||
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
|
||||
}
|
||||
_t5 = _time.monotonic()
|
||||
try:
|
||||
await asyncio.to_thread(_run_post_processing, arxiv_id, schema, paper_meta)
|
||||
except Exception:
|
||||
# 双保险:_run_post_processing 内部已 try/except,此处兜底,
|
||||
# 确保后处理失败绝不影响已 DONE 的总结。
|
||||
logger.warning("Post-processing error for %s", arxiv_id, exc_info=True)
|
||||
_t6 = _time.monotonic()
|
||||
logger.info(" [%s] 后处理(线程池): %.2fs", arxiv_id, _t6 - _t5)
|
||||
|
||||
logger.info(
|
||||
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0
|
||||
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t6 - _t0
|
||||
)
|
||||
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
|
||||
|
||||
|
||||
@@ -131,8 +131,12 @@ def _handle_summary_failure(
|
||||
|
||||
def _persist_summary(
|
||||
db: Session, paper: Paper, json_data: dict, raw_output: str
|
||||
) -> str:
|
||||
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
|
||||
) -> tuple[str, SummarySchema]:
|
||||
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 (quality, schema)。
|
||||
|
||||
后处理(图片提取/ChromaDB)不再在此函数内执行,由调用方搬到线程池,
|
||||
以免阻塞事件循环。返回 schema 供调用方在线程池里跑后处理。
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
arxiv_id = paper.arxiv_id
|
||||
@@ -165,21 +169,10 @@ def _persist_summary(
|
||||
_t4 - _t3,
|
||||
)
|
||||
|
||||
# 触发性增强(失败不影响总结)
|
||||
_t5 = _time.monotonic()
|
||||
_maybe_extract_images(arxiv_id, schema)
|
||||
_t6 = _time.monotonic()
|
||||
_maybe_index_chroma(arxiv_id, paper, schema)
|
||||
_t7 = _time.monotonic()
|
||||
|
||||
logger.info(
|
||||
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
|
||||
arxiv_id,
|
||||
_t6 - _t5,
|
||||
_t7 - _t6,
|
||||
)
|
||||
|
||||
return quality
|
||||
# 后处理(图片提取 + ChromaDB 索引)已上移到调用方 _do_summarize_one,
|
||||
# 经 asyncio.to_thread 在线程池跑——DB session 必须留在事件循环线程,
|
||||
# 而 CPU/IO 密集的后处理搬走才不冻结事件循环。
|
||||
return quality, schema
|
||||
|
||||
|
||||
# ── 清理 ────────────────────────────────────────────────────────────────
|
||||
@@ -226,21 +219,44 @@ def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
|
||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
def _maybe_index_chroma(arxiv_id: str, schema: SummarySchema, paper_meta: dict) -> None:
    """Write the paper into the ChromaDB semantic index, best effort.

    Any exception is logged as a warning and swallowed, so a failure here
    never affects the summary itself.

    Args:
        arxiv_id: Identifier of the paper being indexed.
        schema: Validated summary supplying the title, one-line summary,
            motivation and method fields.
        paper_meta: Plain values (``title_en``, ``tags``, ``paper_date``)
            taken from the ORM object by the caller. Per the original note
            this avoids touching ORM objects from a worker thread.
    """
    try:
        from app.services.embedder import index_paper

        index_paper(
            arxiv_id,
            {
                "arxiv_id": arxiv_id,
                "title_zh": schema.title_zh or "",
                "title_en": paper_meta.get("title_en", ""),
                "tags": paper_meta.get("tags", ""),
                "one_line": schema.one_line or "",
                "motivation_problem": schema.motivation.problem or "",
                "method_key_idea": schema.method.key_idea or "",
                "paper_date": paper_meta.get("paper_date", ""),
            },
        )
    except Exception:
        logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
def _run_post_processing(
    arxiv_id: str, schema: SummarySchema, paper_meta: dict
) -> None:
    """Run the post-summary enrichment steps: image extraction, then indexing.

    Per the original note this is meant to be called from a worker thread
    after the summary has been persisted. Each step runs inside its own
    guard, so an unexpected error escaping one step is logged and the
    remaining step still runs. Nothing is ever raised to the caller.

    Args:
        arxiv_id: Identifier of the paper being processed.
        schema: Validated summary passed to both steps.
        paper_meta: Plain paper values passed to the indexing step.
    """
    steps = (
        ("image extraction", lambda: _maybe_extract_images(arxiv_id, schema)),
        (
            "ChromaDB indexing",
            lambda: _maybe_index_chroma(arxiv_id, schema, paper_meta),
        ),
    )
    for step_name, run_step in steps:
        try:
            run_step()
        except Exception:
            # Guard per step: with one shared try block, a failure in image
            # extraction would silently skip the indexing step as well.
            logger.warning(
                "Post-processing step %s failed for %s (summary already persisted)",
                step_name,
                arxiv_id,
                exc_info=True,
            )
|
||||
|
||||
Reference in New Issue
Block a user