feat: add concurrency safety, caption detection, admin enhancements, and performance improvements

This commit is contained in:
2026-06-14 22:20:02 +08:00
parent 8f13c31991
commit 29fb20828e
23 changed files with 1782 additions and 114 deletions
+1
View File
@@ -26,6 +26,7 @@ class Settings(BaseSettings):
HTTP_TIMEOUT_SECONDS: int = 30 HTTP_TIMEOUT_SECONDS: int = 30
HTTP_MAX_RETRIES: int = 3 HTTP_MAX_RETRIES: int = 3
HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1" HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1"
HF_PROXY: str = ""
PDF_DOWNLOAD_TIMEOUT: int = 120 PDF_DOWNLOAD_TIMEOUT: int = 120
# AI 总结 # AI 总结
+200 -5
View File
@@ -2,8 +2,10 @@
from __future__ import annotations from __future__ import annotations
import csv
import hashlib import hashlib
import hmac import hmac
import io
from datetime import date from datetime import date
from fastapi import ( from fastapi import (
@@ -15,7 +17,7 @@ from fastapi import (
Query, Query,
Request, Request,
) )
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse, Response
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -298,6 +300,183 @@ async def admin_job_detail(
return detail return detail
# ── 任务监控 ──────────────────────────────────────────────────────────
@router.get("/jobs")
async def admin_jobs(
    request: Request,
    _admin: None = Depends(verify_admin),
    db: Session = Depends(get_db),
    status: str = Query("all"),
    job_type: str = Query("all"),
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
):
    """Render the background-job monitoring page.

    Args:
        request: Incoming request; its query string is reused to build
            pagination links.
        status: Job status filter, ``"all"`` disables it.
        job_type: Job type filter, ``"all"`` disables it.
        page: 1-based page number.
        per_page: Page size (1-100).

    Returns:
        The rendered ``admin_jobs.html`` template response.
    """
    # Function-scope import so the block does not depend on a file-level import.
    from urllib.parse import urlencode

    jobs, total = admin_svc.query_jobs(
        db, status=status, job_type=job_type, page=page, per_page=per_page
    )
    counts = admin_svc.get_job_status_counts(db)

    def pagination_url(p: int) -> str:
        """Return the URL of page *p*, preserving the current filters."""
        params = dict(request.query_params)
        params["page"] = str(p)
        # urlencode percent-escapes keys and values; a plain "k=v" join would
        # produce a broken link for values containing "&", "=", spaces, etc.
        return "/admin/jobs?" + urlencode(params)

    return templates.TemplateResponse(
        request,
        "admin_jobs.html",
        {
            "jobs": jobs,
            "total": total,
            "page": page,
            "per_page": per_page,
            "current_status": status,
            "current_type": job_type,
            "status_counts": counts,
            "pagination_url": pagination_url,
        },
    )
# ── 锁管理 ────────────────────────────────────────────────────────────
@router.post("/locks/{lock_id}/release")
async def admin_release_lock(
    lock_id: int,
    _admin: None = Depends(verify_admin),
    db: Session = Depends(get_db),
):
    """Force-release a stuck task lock.

    Raises:
        HTTPException: 404 when the service reports that nothing was released.
    """
    released = admin_svc.force_release_lock(db, lock_id)
    if released:
        return {"status": "success", "lock_id": lock_id}
    raise HTTPException(
        status_code=404,
        detail=f"Lock not found or already released: {lock_id}",
    )
# ── 重抓 ──────────────────────────────────────────────────────────────
@router.post("/paper-recrawl/{arxiv_id}")
async def admin_paper_recrawl(
    arxiv_id: str,
    background_tasks: BackgroundTasks,
    _admin: None = Depends(verify_admin),
    db: Session = Depends(get_db),
):
    """Queue a ``recrawl_one`` job that refreshes one paper's metadata.

    Returns:
        A dict with the queued job id and the requested ``arxiv_id``.
    """
    payload = {"arxiv_id": arxiv_id}
    job = create_job(db, "recrawl_one", owner="admin_recrawl", payload=payload)
    enqueue_job(background_tasks, job.id)
    response = {"status": "queued", "job_id": job.id, "arxiv_id": arxiv_id}
    return response
# ── 索引重建 ──────────────────────────────────────────────────────────
class RebuildIndexRequest(BaseModel):
    # Request body for the index-rebuild endpoint.
    target: str  # one of "fts" / "chroma" / "both"

    @field_validator("target")
    @classmethod
    def target_must_be_valid(cls, v: str) -> str:
        """Accept only the three supported rebuild targets."""
        allowed = ("fts", "chroma", "both")
        if v in allowed:
            return v
        raise ValueError("target must be 'fts', 'chroma' or 'both'")
@router.post("/rebuild-indexes")
async def admin_rebuild_indexes(
    body: RebuildIndexRequest,
    background_tasks: BackgroundTasks,
    _admin: None = Depends(verify_admin),
    db: Session = Depends(get_db),
):
    """Queue rebuild jobs for the search indexes selected by ``body.target``.

    One job is created per index; ``"both"`` queues the FTS job first and the
    Chroma job second.

    Returns:
        A dict with the ids of the queued jobs and the echoed target.
    """
    # (job type, targets that trigger it) in the order the jobs are queued.
    plan = (
        ("reindex_fts", ("fts", "both")),
        ("reindex_chroma", ("chroma", "both")),
    )
    job_ids: list[int] = []
    for job_type, triggers in plan:
        if body.target not in triggers:
            continue
        job = create_job(db, job_type, owner="admin_reindex", payload={})
        enqueue_job(background_tasks, job.id)
        job_ids.append(job.id)
    return {"status": "queued", "job_ids": job_ids, "target": body.target}
# ── 导出 CSV ──────────────────────────────────────────────────────────
@router.get("/papers/export.csv")
async def admin_papers_export(
    _admin: None = Depends(verify_admin),
    db: Session = Depends(get_db),
    q: str = Query(""),
    date_from: str | None = Query(None),
    date_to: str | None = Query(None),
    tag: str = Query(""),
    summary_status: str = Query("all"),
    sort: str = Query("date_desc"),
):
    """Export the papers matching the current filters as a CSV download.

    The payload starts with a UTF-8 byte-order mark so that Excel detects the
    encoding. Authors and tags are each flattened into one ``;``-separated
    cell.

    Returns:
        A ``text/csv`` attachment response named ``papers_<date>.csv``.
    """
    papers, _total, statuses = admin_svc.query_papers(
        db,
        q=q,
        date_from=date_from,
        date_to=date_to,
        tag=tag,
        summary_status=summary_status,
        sort=sort,
        page=1,
        per_page=10**6,  # a single very large page, i.e. effectively unpaginated
    )

    buf = io.StringIO()
    # Written as an escape sequence: a literal U+FEFF is invisible in source
    # and easily lost, which silently turns this into a no-op write.
    buf.write("\ufeff")
    writer = csv.writer(buf)
    writer.writerow(
        [
            "arxiv_id",
            "title_en",
            "title_zh",
            "paper_date",
            "upvotes",
            "summary_status",
            "authors",
            "tags",
            "pdf_url",
        ]
    )
    for paper in papers:
        authors = ";".join(a.name for a in paper.authors)
        tags = ";".join(t.tag for t in paper.tags)
        writer.writerow(
            [
                paper.arxiv_id,
                paper.title_en or "",
                paper.title_zh or "",
                str(paper.paper_date) if paper.paper_date else "",
                paper.upvotes or 0,
                statuses.get(paper.arxiv_id, "none"),
                authors,
                tags,
                paper.pdf_url or "",
            ]
        )

    filename = f"papers_{today_str().replace('-', '')}.csv"
    return Response(
        content=buf.getvalue(),
        media_type="text/csv; charset=utf-8",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
# ── 日志 ────────────────────────────────────────────────────────────── # ── 日志 ──────────────────────────────────────────────────────────────
@@ -438,24 +617,25 @@ async def admin_paper_delete(
class BatchActionRequest(BaseModel): class BatchActionRequest(BaseModel):
action: str # "delete" or "summarize" action: str # "delete" / "summarize" / "recrawl"
arxiv_ids: list[str] arxiv_ids: list[str]
@field_validator("action") @field_validator("action")
@classmethod @classmethod
def action_must_be_valid(cls, v: str) -> str: def action_must_be_valid(cls, v: str) -> str:
if v not in ("delete", "summarize"): if v not in ("delete", "summarize", "recrawl"):
raise ValueError("action must be 'delete' or 'summarize'") raise ValueError("action must be 'delete', 'summarize' or 'recrawl'")
return v return v
@router.post("/papers-batch-action") @router.post("/papers-batch-action")
async def admin_papers_batch_action( async def admin_papers_batch_action(
body: BatchActionRequest, body: BatchActionRequest,
background_tasks: BackgroundTasks,
_admin: None = Depends(verify_admin), _admin: None = Depends(verify_admin),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""批量操作论文(删除或总结)。""" """批量操作论文(删除 / 总结 / 重抓)。"""
if not body.arxiv_ids: if not body.arxiv_ids:
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空") raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
@@ -475,3 +655,18 @@ async def admin_papers_batch_action(
"message": f"已将 {count} 篇论文重置为待总结", "message": f"已将 {count} 篇论文重置为待总结",
"count": count, "count": count,
} }
elif body.action == "recrawl":
job = create_job(
db,
"recrawl_batch",
owner="admin_recrawl",
payload={"arxiv_ids": body.arxiv_ids},
)
enqueue_job(background_tasks, job.id)
return {
"status": "queued",
"job_id": job.id,
"count": len(body.arxiv_ids),
"message": f"已将 {len(body.arxiv_ids)} 篇论文加入重抓队列",
}
+129 -2
View File
@@ -7,7 +7,7 @@ from datetime import date
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable
from sqlalchemy import func, select, text from sqlalchemy import func, select, text, update
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import settings from app.config import settings
@@ -100,7 +100,9 @@ def get_admin_stats(db: Session) -> dict:
# ── 活跃锁 ──────────────────────────────────────────────────────── # ── 活跃锁 ────────────────────────────────────────────────────────
active_locks = ( active_locks = (
db.execute(select(TaskLock).where(TaskLock.status == "running")).scalars().all() db.execute(select(TaskLock).where(TaskLock.status.in_(["running", "stale"])))
.scalars()
.all()
) )
return { return {
@@ -124,6 +126,7 @@ def get_admin_stats(db: Session) -> dict:
"recent_logs": recent_logs, "recent_logs": recent_logs,
"active_locks": active_locks, "active_locks": active_locks,
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS, "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
"config_overview": get_config_overview(),
} }
@@ -370,6 +373,7 @@ def get_logs_context(db: Session, *, page: int, per_page: int) -> dict:
"summary_done": summary_done, "summary_done": summary_done,
"summary_pending": summary_pending, "summary_pending": summary_pending,
"summary_failed": summary_failed, "summary_failed": summary_failed,
"failure_breakdown": get_failure_breakdown(db),
} }
@@ -511,3 +515,126 @@ def reset_summaries_pending(db: Session, arxiv_ids: list[str]) -> int:
db.add(SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING)) db.add(SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING))
db.commit() db.commit()
return len(paper_ids) return len(paper_ids)
# ── 任务监控 ──────────────────────────────────────────────────────────
def query_jobs(
    db: Session,
    *,
    status: str | None = None,
    job_type: str | None = None,
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[dict], int]:
    """List background jobs with optional status/type filters and pagination.

    A filter value that is empty, ``None`` or ``"all"`` is ignored.

    Returns:
        ``(jobs, total)``: the requested page as serialized dicts, newest
        first, and the number of jobs matching the filters.
    """
    conditions = []
    if status and status != "all":
        conditions.append(Job.status == status)
    if job_type and job_type != "all":
        conditions.append(Job.type == job_type)

    stmt = select(Job)
    for condition in conditions:
        stmt = stmt.where(condition)

    count_stmt = select(func.count()).select_from(stmt.subquery())
    total = db.scalar(count_stmt) or 0

    skip = (page - 1) * per_page
    page_stmt = stmt.order_by(Job.created_at.desc()).offset(skip).limit(per_page)
    rows = db.execute(page_stmt).scalars().all()
    return [serialize_job(row) for row in rows], total
def _as_naive(dt):
"""去掉 tzinfo — SQLite 读回的 datetime 是 naive UTC,与 utc_now() 运算前需统一。"""
if dt is not None and getattr(dt, "tzinfo", None) is not None:
return dt.replace(tzinfo=None)
return dt
def serialize_job(job: Job) -> dict:
    """Turn one job row into a display dict that includes its run time.

    ``duration_seconds`` is ``None`` until the job has a start time. For a job
    with no completion time it is measured up to the current time.
    """
    begun = _as_naive(job.started_at)
    if begun:
        finished = _as_naive(job.completed_at) or _as_naive(utc_now())
        elapsed = round((finished - begun).total_seconds(), 1)
    else:
        elapsed = None

    record = {
        "id": job.id,
        "type": job.type,
        "status": job.status,
        "owner": job.owner,
        "created_at": job.created_at,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
        "duration_seconds": elapsed,
        "error": job.error,
    }
    return record
def get_job_status_counts(db: Session) -> dict:
    """Count jobs per status and return the result as ``{status: count}``."""
    stmt = select(Job.status, func.count(Job.id)).group_by(Job.status)
    counts = {}
    for row in db.execute(stmt).fetchall():
        counts[row[0]] = row[1]
    return counts
# ── 锁管理 ────────────────────────────────────────────────────────────
def force_release_lock(db: Session, lock_id: int) -> bool:
    """Mark a stuck ``TaskLock`` as finished.

    Only a lock whose status is ``running`` or ``stale`` is touched.

    Returns:
        ``True`` when a row was updated, ``False`` otherwise.
    """
    releasable = ["running", "stale"]
    stmt = (
        update(TaskLock)
        .where(TaskLock.id == lock_id, TaskLock.status.in_(releasable))
        .values(status="finished", released_at=utc_now())
    )
    outcome = db.execute(stmt)
    db.commit()
    affected = outcome.rowcount or 0
    return affected > 0
# ── 失败原因分布 ──────────────────────────────────────────────────────
def get_failure_breakdown(db: Session) -> list[dict]:
    """Group failed summaries by error type, most frequent first.

    Rows whose status is ``FAILED`` or ``PERMANENT_FAILURE`` are counted; a
    NULL ``error_type`` is reported as ``"unknown"``.
    """
    failed_states = [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
    kind = func.coalesce(SummaryStatus.error_type, "unknown")
    occurrences = func.count(SummaryStatus.id)
    stmt = (
        select(kind, occurrences)
        .where(SummaryStatus.status.in_(failed_states))
        .group_by(kind)
        .order_by(occurrences.desc())
    )
    breakdown = []
    for row in db.execute(stmt).fetchall():
        breakdown.append({"error_type": row[0], "count": row[1]})
    return breakdown
# ── 运行配置概览 ──────────────────────────────────────────────────────
def get_config_overview() -> dict:
    """Collect non-sensitive runtime settings for the admin dashboard.

    Secrets are never returned verbatim: the embedding API key is reported
    only as a configured/not-configured flag, and the database URL is rendered
    with any embedded password hidden.
    """
    # Function-scope imports: only this function needs them.
    from sqlalchemy.engine.url import make_url
    from sqlalchemy.exc import ArgumentError

    try:
        # A URL such as "postgresql://user:secret@host/db" carries credentials;
        # hide_password=True replaces the password with "***".
        database_url = make_url(settings.DATABASE_URL).render_as_string(
            hide_password=True
        )
    except ArgumentError:
        # Do not echo a string that could not be parsed; it may hold a secret.
        database_url = "(无法解析)"

    return {
        "summary_backend": settings.SUMMARY_BACKEND,
        "summary_pdf_mode": settings.SUMMARY_PDF_MODE,
        "summary_concurrency": settings.SUMMARY_CONCURRENCY,
        "summary_timeout_seconds": settings.SUMMARY_TIMEOUT_SECONDS,
        "summary_max_retries": settings.SUMMARY_MAX_RETRIES,
        "scheduler_enabled": settings.SCHEDULER_ENABLED,
        "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
        "chroma_enabled": settings.CHROMA_ENABLED,
        "embed_model": settings.EMBED_MODEL or "(未配置)",
        "top_n": settings.TOP_N,
        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
        "app_workers": settings.APP_WORKERS,
        # Only the file name is shown, not the full model path.
        "layout_model": Path(settings.LAYOUT_MODEL_PATH).name,
        "database_url": database_url,
        "api_key_configured": bool(settings.EMBED_API_KEY),
    }
+64
View File
@@ -270,3 +270,67 @@ def _update_upvotes_only(db: Session, papers_raw: list[dict]) -> int:
db.commit() db.commit()
return updated return updated
async def recrawl_single(db: Session, arxiv_id: str) -> dict:
    """Re-fetch the full metadata of one paper that is already stored.

    The daily list for the paper's ``paper_date`` is fetched again. When the
    paper is found there, its title, abstract, links, upvotes, authors and
    tags are refreshed and its FTS entry is rebuilt.

    Args:
        db: Open database session; committed on success.
        arxiv_id: Identifier of the paper to refresh.

    Returns:
        ``{"updated": True, ...}`` on success, otherwise
        ``{"updated": False, "reason": ...}`` where ``reason`` is
        ``"not_found"`` (no such paper), ``"no_paper_date"`` (the paper has no
        date to look up) or ``"not_in_daily"`` (absent from that day's list).
    """
    paper = db.execute(
        select(Paper).where(Paper.arxiv_id == arxiv_id)
    ).scalar_one_or_none()
    if not paper:
        return {"updated": False, "reason": "not_found", "arxiv_id": arxiv_id}

    if paper.paper_date is None:
        # Without a date there is no daily list to query; report it instead of
        # failing with AttributeError on ``None.isoformat()``.
        return {"updated": False, "reason": "no_paper_date", "arxiv_id": arxiv_id}

    target_date = paper.paper_date.isoformat()
    raw_papers = await fetch_daily(target_date)

    # Keep the parsed form of the first match so it is not parsed a second time.
    meta = None
    for item in raw_papers:
        parsed = _parse_paper(item)
        if parsed["arxiv_id"] == arxiv_id:
            meta = parsed
            break

    if meta is None:
        return {
            "updated": False,
            "reason": "not_in_daily",
            "arxiv_id": arxiv_id,
            "date": target_date,
        }

    now = utc_now()

    # Refresh every scalar field.
    paper.title_en = meta["title_en"]
    paper.abstract = meta["abstract"]
    paper.published_at = meta["published_at"]
    paper.hf_url = meta["hf_url"]
    paper.arxiv_url = meta["arxiv_url"]
    paper.pdf_url = meta["pdf_url"]
    paper.upvotes = meta["upvotes"]
    paper.crawled_at = now

    # Rebuild authors: drop the old rows, then add the new ones, skipping
    # empty and duplicate names. ``position`` keeps the index from the source
    # list, so skipped entries leave gaps.
    # NOTE(review): assumes clearing the relationship deletes the old rows
    # (delete-orphan cascade) — confirm against the model definition.
    paper.authors.clear()
    seen_authors: set[str] = set()
    for idx, name in enumerate(meta["authors"]):
        if name and name not in seen_authors:
            seen_authors.add(name)
            db.add(PaperAuthor(paper_id=paper.id, name=name, position=idx))

    # Rebuild tags the same way (no de-duplication here).
    paper.tags.clear()
    for tag_name in meta["tags"]:
        if tag_name:
            db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="hf"))

    db.flush()
    reindex_paper_fts(db, paper)
    db.commit()

    logger.info("Re-crawled paper %s (full metadata refresh)", arxiv_id)
    return {"updated": True, "arxiv_id": arxiv_id, "date": target_date}
+78 -40
View File
@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import threading
from pathlib import Path from pathlib import Path
from sqlalchemy import select from sqlalchemy import select
@@ -18,14 +19,27 @@ logger = logging.getLogger(__name__)
class ChromaManager: class ChromaManager:
"""封装 ChromaDB 客户端和 collection 的生命周期。""" """封装 ChromaDB 客户端和 collection 的生命周期。
所有客户端/集合访问经 ``self._lock`` 串行化:后处理经 ``asyncio.to_thread``
多 worker 并发调 ``index_paper``,若不串行化会并发建连,触发 chromadb 1.5.x
``SharedSystemClient`` 类级缓存的并发竞争(``_create_system_if_not_exists``
无锁 + refcount release 弹 key)→ ``KeyError: '<persist_dir>'``。
锁用 RLock``index_paper`` 持锁后经 ``get_collection()`` 间接再调 ``init()``
同线程可重入,不死锁。
"""
def __init__(self) -> None: def __init__(self) -> None:
self._lock = threading.RLock()
self._client = None self._client = None
self._collection = None self._collection = None
def init(self) -> None: def init(self) -> None:
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。""" """CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。
双重检查锁串行化首次建连 —— 外层快判已建好就直接返回(快路径),抢到锁后再
判一次(防并发下另一线程已建好),确保 ``PersistentClient`` 全进程只调一次。
"""
if not settings.CHROMA_ENABLED: if not settings.CHROMA_ENABLED:
logger.debug("ChromaDB disabled, skip init") logger.debug("ChromaDB disabled, skip init")
return return
@@ -33,19 +47,23 @@ class ChromaManager:
if self._client is not None: if self._client is not None:
return return
try: with self._lock:
import chromadb if self._client is not None: # 双重检查:抢到锁后可能已被别的线程建好
return
chroma_path = Path(settings.CHROMA_DIR) try:
chroma_path.mkdir(parents=True, exist_ok=True) import chromadb
self._client = chromadb.PersistentClient(path=str(chroma_path)) chroma_path = Path(settings.CHROMA_DIR)
self._collection = self._get_or_create_collection() chroma_path.mkdir(parents=True, exist_ok=True)
logger.info("ChromaDB initialized at %s", chroma_path)
except Exception: self._client = chromadb.PersistentClient(path=str(chroma_path))
logger.exception("Failed to initialize ChromaDB") self._collection = self._get_or_create_collection()
self._client = None logger.info("ChromaDB initialized at %s", chroma_path)
self._collection = None except Exception:
logger.exception("Failed to initialize ChromaDB")
self._client = None
self._collection = None
def _get_or_create_collection(self): def _get_or_create_collection(self):
"""获取或创建 papers_embeddings collection。""" """获取或创建 papers_embeddings collection。"""
@@ -102,6 +120,8 @@ def _get_embedding(text: str) -> list[float] | None:
POST /v1/embeddings, model=EMBED_MODEL POST /v1/embeddings, model=EMBED_MODEL
校验返回向量长度 == EMBED_DIMENSIONS 校验返回向量长度 == EMBED_DIMENSIONS
失败时返回 None 并记录日志。 失败时返回 None 并记录日志。
纯远程 HTTP 调用、线程安全 —— 留在锁外,让多 worker 并行调。
""" """
if not settings.EMBED_API_BASE or not settings.EMBED_MODEL: if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding") logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
@@ -177,9 +197,11 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
Returns: Returns:
True 表示成功,False 表示失败或跳过。 True 表示成功,False 表示失败或跳过。
并发设计:远程 embedding 调用在锁外(多 worker 并行),chroma 集合访问
(含首次 init)在 ``_chroma._lock`` 内串行化。
""" """
col = get_collection() if not settings.CHROMA_ENABLED:
if col is None:
return False return False
try: try:
@@ -227,17 +249,25 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
logger.warning("Empty index text for %s, skip", paper_id) logger.warning("Empty index text for %s, skip", paper_id)
return False return False
vec = _get_embedding(index_text) vec = _get_embedding(index_text) # 远程 HTTP,锁外并行
if vec is None: if vec is None:
return False return False
col.upsert( with _chroma._lock: # 串行化集合访问(首次含 init
ids=[arxiv_id], col = _chroma.get_collection()
embeddings=[vec], if col is None:
metadatas=[ return False
{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date} col.upsert(
], ids=[arxiv_id],
) embeddings=[vec],
metadatas=[
{
"arxiv_id": arxiv_id,
"title_zh": title_zh,
"paper_date": paper_date,
}
],
)
logger.info("Indexed paper %s in ChromaDB", arxiv_id) logger.info("Indexed paper %s in ChromaDB", arxiv_id)
return True return True
@@ -255,17 +285,20 @@ def delete_paper(paper_id: str) -> bool:
Args: Args:
paper_id: arxiv_id paper_id: arxiv_id
""" """
col = get_collection() if not settings.CHROMA_ENABLED:
if col is None:
return False return False
try: with _chroma._lock:
col.delete(ids=[paper_id]) col = _chroma.get_collection()
logger.info("Deleted paper %s from ChromaDB", paper_id) if col is None:
return True return False
except Exception: try:
logger.exception("Failed to delete paper %s from ChromaDB", paper_id) col.delete(ids=[paper_id])
return False logger.info("Deleted paper %s from ChromaDB", paper_id)
return True
except Exception:
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
return False
# ── 相似查询 ──────────────────────────────────────────────────────────── # ── 相似查询 ────────────────────────────────────────────────────────────
@@ -280,21 +313,26 @@ def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
Returns: Returns:
[{"arxiv_id": str, "distance": float}, ...] [{"arxiv_id": str, "distance": float}, ...]
并发设计:远程 embedding 在锁外,集合查询在 ``_chroma._lock`` 内。
""" """
col = get_collection() if not settings.CHROMA_ENABLED:
if col is None:
return [] return []
try: try:
vec = _get_embedding(query_text) vec = _get_embedding(query_text) # 远程 HTTP,锁外
if vec is None: if vec is None:
return [] return []
results = col.query( with _chroma._lock:
query_embeddings=[vec], col = _chroma.get_collection()
n_results=min(top_k, col.count()) if col.count() > 0 else top_k, if col is None:
include=["metadatas", "distances"], return []
) results = col.query(
query_embeddings=[vec],
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
include=["metadatas", "distances"],
)
if not results["ids"] or not results["ids"][0]: if not results["ids"] or not results["ids"][0]:
return [] return []
+15 -1
View File
@@ -148,7 +148,7 @@ async def run_job(db: Session, job_id: int) -> dict:
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict: async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
from app.services.crawler import refresh_upvotes from app.services.crawler import recrawl_single, refresh_upvotes
from app.services.derived import reindex_chroma, reindex_fts from app.services.derived import reindex_chroma, reindex_fts
from app.services.pipeline import run_crawl, run_pipeline from app.services.pipeline import run_crawl, run_pipeline
from app.services.summarizer import summarize_batch, summarize_single from app.services.summarizer import summarize_batch, summarize_single
@@ -193,6 +193,20 @@ async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
return reindex_fts(db) return reindex_fts(db)
if job.type == "reindex_chroma": if job.type == "reindex_chroma":
return reindex_chroma(db) return reindex_chroma(db)
if job.type == "recrawl_one":
return await recrawl_single(db, payload["arxiv_id"])
if job.type == "recrawl_batch":
updated = 0
skipped = 0
results = []
for arxiv_id in payload.get("arxiv_ids", []):
res = await recrawl_single(db, arxiv_id)
results.append(res)
if res.get("updated"):
updated += 1
else:
skipped += 1
return {"updated": updated, "skipped": skipped, "results": results}
raise ValueError(f"Unsupported job type: {job.type}") raise ValueError(f"Unsupported job type: {job.type}")
+73 -11
View File
@@ -23,8 +23,10 @@ from __future__ import annotations
import json import json
import logging import logging
import threading
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
@@ -47,14 +49,18 @@ _FALLBACK_NAMES: dict[int, str] = {
8: "isolate_formula", 8: "isolate_formula",
9: "formula_caption", 9: "formula_caption",
} }
# 下游需 picture/table —— 按 class name 字符串动态匹配(不依赖 class index # 下游需 picture/table 及其 caption —— 按 class name 字符串动态匹配(不依赖 class index
# 规避 DocStructBench 不同发布的类别顺序差异) # 规避 DocStructBench 不同发布的类别顺序差异)
_PICTURE_NAMES = {"figure", "figure_group"} _PICTURE_NAMES = {"figure", "figure_group"}
_TABLE_NAMES = {"table", "table_group"} _TABLE_NAMES = {"table", "table_group"}
_FIGURE_CAPTION_NAMES = {"figure_caption"}
_TABLE_CAPTION_NAMES = {"table_caption"}
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降) # letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
_PAD_VALUE = 114 _PAD_VALUE = 114
# 最小 bbox 尺寸(PDF 点) # 最小 bbox 尺寸(PDF 点)
_MIN_BOX_SIZE = 20 _MIN_BOX_SIZE = 20
_MIN_CAPTION_BOX_WIDTH = 30
_MIN_CAPTION_BOX_HEIGHT = 6
# device → ExecutionProvider 映射 # device → ExecutionProvider 映射
_PROVIDER_MAP: dict[str, str] = { _PROVIDER_MAP: dict[str, str] = {
@@ -72,7 +78,7 @@ _AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
@dataclass @dataclass
class LayoutBox: class LayoutBox:
"""检测到的布局区域,坐标为 PDF 点boxclass ∈ {"picture", "table"}""" """检测到的布局区域,坐标为 PDF 点。"""
x0: float x0: float
y0: float y0: float
@@ -191,13 +197,17 @@ def _postprocess_output(
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
    """Translate a detector class id into the box class used downstream.

    The class name is looked up in *names*, trimmed and lower-cased, then
    matched against the picture, table, figure-caption and table-caption name
    sets in that order.

    Returns:
        ``"picture"``, ``"table"``, ``"figure_caption"`` or
        ``"table_caption"``; ``None`` for any other class.
    """
    label = names.get(cls_id, "").strip().lower()
    lookup = (
        (_PICTURE_NAMES, "picture"),
        (_TABLE_NAMES, "table"),
        (_FIGURE_CAPTION_NAMES, "figure_caption"),
        (_TABLE_CAPTION_NAMES, "table_caption"),
    )
    for members, boxclass in lookup:
        if label in members:
            return boxclass
    return None
@@ -220,15 +230,50 @@ def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
# ── 检测器单例 ────────────────────────────────────────────────────────── # ── 检测器单例 ──────────────────────────────────────────────────────────
class _LayoutDetector: class _Singleton(type):
"""单例:管理 ONNX InferenceSession 生命周期。""" """元类单例:``cls()`` 永远返回同一实例;``reset_instance()`` 清缓存以便重建。
生产代码只应在模块级 ``_detector = _LayoutDetector()`` 创建一次。任何第二处
``_LayoutDetector()`` 都会拿到同一实例(含同一 ONNX session + 同一锁),杜绝
并发推理时各建一份 session 导致内存峰值翻倍(8GB 机器崩溃根因)。双检锁保证
首次实例化线程安全。
"""
_instances: dict[type, Any] = {}
_lock = threading.Lock()
def __call__(cls, *args, **kwargs):
if cls in _Singleton._instances:
return _Singleton._instances[cls]
with _Singleton._lock:
if cls not in _Singleton._instances:
_Singleton._instances[cls] = super().__call__(*args, **kwargs)
return _Singleton._instances[cls]
class _LayoutDetector(metaclass=_Singleton):
"""强约束单例:管理 ONNX InferenceSession 生命周期。
由 ``_Singleton`` 元类保证全进程唯一实例 —— 重复 ``_LayoutDetector()`` 只会返回
已有实例(含已加载的 session 和锁),不会新建。``reset_instance()`` 清缓存,仅供
测试隔离用。
"""
def __init__(self) -> None: def __init__(self) -> None:
self._lock = threading.Lock()
self._session: ort.InferenceSession | None = None self._session: ort.InferenceSession | None = None
self._names: dict[int, str] = {} self._names: dict[int, str] = {}
self._input_name: str = "" self._input_name: str = ""
self._imgsz: int = settings.LAYOUT_IMGSZ self._imgsz: int = settings.LAYOUT_IMGSZ
@classmethod
def reset_instance(cls) -> None:
    """Drop this class's cached singleton so the next call builds a new one.

    Intended for test isolation only. Per the original author's note,
    production code should never call this, because the replacement instance
    starts without the already-loaded model session.
    """
    # pop() with a default is a no-op when no instance has been cached yet.
    _Singleton._instances.pop(cls, None)
def _init_session(self) -> ort.InferenceSession: def _init_session(self) -> ort.InferenceSession:
if self._session is not None: if self._session is not None:
return self._session return self._session
@@ -275,7 +320,7 @@ class _LayoutDetector:
self._imgsz = settings.LAYOUT_IMGSZ self._imgsz = settings.LAYOUT_IMGSZ
return self._session return self._session
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]: def _detect_page_impl(self, page: pymupdf.Page) -> list[LayoutBox]:
"""检测单页 PDF 的 figure / table 区域。 """检测单页 PDF 的 figure / table 区域。
流程: 流程:
@@ -323,21 +368,38 @@ class _LayoutDetector:
y0 = max(0.0, min(y0, page_h)) y0 = max(0.0, min(y0, page_h))
x1 = max(0.0, min(x1, page_w)) x1 = max(0.0, min(x1, page_w))
y1 = max(0.0, min(y1, page_h)) y1 = max(0.0, min(y1, page_h))
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE: if boxclass in ("figure_caption", "table_caption"):
continue if (x1 - x0) < _MIN_CAPTION_BOX_WIDTH or (
y1 - y0
) < _MIN_CAPTION_BOX_HEIGHT:
continue
else:
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
continue
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass)) result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
return result return result
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
    """Public entry point: run page detection while holding the instance lock.

    The whole of ``_detect_page_impl`` runs inside ``self._lock``, so only one
    detection executes at a time on this instance. The original author's note
    gives the reason: concurrent inference from several worker threads
    multiplied peak memory use.

    NOTE(review): the original note also says session initialisation happens
    inside ``_detect_page_impl`` and is therefore covered by the same lock.
    That call is not visible in this view — confirm.
    """
    with self._lock:
        return self._detect_page_impl(page)
# 模块级单例 —— 生产代码唯一的实例化点(_Singleton 元类保证不会再有第二个)
_detector = _LayoutDetector() _detector = _LayoutDetector()
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]: def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
"""检测 PDF 页面中的 figure / table 区域。 """检测 PDF 页面中的 figure / table / caption 区域。
Returns: Returns:
LayoutBox 列表,坐标为 PDF 点,仅含 picture/table。 LayoutBox 列表,坐标为 PDF 点,仅含 picture/table 及其 caption
""" """
return _detector.detect_page(page) return _detector.detect_page(page)
+130 -23
View File
@@ -34,6 +34,8 @@ _CLUSTER_GAP = 15
_MIN_BOX_AREA = 2000 _MIN_BOX_AREA = 2000
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt) # Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
_LABEL_MATCH_DISTANCE = 100 _LABEL_MATCH_DISTANCE = 100
# DocLayout caption 与 figure/table 匹配的最大距离(单位: pt)
_CAPTION_MATCH_DISTANCE = 120
# ── Box 聚类 ───────────────────────────────────────────────────────── # ── Box 聚类 ─────────────────────────────────────────────────────────
@@ -53,6 +55,15 @@ class _BoxCluster:
self.boxclass = "table" if raw == "table-fallback" else raw self.boxclass = "table" if raw == "table-fallback" else raw
def _cluster_to_box(cluster: _BoxCluster) -> list[float]:
return [
round(float(cluster.x0), 1),
round(float(cluster.y0), 1),
round(float(cluster.x1), 1),
round(float(cluster.y1), 1),
]
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]: def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
"""将相邻的同类型 box 合并为聚类。""" """将相邻的同类型 box 合并为聚类。"""
if not boxes: if not boxes:
@@ -92,6 +103,67 @@ def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
return [_BoxCluster(members) for members in groups.values()] return [_BoxCluster(members) for members in groups.values()]
def _caption_class_for_content(boxclass: str) -> str:
return "figure_caption" if boxclass == "picture" else "table_caption"
def _caption_distance(content: _BoxCluster, caption: _BoxCluster) -> float | None:
"""Return a spatial score for pairing a caption with a content box."""
h_overlap = min(content.x1, caption.x1) - max(content.x0, caption.x0)
min_width = min(content.x1 - content.x0, caption.x1 - caption.x0)
if min_width <= 0 or h_overlap < min_width * 0.25:
return None
if caption.y1 < content.y0:
v_gap = content.y0 - caption.y1
elif caption.y0 > content.y1:
v_gap = caption.y0 - content.y1
else:
v_gap = 0.0
return v_gap if v_gap <= _CAPTION_MATCH_DISTANCE else None
def _extract_caption_text(page, caption: _BoxCluster) -> str:
    """Return the page text inside the caption box, whitespace-collapsed.

    Extraction is best-effort: any error from the text lookup yields ``""``.
    """
    clip = pymupdf.Rect(caption.x0, caption.y0, caption.x1, caption.y1)
    try:
        raw = page.get_text("text", clip=clip)
    except Exception:
        return ""
    # Collapse every whitespace run, newlines included, into one space.
    words = raw.split()
    return " ".join(words)
def _match_captions(
    page,
    content_clusters: list[_BoxCluster],
    caption_clusters: list[_BoxCluster],
) -> dict[int, tuple[_BoxCluster, str]]:
    """Pair each content cluster with its closest caption of the same kind.

    Pairs are assigned greedily in order of increasing distance. A content
    cluster gets at most one caption and a caption is used at most once. A
    caption whose extracted text is empty is skipped and stays available.

    Returns:
        Mapping of content-cluster index to ``(caption cluster, caption text)``.
    """
    scored: list[tuple[float, int, int]] = []
    for c_idx, content in enumerate(content_clusters):
        expected = _caption_class_for_content(content.boxclass)
        for cap_idx, caption in enumerate(caption_clusters):
            if caption.boxclass == expected:
                score = _caption_distance(content, caption)
                if score is not None:
                    scored.append((score, c_idx, cap_idx))

    paired: dict[int, tuple[_BoxCluster, str]] = {}
    taken: set[int] = set()
    # Sorting the tuples orders by distance; ties fall back to the indices.
    for _score, c_idx, cap_idx in sorted(scored):
        if c_idx in paired or cap_idx in taken:
            continue
        chosen = caption_clusters[cap_idx]
        text = _extract_caption_text(page, chosen)
        if text:
            paired[c_idx] = (chosen, text)
            taken.add(cap_idx)
    return paired
# ── Phase 1: 检测 + 渲染 ────────────────────────────────────────────── # ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
@@ -102,14 +174,25 @@ def _render_box(
filename: str, filename: str,
cap_type: str, cap_type: str,
page_num: int, page_num: int,
caption: _BoxCluster | None = None,
) -> bool: ) -> bool:
"""渲染单个 box 区域并保存 JPEG,成功返回 True。""" """渲染单个 box 区域并保存 JPEG,成功返回 True。
若提供 caption,则将内容与 caption 区域合并后一起截取,
使同一张截图同时包含图/表及其标题文字。
"""
page_width = page.rect.width page_width = page.rect.width
x0, y0, x1, y1 = box.x0, box.y0, box.x1, box.y1
if caption is not None:
x0 = min(x0, caption.x0)
y0 = min(y0, caption.y0)
x1 = max(x1, caption.x1)
y1 = max(y1, caption.y1)
clip = pymupdf.Rect( clip = pymupdf.Rect(
max(0, box.x0 - _REGION_PADDING), max(0, x0 - _REGION_PADDING),
max(0, box.y0 - _REGION_PADDING), max(0, y0 - _REGION_PADDING),
min(page_width, box.x1 + _REGION_PADDING), min(page_width, x1 + _REGION_PADDING),
box.y1 + _REGION_PADDING, y1 + _REGION_PADDING,
) )
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM) mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
try: try:
@@ -136,25 +219,31 @@ def _process_page(
fig_counter = 0 fig_counter = 0
tbl_counter = 0 tbl_counter = 0
# 收集本页的 table/picture box(跳过极小区域) # 收集本页的 table/picture box 与 caption box(跳过极小区域)
raw_boxes = [] raw_boxes = []
raw_caption_boxes = []
for box in page_boxes: for box in page_boxes:
if box.boxclass not in ("table", "table-fallback", "picture"):
continue
w = box.x1 - box.x0 w = box.x1 - box.x0
h = box.y1 - box.y0 h = box.y1 - box.y0
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA: if box.boxclass in ("table", "table-fallback", "picture"):
continue if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
raw_boxes.append(box) continue
raw_boxes.append(box)
elif box.boxclass in ("figure_caption", "table_caption"):
if w < 30 or h < 6:
continue
raw_caption_boxes.append(box)
if not raw_boxes: if not raw_boxes:
return 0 return 0
# 聚类:将同一 figure/table 的碎片 box 合并 # 聚类:将同一 figure/table 的碎片 box 合并
clusters = _cluster_boxes(raw_boxes) clusters = _cluster_boxes(raw_boxes)
caption_clusters = _cluster_boxes(raw_caption_boxes)
caption_matches = _match_captions(page, clusters, caption_clusters)
extracted = 0 extracted = 0
for cluster in clusters: for cluster_idx, cluster in enumerate(clusters):
cap_type = "figure" if cluster.boxclass == "picture" else "table" cap_type = "figure" if cluster.boxclass == "picture" else "table"
if cap_type == "figure": if cap_type == "figure":
@@ -168,21 +257,33 @@ def _process_page(
continue continue
seen_labels.add(label) seen_labels.add(label)
caption_match = caption_matches.get(cluster_idx)
caption_cluster = caption_match[0] if caption_match else None
filename = f"{label.replace(' ', '_').lower()}.jpg" filename = f"{label.replace(' ', '_').lower()}.jpg"
if not _render_box(page, cluster, images_dest, filename, cap_type, page_num): if not _render_box(
page,
cluster,
images_dest,
filename,
cap_type,
page_num,
caption=caption_cluster,
):
continue continue
manifest[filename] = { info = {
"page": page_num, "page": page_num,
"type": cap_type, "type": cap_type,
"label": label, "label": label,
"box": [ "box": _cluster_to_box(cluster),
round(float(cluster.x0), 1),
round(float(cluster.y0), 1),
round(float(cluster.x1), 1),
round(float(cluster.y1), 1),
],
} }
if caption_match:
info["caption_text"] = caption_match[1][:500]
info["caption_box"] = _cluster_to_box(caption_cluster)
info["caption_source"] = "doclayout"
manifest[filename] = info
extracted += 1 extracted += 1
return extracted return extracted
@@ -446,14 +547,20 @@ def label_images_by_summary(
cap_type = info.get("type", "figure") cap_type = info.get("type", "figure")
# 读取 caption 文本(从 figures 列表) # 读取 caption 文本(从 figures 列表)
caption_text = "" summary_caption_text = ""
for fig in figures: for fig in figures:
if fig.get("id") == fig_id: if fig.get("id") == fig_id:
caption_text = fig.get("caption", "") summary_caption_text = fig.get("caption", "")
break break
info["label"] = fig_id info["label"] = fig_id
info["caption_text"] = caption_text[:200] if caption_text else "" existing_caption_text = info.get("caption_text", "")
if existing_caption_text and summary_caption_text:
info["summary_caption_text"] = summary_caption_text[:500]
else:
info["caption_text"] = (
summary_caption_text[:500] if summary_caption_text else ""
)
info.setdefault("figures" if cap_type == "figure" else "tables", []).append( info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
fig_id fig_id
) )
+22 -2
View File
@@ -31,6 +31,7 @@ from app.services.summary_persister import (
_cleanup_old_images, _cleanup_old_images,
_handle_summary_failure, _handle_summary_failure,
_persist_summary, _persist_summary,
_run_post_processing,
) )
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
@@ -115,12 +116,31 @@ async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -
_t3 = _time.monotonic() _t3 = _time.monotonic()
logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2) logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2)
quality = _persist_summary(db, paper, json_data, raw_output) quality, schema = _persist_summary(db, paper, json_data, raw_output)
_t4 = _time.monotonic() _t4 = _time.monotonic()
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3) logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
# 后处理(图片提取 + ChromaDB 索引)搬到线程池跑,避免 CPU 密集推理冻结
# 事件循环。paper 字段在此(事件循环线程)提取成纯值再传入,规避 worker
# 线程跨线程访问 ORM 的 DetachedInstanceError。DocLayout 推理由单例的
# threading.Lock 串行化,并发 worker 不会同时压模型。
paper_meta = {
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
_t5 = _time.monotonic()
try:
await asyncio.to_thread(_run_post_processing, arxiv_id, schema, paper_meta)
except Exception:
# 双保险:_run_post_processing 内部已 try/except,此处兜底,
# 确保后处理失败绝不影响已 DONE 的总结。
logger.warning("Post-processing error for %s", arxiv_id, exc_info=True)
_t6 = _time.monotonic()
logger.info(" [%s] 后处理(线程池): %.2fs", arxiv_id, _t6 - _t5)
logger.info( logger.info(
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0 "✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t6 - _t0
) )
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality} return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
+38 -22
View File
@@ -131,8 +131,12 @@ def _handle_summary_failure(
def _persist_summary( def _persist_summary(
db: Session, paper: Paper, json_data: dict, raw_output: str db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str: ) -> tuple[str, SummarySchema]:
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality""" """Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 (quality, schema)。
后处理(图片提取/ChromaDB)不再在此函数内执行,由调用方搬到线程池,
以免阻塞事件循环。返回 schema 供调用方在线程池里跑后处理。
"""
import time as _time import time as _time
arxiv_id = paper.arxiv_id arxiv_id = paper.arxiv_id
@@ -165,21 +169,10 @@ def _persist_summary(
_t4 - _t3, _t4 - _t3,
) )
# 触发性增强(失败不影响总结) # 后处理(图片提取 + ChromaDB 索引)已上移到调用方 _do_summarize_one
_t5 = _time.monotonic() # 经 asyncio.to_thread 在线程池跑——DB session 必须留在事件循环线程,
_maybe_extract_images(arxiv_id, schema) # 而 CPU/IO 密集的后处理搬走才不冻结事件循环。
_t6 = _time.monotonic() return quality, schema
_maybe_index_chroma(arxiv_id, paper, schema)
_t7 = _time.monotonic()
logger.info(
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
arxiv_id,
_t6 - _t5,
_t7 - _t6,
)
return quality
# ── 清理 ──────────────────────────────────────────────────────────────── # ── 清理 ────────────────────────────────────────────────────────────────
@@ -226,21 +219,44 @@ def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True) logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
def _maybe_index_chroma(arxiv_id: str, schema: SummarySchema, paper_meta: dict) -> None:
    """Write the paper into the ChromaDB semantic index (best effort).

    Any failure is logged and swallowed so that it never affects the summary.

    Args:
        arxiv_id: Identifier of the paper being indexed.
        schema: Validated summary; supplies ``title_zh``, ``one_line``,
            ``motivation.problem`` and ``method.key_idea``.
        paper_meta: Plain values (``title_en`` / ``tags`` / ``paper_date``)
            that the caller extracted from the ORM object beforehand, so
            this function does not touch the ORM when it runs in a worker
            thread. Missing keys fall back to an empty string.
    """
    try:
        # Imported lazily, inside the guarded region, so that an import
        # failure is also treated as a non-fatal indexing failure.
        from app.services.embedder import index_paper

        texts_dict = {
            "arxiv_id": arxiv_id,
            "title_zh": schema.title_zh or "",
            "title_en": paper_meta.get("title_en", ""),
            "tags": paper_meta.get("tags", ""),
            "one_line": schema.one_line or "",
            "motivation_problem": schema.motivation.problem or "",
            "method_key_idea": schema.method.key_idea or "",
            "paper_date": paper_meta.get("paper_date", ""),
        }
        index_paper(arxiv_id, texts_dict)
    except Exception:
        logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
def _run_post_processing(
    arxiv_id: str, schema: SummarySchema, paper_meta: dict
) -> None:
    """Run the post-summary steps in order: image extraction, then ChromaDB indexing.

    Invoked by ``_do_summarize_one`` through ``asyncio.to_thread``. The two
    helpers already catch their own errors; the guard here is a second line
    of defence. Each step is guarded *separately*, so an unexpected exception
    escaping image extraction cannot prevent the indexing step from running.

    Args:
        arxiv_id: Identifier of the paper whose summary was just persisted.
        schema: Validated summary passed on to both steps.
        paper_meta: Plain paper values forwarded to the indexing step.
    """
    steps = (
        lambda: _maybe_extract_images(arxiv_id, schema),
        lambda: _maybe_index_chroma(arxiv_id, schema, paper_meta),
    )
    for run_step in steps:
        try:
            run_step()
        except Exception:
            # Best effort: the summary is already saved, so only log.
            logger.warning(
                "Post-processing failed for %s (summary already persisted)",
                arxiv_id,
                exc_info=True,
            )
+3
View File
@@ -42,6 +42,9 @@
.status-success { background:var(--success-bg); color:#388e3c; } .status-success { background:var(--success-bg); color:#388e3c; }
.status-running { background:var(--info-bg); color:#1976d2; } .status-running { background:var(--info-bg); color:#1976d2; }
.status-failed { background:var(--danger-bg); color:var(--danger-bright); } .status-failed { background:var(--danger-bg); color:var(--danger-bright); }
.status-queued { background:#fff8e1; color:#8a6d3b; }
.status-stale { background:var(--border); color:var(--ink-muted); }
.task-reindex { background:#fff3e0; color:#e65100; }
.time-cell { white-space:nowrap; color:var(--ink-light); } .time-cell { white-space:nowrap; color:var(--ink-light); }
.error-cell { max-width:200px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; color:var(--danger-bright); font-size:.8rem; } .error-cell { max-width:200px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; color:var(--danger-bright); font-size:.8rem; }
+36 -1
View File
@@ -69,7 +69,8 @@
<span class="info-label">活跃任务</span> <span class="info-label">活跃任务</span>
<span class="info-value"> <span class="info-value">
{% for lock in stats.active_locks %} {% for lock in stats.active_locks %}
<span class="task-badge task-{{ lock.task }}">{{ lock.task }}</span> <span class="task-badge task-{{ lock.task }}" title="{{ lock.status }} · #{{ lock.id }}">{{ lock.task }}</span>
<button class="admin-action-btn admin-action-btn-sm admin-action-btn-danger" title="强制释放锁 #{{ lock.id }}" onclick="releaseLock({{ lock.id }})">🔓</button>
{% endfor %} {% endfor %}
</span> </span>
</div> </div>
@@ -118,6 +119,15 @@
<div class="info-row"><span class="info-label">数据库</span><span class="info-value">{{ stats.db_size }}</span></div> <div class="info-row"><span class="info-label">数据库</span><span class="info-value">{{ stats.db_size }}</span></div>
<div class="info-row"><span class="info-label">论文文件</span><span class="info-value">{{ stats.papers_size }}</span></div> <div class="info-row"><span class="info-label">论文文件</span><span class="info-value">{{ stats.papers_size }}</span></div>
<div class="info-row"><span class="info-label">临时文件</span><span class="info-value">{{ stats.tmp_size }}</span></div> <div class="info-row"><span class="info-label">临时文件</span><span class="info-value">{{ stats.tmp_size }}</span></div>
<div class="info-row">
<span class="info-label">搜索索引</span>
<span class="info-value">
<button class="admin-action-btn admin-action-btn-sm" onclick="rebuildIndexes('fts')">🔤 重建全文</button>
{% if stats.config_overview.chroma_enabled %}
<button class="admin-action-btn admin-action-btn-sm" onclick="rebuildIndexes('chroma')">🧠 重建语义</button>
{% endif %}
</span>
</div>
</div> </div>
<div class="summary-dist"> <div class="summary-dist">
<h3 class="section-subtitle">总结状态分布</h3> <h3 class="section-subtitle">总结状态分布</h3>
@@ -136,6 +146,19 @@
</div> </div>
</div> </div>
</div> </div>
<div class="admin-info-card">
<h2 class="admin-info-title">⚙️ 运行配置</h2>
<div class="admin-info-body">
<div class="info-row"><span class="info-label">总结后端</span><span class="info-value">{{ stats.config_overview.summary_backend }} · {{ stats.config_overview.summary_pdf_mode }} 模式</span></div>
<div class="info-row"><span class="info-label">并发/超时</span><span class="info-value">{{ stats.config_overview.summary_concurrency }} 并发 · {{ stats.config_overview.summary_timeout_seconds }}s · 重试 {{ stats.config_overview.summary_max_retries }}</span></div>
<div class="info-row"><span class="info-label">调度</span><span class="info-value">{{ '启用' if stats.config_overview.scheduler_enabled else '未启用' }} · {{ stats.config_overview.schedule_time }} · {{ stats.config_overview.app_workers }} worker</span></div>
<div class="info-row"><span class="info-label">语义搜索</span><span class="info-value">{{ '启用' if stats.config_overview.chroma_enabled else '未启用' }} · {{ stats.config_overview.embed_model }}</span></div>
<div class="info-row"><span class="info-label">抓取</span><span class="info-value">TOP {{ stats.config_overview.top_n }} · 投票刷新 {{ stats.config_overview.upvote_refresh_days }} 天</span></div>
<div class="info-row"><span class="info-label">布局模型</span><span class="info-value">{{ stats.config_overview.layout_model }}</span></div>
<div class="info-row"><span class="info-label">数据库</span><span class="info-value">{{ stats.config_overview.database_url }}</span></div>
<div class="info-row"><span class="info-label">嵌入密钥</span><span class="info-value">{{ '已配置' if stats.config_overview.api_key_configured else '未配置' }}</span></div>
</div>
</div>
</div> </div>
<div class="admin-section"> <div class="admin-section">
@@ -193,5 +216,17 @@
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : `✅ 已刷新 ${data.updated || 0} 篇论文投票`); }) .then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : `✅ 已刷新 ${data.updated || 0} 篇论文投票`); })
.catch(err => showToast("❌ 请求失败")); .catch(err => showToast("❌ 请求失败"));
} }
// Ask the server to release the task lock with the given id, then show a toast.
// On success the toast is given a callback that reloads the page.
function releaseLock(lockId) {
fetch("/admin/locks/"+lockId+"/release", { method: "POST", headers: { "Content-Type": "application/json" } })
// A 303 or 401 status sends the browser to the login page; returning undefined skips the toast step below.
.then(r => { if (r.status===303||r.status===401) { window.location.href="/admin/login"; return; } return r.json(); })
// Server error text is truncated to 200 characters before display.
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : "✅ 已释放锁,刷新中…", {callback:function(){location.reload();}}); })
.catch(err => showToast("❌ 请求失败"));
}
// Request an index rebuild. `target` is sent as-is in the JSON body;
// the buttons on this page pass 'fts' or 'chroma'.
function rebuildIndexes(target) {
fetch("/admin/rebuild-indexes", { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({target: target}) })
// A 303 or 401 status sends the browser to the login page; returning undefined skips the toast step below.
.then(r => { if (r.status===303||r.status===401) { window.location.href="/admin/login"; return; } return r.json(); })
// Server error text is truncated to 200 characters before display.
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : "✅ 重建任务已创建,可在任务页查看"); })
.catch(err => showToast("❌ 请求失败"));
}
</script> </script>
{% endblock %} {% endblock %}
+149
View File
@@ -0,0 +1,149 @@
{% extends "base.html" %}
{% block title %}任务监控 — HF Daily Papers{% endblock %}
{% set type_label = {"crawl_daily":"抓取","pipeline_daily":"流水线","summarize_batch":"批量总结","summarize_one":"单篇总结","refresh_upvotes":"刷新投票","delete_range":"删除","cleanup_tmp":"清理","reindex_fts":"重建全文","reindex_chroma":"重建语义","recrawl_one":"重抓","recrawl_batch":"批量重抓"} %}
{% set type_badge = {"crawl_daily":"task-crawl","pipeline_daily":"task-crawl","recrawl_one":"task-crawl","recrawl_batch":"task-crawl","refresh_upvotes":"task-crawl","summarize_batch":"task-summarize","summarize_one":"task-summarize","cleanup_tmp":"task-cleanup","delete_range":"task-delete","reindex_fts":"task-reindex","reindex_chroma":"task-reindex"} %}
{% set status_label = {"queued":"排队","running":"运行中","success":"成功","failed":"失败","stale":"已过期","cancelled":"已取消"} %}
{% set status_badge = {"queued":"status-queued","running":"status-running","success":"status-success","failed":"status-failed","stale":"status-stale","cancelled":"status-stale"} %}
{% macro fmt_duration(s) -%}
{#- Format a duration in seconds as "Ns", "Mm Ns" or "Hh Mm"; none renders as "-".
    The value is rounded to whole seconds once, up front, so that e.g. 119.6
    renders as "2m 0s" instead of "1m 60s". -#}
{%- if s is none %}-
{%- else %}
  {%- set t = s|round|int %}
  {%- if t < 60 %}{{ t }}s
  {%- elif t < 3600 %}{{ t // 60 }}m {{ t % 60 }}s
  {%- else %}{{ t // 3600 }}h {{ (t % 3600) // 60 }}m
  {%- endif %}
{%- endif -%}
{%- endmacro %}
{% block content %}
<div class="admin-page">
{% set active = "jobs" %}{% include "partials/admin_subnav.html" %}
<h1 class="page-heading">🧰 任务监控</h1>
{% set _total = (status_counts.values() | sum) if status_counts else 0 %}
<div class="summary-stats-row">
<span class="summary-stat">总计 <strong>{{ _total }}</strong></span>
<span class="summary-stat summary-stat-pending">排队 <strong>{{ status_counts.get('queued', 0) }}</strong></span>
<span class="summary-stat">运行中 <strong>{{ status_counts.get('running', 0) }}</strong></span>
<span class="summary-stat summary-stat-done">成功 <strong>{{ status_counts.get('success', 0) }}</strong></span>
<span class="summary-stat summary-stat-failed">失败 <strong>{{ status_counts.get('failed', 0) + status_counts.get('stale', 0) }}</strong></span>
</div>
{# Status chips and type selector for the job list.
   The query parameter is named "job_type" because the /admin/jobs route
   declares it under that name; sending "type" would leave the type filter
   at its default value. #}
{% set statuses = [("all","全部"),("queued","排队"),("running","运行中"),("success","成功"),("failed","失败"),("stale","已过期")] %}
<div class="summary-filters">
<span class="summary-filter-label">状态:</span>
{% for key, label in statuses %}
<a class="filter-chip {{ 'active' if current_status == key else '' }}"
href="?status={{ key }}{% if current_type != 'all' %}&amp;job_type={{ current_type }}{% endif %}">{{ label }}
({% if key == 'all' %}{{ _total }}{% else %}{{ status_counts.get(key, 0) }}{% endif %})</a>
{% endfor %}
</div>
{% set types = [("crawl_daily","抓取"),("pipeline_daily","流水线"),("summarize_batch","批量总结"),("summarize_one","单篇总结"),("refresh_upvotes","刷新投票"),("recrawl_one","重抓"),("recrawl_batch","批量重抓"),("delete_range","删除"),("cleanup_tmp","清理"),("reindex_fts","重建全文"),("reindex_chroma","重建语义")] %}
<form method="get" class="summary-filters">
<span class="summary-filter-label">类型:</span>
{# Keep the current status filter when the type selection changes. #}
<input type="hidden" name="status" value="{{ current_status }}" />
<select name="job_type" class="paper-filter-input" onchange="this.form.submit()">
<option value="all" {{ 'selected' if current_type == 'all' }}>全部类型</option>
{% for key, label in types %}
<option value="{{ key }}" {{ 'selected' if current_type == key }}>{{ label }}</option>
{% endfor %}
</select>
</form>
{% if jobs %}
<div class="admin-table-wrap">
<table class="admin-table admin-table-compact">
<thead>
<tr><th>ID</th><th>类型</th><th>状态</th><th>触发者</th><th>创建时间</th><th>耗时</th><th>操作</th></tr>
</thead>
<tbody>
{% for job in jobs %}
<tr>
<td>{{ job.id }}</td>
<td><span class="task-badge {{ type_badge.get(job.type, 'task-crawl') }}">{{ type_label.get(job.type, job.type) }}</span></td>
<td><span class="status-badge {{ status_badge.get(job.status, 'status-running') }}">{{ status_label.get(job.status, job.status) }}</span></td>
<td>{{ job.owner or '-' }}</td>
<td class="time-cell">{{ job.created_at.strftime('%Y-%m-%d %H:%M:%S') if job.created_at else '-' }}</td>
<td class="time-cell">{{ fmt_duration(job.duration_seconds) }}</td>
<td class="action-cell"><button class="action-btn-sm" title="详情" onclick="showJobDetail({{ job.id }})">📋</button></td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
{% set total_pages = ((total + per_page - 1) // per_page) if total else 1 %}
{% if total_pages > 1 %}
<div class="pagination">
{% if page > 1 %}<a class="page-btn" href="{{ pagination_url(page - 1) }}">← 上一页</a>{% endif %}
<span class="page-info">第 {{ page }} / {{ total_pages }} 页(共 {{ total }} 个)</span>
{% if page < total_pages %}<a class="page-btn" href="{{ pagination_url(page + 1) }}">下一页 →</a>{% endif %}
</div>
{% endif %}
{% else %}
<div class="empty-state">
<p>暂无任务记录</p>
<p class="hint">触发抓取、总结等操作后,任务会出现在这里。可在「详情」中查看阶段事件。</p>
</div>
{% endif %}
</div>
<!-- 任务详情 modal -->
<div class="confirm-overlay" id="job-detail-overlay" style="display:none;">
<div class="confirm-dialog" style="max-width:660px;max-height:85vh;overflow:auto;">
<h3 class="admin-info-title" id="job-detail-title">任务详情</h3>
<div id="job-detail-body"><p class="hint">加载中...</p></div>
<div class="confirm-actions">
<button class="confirm-btn confirm-btn-cancel" onclick="closeJobDetail()">关闭</button>
</div>
</div>
</div>
{% endblock %}
{% block scripts %}
<script>
const TYPE_LABEL = {"crawl_daily":"抓取","pipeline_daily":"流水线","summarize_batch":"批量总结","summarize_one":"单篇总结","refresh_upvotes":"刷新投票","delete_range":"删除","cleanup_tmp":"清理","reindex_fts":"重建全文","reindex_chroma":"重建语义","recrawl_one":"重抓","recrawl_batch":"批量重抓"};
const STATUS_LABEL = {"queued":"排队","running":"运行中","success":"成功","failed":"失败","stale":"已过期","cancelled":"已取消"};
function fmtTime(s){ return s ? s.replace('T',' ').slice(0,19) : '-'; }
// Escape a value for insertion into HTML text or a quoted attribute value.
// null/undefined become ''; other values are stringified first. Single quotes
// are escaped too, so the result is also safe inside single-quoted attributes.
function esc(s){ return String(s==null?'':s).replace(/[&<>"']/g, c=>({'&':'&amp;','<':'&lt;','>':'&gt;','"':'&quot;',"'":'&#39;'}[c])); }
function eventBadge(s){ return {'success':'status-success','failed':'status-failed','started':'status-running','info':'status-queued'}[s] || 'status-queued'; }
function jobStatusBadge(s){ return {'success':'success','failed':'failed','running':'running','stale':'stale','cancelled':'stale','queued':'queued'}[s] || 'running'; }
function infoRow(label, val){ return '<div class="info-row"><span class="info-label">'+label+'</span><span class="info-value">'+val+'</span></div>'; }
// Open the job-detail overlay, show a loading placeholder, then fetch
// /admin/jobs/<id> and hand the parsed JSON to renderJobDetail.
function showJobDetail(id){
document.getElementById('job-detail-overlay').style.display='flex';
document.getElementById('job-detail-body').innerHTML='<p class="hint">加载中...</p>';
fetch('/admin/jobs/'+id)
// A 303 or 401 status sends the browser to the login page; returning undefined skips rendering.
.then(r=>{if(r.status===303||r.status===401){window.location.href='/admin/login';return;}return r.json();})
.then(d=>{ if(d) renderJobDetail(d); })
// Any failure in the chain replaces the body with a "load failed" hint.
.catch(()=>{document.getElementById('job-detail-body').innerHTML='<p class="hint">加载失败</p>';});
}
// Build the job-detail markup from the JSON object `d` and write it into
// #job-detail-body. Every value taken from `d` is passed through esc()
// before it is concatenated into the HTML string, because the result is
// assigned to innerHTML. This includes the id and the timestamp strings.
function renderJobDetail(d){
  let h='<div class="admin-info-body">';
  h+=infoRow('ID', esc(d.id));
  h+=infoRow('类型', esc(TYPE_LABEL[d.type]||d.type));
  h+=infoRow('状态', '<span class="status-badge status-'+jobStatusBadge(d.status)+'">'+esc(STATUS_LABEL[d.status]||d.status)+'</span>');
  h+=infoRow('触发者', esc(d.owner||'-'));
  h+=infoRow('创建', esc(fmtTime(d.created_at)));
  h+=infoRow('开始', esc(fmtTime(d.started_at)));
  h+=infoRow('完成', esc(fmtTime(d.completed_at)));
  // Optional rows: shown only when the job actually carries the data.
  if(d.payload && Object.keys(d.payload).length) h+=infoRow('参数', '<code style="word-break:break-all;">'+esc(JSON.stringify(d.payload))+'</code>');
  if(d.result) h+=infoRow('结果', '<code style="word-break:break-all;">'+esc(JSON.stringify(d.result))+'</code>');
  if(d.error) h+=infoRow('错误', '<span class="error-cell" style="max-width:480px;">'+esc(d.error)+'</span>');
  h+='</div>';
  // Event timeline table, rendered only when there is at least one event.
  if(d.events && d.events.length){
    h+='<h3 class="section-subtitle" style="margin-top:18px;">事件时间线</h3>';
    h+='<div class="admin-table-wrap" style="max-height:220px;overflow:auto;"><table class="admin-table admin-table-compact"><thead><tr><th>阶段</th><th>状态</th><th>时间</th><th>消息</th></tr></thead><tbody>';
    d.events.forEach(e=>{
      h+='<tr><td>'+esc(e.stage)+'</td><td><span class="status-badge '+eventBadge(e.status)+'">'+esc(e.status)+'</span></td><td class="time-cell">'+esc(fmtTime(e.created_at))+'</td><td class="error-cell" style="max-width:240px;">'+esc(e.message||'')+'</td></tr>';
    });
    h+='</tbody></table></div>';
  }
  document.getElementById('job-detail-body').innerHTML=h;
}
function closeJobDetail(){ document.getElementById('job-detail-overlay').style.display='none'; }
document.addEventListener('keydown',e=>{if(e.key==='Escape')closeJobDetail();});
</script>
{% endblock %}
+16
View File
@@ -109,6 +109,22 @@
<span class="summary-stat summary-stat-failed">失败 <strong>{{ summary_failed or 0 }}</strong></span> <span class="summary-stat summary-stat-failed">失败 <strong>{{ summary_failed or 0 }}</strong></span>
<span class="summary-stat summary-stat-done">已完成 <strong>{{ summary_done or 0 }}</strong></span> <span class="summary-stat summary-stat-done">已完成 <strong>{{ summary_done or 0 }}</strong></span>
</div> </div>
{% if failure_breakdown %}
<div class="summary-dist" style="margin-top:12px;">
<h3 class="section-subtitle">失败原因分布({{ summary_failed or 0 }} 篇)</h3>
<div class="summary-dist-bars">
{% set fb_total = (failure_breakdown | map(attribute='count') | sum) or 1 %}
{% set error_labels = {"pdf_download_failed":"PDF下载失败","timeout":"超时","process_error":"进程错误","json_not_found":"JSON缺失","json_invalid":"JSON无效","field_missing":"字段缺失","schema_error":"结构错误","unknown":"未分类"} %}
{% for item in failure_breakdown %}
<div class="dist-row">
<span class="dist-label">{{ error_labels.get(item.error_type, item.error_type) }}</span>
<div class="dist-bar-wrap"><div class="dist-bar dist-bar-failed" style="width:{{ (item.count / fb_total * 100)|round(1) }}%"></div></div>
<span class="dist-count">{{ item.count }}</span>
</div>
{% endfor %}
</div>
</div>
{% endif %}
<div id="summary-list" <div id="summary-list"
hx-get="/admin/summary-status" hx-get="/admin/summary-status"
hx-trigger="load" hx-trigger="load"
+17
View File
@@ -29,6 +29,7 @@
<option value="title_asc" {% if current_sort == 'title_asc' %}selected{% endif %}>标题 A→Z</option> <option value="title_asc" {% if current_sort == 'title_asc' %}selected{% endif %}>标题 A→Z</option>
</select> </select>
<button type="submit" class="paper-search-btn">搜索</button> <button type="submit" class="paper-search-btn">搜索</button>
<a class="admin-action-btn admin-action-btn-sm" href="/admin/papers/export.csv{% if request.query_params %}?{{ request.query_params }}{% endif %}">⬇ 导出 CSV</a>
</div> </div>
</form> </form>
@@ -37,6 +38,7 @@
<span class="paper-batch-label">批量操作</span> <span class="paper-batch-label">批量操作</span>
<span class="paper-selected-count" id="selected-count">已选 0 篇</span> <span class="paper-selected-count" id="selected-count">已选 0 篇</span>
<button class="admin-action-btn admin-action-btn-sm" onclick="batchAction('summarize')" id="batch-summarize-btn" disabled>📝 批量总结</button> <button class="admin-action-btn admin-action-btn-sm" onclick="batchAction('summarize')" id="batch-summarize-btn" disabled>📝 批量总结</button>
<button class="admin-action-btn admin-action-btn-sm" onclick="batchAction('recrawl')" id="batch-recrawl-btn" disabled>🔄 批量重抓</button>
<button class="admin-action-btn admin-action-btn-sm admin-action-btn-danger" onclick="batchAction('delete')" id="batch-delete-btn" disabled>🗑 批量删除</button> <button class="admin-action-btn admin-action-btn-sm admin-action-btn-danger" onclick="batchAction('delete')" id="batch-delete-btn" disabled>🗑 批量删除</button>
</div> </div>
@@ -72,6 +74,7 @@
</td> </td>
<td class="action-cell"> <td class="action-cell">
<button class="action-btn-sm" title="重新总结" onclick="retryOne('{{ paper.arxiv_id }}', this)"></button> <button class="action-btn-sm" title="重新总结" onclick="retryOne('{{ paper.arxiv_id }}', this)"></button>
<button class="action-btn-sm" title="重新抓取元数据" onclick="recrawlOne('{{ paper.arxiv_id }}', this)">🔄</button>
<button class="action-btn-sm action-btn-danger" title="删除" onclick="confirmDeleteSingle('{{ paper.arxiv_id }}', '{{ (paper.title_zh or paper.title_en)[:40] | replace("'", "\\'") }}')">🗑</button> <button class="action-btn-sm action-btn-danger" title="删除" onclick="confirmDeleteSingle('{{ paper.arxiv_id }}', '{{ (paper.title_zh or paper.title_en)[:40] | replace("'", "\\'") }}')">🗑</button>
</td> </td>
</tr> </tr>
@@ -124,6 +127,7 @@
const n=document.querySelectorAll('.paper-check:checked').length; const n=document.querySelectorAll('.paper-check:checked').length;
document.getElementById('selected-count').textContent='已选 '+n+' 篇'; document.getElementById('selected-count').textContent='已选 '+n+' 篇';
document.getElementById('batch-summarize-btn').disabled=n===0; document.getElementById('batch-summarize-btn').disabled=n===0;
document.getElementById('batch-recrawl-btn').disabled=n===0;
document.getElementById('batch-delete-btn').disabled=n===0; document.getElementById('batch-delete-btn').disabled=n===0;
} }
function retryOne(arxivId,btn) { function retryOne(arxivId,btn) {
@@ -134,6 +138,14 @@
.catch(()=>showToast('❌ 请求失败')) .catch(()=>showToast('❌ 请求失败'))
.finally(()=>{btn.disabled=false;btn.textContent='↻';}); .finally(()=>{btn.disabled=false;btn.textContent='↻';});
} }
// POST a re-crawl request for one paper. The button is disabled and shows
// '...' while the request is in flight, and is restored in finally().
function recrawlOne(arxivId,btn) {
btn.disabled=true;btn.textContent='...';
fetch('/admin/paper-recrawl/'+arxivId,{method:'POST',headers:{'Content-Type':'application/json'}})
.then(r=>r.json())
// Server error text is truncated to 100 characters before display.
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 重抓任务已创建,可在任务页查看'))
.catch(()=>showToast('❌ 请求失败'))
.finally(()=>{btn.disabled=false;btn.textContent='🔄';});
}
function confirmDeleteSingle(arxivId,title) { function confirmDeleteSingle(arxivId,title) {
document.getElementById('confirm-msg').textContent='确定删除论文「'+title+'」?此操作不可恢复。'; document.getElementById('confirm-msg').textContent='确定删除论文「'+title+'」?此操作不可恢复。';
_confirmAction='delete-single'; _confirmTarget=arxivId; _confirmAction='delete-single'; _confirmTarget=arxivId;
@@ -151,6 +163,11 @@
.then(r=>r.json()) .then(r=>r.json())
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 已提交批量总结')) .then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 已提交批量总结'))
.catch(()=>showToast('❌ 请求失败')); .catch(()=>showToast('❌ 请求失败'));
} else if(action==='recrawl'){
fetch('/admin/papers-batch-action',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({action:'recrawl',arxiv_ids:ids})})
.then(r=>r.json())
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 已提交批量重抓,可在任务页查看'))
.catch(()=>showToast('❌ 请求失败'));
} }
} }
function doConfirmAction() { function doConfirmAction() {
-2
View File
@@ -29,8 +29,6 @@
<a href="/reading-list">阅读列表</a> <a href="/reading-list">阅读列表</a>
{% if is_admin %} {% if is_admin %}
<a href="/admin/">管理</a> <a href="/admin/">管理</a>
<a href="/admin/logout" onclick="event.preventDefault();this.closest('form').submit()">退出</a>
<form action="/admin/logout" method="post" style="display:none"></form>
{% else %} {% else %}
<a href="/admin/login">管理</a> <a href="/admin/login">管理</a>
{% endif %} {% endif %}
+2 -1
View File
@@ -1,7 +1,8 @@
{# Admin subnav — 管理后台三个页面共享。active 参数: "dashboard" / "papers" / "logs" #} {# Admin subnav — 管理后台共享。active 参数: "dashboard" / "papers" / "jobs" / "logs" #}
<nav class="admin-subnav"> <nav class="admin-subnav">
<a href="/admin/" class="admin-subnav-link {{ 'active' if active == 'dashboard' else '' }}">仪表盘</a> <a href="/admin/" class="admin-subnav-link {{ 'active' if active == 'dashboard' else '' }}">仪表盘</a>
<a href="/admin/papers" class="admin-subnav-link {{ 'active' if active == 'papers' else '' }}">论文管理</a> <a href="/admin/papers" class="admin-subnav-link {{ 'active' if active == 'papers' else '' }}">论文管理</a>
<a href="/admin/jobs" class="admin-subnav-link {{ 'active' if active == 'jobs' else '' }}">任务</a>
<a href="/admin/logs" class="admin-subnav-link {{ 'active' if active == 'logs' else '' }}">日志</a> <a href="/admin/logs" class="admin-subnav-link {{ 'active' if active == 'logs' else '' }}">日志</a>
<span class="admin-subnav-spacer"></span> <span class="admin-subnav-spacer"></span>
<form action="/admin/logout" method="post" class="admin-subnav-form"> <form action="/admin/logout" method="post" class="admin-subnav-form">
+20
View File
@@ -24,6 +24,26 @@ from app.models import (
from app.utils import utc_now from app.utils import utc_now
# ── ChromaDB 隔离(autouse,所有测试)──────────────────────────────────
@pytest.fixture(autouse=True)
def _isolate_chroma(monkeypatch, tmp_path):
    """Point ChromaDB at a per-test temporary directory and reset its singleton.

    Applied automatically to every test. ``settings.CHROMA_DIR`` is patched to
    a ``chroma`` folder under ``tmp_path`` and the embedder's ``_chroma``
    singleton is reset before the test, so it picks up the patched directory,
    and again afterwards. The intent is that test fixtures never end up in
    the real ChromaDB store.
    """
    from app.services import embedder
    from app.config import settings

    chroma_dir = tmp_path / "chroma"
    monkeypatch.setattr(settings, "CHROMA_DIR", str(chroma_dir))
    embedder._chroma.reset()
    yield
    embedder._chroma.reset()
# ── 内存数据库 ────────────────────────────────────────────────────────── # ── 内存数据库 ──────────────────────────────────────────────────────────
+380
View File
@@ -0,0 +1,380 @@
"""管理后台新功能测试 — 任务监控、锁释放、重抓、失败分布、配置、导出、重建索引。"""
from __future__ import annotations
import pytest
from sqlalchemy import select
from unittest.mock import AsyncMock, patch
from app.models import Job, JobStatus, SummaryState, SummaryStatus, TaskLock
from app.services import admin as admin_svc
from app.services.crawler import recrawl_single
from app.services.jobs import create_job, run_job
from app.utils import utc_now
@pytest.fixture
def no_enqueue(monkeypatch):
    """Swap the admin route module's ``enqueue_job`` for a no-op.

    Keeps route tests from actually running a background task: the patched
    callable accepts any arguments and returns ``None``.
    """
    from app.routes import admin as admin_routes

    def _discard(*args, **kwargs):
        return None

    monkeypatch.setattr(admin_routes, "enqueue_job", _discard)
# ═══════════════════════════════════════════════════════════════════════
# 任务监控
# ═══════════════════════════════════════════════════════════════════════
def _make_job(db_session, *, type="crawl_daily", status=JobStatus.QUEUED, owner="t"):
    """Insert one ``Job`` row with an empty JSON payload and return it.

    The row is committed before returning. ``type`` shadows the builtin on
    purpose: it mirrors the ``Job.type`` column and callers pass it by keyword.
    """
    record = Job(
        type=type,
        status=status,
        owner=owner,
        payload_json="{}",
        created_at=utc_now(),
    )
    db_session.add(record)
    db_session.commit()
    return record
def test_query_jobs_filter_and_pagination(db_session):
    """``query_jobs`` paginates and filters by status and by job type."""
    # Seed 25 successful jobs followed by 5 failed ones.
    for state, how_many in ((JobStatus.SUCCESS, 25), (JobStatus.FAILED, 5)):
        for _ in range(how_many):
            _make_job(db_session, status=state)

    # No filter: 30 rows split into pages of 20 and 10.
    first_page, total = admin_svc.query_jobs(db_session, page=1, per_page=20)
    assert total == 30
    assert len(first_page) == 20

    second_page, _ = admin_svc.query_jobs(db_session, page=2, per_page=20)
    assert len(second_page) == 10

    # Status filter.
    failed_jobs, failed_total = admin_svc.query_jobs(
        db_session, status="failed", per_page=50
    )
    assert failed_total == 5
    assert len(failed_jobs) == 5
    assert all(entry["status"] == "failed" for entry in failed_jobs)

    # Type filter.
    _make_job(db_session, type="reindex_fts", status=JobStatus.QUEUED)
    typed_jobs, typed_total = admin_svc.query_jobs(db_session, job_type="reindex_fts")
    assert typed_total == 1
    assert typed_jobs[0]["type"] == "reindex_fts"
def test_serialize_job_includes_duration(db_session):
    """A job with both timestamps set serializes a non-negative duration."""
    finished_job = _make_job(db_session, status=JobStatus.SUCCESS)
    finished_job.started_at = utc_now()
    finished_job.completed_at = utc_now()
    db_session.commit()

    duration = admin_svc.serialize_job(finished_job)["duration_seconds"]

    assert duration is not None
    assert duration >= 0
def test_serialize_job_running_without_completed(db_session):
    """A running job (no ``completed_at``) still serializes a duration."""
    # Regression test. Per the original note: a running job has
    # completed_at=None, and started_at read back from the db is naive UTC,
    # so it cannot be subtracted directly from the timezone-aware utc_now().
    running_job = _make_job(db_session, status=JobStatus.RUNNING)
    running_job.started_at = utc_now()
    db_session.commit()

    duration = admin_svc.serialize_job(running_job)["duration_seconds"]

    assert duration is not None
    assert duration >= 0
def test_get_job_status_counts(db_session):
    """``get_job_status_counts`` tallies jobs per status value."""
    for state in (JobStatus.QUEUED, JobStatus.QUEUED, JobStatus.RUNNING):
        _make_job(db_session, status=state)

    tally = admin_svc.get_job_status_counts(db_session)

    assert tally.get("queued") == 2
    assert tally.get("running") == 1
# ═══════════════════════════════════════════════════════════════════════
# 锁释放
# ═══════════════════════════════════════════════════════════════════════
def test_force_release_lock(db_session):
    """Running and stale locks can be force-released; others cannot."""
    # One lock per starting state, keyed by that state (insertion order kept).
    locks = {
        state: TaskLock(
            task="crawl", lock_key=key, status=state, acquired_at=utc_now()
        )
        for key, state in (("k1", "running"), ("k2", "stale"), ("k3", "finished"))
    }
    db_session.add_all(list(locks.values()))
    db_session.commit()

    assert admin_svc.force_release_lock(db_session, locks["running"].id) is True
    assert admin_svc.force_release_lock(db_session, locks["stale"].id) is True
    # An already-finished lock must not be released a second time.
    assert admin_svc.force_release_lock(db_session, locks["finished"].id) is False
    # Unknown id.
    assert admin_svc.force_release_lock(db_session, 999999) is False

    for lock in locks.values():
        db_session.refresh(lock)

    assert locks["running"].status == "finished"
    assert locks["running"].released_at is not None
    assert locks["stale"].status == "finished"
    assert locks["finished"].status == "finished"
# ═══════════════════════════════════════════════════════════════════════
# 失败原因分布
# ═══════════════════════════════════════════════════════════════════════
def test_get_failure_breakdown(db_session, sample_papers_range):
    """Failure breakdown groups by ``error_type`` and sorts by count desc."""
    query = select(SummaryStatus).order_by(SummaryStatus.id)
    rows = db_session.execute(query).scalars().all()

    # Mark the first three rows as failures; a missing error_type is expected
    # to be reported under "unknown".
    planned = (
        (SummaryState.FAILED, "pdf_download_failed"),
        (SummaryState.FAILED, "timeout"),
        (SummaryState.PERMANENT_FAILURE, None),
    )
    for position, (state, error_type) in enumerate(planned):
        rows[position].status = state
        rows[position].error_type = error_type
    db_session.commit()

    breakdown = admin_svc.get_failure_breakdown(db_session)
    count_by_type = {entry["error_type"]: entry["count"] for entry in breakdown}

    assert count_by_type.get("pdf_download_failed") == 1
    assert count_by_type.get("timeout") == 1
    assert count_by_type.get("unknown") == 1

    # Entries come back in descending count order.
    ordered_counts = [entry["count"] for entry in breakdown]
    assert ordered_counts == sorted(ordered_counts, reverse=True)
# ═══════════════════════════════════════════════════════════════════════
# 配置概览
# ═══════════════════════════════════════════════════════════════════════
def test_get_config_overview_no_secrets():
    """The config overview exposes expected keys without leaking secrets."""
    overview = admin_svc.get_config_overview()

    # "api_key_configured" only flags whether a key is set; it is not the value.
    for expected_key in ("summary_backend", "schedule_time", "api_key_configured"):
        assert expected_key in overview

    # The default secret value must not appear anywhere in the overview.
    assert "change-me" not in str(overview)
# ═══════════════════════════════════════════════════════════════════════
# 单篇/批量重抓
# ═══════════════════════════════════════════════════════════════════════
class TestRecrawl:
    """Direct calls to ``recrawl_single`` for a single paper."""

    @pytest.mark.asyncio
    async def test_not_found(self, db_session):
        """An arxiv id that is not stored is reported as ``not_found``."""
        outcome = await recrawl_single(db_session, "9999.99999")

        assert outcome["updated"] is False
        assert outcome["reason"] == "not_found"

    @pytest.mark.asyncio
    async def test_updates_full_metadata(self, db_session, sample_paper):
        """A daily item for the paper overwrites title, abstract, votes, authors, tags."""
        refreshed_item = {
            "paper": {
                "id": sample_paper.arxiv_id,
                "title": "Updated Title",
                "abstract": "New abstract",
                "publishedAt": "2024-01-15T00:00:00",
                "authors": [{"name": "New Author"}],
                "tags": [{"name": "CV"}, {"name": "Diffusion"}],
                "upvotes": 100,
            }
        }
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[refreshed_item],
        )
        with fetch_patch:
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Updated Title"
        assert sample_paper.abstract == "New abstract"
        assert sample_paper.upvotes == 100
        # Authors are rebuilt (originally Alice/Bob -> New Author).
        author_names = sorted(author.name for author in sample_paper.authors)
        assert author_names == ["New Author"]
        # Tags are rebuilt (originally NLP/LLM -> CV/Diffusion).
        tag_names = sorted(entry.tag for entry in sample_paper.tags)
        assert tag_names == ["CV", "Diffusion"]

    @pytest.mark.asyncio
    async def test_not_in_daily(self, db_session, sample_paper):
        """An empty daily listing yields ``not_in_daily`` plus a ``date`` key."""
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[],
        )
        with fetch_patch:
            outcome = await recrawl_single(db_session, sample_paper.arxiv_id)

        assert outcome["updated"] is False
        assert outcome["reason"] == "not_in_daily"
        assert "date" in outcome
class TestDispatchRecrawl:
    """Recrawl jobs created with ``create_job`` and executed via ``run_job``."""

    @pytest.mark.asyncio
    async def test_recrawl_one_via_run_job(self, db_session, sample_paper):
        """A ``recrawl_one`` job updates the paper named in its payload."""
        daily_item = {
            "paper": {
                "id": sample_paper.arxiv_id,
                "title": "Via Job",
                "authors": [],
                "tags": [],
                "upvotes": 5,
            }
        }
        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            return_value=[daily_item],
        )
        with fetch_patch:
            queued = create_job(
                db_session,
                "recrawl_one",
                owner="test",
                payload={"arxiv_id": sample_paper.arxiv_id},
            )
            outcome = await run_job(db_session, queued.id)

        assert outcome["updated"] is True
        db_session.refresh(sample_paper)
        assert sample_paper.title_en == "Via Job"

    @pytest.mark.asyncio
    async def test_recrawl_batch_via_run_job(self, db_session, sample_papers_range):
        """A ``recrawl_batch`` job updates every listed paper and skips none."""
        arxiv_ids = [paper.arxiv_id for paper in sample_papers_range[:2]]

        def daily_item_for(arxiv_id):
            return {
                "paper": {
                    "id": arxiv_id,
                    "title": "Batch " + arxiv_id,
                    "authors": [],
                    "tags": [],
                    "upvotes": 1,
                }
            }

        items = [daily_item_for(arxiv_id) for arxiv_id in arxiv_ids]

        def fake_fetch(d):
            # Same listing regardless of the single argument it is called with.
            return items

        fetch_patch = patch(
            "app.services.crawler.fetch_daily",
            new_callable=AsyncMock,
            side_effect=fake_fetch,
        )
        with fetch_patch:
            queued = create_job(
                db_session,
                "recrawl_batch",
                owner="test",
                payload={"arxiv_ids": arxiv_ids},
            )
            outcome = await run_job(db_session, queued.id)

        assert outcome["updated"] == 2
        assert outcome["skipped"] == 0
# ═══════════════════════════════════════════════════════════════════════
# 路由
# ═══════════════════════════════════════════════════════════════════════
class TestRoutes:
    """HTTP-level checks for the admin routes, using an authenticated client."""

    @staticmethod
    def _jobs_of_type(db_session, job_type):
        """Return every stored ``Job`` whose ``type`` equals ``job_type``."""
        query = select(Job).where(Job.type == job_type)
        return db_session.execute(query).scalars().all()

    def test_jobs_page_renders(self, auth_client):
        """The job monitor page responds 200 and contains its heading text."""
        response = auth_client.get("/admin/jobs")

        assert response.status_code == 200
        assert "任务监控" in response.text

    def test_jobs_page_filters_by_status(self, auth_client, db_session):
        """The job monitor page accepts a ``status`` query filter."""
        _make_job(db_session, status=JobStatus.FAILED)

        response = auth_client.get("/admin/jobs?status=failed")

        assert response.status_code == 200

    def test_export_csv(self, auth_client, sample_papers_range):
        """CSV export returns text/csv with a UTF-8 BOM, header and data rows."""
        response = auth_client.get("/admin/papers/export.csv")

        assert response.status_code == 200
        assert "text/csv" in response.headers["content-type"]
        # UTF-8 BOM for Excel.
        assert response.content.startswith(b"\xef\xbb\xbf")
        # Header row plus data.
        body = response.text
        assert "arxiv_id" in body
        assert "2401.10001" in body

    def test_export_csv_respects_filter(self, auth_client, sample_papers_range):
        """CSV export honours the ``q`` search filter."""
        response = auth_client.get("/admin/papers/export.csv?q=Paper%203")

        assert response.status_code == 200
        body = response.text
        assert "2401.10003" in body
        assert "2401.10001" not in body

    def test_rebuild_indexes_fts(self, auth_client, db_session, no_enqueue):
        """Rebuilding the FTS index queues exactly one ``reindex_fts`` job."""
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "fts"})

        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "queued"
        assert len(payload["job_ids"]) == 1
        assert len(self._jobs_of_type(db_session, "reindex_fts")) == 1

    def test_rebuild_indexes_both(self, auth_client, db_session, no_enqueue):
        """Rebuilding both indexes reports two job ids."""
        response = auth_client.post("/admin/rebuild-indexes", json={"target": "both"})

        assert response.status_code == 200
        payload = response.json()
        assert len(payload["job_ids"]) == 2

    def test_release_lock_route(self, auth_client, db_session):
        """Posting to the release endpoint marks a running lock as finished."""
        held = TaskLock(
            task="crawl", lock_key="rt", status="running", acquired_at=utc_now()
        )
        db_session.add(held)
        db_session.commit()

        response = auth_client.post(f"/admin/locks/{held.id}/release")

        assert response.status_code == 200
        db_session.refresh(held)
        assert held.status == "finished"

    def test_paper_recrawl_route(
        self, auth_client, sample_paper, db_session, no_enqueue
    ):
        """The single-paper recrawl endpoint queues one ``recrawl_one`` job."""
        response = auth_client.post(f"/admin/paper-recrawl/{sample_paper.arxiv_id}")

        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_one")) == 1

    def test_batch_recrawl_route(
        self, auth_client, sample_papers_range, db_session, no_enqueue
    ):
        """The batch action endpoint queues one ``recrawl_batch`` job."""
        selected_ids = [paper.arxiv_id for paper in sample_papers_range[:3]]

        response = auth_client.post(
            "/admin/papers-batch-action",
            json={"action": "recrawl", "arxiv_ids": selected_ids},
        )

        assert response.status_code == 200
        assert response.json()["status"] == "queued"
        assert len(self._jobs_of_type(db_session, "recrawl_batch")) == 1
+83
View File
@@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
import threading
import time
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -154,3 +156,84 @@ class TestEmbeddingApi:
) )
result = emb._get_embedding("test") result = emb._get_embedding("test")
assert result is None assert result is None
# ═══════════════════════════════════════════════════════════════════════
# 并发安全:init() 双重检查锁 + 集合访问串行化
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderConcurrency:
    """Safety of the embedder when several threads use it at once.

    Per the original note, post-processing reaches ``index_paper`` from
    multiple workers via ``asyncio.to_thread``.
    """

    def test_init_serialized_under_concurrency(self, monkeypatch, tmp_path):
        """Ten concurrent ``init()`` calls construct ``PersistentClient`` once.

        The fake client factory sleeps to widen the construction window so
        that unsynchronized callers would overlap. Per the original note, the
        scenario reproduces a race on chromadb's shared client cache that the
        production code guards against with a lock around initialization.
        """
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))

        import app.services.embedder as embedder_mod

        embedder_mod._chroma.reset()

        counter = {"n": 0}
        counter_lock = threading.Lock()

        def fake_persistent_client(path):
            with counter_lock:
                counter["n"] += 1
            time.sleep(0.05)  # widen the construction window to provoke races
            stub = MagicMock()
            # A failing get_collection forces the create_collection path.
            stub.get_collection.side_effect = Exception("not exist")
            stub.create_collection.return_value = MagicMock()
            return stub

        with patch("chromadb.PersistentClient", side_effect=fake_persistent_client):
            workers = []
            for _ in range(10):
                workers.append(threading.Thread(target=embedder_mod._chroma.init))
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join()

        assert counter["n"] == 1, f"PersistentClient 应只调一次,实际 {counter['n']}"
        assert embedder_mod._chroma._client is not None
        embedder_mod._chroma.reset()

    def test_index_paper_concurrent_no_error(self, monkeypatch, tmp_path):
        """Ten concurrent ``index_paper`` calls all succeed and each upserts once."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))

        import app.services.embedder as embedder_mod

        embedder_mod._chroma.reset()

        # Skip init entirely: inject a mock client and collection directly.
        embedder_mod._chroma._client = MagicMock()
        collection = MagicMock()
        collection.count.return_value = 0
        embedder_mod._chroma._collection = collection

        with patch.object(embedder_mod, "_get_embedding", return_value=[0.1, 0.2, 0.3]):
            errors: list[BaseException] = []

            def worker(i: int) -> None:
                paper_id = f"id-{i}"
                try:
                    embedder_mod.index_paper(
                        paper_id, {"arxiv_id": paper_id, "title_zh": f"标题{i}"}
                    )
                except BaseException as exc:  # noqa: BLE001 - collect every error
                    errors.append(exc)

            workers = [
                threading.Thread(target=worker, args=(index,)) for index in range(10)
            ]
            for thread in workers:
                thread.start()
            for thread in workers:
                thread.join()

        assert errors == []
        assert collection.upsert.call_count == 10
        embedder_mod._chroma.reset()
+160 -4
View File
@@ -7,6 +7,8 @@
from __future__ import annotations from __future__ import annotations
import json import json
import threading
import time
from unittest.mock import MagicMock from unittest.mock import MagicMock
import numpy as np import numpy as np
@@ -166,10 +168,10 @@ class TestClassMapping:
def test_table(self): def test_table(self):
assert _map_class_to_boxclass(5, {5: "table"}) == "table" assert _map_class_to_boxclass(5, {5: "table"}) == "table"
def test_caption_ignored(self): def test_caption_classes(self):
names = {4: "figure_caption", 6: "table_caption"} names = {4: "figure_caption", 6: "table_caption"}
assert _map_class_to_boxclass(4, names) is None assert _map_class_to_boxclass(4, names) == "figure_caption"
assert _map_class_to_boxclass(6, names) is None assert _map_class_to_boxclass(6, names) == "table_caption"
def test_other_classes_ignored(self): def test_other_classes_ignored(self):
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"} names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
@@ -240,9 +242,11 @@ class TestPostprocessOutput:
class TestDetectPage: class TestDetectPage:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _reset_detector(self): def _reset_detector(self):
"""每个测试前重置模块级单例,避免复用上个测试的 mock session。""" """每个测试前重建单例(带新锁 + 空 session),避免复用上个测试的 mock session。"""
mod._LayoutDetector.reset_instance()
mod._detector = mod._LayoutDetector() mod._detector = mod._LayoutDetector()
yield yield
mod._LayoutDetector.reset_instance()
mod._detector = mod._LayoutDetector() mod._detector = mod._LayoutDetector()
@staticmethod @staticmethod
@@ -338,6 +342,20 @@ class TestDetectPage:
assert len(boxes) == 1 assert len(boxes) == 1
assert boxes[0].boxclass == "table" assert boxes[0].boxclass == "table"
def test_returns_caption_box_with_small_height(self, monkeypatch, tmp_path):
    """A ``figure_caption`` detection about 12 units tall is returned."""
    caption_detection = (4, 100, 405, 300, 417, 0.9)
    sess, (pw, ph) = self._build_mock_session(
        595, 842, [caption_detection], {4: "figure_caption"}
    )
    self._setup(monkeypatch, tmp_path, sess)
    page = self._make_mock_page(595, 842, pw, ph)

    boxes = detect_page_layout(page)

    assert len(boxes) == 1
    caption = boxes[0]
    assert caption.boxclass == "figure_caption"
    assert caption.y1 - caption.y0 == pytest.approx(12, abs=1.0)
def test_filters_low_confidence(self, monkeypatch, tmp_path): def test_filters_low_confidence(self, monkeypatch, tmp_path):
names = {3: "figure"} names = {3: "figure"}
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤 # conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
@@ -401,3 +419,141 @@ class TestDetectPage:
boxes = detect_page_layout(page) boxes = detect_page_layout(page)
assert len(boxes) == 1 assert len(boxes) == 1
assert boxes[0].boxclass == "picture" assert boxes[0].boxclass == "picture"
# ═══════════════════════════════════════════════════════════════════════
# 并发安全:锁串行化推理 + 单例 session 只初始化一次
# ═══════════════════════════════════════════════════════════════════════
class TestDetectPageConcurrency:
    """Safety of ``detect_page_layout`` when called from several threads."""

    @pytest.fixture(autouse=True)
    def _reset_detector(self):
        """Rebuild the module-level detector before and after each test."""

        def rebuild():
            mod._LayoutDetector.reset_instance()
            mod._detector = mod._LayoutDetector()

        rebuild()
        yield
        rebuild()

    @staticmethod
    def _build_mock_session(page_w, page_h, boxes, names):
        """Build a mock inference session that reports ``boxes``.

        Each box is ``(cls_id, x0, y0, x1, y1, conf)`` in page coordinates; the
        coordinates are scaled by the render ratio and offset by the letterbox
        padding before being packed into the fake model output.

        Returns ``(session, (pix_w, pix_h), fake_output)`` so callers can reuse
        ``fake_output`` from a custom ``side_effect``.
        """
        scale = _compute_render_geometry(page_w, page_h, IMGSZ)
        pix_w = round(page_w * scale)
        pix_h = round(page_h * scale)
        pad_x, pad_y = _letterbox_padding(pix_w, pix_h, IMGSZ)

        rows = [
            [
                x0 * scale + pad_x,
                y0 * scale + pad_y,
                x1 * scale + pad_x,
                y1 * scale + pad_y,
                conf,
                cls_id,
            ]
            for cls_id, x0, y0, x1, y1, conf in boxes
        ]
        if rows:
            fake_output = np.array([rows], dtype=np.float32)
        else:
            fake_output = np.zeros((1, 0, 6), dtype=np.float32)

        model_input = MagicMock()
        model_input.name = "images"

        metadata = MagicMock()
        metadata.custom_metadata_map = {
            "names": json.dumps({str(key): label for key, label in names.items()})
        }

        sess = MagicMock()
        sess.get_inputs.return_value = [model_input]
        sess.run.return_value = [fake_output]
        sess.get_providers.return_value = ["CPUExecutionProvider"]
        sess.get_modelmeta.return_value = metadata
        return sess, (pix_w, pix_h), fake_output

    @staticmethod
    def _make_mock_page(page_w, page_h, pix_w, pix_h):
        """Return a mock page whose pixmap is a uniform 3-channel image."""
        pixmap = MagicMock()
        pixmap.width = pix_w
        pixmap.height = pix_h
        pixmap.n = 3
        pixmap.samples = bytes([128]) * (pix_w * pix_h * 3)

        page = MagicMock()
        page.rect.width = page_w
        page.rect.height = page_h
        page.get_pixmap.return_value = pixmap
        return page

    def _setup(self, monkeypatch, tmp_path, sess):
        """Point the model path at a dummy file and make ``ort`` return ``sess``."""
        model_file = tmp_path / "m.onnx"
        monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(model_file))
        model_file.write_bytes(b"x")
        monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)

    def test_detect_page_serializes_concurrent_calls(self, monkeypatch, tmp_path):
        """At most one thread is ever inside ``session.run`` at a time."""
        sess, (pw, ph), fake_output = self._build_mock_session(
            595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
        )

        state = {"active": 0, "peak": 0}
        state_lock = threading.Lock()

        def counting_run(*args, **kwargs):
            with state_lock:
                state["active"] += 1
                state["peak"] = max(state["peak"], state["active"])
            time.sleep(0.02)  # widen the window so unsynchronized calls overlap
            with state_lock:
                state["active"] -= 1
            return [fake_output]

        sess.run.side_effect = counting_run
        self._setup(monkeypatch, tmp_path, sess)

        workers = []
        for _ in range(8):
            page = self._make_mock_page(595, 842, pw, ph)
            workers.append(threading.Thread(target=detect_page_layout, args=(page,)))
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        # With the lock in place only one caller is in the critical section;
        # without it this value would exceed 1 (regression guard).
        assert state["peak"] == 1

    def test_session_created_once_under_concurrency(self, monkeypatch, tmp_path):
        """Concurrent first calls create the ``InferenceSession`` only once."""
        sess, (pw, ph), _fake_output = self._build_mock_session(
            595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
        )

        created = {"count": 0}
        created_lock = threading.Lock()

        def counting_init(*args, **kwargs):
            with created_lock:
                created["count"] += 1
            time.sleep(0.02)  # widen the window so first callers all contend
            return sess

        monkeypatch.setattr(ort, "InferenceSession", counting_init)
        model_file = tmp_path / "m.onnx"
        monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(model_file))
        model_file.write_bytes(b"x")

        workers = []
        for _ in range(6):
            page = self._make_mock_page(595, 842, pw, ph)
            workers.append(threading.Thread(target=detect_page_layout, args=(page,)))
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        assert created["count"] == 1
+134
View File
@@ -0,0 +1,134 @@
from __future__ import annotations
import json
from unittest.mock import MagicMock
import pymupdf
from app.services import pdf_image_extractor as mod
from app.services.layout_detector import LayoutBox
def test_process_page_extracts_doclayout_caption(tmp_path):
    """A ``figure_caption`` box next to a picture is recorded in the manifest."""
    images_dest = tmp_path / "images"
    images_dest.mkdir()
    manifest: dict[str, dict] = {}

    rendered = MagicMock()
    rendered.tobytes.return_value = b"jpeg"

    page = MagicMock()
    page.rect.width = 600
    page.get_pixmap.return_value = rendered
    page.get_text.return_value = "Figure 1: Overall architecture.\n"

    doc = MagicMock()
    doc.__getitem__.return_value = page

    picture_box = LayoutBox(100, 100, 300, 300, "picture")
    caption_box = LayoutBox(95, 310, 320, 325, "figure_caption")

    extracted = mod._process_page(
        doc,
        0,
        [picture_box, caption_box],
        images_dest=images_dest,
        manifest=manifest,
        seen_labels=set(),
        arxiv_id="2401.00001",
    )

    assert extracted == 1
    entry = manifest["figure_(p1-1).jpg"]
    assert entry["caption_text"] == "Figure 1: Overall architecture."
    assert entry["caption_source"] == "doclayout"
    assert entry["caption_box"] == [95.0, 310.0, 320.0, 325.0]
def test_process_page_includes_caption_in_render(tmp_path):
    """The rendered clip covers both the picture and its caption region."""
    images_dest = tmp_path / "images"
    images_dest.mkdir()
    manifest: dict[str, dict] = {}

    rendered = MagicMock()
    rendered.tobytes.return_value = b"jpeg"

    page = MagicMock()
    page.rect.width = 600
    page.get_pixmap.return_value = rendered
    page.get_text.return_value = "Figure 1: Caption text.\n"

    doc = MagicMock()
    doc.__getitem__.return_value = page

    picture_box = LayoutBox(100, 100, 300, 300, "picture")
    caption_box = LayoutBox(95, 310, 320, 325, "figure_caption")

    mod._process_page(
        doc,
        0,
        [picture_box, caption_box],
        images_dest=images_dest,
        manifest=manifest,
        seen_labels=set(),
        arxiv_id="2401.00001",
    )

    # Content [100,100,300,300] united with caption [95,310,320,325], then
    # padded by _REGION_PADDING=5 on every side -> Rect(90, 95, 325, 330).
    used_clip = page.get_pixmap.call_args.kwargs["clip"]
    assert used_clip == pymupdf.Rect(90, 95, 325, 330)
def test_label_images_preserves_doclayout_caption(tmp_path, monkeypatch):
    """Relabeling keeps the PDF caption and stores the summary caption separately."""
    arxiv_id = "2401.00001"
    paper_root = tmp_path / arxiv_id
    images_dest = paper_root / "images"
    images_dest.mkdir(parents=True)
    manifest_path = images_dest / "manifest.json"

    # Seed one extracted image plus a manifest entry carrying a PDF caption.
    (images_dest / "figure_(p1-1).jpg").write_bytes(b"jpeg")
    seeded_manifest = {
        "figure_(p1-1).jpg": {
            "page": 1,
            "type": "figure",
            "label": "Figure (p1-1)",
            "box": [100, 100, 300, 300],
            "caption_text": "Figure 1: PDF original caption.",
            "caption_source": "doclayout",
        }
    }
    manifest_path.write_text(json.dumps(seeded_manifest))

    pdf_path = tmp_path / "paper.pdf"
    pdf_path.write_bytes(b"%PDF")
    monkeypatch.setattr(mod, "paper_dir", lambda _arxiv_id: paper_root)

    # Fake one-page document whose text search returns a single hit.
    page = MagicMock()
    page.search_for.return_value = [pymupdf.Rect(120, 305, 180, 320)]
    fake_doc = MagicMock()
    fake_doc.page_count = 1
    fake_doc.__getitem__.return_value = page
    fake_doc.__enter__.return_value = fake_doc
    fake_doc.__exit__.return_value = False
    monkeypatch.setattr(mod.pymupdf, "open", lambda _path: fake_doc)

    summary_figures = [{"id": "Figure 1", "caption": "Summary caption."}]
    labeled = mod.label_images_by_summary(
        arxiv_id,
        summary_figures,
        pdf_path=pdf_path,
    )

    assert labeled == 1
    updated_manifest = json.loads(manifest_path.read_text())
    entry = updated_manifest["figure_1.jpg"]
    assert entry["caption_text"] == "Figure 1: PDF original caption."
    assert entry["caption_source"] == "doclayout"
    assert entry["summary_caption_text"] == "Summary caption."
+32
View File
@@ -366,6 +366,38 @@ class TestSummarizeOneFlow:
result = await summarize_one(db_session, sample_paper) result = await summarize_one(db_session, sample_paper)
assert result["status"] == "skipped" assert result["status"] == "skipped"
@pytest.mark.asyncio
async def test_post_processing_runs_in_thread(
    self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths
):
    """Post-processing runs on a thread other than the one awaiting the call.

    The image-extraction hook is replaced by a spy that records the id of the
    thread it runs on; that id must differ from the test's own thread.
    """
    import threading

    seen_threads: list[int] = []
    main_thread = threading.current_thread().ident

    def spy_extract(arxiv_id, schema):
        seen_threads.append(threading.current_thread().ident)

    download_patch = patch(
        "app.services.summarizer.download_pdf", new_callable=AsyncMock
    )
    call_pi_patch = patch(
        "app.services.summary_generator.call_pi",
        new_callable=AsyncMock,
        return_value=(mock_pi_output, "test-session-id"),
    )
    extract_patch = patch(
        "app.services.summary_persister._maybe_extract_images",
        side_effect=spy_extract,
    )
    chroma_patch = patch("app.services.summary_persister._maybe_index_chroma")

    with download_patch, call_pi_patch, extract_patch, chroma_patch:
        result = await summarize_one(db_session, sample_paper)

    assert result["status"] == "done"
    assert seen_threads, "后处理未被调用"
    assert seen_threads[0] != main_thread, "后处理应在工作线程执行,不阻塞事件循环"
# ═══════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════
# 批量操作测试 # 批量操作测试