Compare commits
3 Commits
743d69efd0
...
29fb20828e
| Author | SHA1 | Date | |
|---|---|---|---|
| 29fb20828e | |||
| 8f13c31991 | |||
| 90fe705e8f |
+15
-4
@@ -12,7 +12,6 @@ SECRET_KEY=your_random_secret_key
|
||||
|
||||
# ─── HuggingFace / arXiv ────────────────
|
||||
HF_API_BASE=https://huggingface.co/api
|
||||
HF_PROXY=
|
||||
TOP_N=20
|
||||
HTTP_TIMEOUT_SECONDS=30
|
||||
HTTP_MAX_RETRIES=3
|
||||
@@ -48,6 +47,18 @@ EMBED_MODEL=Qwen/Qwen3-Embedding-4B
|
||||
EMBED_DIMENSIONS=2560
|
||||
|
||||
# ─── 布局检测 ─────────────────────────────
|
||||
# ONNX 模型路径(首次运行前执行 scripts/export_picodet_onnx.py 导出)
|
||||
# LAYOUT_MODEL_PATH=data/models/picodet_layout_3cls.onnx
|
||||
# LAYOUT_THRESHOLD=0.5
|
||||
# DocLayout-YOLO ONNX 模型(首次运行前执行 scripts/export_doclayout_yolo_onnx.py 导出)
|
||||
# LAYOUT_MODEL_PATH=data/models/doclayout_yolo_docstructbench_imgsz1024.onnx
|
||||
# 模型输入尺寸(DocLayout-YOLO 推荐 1024)
|
||||
# LAYOUT_IMGSZ=1024
|
||||
# 检测置信度阈值(DocLayout-YOLO 推荐 0.2)
|
||||
# LAYOUT_THRESHOLD=0.2
|
||||
# 推理设备:auto|cpu|cuda|directml|openvino|cann|tensorrt|qnn
|
||||
# auto = 按优先级 [CUDA, DirectML, OpenVINO, CPU] 自动探测,失败降级 CPU
|
||||
# LAYOUT_DEVICE=auto
|
||||
# 设备 ID(GPU 序号)
|
||||
# LAYOUT_DEVICE_ID=0
|
||||
#
|
||||
# GPU 用户:onnxruntime 与 onnxruntime-gpu/-directml 同环境冲突,需手动二选一:
|
||||
# pip uninstall onnxruntime && pip install onnxruntime-gpu # NVIDIA CUDA
|
||||
# pip uninstall onnxruntime && pip install onnxruntime-directml # Windows 任意 GPU
|
||||
|
||||
@@ -125,7 +125,7 @@ paper/
|
||||
├── scripts/
|
||||
│ ├── init_db.py # 数据库初始化
|
||||
│ ├── manual_crawl.py # 手动抓取脚本
|
||||
│ ├── export_picodet_onnx.py # 导出布局检测 ONNX 模型
|
||||
│ ├── export_doclayout_yolo_onnx.py # 导出布局检测 ONNX 模型
|
||||
│ ├── reextract_images.py # 批量重新提取图片
|
||||
│ └── validate_summary.py # 校验总结 JSON 结构
|
||||
│
|
||||
@@ -198,8 +198,10 @@ SECRET_KEY=your_random_secret_key
|
||||
| `EMBED_API_KEY` | — | Embedding API Key |
|
||||
| `EMBED_MODEL` | — | Embedding 模型名 |
|
||||
| `EMBED_DIMENSIONS` | `0` | 向量维度 |
|
||||
| `LAYOUT_MODEL_PATH` | `data/models/picodet_layout_3cls.onnx` | ONNX 布局检测模型路径(可选) |
|
||||
| `LAYOUT_THRESHOLD` | `0.5` | 布局检测置信度阈值(可选) |
|
||||
| `LAYOUT_MODEL_PATH` | `data/models/doclayout_yolo_docstructbench_imgsz1024.onnx` | DocLayout-YOLO ONNX 模型路径(可选) |
|
||||
| `LAYOUT_IMGSZ` | `1024` | 模型输入尺寸 |
|
||||
| `LAYOUT_THRESHOLD` | `0.2` | 布局检测置信度阈值(可选) |
|
||||
| `LAYOUT_DEVICE` | `auto` | 推理设备:auto/cpu/cuda/directml/openvino/...(可选) |
|
||||
|
||||
### 4. 初始化数据库
|
||||
|
||||
|
||||
+8
-3
@@ -22,11 +22,11 @@ class Settings(BaseSettings):
|
||||
|
||||
# HuggingFace / arXiv
|
||||
HF_API_BASE: str = "https://huggingface.co/api"
|
||||
HF_PROXY: str = ""
|
||||
TOP_N: int = 20
|
||||
HTTP_TIMEOUT_SECONDS: int = 30
|
||||
HTTP_MAX_RETRIES: int = 3
|
||||
HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1"
|
||||
HF_PROXY: str = ""
|
||||
PDF_DOWNLOAD_TIMEOUT: int = 120
|
||||
|
||||
# AI 总结
|
||||
@@ -60,8 +60,13 @@ class Settings(BaseSettings):
|
||||
EMBED_DIMENSIONS: int = 0
|
||||
|
||||
# 布局检测
|
||||
LAYOUT_MODEL_PATH: str = "data/models/picodet_layout_3cls.onnx"
|
||||
LAYOUT_THRESHOLD: float = 0.5
|
||||
LAYOUT_MODEL_PATH: str = "data/models/doclayout_yolo_docstructbench_imgsz1024.onnx"
|
||||
LAYOUT_IMGSZ: int = 1024
|
||||
LAYOUT_THRESHOLD: float = 0.2
|
||||
# 推理设备:auto|cpu|cuda|directml|openvino|cann|tensorrt|qnn
|
||||
# auto = 按优先级 [CUDA, DirectML, OpenVINO, CPU] 自动探测
|
||||
LAYOUT_DEVICE: str = "auto"
|
||||
LAYOUT_DEVICE_ID: int = 0
|
||||
|
||||
model_config = {
|
||||
"env_file": str(BASE_DIR / ".env"),
|
||||
|
||||
+8
-1
@@ -10,7 +10,14 @@ from fastapi.staticfiles import StaticFiles
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.exceptions import AppError, ConflictError, ExternalAPIError, NotFoundError, PdfProcessError, ValidationError
|
||||
from app.exceptions import (
|
||||
AppError,
|
||||
ConflictError,
|
||||
ExternalAPIError,
|
||||
NotFoundError,
|
||||
PdfProcessError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.database import engine, init_db
|
||||
from app.routes.admin import router as admin_router
|
||||
from app.routes.compare import router as compare_router
|
||||
|
||||
+209
-6
@@ -2,12 +2,22 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import hashlib
|
||||
import hmac
|
||||
import io
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, Query, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
Form,
|
||||
HTTPException,
|
||||
Query,
|
||||
Request,
|
||||
)
|
||||
from fastapi.responses import RedirectResponse, Response
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -290,6 +300,183 @@ async def admin_job_detail(
|
||||
return detail
|
||||
|
||||
|
||||
# ── 任务监控 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/jobs")
|
||||
async def admin_jobs(
|
||||
request: Request,
|
||||
_admin: None = Depends(verify_admin),
|
||||
db: Session = Depends(get_db),
|
||||
status: str = Query("all"),
|
||||
job_type: str = Query("all"),
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(20, ge=1, le=100),
|
||||
):
|
||||
"""后台任务监控页。"""
|
||||
jobs, total = admin_svc.query_jobs(
|
||||
db, status=status, job_type=job_type, page=page, per_page=per_page
|
||||
)
|
||||
counts = admin_svc.get_job_status_counts(db)
|
||||
|
||||
def pagination_url(p: int) -> str:
|
||||
params = dict(request.query_params)
|
||||
params["page"] = str(p)
|
||||
return "/admin/jobs?" + "&".join(f"{k}={v}" for k, v in params.items())
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"admin_jobs.html",
|
||||
{
|
||||
"jobs": jobs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"current_status": status,
|
||||
"current_type": job_type,
|
||||
"status_counts": counts,
|
||||
"pagination_url": pagination_url,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── 锁管理 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/locks/{lock_id}/release")
|
||||
async def admin_release_lock(
|
||||
lock_id: int,
|
||||
_admin: None = Depends(verify_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""强制释放一个卡死的任务锁。"""
|
||||
if not admin_svc.force_release_lock(db, lock_id):
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Lock not found or already released: {lock_id}"
|
||||
)
|
||||
return {"status": "success", "lock_id": lock_id}
|
||||
|
||||
|
||||
# ── 重抓 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/paper-recrawl/{arxiv_id}")
|
||||
async def admin_paper_recrawl(
|
||||
arxiv_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
_admin: None = Depends(verify_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""重新抓取单篇已存在论文的完整元数据。"""
|
||||
job = create_job(
|
||||
db, "recrawl_one", owner="admin_recrawl", payload={"arxiv_id": arxiv_id}
|
||||
)
|
||||
enqueue_job(background_tasks, job.id)
|
||||
return {"status": "queued", "job_id": job.id, "arxiv_id": arxiv_id}
|
||||
|
||||
|
||||
# ── 索引重建 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RebuildIndexRequest(BaseModel):
|
||||
target: str # "fts" / "chroma" / "both"
|
||||
|
||||
@field_validator("target")
|
||||
@classmethod
|
||||
def target_must_be_valid(cls, v: str) -> str:
|
||||
if v not in ("fts", "chroma", "both"):
|
||||
raise ValueError("target must be 'fts', 'chroma' or 'both'")
|
||||
return v
|
||||
|
||||
|
||||
@router.post("/rebuild-indexes")
|
||||
async def admin_rebuild_indexes(
|
||||
body: RebuildIndexRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
_admin: None = Depends(verify_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""重建搜索索引(FTS5 / ChromaDB)。"""
|
||||
job_ids: list[int] = []
|
||||
if body.target in ("fts", "both"):
|
||||
job = create_job(db, "reindex_fts", owner="admin_reindex", payload={})
|
||||
enqueue_job(background_tasks, job.id)
|
||||
job_ids.append(job.id)
|
||||
if body.target in ("chroma", "both"):
|
||||
job = create_job(db, "reindex_chroma", owner="admin_reindex", payload={})
|
||||
enqueue_job(background_tasks, job.id)
|
||||
job_ids.append(job.id)
|
||||
return {"status": "queued", "job_ids": job_ids, "target": body.target}
|
||||
|
||||
|
||||
# ── 导出 CSV ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/papers/export.csv")
|
||||
async def admin_papers_export(
|
||||
_admin: None = Depends(verify_admin),
|
||||
db: Session = Depends(get_db),
|
||||
q: str = Query(""),
|
||||
date_from: str | None = Query(None),
|
||||
date_to: str | None = Query(None),
|
||||
tag: str = Query(""),
|
||||
summary_status: str = Query("all"),
|
||||
sort: str = Query("date_desc"),
|
||||
):
|
||||
"""导出当前过滤条件下的论文为 CSV(含 UTF-8 BOM,Excel 友好)。"""
|
||||
papers, _total, statuses = admin_svc.query_papers(
|
||||
db,
|
||||
q=q,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
tag=tag,
|
||||
summary_status=summary_status,
|
||||
sort=sort,
|
||||
page=1,
|
||||
per_page=10**6,
|
||||
)
|
||||
|
||||
buf = io.StringIO()
|
||||
buf.write("") # UTF-8 BOM for Excel
|
||||
writer = csv.writer(buf)
|
||||
writer.writerow(
|
||||
[
|
||||
"arxiv_id",
|
||||
"title_en",
|
||||
"title_zh",
|
||||
"paper_date",
|
||||
"upvotes",
|
||||
"summary_status",
|
||||
"authors",
|
||||
"tags",
|
||||
"pdf_url",
|
||||
]
|
||||
)
|
||||
for paper in papers:
|
||||
authors = ";".join(a.name for a in paper.authors)
|
||||
tags = ";".join(t.tag for t in paper.tags)
|
||||
writer.writerow(
|
||||
[
|
||||
paper.arxiv_id,
|
||||
paper.title_en or "",
|
||||
paper.title_zh or "",
|
||||
str(paper.paper_date) if paper.paper_date else "",
|
||||
paper.upvotes or 0,
|
||||
statuses.get(paper.arxiv_id, "none"),
|
||||
authors,
|
||||
tags,
|
||||
paper.pdf_url or "",
|
||||
]
|
||||
)
|
||||
|
||||
filename = f"papers_{today_str().replace('-', '')}.csv"
|
||||
return Response(
|
||||
content=buf.getvalue(),
|
||||
media_type="text/csv; charset=utf-8",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
# ── 日志 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -430,24 +617,25 @@ async def admin_paper_delete(
|
||||
|
||||
|
||||
class BatchActionRequest(BaseModel):
|
||||
action: str # "delete" or "summarize"
|
||||
action: str # "delete" / "summarize" / "recrawl"
|
||||
arxiv_ids: list[str]
|
||||
|
||||
@field_validator("action")
|
||||
@classmethod
|
||||
def action_must_be_valid(cls, v: str) -> str:
|
||||
if v not in ("delete", "summarize"):
|
||||
raise ValueError("action must be 'delete' or 'summarize'")
|
||||
if v not in ("delete", "summarize", "recrawl"):
|
||||
raise ValueError("action must be 'delete', 'summarize' or 'recrawl'")
|
||||
return v
|
||||
|
||||
|
||||
@router.post("/papers-batch-action")
|
||||
async def admin_papers_batch_action(
|
||||
body: BatchActionRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
_admin: None = Depends(verify_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""批量操作论文(删除或总结)。"""
|
||||
"""批量操作论文(删除 / 总结 / 重抓)。"""
|
||||
if not body.arxiv_ids:
|
||||
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
|
||||
|
||||
@@ -467,3 +655,18 @@ async def admin_papers_batch_action(
|
||||
"message": f"已将 {count} 篇论文重置为待总结",
|
||||
"count": count,
|
||||
}
|
||||
|
||||
elif body.action == "recrawl":
|
||||
job = create_job(
|
||||
db,
|
||||
"recrawl_batch",
|
||||
owner="admin_recrawl",
|
||||
payload={"arxiv_ids": body.arxiv_ids},
|
||||
)
|
||||
enqueue_job(background_tasks, job.id)
|
||||
return {
|
||||
"status": "queued",
|
||||
"job_id": job.id,
|
||||
"count": len(body.arxiv_ids),
|
||||
"message": f"已将 {len(body.arxiv_ids)} 篇论文加入重抓队列",
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
|
||||
|
||||
+129
-2
@@ -7,7 +7,7 @@ from datetime import date
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy import func, select, text, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
@@ -100,7 +100,9 @@ def get_admin_stats(db: Session) -> dict:
|
||||
|
||||
# ── 活跃锁 ────────────────────────────────────────────────────────
|
||||
active_locks = (
|
||||
db.execute(select(TaskLock).where(TaskLock.status == "running")).scalars().all()
|
||||
db.execute(select(TaskLock).where(TaskLock.status.in_(["running", "stale"])))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -124,6 +126,7 @@ def get_admin_stats(db: Session) -> dict:
|
||||
"recent_logs": recent_logs,
|
||||
"active_locks": active_locks,
|
||||
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
|
||||
"config_overview": get_config_overview(),
|
||||
}
|
||||
|
||||
|
||||
@@ -370,6 +373,7 @@ def get_logs_context(db: Session, *, page: int, per_page: int) -> dict:
|
||||
"summary_done": summary_done,
|
||||
"summary_pending": summary_pending,
|
||||
"summary_failed": summary_failed,
|
||||
"failure_breakdown": get_failure_breakdown(db),
|
||||
}
|
||||
|
||||
|
||||
@@ -511,3 +515,126 @@ def reset_summaries_pending(db: Session, arxiv_ids: list[str]) -> int:
|
||||
db.add(SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING))
|
||||
db.commit()
|
||||
return len(paper_ids)
|
||||
|
||||
|
||||
# ── 任务监控 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def query_jobs(
|
||||
db: Session,
|
||||
*,
|
||||
status: str | None = None,
|
||||
job_type: str | None = None,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""后台任务列表查询 — 支持 status/type 过滤 + 分页,返回已 enrich 的 dict 列表。"""
|
||||
query = select(Job)
|
||||
if status and status != "all":
|
||||
query = query.where(Job.status == status)
|
||||
if job_type and job_type != "all":
|
||||
query = query.where(Job.type == job_type)
|
||||
|
||||
total = db.scalar(select(func.count()).select_from(query.subquery())) or 0
|
||||
jobs = (
|
||||
db.execute(
|
||||
query.order_by(Job.created_at.desc())
|
||||
.offset((page - 1) * per_page)
|
||||
.limit(per_page)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
return [serialize_job(j) for j in jobs], total
|
||||
|
||||
|
||||
def _as_naive(dt):
|
||||
"""去掉 tzinfo — SQLite 读回的 datetime 是 naive UTC,与 utc_now() 运算前需统一。"""
|
||||
if dt is not None and getattr(dt, "tzinfo", None) is not None:
|
||||
return dt.replace(tzinfo=None)
|
||||
return dt
|
||||
|
||||
|
||||
def serialize_job(job: Job) -> dict:
|
||||
"""单条 job 序列化为展示用 dict(含耗时)。"""
|
||||
duration = None
|
||||
started = _as_naive(job.started_at)
|
||||
if started:
|
||||
end = _as_naive(job.completed_at) or _as_naive(utc_now())
|
||||
duration = round((end - started).total_seconds(), 1)
|
||||
return {
|
||||
"id": job.id,
|
||||
"type": job.type,
|
||||
"status": job.status,
|
||||
"owner": job.owner,
|
||||
"created_at": job.created_at,
|
||||
"started_at": job.started_at,
|
||||
"completed_at": job.completed_at,
|
||||
"duration_seconds": duration,
|
||||
"error": job.error,
|
||||
}
|
||||
|
||||
|
||||
def get_job_status_counts(db: Session) -> dict:
|
||||
"""按 status 聚合 job 计数,供任务页顶部小统计行用。"""
|
||||
rows = db.execute(
|
||||
select(Job.status, func.count(Job.id)).group_by(Job.status)
|
||||
).fetchall()
|
||||
return {row[0]: row[1] for row in rows}
|
||||
|
||||
|
||||
# ── 锁管理 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def force_release_lock(db: Session, lock_id: int) -> bool:
|
||||
"""强制释放一个卡死的 TaskLock(仅对 running/stale 生效)。"""
|
||||
result = db.execute(
|
||||
update(TaskLock)
|
||||
.where(TaskLock.id == lock_id, TaskLock.status.in_(["running", "stale"]))
|
||||
.values(status="finished", released_at=utc_now())
|
||||
)
|
||||
db.commit()
|
||||
return (result.rowcount or 0) > 0
|
||||
|
||||
|
||||
# ── 失败原因分布 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_failure_breakdown(db: Session) -> list[dict]:
|
||||
"""按 error_type 聚合失败/永久失败的总结,按数量降序。NULL 归 unknown。"""
|
||||
error_expr = func.coalesce(SummaryStatus.error_type, "unknown")
|
||||
rows = db.execute(
|
||||
select(error_expr, func.count(SummaryStatus.id))
|
||||
.where(
|
||||
SummaryStatus.status.in_(
|
||||
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||
)
|
||||
)
|
||||
.group_by(error_expr)
|
||||
.order_by(func.count(SummaryStatus.id).desc())
|
||||
).fetchall()
|
||||
return [{"error_type": row[0], "count": row[1]} for row in rows]
|
||||
|
||||
|
||||
# ── 运行配置概览 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_config_overview() -> dict:
|
||||
"""聚合非敏感配置,供仪表盘展示。敏感字段只标是否已配置,不显示值。"""
|
||||
return {
|
||||
"summary_backend": settings.SUMMARY_BACKEND,
|
||||
"summary_pdf_mode": settings.SUMMARY_PDF_MODE,
|
||||
"summary_concurrency": settings.SUMMARY_CONCURRENCY,
|
||||
"summary_timeout_seconds": settings.SUMMARY_TIMEOUT_SECONDS,
|
||||
"summary_max_retries": settings.SUMMARY_MAX_RETRIES,
|
||||
"scheduler_enabled": settings.SCHEDULER_ENABLED,
|
||||
"schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
|
||||
"chroma_enabled": settings.CHROMA_ENABLED,
|
||||
"embed_model": settings.EMBED_MODEL or "(未配置)",
|
||||
"top_n": settings.TOP_N,
|
||||
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
|
||||
"app_workers": settings.APP_WORKERS,
|
||||
"layout_model": Path(settings.LAYOUT_MODEL_PATH).name,
|
||||
"database_url": settings.DATABASE_URL,
|
||||
"api_key_configured": bool(settings.EMBED_API_KEY),
|
||||
}
|
||||
|
||||
@@ -270,3 +270,67 @@ def _update_upvotes_only(db: Session, papers_raw: list[dict]) -> int:
|
||||
|
||||
db.commit()
|
||||
return updated
|
||||
|
||||
|
||||
async def recrawl_single(db: Session, arxiv_id: str) -> dict:
|
||||
"""重新抓取一篇已存在论文的完整元数据。
|
||||
|
||||
基于 paper.paper_date 重新拉取 HF Daily 列表,命中后全字段刷新
|
||||
(标题/摘要/作者/标签/链接/upvotes)并重建 FTS。若该论文不在其收录日的
|
||||
列表中则无法重抓。
|
||||
"""
|
||||
paper = db.execute(
|
||||
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||
).scalar_one_or_none()
|
||||
if not paper:
|
||||
return {"updated": False, "reason": "not_found", "arxiv_id": arxiv_id}
|
||||
|
||||
target_date = paper.paper_date.isoformat()
|
||||
raw_papers = await fetch_daily(target_date)
|
||||
|
||||
target = None
|
||||
for item in raw_papers:
|
||||
if _parse_paper(item)["arxiv_id"] == arxiv_id:
|
||||
target = item
|
||||
break
|
||||
|
||||
if target is None:
|
||||
return {
|
||||
"updated": False,
|
||||
"reason": "not_in_daily",
|
||||
"arxiv_id": arxiv_id,
|
||||
"date": target_date,
|
||||
}
|
||||
|
||||
meta = _parse_paper(target)
|
||||
now = utc_now()
|
||||
|
||||
# 全字段刷新
|
||||
paper.title_en = meta["title_en"]
|
||||
paper.abstract = meta["abstract"]
|
||||
paper.published_at = meta["published_at"]
|
||||
paper.hf_url = meta["hf_url"]
|
||||
paper.arxiv_url = meta["arxiv_url"]
|
||||
paper.pdf_url = meta["pdf_url"]
|
||||
paper.upvotes = meta["upvotes"]
|
||||
paper.crawled_at = now
|
||||
|
||||
# 重建 authors(删旧再加新)
|
||||
paper.authors.clear()
|
||||
seen_authors: set[str] = set()
|
||||
for idx, name in enumerate(meta["authors"]):
|
||||
if name and name not in seen_authors:
|
||||
seen_authors.add(name)
|
||||
db.add(PaperAuthor(paper_id=paper.id, name=name, position=idx))
|
||||
|
||||
# 重建 tags
|
||||
paper.tags.clear()
|
||||
for tag_name in meta["tags"]:
|
||||
if tag_name:
|
||||
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="hf"))
|
||||
|
||||
db.flush()
|
||||
reindex_paper_fts(db, paper)
|
||||
db.commit()
|
||||
logger.info("Re-crawled paper %s (full metadata refresh)", arxiv_id)
|
||||
return {"updated": True, "arxiv_id": arxiv_id, "date": target_date}
|
||||
|
||||
+78
-40
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -18,14 +19,27 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChromaManager:
|
||||
"""封装 ChromaDB 客户端和 collection 的生命周期。"""
|
||||
"""封装 ChromaDB 客户端和 collection 的生命周期。
|
||||
|
||||
所有客户端/集合访问经 ``self._lock`` 串行化:后处理经 ``asyncio.to_thread``
|
||||
多 worker 并发调 ``index_paper``,若不串行化会并发建连,触发 chromadb 1.5.x
|
||||
``SharedSystemClient`` 类级缓存的并发竞争(``_create_system_if_not_exists``
|
||||
无锁 + refcount release 弹 key)→ ``KeyError: '<persist_dir>'``。
|
||||
锁用 RLock:``index_paper`` 持锁后经 ``get_collection()`` 间接再调 ``init()``,
|
||||
同线程可重入,不死锁。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.RLock()
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
def init(self) -> None:
|
||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
|
||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。
|
||||
|
||||
双重检查锁串行化首次建连 —— 外层快判已建好就直接返回(快路径),抢到锁后再
|
||||
判一次(防并发下另一线程已建好),确保 ``PersistentClient`` 全进程只调一次。
|
||||
"""
|
||||
if not settings.CHROMA_ENABLED:
|
||||
logger.debug("ChromaDB disabled, skip init")
|
||||
return
|
||||
@@ -33,19 +47,23 @@ class ChromaManager:
|
||||
if self._client is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
with self._lock:
|
||||
if self._client is not None: # 双重检查:抢到锁后可能已被别的线程建好
|
||||
return
|
||||
|
||||
chroma_path = Path(settings.CHROMA_DIR)
|
||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
import chromadb
|
||||
|
||||
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
||||
self._collection = self._get_or_create_collection()
|
||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize ChromaDB")
|
||||
self._client = None
|
||||
self._collection = None
|
||||
chroma_path = Path(settings.CHROMA_DIR)
|
||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
||||
self._collection = self._get_or_create_collection()
|
||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize ChromaDB")
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
def _get_or_create_collection(self):
|
||||
"""获取或创建 papers_embeddings collection。"""
|
||||
@@ -102,6 +120,8 @@ def _get_embedding(text: str) -> list[float] | None:
|
||||
POST /v1/embeddings, model=EMBED_MODEL
|
||||
校验返回向量长度 == EMBED_DIMENSIONS
|
||||
失败时返回 None 并记录日志。
|
||||
|
||||
纯远程 HTTP 调用、线程安全 —— 留在锁外,让多 worker 并行调。
|
||||
"""
|
||||
if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
|
||||
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
|
||||
@@ -177,9 +197,11 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败或跳过。
|
||||
|
||||
并发设计:远程 embedding 调用在锁外(多 worker 并行),chroma 集合访问
|
||||
(含首次 init)在 ``_chroma._lock`` 内串行化。
|
||||
"""
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -227,17 +249,25 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
||||
logger.warning("Empty index text for %s, skip", paper_id)
|
||||
return False
|
||||
|
||||
vec = _get_embedding(index_text)
|
||||
vec = _get_embedding(index_text) # 远程 HTTP,锁外并行
|
||||
if vec is None:
|
||||
return False
|
||||
|
||||
col.upsert(
|
||||
ids=[arxiv_id],
|
||||
embeddings=[vec],
|
||||
metadatas=[
|
||||
{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}
|
||||
],
|
||||
)
|
||||
with _chroma._lock: # 串行化集合访问(首次含 init)
|
||||
col = _chroma.get_collection()
|
||||
if col is None:
|
||||
return False
|
||||
col.upsert(
|
||||
ids=[arxiv_id],
|
||||
embeddings=[vec],
|
||||
metadatas=[
|
||||
{
|
||||
"arxiv_id": arxiv_id,
|
||||
"title_zh": title_zh,
|
||||
"paper_date": paper_date,
|
||||
}
|
||||
],
|
||||
)
|
||||
logger.info("Indexed paper %s in ChromaDB", arxiv_id)
|
||||
return True
|
||||
|
||||
@@ -255,17 +285,20 @@ def delete_paper(paper_id: str) -> bool:
|
||||
Args:
|
||||
paper_id: arxiv_id
|
||||
"""
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
col.delete(ids=[paper_id])
|
||||
logger.info("Deleted paper %s from ChromaDB", paper_id)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
|
||||
return False
|
||||
with _chroma._lock:
|
||||
col = _chroma.get_collection()
|
||||
if col is None:
|
||||
return False
|
||||
try:
|
||||
col.delete(ids=[paper_id])
|
||||
logger.info("Deleted paper %s from ChromaDB", paper_id)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
|
||||
return False
|
||||
|
||||
|
||||
# ── 相似查询 ────────────────────────────────────────────────────────────
|
||||
@@ -280,21 +313,26 @@ def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
|
||||
|
||||
Returns:
|
||||
[{"arxiv_id": str, "distance": float}, ...]
|
||||
|
||||
并发设计:远程 embedding 在锁外,集合查询在 ``_chroma._lock`` 内。
|
||||
"""
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return []
|
||||
|
||||
try:
|
||||
vec = _get_embedding(query_text)
|
||||
vec = _get_embedding(query_text) # 远程 HTTP,锁外
|
||||
if vec is None:
|
||||
return []
|
||||
|
||||
results = col.query(
|
||||
query_embeddings=[vec],
|
||||
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
|
||||
include=["metadatas", "distances"],
|
||||
)
|
||||
with _chroma._lock:
|
||||
col = _chroma.get_collection()
|
||||
if col is None:
|
||||
return []
|
||||
results = col.query(
|
||||
query_embeddings=[vec],
|
||||
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
|
||||
include=["metadatas", "distances"],
|
||||
)
|
||||
|
||||
if not results["ids"] or not results["ids"][0]:
|
||||
return []
|
||||
|
||||
+18
-2
@@ -141,12 +141,14 @@ async def run_job(db: Session, job_id: int) -> dict:
|
||||
status=JobEventStatus.SUCCESS,
|
||||
payload=result if isinstance(result, dict) else {"result": result},
|
||||
)
|
||||
return result if isinstance(result, dict) else {"status": "success", "result": result}
|
||||
return (
|
||||
result if isinstance(result, dict) else {"status": "success", "result": result}
|
||||
)
|
||||
|
||||
|
||||
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
|
||||
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
||||
from app.services.crawler import refresh_upvotes
|
||||
from app.services.crawler import recrawl_single, refresh_upvotes
|
||||
from app.services.derived import reindex_chroma, reindex_fts
|
||||
from app.services.pipeline import run_crawl, run_pipeline
|
||||
from app.services.summarizer import summarize_batch, summarize_single
|
||||
@@ -191,6 +193,20 @@ async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
|
||||
return reindex_fts(db)
|
||||
if job.type == "reindex_chroma":
|
||||
return reindex_chroma(db)
|
||||
if job.type == "recrawl_one":
|
||||
return await recrawl_single(db, payload["arxiv_id"])
|
||||
if job.type == "recrawl_batch":
|
||||
updated = 0
|
||||
skipped = 0
|
||||
results = []
|
||||
for arxiv_id in payload.get("arxiv_ids", []):
|
||||
res = await recrawl_single(db, arxiv_id)
|
||||
results.append(res)
|
||||
if res.get("updated"):
|
||||
updated += 1
|
||||
else:
|
||||
skipped += 1
|
||||
return {"updated": updated, "skipped": skipped, "results": results}
|
||||
|
||||
raise ValueError(f"Unsupported job type: {job.type}")
|
||||
|
||||
|
||||
+325
-94
@@ -1,22 +1,32 @@
|
||||
"""PicoDet-S_layout_3cls 布局检测 — 纯 ONNX Runtime 推理.
|
||||
"""DocLayout-YOLO 布局检测 — ONNX Runtime 推理,支持 CPU/GPU/NPU 多设备.
|
||||
|
||||
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。
|
||||
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框。
|
||||
用 onnxruntime 加载 DocLayout-YOLO(DocStructBench, imgsz=1024)ONNX 模型,
|
||||
检测 PDF 页面中的 figure / table 区域。
|
||||
|
||||
预处理:letterbox(保比例缩放 + 灰边 padding 到 imgsz×imgsz),RGB,仅 /255 归一化
|
||||
(不做 ImageNet mean/std)。缩放由 pymupdf Matrix 完成,不依赖 OpenCV。
|
||||
后处理:YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls](已内置 NMS)。
|
||||
坐标还原:(model_coord - padding) / ratio —— 渲染缩放与 letterbox 缩放在 pymupdf
|
||||
渲染阶段合二为一,故只需一次除法。
|
||||
|
||||
设备:resolve_providers() 按 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表;
|
||||
_init_session() 逐个 try,首个不可用则降级,CPU 永远兜底。
|
||||
|
||||
输入:
|
||||
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图片
|
||||
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
|
||||
images: (1, 3, imgsz, imgsz) float32 —— letterbox + /255 后的图
|
||||
|
||||
输出:
|
||||
fetch_name_0: (N, 6) float32 — [xmin, ymin, xmax, ymax, score, class_id]
|
||||
fetch_name_1: (1,) int32 — 有效框数量 N
|
||||
output0: (1, N, 6) float32 —— [x1, y1, x2, y2, conf, cls],已 NMS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
@@ -26,37 +36,243 @@ from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 模型输入尺寸
|
||||
_MODEL_SIZE = 480
|
||||
# ImageNet normalize
|
||||
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
||||
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||
# PicoDet label → 内部 boxclass
|
||||
_LABEL_MAP: dict[int, str] = {
|
||||
0: "picture", # PicoDet "image" → "picture"
|
||||
1: "table",
|
||||
# 2: seal — 忽略
|
||||
# DocLayout-YOLO DocStructBench 标准 10 类(ONNX metadata 读不到时的兜底,以实际为准)
|
||||
_FALLBACK_NAMES: dict[int, str] = {
|
||||
0: "title",
|
||||
1: "plain text",
|
||||
2: "abandon",
|
||||
3: "figure",
|
||||
4: "figure_caption",
|
||||
5: "table",
|
||||
6: "table_caption",
|
||||
7: "table_footnote",
|
||||
8: "isolate_formula",
|
||||
9: "formula_caption",
|
||||
}
|
||||
# 下游需要 picture/table 及其 caption —— 按 class name 字符串动态匹配(不依赖 class index,
|
||||
# 规避 DocStructBench 不同发布的类别顺序差异)
|
||||
_PICTURE_NAMES = {"figure", "figure_group"}
|
||||
_TABLE_NAMES = {"table", "table_group"}
|
||||
_FIGURE_CAPTION_NAMES = {"figure_caption"}
|
||||
_TABLE_CAPTION_NAMES = {"table_caption"}
|
||||
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
|
||||
_PAD_VALUE = 114
|
||||
# 最小 bbox 尺寸(PDF 点)
|
||||
_MIN_BOX_SIZE = 20
|
||||
_MIN_CAPTION_BOX_WIDTH = 30
|
||||
_MIN_CAPTION_BOX_HEIGHT = 6
|
||||
|
||||
# device → ExecutionProvider 映射
|
||||
_PROVIDER_MAP: dict[str, str] = {
|
||||
"cpu": "CPUExecutionProvider",
|
||||
"cuda": "CUDAExecutionProvider",
|
||||
"directml": "DmlExecutionProvider",
|
||||
"openvino": "OpenVINOExecutionProvider",
|
||||
"cann": "CannExecutionProvider",
|
||||
"tensorrt": "TensorrtExecutionProvider",
|
||||
"qnn": "QNNExecutionProvider",
|
||||
}
|
||||
# auto 探测优先级(不含 cpu,cpu 永远兜底)
|
||||
_AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayoutBox:
|
||||
"""检测到的布局区域,兼容现有 _process_page 代码。"""
|
||||
"""检测到的布局区域,坐标为 PDF 点。"""
|
||||
|
||||
x0: float
|
||||
y0: float
|
||||
x1: float
|
||||
y1: float
|
||||
boxclass: str # "picture" | "table"
|
||||
boxclass: str
|
||||
|
||||
|
||||
class _LayoutDetector:
|
||||
"""单例:管理 ONNX InferenceSession 生命周期。"""
|
||||
# ── 设备选择 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def resolve_providers(device: str, device_id: int) -> list[tuple[str, dict]]:
|
||||
"""根据 LAYOUT_DEVICE 产出候选 ExecutionProvider 列表(首选在前,均带 CPU 兜底)。
|
||||
|
||||
返回 list[tuple[ep_name, provider_options]],供 _init_session() 逐个 try。
|
||||
onnxruntime 创建 session 时若指定 EP 在本机变体里未注册会直接抛错,
|
||||
故降级逻辑由 _init_session() 完成,这里只产出候选。
|
||||
"""
|
||||
if device == "cpu":
|
||||
return [("CPUExecutionProvider", {})]
|
||||
|
||||
opts = {"device_id": str(device_id)}
|
||||
|
||||
if device == "auto":
|
||||
available = set(ort.get_available_providers())
|
||||
for dev in _AUTO_PRIORITY:
|
||||
ep = _PROVIDER_MAP[dev]
|
||||
if ep in available:
|
||||
logger.info("auto: selected provider %s", ep)
|
||||
return [(ep, opts), ("CPUExecutionProvider", {})]
|
||||
logger.info("auto: no GPU/NPU provider available, using CPU")
|
||||
return [("CPUExecutionProvider", {})]
|
||||
|
||||
ep = _PROVIDER_MAP.get(device)
|
||||
if ep is None:
|
||||
logger.warning("Unknown LAYOUT_DEVICE=%r, falling back to CPU", device)
|
||||
return [("CPUExecutionProvider", {})]
|
||||
return [(ep, opts), ("CPUExecutionProvider", {})]
|
||||
|
||||
|
||||
# ── 预处理:渲染几何与 letterbox ────────────────────────────────────────
|
||||
|
||||
|
||||
def _compute_render_geometry(page_w: float, page_h: float, imgsz: int) -> float:
|
||||
"""letterbox 渲染缩放 ratio = min(imgsz/page_w, imgsz/page_h)。
|
||||
|
||||
pymupdf 以 Matrix(ratio, ratio) 渲染,长边贴到 imgsz,短边留灰边。
|
||||
"""
|
||||
return min(imgsz / page_w, imgsz / page_h)
|
||||
|
||||
|
||||
def _letterbox_padding(
|
||||
content_w: float, content_h: float, imgsz: int
|
||||
) -> tuple[float, float]:
|
||||
"""居中 padding:(imgsz - content) / 2。content 为实际 pixmap 尺寸(已取整)。"""
|
||||
return (imgsz - content_w) / 2.0, (imgsz - content_h) / 2.0
|
||||
|
||||
|
||||
def _padded_nchw_from_pixmap(
|
||||
pix: pymupdf.Pixmap, imgsz: int, dw: float, dh: float
|
||||
) -> np.ndarray:
|
||||
"""pixmap → letterbox padded (1, 3, imgsz, imgsz) float32,灰边=114,/255 归一化。"""
|
||||
arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
|
||||
pix.height, pix.width, pix.n
|
||||
)
|
||||
if arr.shape[2] == 4: # 去 alpha(csRGB alpha=False 一般不会,防御性)
|
||||
arr = arr[:, :, :3]
|
||||
|
||||
canvas = np.full((imgsz, imgsz, 3), _PAD_VALUE, dtype=np.uint8)
|
||||
top = int(round(dh))
|
||||
left = int(round(dw))
|
||||
canvas[top : top + pix.height, left : left + pix.width] = arr
|
||||
|
||||
out = canvas.astype(np.float32) / 255.0
|
||||
return out.transpose(2, 0, 1)[np.newaxis] # (1, 3, imgsz, imgsz)
|
||||
|
||||
|
||||
def _model_to_pdf(
|
||||
model_x: float, model_y: float, dw: float, dh: float, ratio: float
|
||||
) -> tuple[float, float]:
|
||||
"""模型 imgsz 空间坐标 → PDF 点:(model - padding) / ratio。"""
|
||||
return (model_x - dw) / ratio, (model_y - dh) / ratio
|
||||
|
||||
|
||||
# ── 后处理 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _postprocess_output(
|
||||
output: np.ndarray, threshold: float, names: dict[int, str]
|
||||
) -> list[tuple[int, float, float, float, float]]:
|
||||
"""解析 YOLOv10 end-to-end 输出,过滤 conf < threshold。
|
||||
|
||||
Args:
|
||||
output: session.run 返回的第一个输出,shape [1, N, 6]
|
||||
threshold: 置信度阈值
|
||||
names: class id → name(仅用于日志,过滤不依赖)
|
||||
|
||||
Returns:
|
||||
[(cls_id, x1, y1, x2, y2), ...],坐标为模型 imgsz padded 空间。
|
||||
"""
|
||||
out = output[0] # 去 batch 维
|
||||
if out.ndim != 2 or out.shape[1] != 6:
|
||||
logger.warning(
|
||||
"Unexpected DocLayout-YOLO output shape %s (expected [N,6]); skip page",
|
||||
tuple(out.shape),
|
||||
)
|
||||
return []
|
||||
|
||||
results: list[tuple[int, float, float, float, float]] = []
|
||||
for row in out:
|
||||
x1, y1, x2, y2, conf, cls = row.tolist()
|
||||
if conf < threshold:
|
||||
continue
|
||||
results.append((int(cls), x1, y1, x2, y2))
|
||||
return results
|
||||
|
||||
|
||||
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
|
||||
"""按 class name 匹配下游关心的布局类别,其余返回 None。"""
|
||||
name = names.get(cls_id, "")
|
||||
n = name.strip().lower()
|
||||
if n in _PICTURE_NAMES:
|
||||
return "picture"
|
||||
if n in _TABLE_NAMES:
|
||||
return "table"
|
||||
if n in _FIGURE_CAPTION_NAMES:
|
||||
return "figure_caption"
|
||||
if n in _TABLE_CAPTION_NAMES:
|
||||
return "table_caption"
|
||||
return None
|
||||
|
||||
|
||||
def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
|
||||
"""从 ONNX metadata 读 names(ultralytics 导出写入的 JSON),读不到用兜底。"""
|
||||
raw = None
|
||||
try:
|
||||
raw = session.get_modelmeta().custom_metadata_map.get("names")
|
||||
except Exception:
|
||||
raw = None
|
||||
if raw:
|
||||
try:
|
||||
d = json.loads(raw)
|
||||
return {int(k): str(v) for k, v in d.items()}
|
||||
except Exception:
|
||||
logger.warning("Failed to parse ONNX names metadata; using fallback")
|
||||
return dict(_FALLBACK_NAMES)
|
||||
|
||||
|
||||
# ── 检测器单例 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _Singleton(type):
|
||||
"""元类单例:``cls()`` 永远返回同一实例;``reset_instance()`` 清缓存以便重建。
|
||||
|
||||
生产代码只应在模块级 ``_detector = _LayoutDetector()`` 创建一次。任何第二处
|
||||
``_LayoutDetector()`` 都会拿到同一实例(含同一 ONNX session + 同一锁),杜绝
|
||||
并发推理时各建一份 session 导致内存峰值翻倍(8GB 机器崩溃根因)。双检锁保证
|
||||
首次实例化线程安全。
|
||||
"""
|
||||
|
||||
_instances: dict[type, Any] = {}
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls in _Singleton._instances:
|
||||
return _Singleton._instances[cls]
|
||||
with _Singleton._lock:
|
||||
if cls not in _Singleton._instances:
|
||||
_Singleton._instances[cls] = super().__call__(*args, **kwargs)
|
||||
return _Singleton._instances[cls]
|
||||
|
||||
|
||||
class _LayoutDetector(metaclass=_Singleton):
|
||||
"""强约束单例:管理 ONNX InferenceSession 生命周期。
|
||||
|
||||
由 ``_Singleton`` 元类保证全进程唯一实例 —— 重复 ``_LayoutDetector()`` 只会返回
|
||||
已有实例(含已加载的 session 和锁),不会新建。``reset_instance()`` 清缓存,仅供
|
||||
测试隔离用。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._session: ort.InferenceSession | None = None
|
||||
self._names: dict[int, str] = {}
|
||||
self._input_name: str = ""
|
||||
self._imgsz: int = settings.LAYOUT_IMGSZ
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""清空单例缓存,下次 ``_LayoutDetector()`` 重建新实例(含新锁 + 空 session)。
|
||||
|
||||
仅用于测试隔离 —— 生产代码永远不该调用(否则会丢掉已加载的模型 session)。
|
||||
"""
|
||||
_Singleton._instances.pop(cls, None)
|
||||
|
||||
def _init_session(self) -> ort.InferenceSession:
|
||||
if self._session is not None:
|
||||
@@ -66,109 +282,124 @@ class _LayoutDetector:
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Layout model not found: {model_path}. "
|
||||
"Run scripts/export_picodet_onnx.py first."
|
||||
"Run scripts/export_doclayout_yolo_onnx.py first."
|
||||
)
|
||||
|
||||
logger.info("Loading ONNX layout model: %s", model_path)
|
||||
self._session = ort.InferenceSession(
|
||||
str(model_path), providers=["CPUExecutionProvider"]
|
||||
eps = resolve_providers(settings.LAYOUT_DEVICE, settings.LAYOUT_DEVICE_ID)
|
||||
logger.info(
|
||||
"Loading layout model %s, candidate providers: %s",
|
||||
model_path,
|
||||
[ep[0] for ep in eps],
|
||||
)
|
||||
logger.info("ONNX layout model loaded")
|
||||
|
||||
# 逐个 EP 尝试,首个不可用则降级
|
||||
last_err: Exception | None = None
|
||||
for idx, (ep_name, ep_opts) in enumerate(eps):
|
||||
try:
|
||||
self._session = ort.InferenceSession(
|
||||
str(model_path), providers=[(ep_name, ep_opts)]
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
if idx < len(eps) - 1:
|
||||
logger.warning(
|
||||
"Provider %s unavailable (%s); falling back to %s",
|
||||
ep_name,
|
||||
e,
|
||||
eps[idx + 1][0],
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Failed to create layout session: {last_err}")
|
||||
|
||||
logger.info(
|
||||
"Layout session active providers: %s", self._session.get_providers()
|
||||
)
|
||||
self._input_name = self._session.get_inputs()[0].name
|
||||
self._names = _parse_names_from_meta(self._session)
|
||||
self._imgsz = settings.LAYOUT_IMGSZ
|
||||
return self._session
|
||||
|
||||
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
def _detect_page_impl(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
"""检测单页 PDF 的 figure / table 区域。
|
||||
|
||||
流程:
|
||||
1. pymupdf 以 480×480 渲染页面
|
||||
2. ImageNet normalize → NCHW
|
||||
3. ONNX 推理 → 得到已解码+NMS 的检测框
|
||||
4. 像素坐标 → PDF 点坐标
|
||||
5. 过滤 seal 类和低置信度框
|
||||
|
||||
Args:
|
||||
page: pymupdf Page 对象
|
||||
1. letterbox 渲染:保比例缩放到长边=imgsz,短边留灰边
|
||||
2. /255 + NCHW → ONNX 推理
|
||||
3. YOLOv10 end-to-end 后处理(已 NMS)
|
||||
4. 模型坐标 → PDF 点
|
||||
5. 过滤非 figure/table 类、极小框、越界 clip
|
||||
|
||||
Returns:
|
||||
LayoutBox 列表,坐标为 PDF 点
|
||||
LayoutBox 列表,坐标为 PDF 点。
|
||||
"""
|
||||
session = self._init_session()
|
||||
|
||||
page_w = page.rect.width
|
||||
page_h = page.rect.height
|
||||
ratio = _compute_render_geometry(page_w, page_h, self._imgsz)
|
||||
|
||||
# 1. 渲染页面到 _MODEL_SIZE × _MODEL_SIZE
|
||||
zoom_x = _MODEL_SIZE / page_w
|
||||
zoom_y = _MODEL_SIZE / page_h
|
||||
mat = pymupdf.Matrix(zoom_x, zoom_y)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
# 2. 预处理
|
||||
img = (
|
||||
np.frombuffer(pix.samples, dtype=np.uint8)
|
||||
.reshape(pix.height, pix.width, pix.n)
|
||||
.astype(np.float32)
|
||||
/ 255.0
|
||||
# 1. 保比例渲染(长边贴 imgsz)
|
||||
pix = page.get_pixmap(
|
||||
matrix=pymupdf.Matrix(ratio, ratio),
|
||||
colorspace=pymupdf.csRGB,
|
||||
alpha=False,
|
||||
)
|
||||
# 去掉 alpha 通道(如有)
|
||||
if img.shape[2] == 4:
|
||||
img = img[:, :, :3]
|
||||
img = (img - _MEAN) / _STD
|
||||
img = img.transpose(2, 0, 1)[np.newaxis] # (1, 3, H, W)
|
||||
# 用 pixmap 实际尺寸(已取整)算 padding,消除取整导致的坐标偏移
|
||||
dw, dh = _letterbox_padding(pix.width, pix.height, self._imgsz)
|
||||
tensor = _padded_nchw_from_pixmap(pix, self._imgsz, dw, dh)
|
||||
|
||||
# scale_factor 用于坐标还原(模型内部可能用)
|
||||
scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)
|
||||
|
||||
# 3. 推理
|
||||
input_names = [i.name for i in session.get_inputs()]
|
||||
feed = {input_names[0]: img}
|
||||
if len(input_names) > 1:
|
||||
feed[input_names[1]] = scale_factor
|
||||
|
||||
outputs = session.run(None, feed)
|
||||
boxes_raw = outputs[0] # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
|
||||
num_boxes = int(outputs[1][0]) # 有效框数
|
||||
|
||||
if num_boxes == 0:
|
||||
return []
|
||||
|
||||
# 4. 像素 → PDF 点坐标
|
||||
sx = page_w / _MODEL_SIZE
|
||||
sy = page_h / _MODEL_SIZE
|
||||
# 2. 推理
|
||||
outputs = session.run(None, {self._input_name: tensor})
|
||||
detections = _postprocess_output(
|
||||
outputs[0], settings.LAYOUT_THRESHOLD, self._names
|
||||
)
|
||||
|
||||
# 3. 坐标还原 + 过滤
|
||||
result: list[LayoutBox] = []
|
||||
for i in range(min(num_boxes, len(boxes_raw))):
|
||||
cls_id, score, xmin, ymin, xmax, ymax = boxes_raw[i]
|
||||
cls_id = int(cls_id)
|
||||
|
||||
# 跳过 seal 类和低置信度
|
||||
if cls_id not in _LABEL_MAP:
|
||||
for cls_id, x1m, y1m, x2m, y2m in detections:
|
||||
boxclass = _map_class_to_boxclass(cls_id, self._names)
|
||||
if boxclass is None:
|
||||
continue
|
||||
if score < settings.LAYOUT_THRESHOLD:
|
||||
continue
|
||||
|
||||
x0, y0 = xmin * sx, ymin * sy
|
||||
x1, y1 = xmax * sx, ymax * sy
|
||||
|
||||
# 跳过极小区域
|
||||
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||
continue
|
||||
|
||||
result.append(
|
||||
LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=_LABEL_MAP[cls_id])
|
||||
)
|
||||
x0, y0 = _model_to_pdf(x1m, y1m, dw, dh, ratio)
|
||||
x1, y1 = _model_to_pdf(x2m, y2m, dw, dh, ratio)
|
||||
# clip 到页面范围
|
||||
x0 = max(0.0, min(x0, page_w))
|
||||
y0 = max(0.0, min(y0, page_h))
|
||||
x1 = max(0.0, min(x1, page_w))
|
||||
y1 = max(0.0, min(y1, page_h))
|
||||
if boxclass in ("figure_caption", "table_caption"):
|
||||
if (x1 - x0) < _MIN_CAPTION_BOX_WIDTH or (
|
||||
y1 - y0
|
||||
) < _MIN_CAPTION_BOX_HEIGHT:
|
||||
continue
|
||||
else:
|
||||
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||
continue
|
||||
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
|
||||
|
||||
return result
|
||||
|
||||
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||
"""公共入口:加锁串行化推理。
|
||||
|
||||
# 模块级单例
|
||||
包裹整段 _detect_page_impl(含 pixmap 渲染 + tensor 构造 + session.run),
|
||||
保证同一时刻只有一个推理在跑——避免 SUMMARY_CONCURRENCY>1 时多个 to_thread
|
||||
线程并发推理导致内存峰值翻倍(8GB 机器崩溃根因)。锁由 _detect_page_impl
|
||||
间接保护 _init_session,首次加载也串行,杜绝并发各建一份 session。
|
||||
"""
|
||||
with self._lock:
|
||||
return self._detect_page_impl(page)
|
||||
|
||||
|
||||
# 模块级单例 —— 生产代码唯一的实例化点(_Singleton 元类保证不会再有第二个)
|
||||
_detector = _LayoutDetector()
|
||||
|
||||
|
||||
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
|
||||
"""检测 PDF 页面中的 figure / table 区域。
|
||||
"""检测 PDF 页面中的 figure / table / caption 区域。
|
||||
|
||||
Returns:
|
||||
LayoutBox 列表,坐标为 PDF 点,仅含 picture/table。
|
||||
LayoutBox 列表,坐标为 PDF 点,仅含 picture/table 及其 caption。
|
||||
"""
|
||||
return _detector.detect_page(page)
|
||||
|
||||
@@ -71,7 +71,9 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||
|
||||
try:
|
||||
session = _get_session()
|
||||
resp = session.get(pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True)
|
||||
resp = session.get(
|
||||
pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True
|
||||
)
|
||||
resp.raise_for_status()
|
||||
dest.write_bytes(resp.content)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""PDF 图片与表格提取 — 两阶段流水线。
|
||||
|
||||
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
|
||||
Phase 1: DocLayout-YOLO 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
|
||||
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
|
||||
|
||||
相比旧方案(正则匹配 caption):
|
||||
@@ -34,6 +34,8 @@ _CLUSTER_GAP = 15
|
||||
_MIN_BOX_AREA = 2000
|
||||
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
|
||||
_LABEL_MATCH_DISTANCE = 100
|
||||
# DocLayout caption 与 figure/table 匹配的最大距离(单位: pt)
|
||||
_CAPTION_MATCH_DISTANCE = 120
|
||||
|
||||
|
||||
# ── Box 聚类 ─────────────────────────────────────────────────────────
|
||||
@@ -53,6 +55,15 @@ class _BoxCluster:
|
||||
self.boxclass = "table" if raw == "table-fallback" else raw
|
||||
|
||||
|
||||
def _cluster_to_box(cluster: _BoxCluster) -> list[float]:
|
||||
return [
|
||||
round(float(cluster.x0), 1),
|
||||
round(float(cluster.y0), 1),
|
||||
round(float(cluster.x1), 1),
|
||||
round(float(cluster.y1), 1),
|
||||
]
|
||||
|
||||
|
||||
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
||||
"""将相邻的同类型 box 合并为聚类。"""
|
||||
if not boxes:
|
||||
@@ -92,6 +103,67 @@ def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
||||
return [_BoxCluster(members) for members in groups.values()]
|
||||
|
||||
|
||||
def _caption_class_for_content(boxclass: str) -> str:
|
||||
return "figure_caption" if boxclass == "picture" else "table_caption"
|
||||
|
||||
|
||||
def _caption_distance(content: _BoxCluster, caption: _BoxCluster) -> float | None:
|
||||
"""Return a spatial score for pairing a caption with a content box."""
|
||||
h_overlap = min(content.x1, caption.x1) - max(content.x0, caption.x0)
|
||||
min_width = min(content.x1 - content.x0, caption.x1 - caption.x0)
|
||||
if min_width <= 0 or h_overlap < min_width * 0.25:
|
||||
return None
|
||||
|
||||
if caption.y1 < content.y0:
|
||||
v_gap = content.y0 - caption.y1
|
||||
elif caption.y0 > content.y1:
|
||||
v_gap = caption.y0 - content.y1
|
||||
else:
|
||||
v_gap = 0.0
|
||||
|
||||
return v_gap if v_gap <= _CAPTION_MATCH_DISTANCE else None
|
||||
|
||||
|
||||
def _extract_caption_text(page, caption: _BoxCluster) -> str:
|
||||
rect = pymupdf.Rect(caption.x0, caption.y0, caption.x1, caption.y1)
|
||||
try:
|
||||
text = page.get_text("text", clip=rect)
|
||||
except Exception:
|
||||
return ""
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def _match_captions(
|
||||
page,
|
||||
content_clusters: list[_BoxCluster],
|
||||
caption_clusters: list[_BoxCluster],
|
||||
) -> dict[int, tuple[_BoxCluster, str]]:
|
||||
"""Match each content cluster to its nearest same-type DocLayout caption."""
|
||||
matches: dict[int, tuple[_BoxCluster, str]] = {}
|
||||
used_captions: set[int] = set()
|
||||
candidates: list[tuple[float, int, int]] = []
|
||||
|
||||
for content_idx, content in enumerate(content_clusters):
|
||||
wanted_caption_class = _caption_class_for_content(content.boxclass)
|
||||
for caption_idx, caption in enumerate(caption_clusters):
|
||||
if caption.boxclass != wanted_caption_class:
|
||||
continue
|
||||
dist = _caption_distance(content, caption)
|
||||
if dist is not None:
|
||||
candidates.append((dist, content_idx, caption_idx))
|
||||
|
||||
for _dist, content_idx, caption_idx in sorted(candidates):
|
||||
if content_idx in matches or caption_idx in used_captions:
|
||||
continue
|
||||
text = _extract_caption_text(page, caption_clusters[caption_idx])
|
||||
if not text:
|
||||
continue
|
||||
matches[content_idx] = (caption_clusters[caption_idx], text)
|
||||
used_captions.add(caption_idx)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -102,14 +174,25 @@ def _render_box(
|
||||
filename: str,
|
||||
cap_type: str,
|
||||
page_num: int,
|
||||
caption: _BoxCluster | None = None,
|
||||
) -> bool:
|
||||
"""渲染单个 box 区域并保存 JPEG,成功返回 True。"""
|
||||
"""渲染单个 box 区域并保存 JPEG,成功返回 True。
|
||||
|
||||
若提供 caption,则将内容与 caption 区域合并后一起截取,
|
||||
使同一张截图同时包含图/表及其标题文字。
|
||||
"""
|
||||
page_width = page.rect.width
|
||||
x0, y0, x1, y1 = box.x0, box.y0, box.x1, box.y1
|
||||
if caption is not None:
|
||||
x0 = min(x0, caption.x0)
|
||||
y0 = min(y0, caption.y0)
|
||||
x1 = max(x1, caption.x1)
|
||||
y1 = max(y1, caption.y1)
|
||||
clip = pymupdf.Rect(
|
||||
max(0, box.x0 - _REGION_PADDING),
|
||||
max(0, box.y0 - _REGION_PADDING),
|
||||
min(page_width, box.x1 + _REGION_PADDING),
|
||||
box.y1 + _REGION_PADDING,
|
||||
max(0, x0 - _REGION_PADDING),
|
||||
max(0, y0 - _REGION_PADDING),
|
||||
min(page_width, x1 + _REGION_PADDING),
|
||||
y1 + _REGION_PADDING,
|
||||
)
|
||||
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
|
||||
try:
|
||||
@@ -136,25 +219,31 @@ def _process_page(
|
||||
fig_counter = 0
|
||||
tbl_counter = 0
|
||||
|
||||
# 收集本页的 table/picture box(跳过极小区域)
|
||||
# 收集本页的 table/picture box 与 caption box(跳过极小区域)
|
||||
raw_boxes = []
|
||||
raw_caption_boxes = []
|
||||
for box in page_boxes:
|
||||
if box.boxclass not in ("table", "table-fallback", "picture"):
|
||||
continue
|
||||
w = box.x1 - box.x0
|
||||
h = box.y1 - box.y0
|
||||
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
|
||||
continue
|
||||
raw_boxes.append(box)
|
||||
if box.boxclass in ("table", "table-fallback", "picture"):
|
||||
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
|
||||
continue
|
||||
raw_boxes.append(box)
|
||||
elif box.boxclass in ("figure_caption", "table_caption"):
|
||||
if w < 30 or h < 6:
|
||||
continue
|
||||
raw_caption_boxes.append(box)
|
||||
|
||||
if not raw_boxes:
|
||||
return 0
|
||||
|
||||
# 聚类:将同一 figure/table 的碎片 box 合并
|
||||
clusters = _cluster_boxes(raw_boxes)
|
||||
caption_clusters = _cluster_boxes(raw_caption_boxes)
|
||||
caption_matches = _match_captions(page, clusters, caption_clusters)
|
||||
|
||||
extracted = 0
|
||||
for cluster in clusters:
|
||||
for cluster_idx, cluster in enumerate(clusters):
|
||||
cap_type = "figure" if cluster.boxclass == "picture" else "table"
|
||||
|
||||
if cap_type == "figure":
|
||||
@@ -168,21 +257,33 @@ def _process_page(
|
||||
continue
|
||||
seen_labels.add(label)
|
||||
|
||||
caption_match = caption_matches.get(cluster_idx)
|
||||
caption_cluster = caption_match[0] if caption_match else None
|
||||
|
||||
filename = f"{label.replace(' ', '_').lower()}.jpg"
|
||||
if not _render_box(page, cluster, images_dest, filename, cap_type, page_num):
|
||||
if not _render_box(
|
||||
page,
|
||||
cluster,
|
||||
images_dest,
|
||||
filename,
|
||||
cap_type,
|
||||
page_num,
|
||||
caption=caption_cluster,
|
||||
):
|
||||
continue
|
||||
|
||||
manifest[filename] = {
|
||||
info = {
|
||||
"page": page_num,
|
||||
"type": cap_type,
|
||||
"label": label,
|
||||
"box": [
|
||||
round(float(cluster.x0), 1),
|
||||
round(float(cluster.y0), 1),
|
||||
round(float(cluster.x1), 1),
|
||||
round(float(cluster.y1), 1),
|
||||
],
|
||||
"box": _cluster_to_box(cluster),
|
||||
}
|
||||
if caption_match:
|
||||
info["caption_text"] = caption_match[1][:500]
|
||||
info["caption_box"] = _cluster_to_box(caption_cluster)
|
||||
info["caption_source"] = "doclayout"
|
||||
|
||||
manifest[filename] = info
|
||||
extracted += 1
|
||||
|
||||
return extracted
|
||||
@@ -446,14 +547,20 @@ def label_images_by_summary(
|
||||
cap_type = info.get("type", "figure")
|
||||
|
||||
# 读取 caption 文本(从 figures 列表)
|
||||
caption_text = ""
|
||||
summary_caption_text = ""
|
||||
for fig in figures:
|
||||
if fig.get("id") == fig_id:
|
||||
caption_text = fig.get("caption", "")
|
||||
summary_caption_text = fig.get("caption", "")
|
||||
break
|
||||
|
||||
info["label"] = fig_id
|
||||
info["caption_text"] = caption_text[:200] if caption_text else ""
|
||||
existing_caption_text = info.get("caption_text", "")
|
||||
if existing_caption_text and summary_caption_text:
|
||||
info["summary_caption_text"] = summary_caption_text[:500]
|
||||
else:
|
||||
info["caption_text"] = (
|
||||
summary_caption_text[:500] if summary_caption_text else ""
|
||||
)
|
||||
info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
|
||||
fig_id
|
||||
)
|
||||
@@ -501,10 +608,6 @@ def _image_sort_key(name: str) -> tuple[int, int]:
|
||||
m = re.search(r"(?:figure|table)_(\d+)", name)
|
||||
if m:
|
||||
return (0, int(m.group(1)))
|
||||
# 旧格式:page2_img1.png, page5_table1.png, figure_1.png
|
||||
m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
|
||||
if m2:
|
||||
return (int(m2.group(1)), int(m2.group(2)))
|
||||
return (0, 0)
|
||||
|
||||
|
||||
|
||||
@@ -13,11 +13,8 @@ from pathlib import Path
|
||||
from app.config import settings
|
||||
from app.utils import truncate_error
|
||||
from app.services.summary_utils import (
|
||||
JsonNotFoundError,
|
||||
build_prompt,
|
||||
extract_json,
|
||||
extract_pdf_text,
|
||||
write_meta_json,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,18 +22,6 @@ logger = logging.getLogger(__name__)
|
||||
# PDF 全文注入模式的字符上限 — 超过此阈值自动切换到 search 模式
|
||||
_PDF_MAX_CHARS = 80_000
|
||||
|
||||
# 重新导出,保持向后兼容
|
||||
__all__ = [
|
||||
"PiTimeoutError",
|
||||
"PiProcessError",
|
||||
"JsonNotFoundError",
|
||||
"call_pi",
|
||||
"write_meta_json",
|
||||
"extract_pdf_text",
|
||||
"build_prompt",
|
||||
"extract_json",
|
||||
]
|
||||
|
||||
|
||||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -80,13 +65,17 @@ async def call_pi(
|
||||
actual_mode = "search"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars > %dk → search",
|
||||
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||
arxiv_id,
|
||||
txt_size,
|
||||
_PDF_MAX_CHARS // 1000,
|
||||
)
|
||||
else:
|
||||
actual_mode = "inject"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars ≤ %dk → inject",
|
||||
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||
arxiv_id,
|
||||
txt_size,
|
||||
_PDF_MAX_CHARS // 1000,
|
||||
)
|
||||
|
||||
# inject 模式需要截断过长的文本(避免撑爆 context)
|
||||
|
||||
@@ -31,6 +31,7 @@ from app.services.summary_persister import (
|
||||
_cleanup_old_images,
|
||||
_handle_summary_failure,
|
||||
_persist_summary,
|
||||
_run_post_processing,
|
||||
)
|
||||
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
|
||||
|
||||
@@ -115,12 +116,31 @@ async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -
|
||||
_t3 = _time.monotonic()
|
||||
logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2)
|
||||
|
||||
quality = _persist_summary(db, paper, json_data, raw_output)
|
||||
quality, schema = _persist_summary(db, paper, json_data, raw_output)
|
||||
_t4 = _time.monotonic()
|
||||
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
|
||||
|
||||
# 后处理(图片提取 + ChromaDB 索引)搬到线程池跑,避免 CPU 密集推理冻结
|
||||
# 事件循环。paper 字段在此(事件循环线程)提取成纯值再传入,规避 worker
|
||||
# 线程跨线程访问 ORM 的 DetachedInstanceError。DocLayout 推理由单例的
|
||||
# threading.Lock 串行化,并发 worker 不会同时压模型。
|
||||
paper_meta = {
|
||||
"title_en": paper.title_en or "",
|
||||
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
|
||||
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
|
||||
}
|
||||
_t5 = _time.monotonic()
|
||||
try:
|
||||
await asyncio.to_thread(_run_post_processing, arxiv_id, schema, paper_meta)
|
||||
except Exception:
|
||||
# 双保险:_run_post_processing 内部已 try/except,此处兜底,
|
||||
# 确保后处理失败绝不影响已 DONE 的总结。
|
||||
logger.warning("Post-processing error for %s", arxiv_id, exc_info=True)
|
||||
_t6 = _time.monotonic()
|
||||
logger.info(" [%s] 后处理(线程池): %.2fs", arxiv_id, _t6 - _t5)
|
||||
|
||||
logger.info(
|
||||
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0
|
||||
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t6 - _t0
|
||||
)
|
||||
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
|
||||
|
||||
|
||||
@@ -24,22 +24,6 @@ from app.utils import TMP_DIR, truncate_error, utc_now
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_fts_summary_text(schema: SummarySchema) -> str:
|
||||
"""拼接用于 FTS5 索引的总结文本。"""
|
||||
parts = [
|
||||
schema.one_line or "",
|
||||
schema.motivation.problem or "",
|
||||
schema.motivation.goal or "",
|
||||
schema.method.overview or "",
|
||||
schema.method.key_idea or "",
|
||||
schema.results.main_findings or "",
|
||||
]
|
||||
return " ".join(p for p in parts if p)
|
||||
|
||||
|
||||
# ── DB 更新 ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -147,8 +131,12 @@ def _handle_summary_failure(
|
||||
|
||||
def _persist_summary(
|
||||
db: Session, paper: Paper, json_data: dict, raw_output: str
|
||||
) -> str:
|
||||
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
|
||||
) -> tuple[str, SummarySchema]:
|
||||
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 (quality, schema)。
|
||||
|
||||
后处理(图片提取/ChromaDB)不再在此函数内执行,由调用方搬到线程池,
|
||||
以免阻塞事件循环。返回 schema 供调用方在线程池里跑后处理。
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
arxiv_id = paper.arxiv_id
|
||||
@@ -181,21 +169,10 @@ def _persist_summary(
|
||||
_t4 - _t3,
|
||||
)
|
||||
|
||||
# 触发性增强(失败不影响总结)
|
||||
_t5 = _time.monotonic()
|
||||
_maybe_extract_images(arxiv_id, schema)
|
||||
_t6 = _time.monotonic()
|
||||
_maybe_index_chroma(arxiv_id, paper, schema)
|
||||
_t7 = _time.monotonic()
|
||||
|
||||
logger.info(
|
||||
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
|
||||
arxiv_id,
|
||||
_t6 - _t5,
|
||||
_t7 - _t6,
|
||||
)
|
||||
|
||||
return quality
|
||||
# 后处理(图片提取 + ChromaDB 索引)已上移到调用方 _do_summarize_one,
|
||||
# 经 asyncio.to_thread 在线程池跑——DB session 必须留在事件循环线程,
|
||||
# 而 CPU/IO 密集的后处理搬走才不冻结事件循环。
|
||||
return quality, schema
|
||||
|
||||
|
||||
# ── 清理 ────────────────────────────────────────────────────────────────
|
||||
@@ -225,7 +202,7 @@ def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
|
||||
"""从 PDF 提取图片和表格(失败不影响总结)。
|
||||
|
||||
两阶段流水线:
|
||||
1. PicoDet 检测 + 渲染截图(通用标签)
|
||||
1. DocLayout-YOLO 检测 + 渲染截图(通用标签)
|
||||
2. 用 summary 的 figures ID 在 PDF 中搜索定位 → 重命名
|
||||
"""
|
||||
try:
|
||||
@@ -242,21 +219,44 @@ def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
|
||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
|
||||
"""写入 ChromaDB 语义索引(失败不影响总结)。"""
|
||||
def _maybe_index_chroma(arxiv_id: str, schema: SummarySchema, paper_meta: dict) -> None:
|
||||
"""写入 ChromaDB 语义索引(失败不影响总结)。
|
||||
|
||||
paper_meta 是调用方在事件循环线程从 ORM 提取的纯值(title_en/tags/paper_date),
|
||||
规避此函数在线程池跑时跨线程访问 ORM 的 DetachedInstanceError 风险。
|
||||
"""
|
||||
try:
|
||||
from app.services.embedder import index_paper
|
||||
|
||||
texts_dict = {
|
||||
"arxiv_id": arxiv_id,
|
||||
"title_zh": schema.title_zh or "",
|
||||
"title_en": paper.title_en or "",
|
||||
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
|
||||
"title_en": paper_meta.get("title_en", ""),
|
||||
"tags": paper_meta.get("tags", ""),
|
||||
"one_line": schema.one_line or "",
|
||||
"motivation_problem": schema.motivation.problem or "",
|
||||
"method_key_idea": schema.method.key_idea or "",
|
||||
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
|
||||
"paper_date": paper_meta.get("paper_date", ""),
|
||||
}
|
||||
index_paper(arxiv_id, texts_dict)
|
||||
except Exception:
|
||||
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
def _run_post_processing(
|
||||
arxiv_id: str, schema: SummarySchema, paper_meta: dict
|
||||
) -> None:
|
||||
"""线程池里跑的 CPU/IO 密集后处理(由 _do_summarize_one 经 asyncio.to_thread 调用)。
|
||||
|
||||
顺序与原 _persist_summary 内部一致:图片提取 → ChromaDB 索引。两者各自
|
||||
try/except(失败不影响已成功的总结),此处再包一层做双保险。
|
||||
"""
|
||||
try:
|
||||
_maybe_extract_images(arxiv_id, schema)
|
||||
_maybe_index_chroma(arxiv_id, schema, paper_meta)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Post-processing failed for %s (summary already persisted)",
|
||||
arxiv_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@@ -42,6 +42,9 @@
|
||||
.status-success { background:var(--success-bg); color:#388e3c; }
|
||||
.status-running { background:var(--info-bg); color:#1976d2; }
|
||||
.status-failed { background:var(--danger-bg); color:var(--danger-bright); }
|
||||
.status-queued { background:#fff8e1; color:#8a6d3b; }
|
||||
.status-stale { background:var(--border); color:var(--ink-muted); }
|
||||
.task-reindex { background:#fff3e0; color:#e65100; }
|
||||
.time-cell { white-space:nowrap; color:var(--ink-light); }
|
||||
.error-cell { max-width:200px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; color:var(--danger-bright); font-size:.8rem; }
|
||||
|
||||
|
||||
@@ -69,7 +69,8 @@
|
||||
<span class="info-label">活跃任务</span>
|
||||
<span class="info-value">
|
||||
{% for lock in stats.active_locks %}
|
||||
<span class="task-badge task-{{ lock.task }}">{{ lock.task }}</span>
|
||||
<span class="task-badge task-{{ lock.task }}" title="{{ lock.status }} · #{{ lock.id }}">{{ lock.task }}</span>
|
||||
<button class="admin-action-btn admin-action-btn-sm admin-action-btn-danger" title="强制释放锁 #{{ lock.id }}" onclick="releaseLock({{ lock.id }})">🔓</button>
|
||||
{% endfor %}
|
||||
</span>
|
||||
</div>
|
||||
@@ -118,6 +119,15 @@
|
||||
<div class="info-row"><span class="info-label">数据库</span><span class="info-value">{{ stats.db_size }}</span></div>
|
||||
<div class="info-row"><span class="info-label">论文文件</span><span class="info-value">{{ stats.papers_size }}</span></div>
|
||||
<div class="info-row"><span class="info-label">临时文件</span><span class="info-value">{{ stats.tmp_size }}</span></div>
|
||||
<div class="info-row">
|
||||
<span class="info-label">搜索索引</span>
|
||||
<span class="info-value">
|
||||
<button class="admin-action-btn admin-action-btn-sm" onclick="rebuildIndexes('fts')">🔤 重建全文</button>
|
||||
{% if stats.config_overview.chroma_enabled %}
|
||||
<button class="admin-action-btn admin-action-btn-sm" onclick="rebuildIndexes('chroma')">🧠 重建语义</button>
|
||||
{% endif %}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="summary-dist">
|
||||
<h3 class="section-subtitle">总结状态分布</h3>
|
||||
@@ -136,6 +146,19 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="admin-info-card">
|
||||
<h2 class="admin-info-title">⚙️ 运行配置</h2>
|
||||
<div class="admin-info-body">
|
||||
<div class="info-row"><span class="info-label">总结后端</span><span class="info-value">{{ stats.config_overview.summary_backend }} · {{ stats.config_overview.summary_pdf_mode }} 模式</span></div>
|
||||
<div class="info-row"><span class="info-label">并发/超时</span><span class="info-value">{{ stats.config_overview.summary_concurrency }} 并发 · {{ stats.config_overview.summary_timeout_seconds }}s · 重试 {{ stats.config_overview.summary_max_retries }}</span></div>
|
||||
<div class="info-row"><span class="info-label">调度</span><span class="info-value">{{ '启用' if stats.config_overview.scheduler_enabled else '未启用' }} · {{ stats.config_overview.schedule_time }} · {{ stats.config_overview.app_workers }} worker</span></div>
|
||||
<div class="info-row"><span class="info-label">语义搜索</span><span class="info-value">{{ '启用' if stats.config_overview.chroma_enabled else '未启用' }} · {{ stats.config_overview.embed_model }}</span></div>
|
||||
<div class="info-row"><span class="info-label">抓取</span><span class="info-value">TOP {{ stats.config_overview.top_n }} · 投票刷新 {{ stats.config_overview.upvote_refresh_days }} 天</span></div>
|
||||
<div class="info-row"><span class="info-label">布局模型</span><span class="info-value">{{ stats.config_overview.layout_model }}</span></div>
|
||||
<div class="info-row"><span class="info-label">数据库</span><span class="info-value">{{ stats.config_overview.database_url }}</span></div>
|
||||
<div class="info-row"><span class="info-label">嵌入密钥</span><span class="info-value">{{ '已配置' if stats.config_overview.api_key_configured else '未配置' }}</span></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="admin-section">
|
||||
@@ -193,5 +216,17 @@
|
||||
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : `✅ 已刷新 ${data.updated || 0} 篇论文投票`); })
|
||||
.catch(err => showToast("❌ 请求失败"));
|
||||
}
|
||||
function releaseLock(lockId) {
|
||||
fetch("/admin/locks/"+lockId+"/release", { method: "POST", headers: { "Content-Type": "application/json" } })
|
||||
.then(r => { if (r.status===303||r.status===401) { window.location.href="/admin/login"; return; } return r.json(); })
|
||||
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : "✅ 已释放锁,刷新中…", {callback:function(){location.reload();}}); })
|
||||
.catch(err => showToast("❌ 请求失败"));
|
||||
}
|
||||
function rebuildIndexes(target) {
|
||||
fetch("/admin/rebuild-indexes", { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({target: target}) })
|
||||
.then(r => { if (r.status===303||r.status===401) { window.location.href="/admin/login"; return; } return r.json(); })
|
||||
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : "✅ 重建任务已创建,可在任务页查看"); })
|
||||
.catch(err => showToast("❌ 请求失败"));
|
||||
}
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}任务监控 — HF Daily Papers{% endblock %}
|
||||
|
||||
{% set type_label = {"crawl_daily":"抓取","pipeline_daily":"流水线","summarize_batch":"批量总结","summarize_one":"单篇总结","refresh_upvotes":"刷新投票","delete_range":"删除","cleanup_tmp":"清理","reindex_fts":"重建全文","reindex_chroma":"重建语义","recrawl_one":"重抓","recrawl_batch":"批量重抓"} %}
|
||||
{% set type_badge = {"crawl_daily":"task-crawl","pipeline_daily":"task-crawl","recrawl_one":"task-crawl","recrawl_batch":"task-crawl","refresh_upvotes":"task-crawl","summarize_batch":"task-summarize","summarize_one":"task-summarize","cleanup_tmp":"task-cleanup","delete_range":"task-delete","reindex_fts":"task-reindex","reindex_chroma":"task-reindex"} %}
|
||||
{% set status_label = {"queued":"排队","running":"运行中","success":"成功","failed":"失败","stale":"已过期","cancelled":"已取消"} %}
|
||||
{% set status_badge = {"queued":"status-queued","running":"status-running","success":"status-success","failed":"status-failed","stale":"status-stale","cancelled":"status-stale"} %}
|
||||
|
||||
{% macro fmt_duration(s) -%}
|
||||
{%- if s is none %}-
|
||||
{%- elif s < 60 %}{{ "%.0f"|format(s) }}s
|
||||
{%- elif s < 3600 %}{{ (s // 60)|int }}m {{ (s % 60)|round|int }}s
|
||||
{%- else %}{{ (s // 3600)|int }}h {{ ((s % 3600) // 60)|int }}m
|
||||
{%- endif -%}
|
||||
{%- endmacro %}
|
||||
|
||||
{% block content %}
|
||||
<div class="admin-page">
|
||||
{% set active = "jobs" %}{% include "partials/admin_subnav.html" %}
|
||||
|
||||
<h1 class="page-heading">🧰 任务监控</h1>
|
||||
|
||||
{% set _total = (status_counts.values() | sum) if status_counts else 0 %}
|
||||
<div class="summary-stats-row">
|
||||
<span class="summary-stat">总计 <strong>{{ _total }}</strong></span>
|
||||
<span class="summary-stat summary-stat-pending">排队 <strong>{{ status_counts.get('queued', 0) }}</strong></span>
|
||||
<span class="summary-stat">运行中 <strong>{{ status_counts.get('running', 0) }}</strong></span>
|
||||
<span class="summary-stat summary-stat-done">成功 <strong>{{ status_counts.get('success', 0) }}</strong></span>
|
||||
<span class="summary-stat summary-stat-failed">失败 <strong>{{ status_counts.get('failed', 0) + status_counts.get('stale', 0) }}</strong></span>
|
||||
</div>
|
||||
|
||||
{% set statuses = [("all","全部"),("queued","排队"),("running","运行中"),("success","成功"),("failed","失败"),("stale","已过期")] %}
|
||||
<div class="summary-filters">
|
||||
<span class="summary-filter-label">状态:</span>
|
||||
{% for key, label in statuses %}
|
||||
<a class="filter-chip {{ 'active' if current_status == key else '' }}"
|
||||
href="?status={{ key }}{% if current_type != 'all' %}&type={{ current_type }}{% endif %}">{{ label }}
|
||||
({% if key == 'all' %}{{ _total }}{% else %}{{ status_counts.get(key, 0) }}{% endif %})</a>
|
||||
{% endfor %}
|
||||
</div>
|
||||
|
||||
{% set types = [("crawl_daily","抓取"),("pipeline_daily","流水线"),("summarize_batch","批量总结"),("summarize_one","单篇总结"),("refresh_upvotes","刷新投票"),("recrawl_one","重抓"),("recrawl_batch","批量重抓"),("delete_range","删除"),("cleanup_tmp","清理"),("reindex_fts","重建全文"),("reindex_chroma","重建语义")] %}
|
||||
<form method="get" class="summary-filters">
|
||||
<span class="summary-filter-label">类型:</span>
|
||||
<input type="hidden" name="status" value="{{ current_status }}" />
|
||||
<select name="type" class="paper-filter-input" onchange="this.form.submit()">
|
||||
<option value="all" {{ 'selected' if current_type == 'all' }}>全部类型</option>
|
||||
{% for key, label in types %}
|
||||
<option value="{{ key }}" {{ 'selected' if current_type == key }}>{{ label }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
</form>
|
||||
|
||||
{% if jobs %}
|
||||
<div class="admin-table-wrap">
|
||||
<table class="admin-table admin-table-compact">
|
||||
<thead>
|
||||
<tr><th>ID</th><th>类型</th><th>状态</th><th>触发者</th><th>创建时间</th><th>耗时</th><th>操作</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for job in jobs %}
|
||||
<tr>
|
||||
<td>{{ job.id }}</td>
|
||||
<td><span class="task-badge {{ type_badge.get(job.type, 'task-crawl') }}">{{ type_label.get(job.type, job.type) }}</span></td>
|
||||
<td><span class="status-badge {{ status_badge.get(job.status, 'status-running') }}">{{ status_label.get(job.status, job.status) }}</span></td>
|
||||
<td>{{ job.owner or '-' }}</td>
|
||||
<td class="time-cell">{{ job.created_at.strftime('%Y-%m-%d %H:%M:%S') if job.created_at else '-' }}</td>
|
||||
<td class="time-cell">{{ fmt_duration(job.duration_seconds) }}</td>
|
||||
<td class="action-cell"><button class="action-btn-sm" title="详情" onclick="showJobDetail({{ job.id }})">📋</button></td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
{% set total_pages = ((total + per_page - 1) // per_page) if total else 1 %}
|
||||
{% if total_pages > 1 %}
|
||||
<div class="pagination">
|
||||
{% if page > 1 %}<a class="page-btn" href="{{ pagination_url(page - 1) }}">← 上一页</a>{% endif %}
|
||||
<span class="page-info">第 {{ page }} / {{ total_pages }} 页(共 {{ total }} 个)</span>
|
||||
{% if page < total_pages %}<a class="page-btn" href="{{ pagination_url(page + 1) }}">下一页 →</a>{% endif %}
|
||||
</div>
|
||||
{% endif %}
|
||||
{% else %}
|
||||
<div class="empty-state">
|
||||
<p>暂无任务记录</p>
|
||||
<p class="hint">触发抓取、总结等操作后,任务会出现在这里。可在「详情」中查看阶段事件。</p>
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<!-- 任务详情 modal -->
|
||||
<div class="confirm-overlay" id="job-detail-overlay" style="display:none;">
|
||||
<div class="confirm-dialog" style="max-width:660px;max-height:85vh;overflow:auto;">
|
||||
<h3 class="admin-info-title" id="job-detail-title">任务详情</h3>
|
||||
<div id="job-detail-body"><p class="hint">加载中...</p></div>
|
||||
<div class="confirm-actions">
|
||||
<button class="confirm-btn confirm-btn-cancel" onclick="closeJobDetail()">关闭</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% endblock %}
|
||||
|
||||
{% block scripts %}
|
||||
<script>
|
||||
const TYPE_LABEL = {"crawl_daily":"抓取","pipeline_daily":"流水线","summarize_batch":"批量总结","summarize_one":"单篇总结","refresh_upvotes":"刷新投票","delete_range":"删除","cleanup_tmp":"清理","reindex_fts":"重建全文","reindex_chroma":"重建语义","recrawl_one":"重抓","recrawl_batch":"批量重抓"};
|
||||
const STATUS_LABEL = {"queued":"排队","running":"运行中","success":"成功","failed":"失败","stale":"已过期","cancelled":"已取消"};
|
||||
|
||||
function fmtTime(s){ return s ? s.replace('T',' ').slice(0,19) : '-'; }
|
||||
function esc(s){ return String(s==null?'':s).replace(/[&<>"]/g, c=>({'&':'&','<':'<','>':'>','"':'"'}[c])); }
|
||||
function eventBadge(s){ return {'success':'status-success','failed':'status-failed','started':'status-running','info':'status-queued'}[s] || 'status-queued'; }
|
||||
function jobStatusBadge(s){ return {'success':'success','failed':'failed','running':'running','stale':'stale','cancelled':'stale','queued':'queued'}[s] || 'running'; }
|
||||
function infoRow(label, val){ return '<div class="info-row"><span class="info-label">'+label+'</span><span class="info-value">'+val+'</span></div>'; }
|
||||
|
||||
function showJobDetail(id){
|
||||
document.getElementById('job-detail-overlay').style.display='flex';
|
||||
document.getElementById('job-detail-body').innerHTML='<p class="hint">加载中...</p>';
|
||||
fetch('/admin/jobs/'+id)
|
||||
.then(r=>{if(r.status===303||r.status===401){window.location.href='/admin/login';return;}return r.json();})
|
||||
.then(d=>{ if(d) renderJobDetail(d); })
|
||||
.catch(()=>{document.getElementById('job-detail-body').innerHTML='<p class="hint">加载失败</p>';});
|
||||
}
|
||||
function renderJobDetail(d){
|
||||
let h='<div class="admin-info-body">';
|
||||
h+=infoRow('ID', d.id);
|
||||
h+=infoRow('类型', esc(TYPE_LABEL[d.type]||d.type));
|
||||
h+=infoRow('状态', '<span class="status-badge status-'+jobStatusBadge(d.status)+'">'+esc(STATUS_LABEL[d.status]||d.status)+'</span>');
|
||||
h+=infoRow('触发者', esc(d.owner||'-'));
|
||||
h+=infoRow('创建', fmtTime(d.created_at));
|
||||
h+=infoRow('开始', fmtTime(d.started_at));
|
||||
h+=infoRow('完成', fmtTime(d.completed_at));
|
||||
if(d.payload && Object.keys(d.payload).length) h+=infoRow('参数', '<code style="word-break:break-all;">'+esc(JSON.stringify(d.payload))+'</code>');
|
||||
if(d.result) h+=infoRow('结果', '<code style="word-break:break-all;">'+esc(JSON.stringify(d.result))+'</code>');
|
||||
if(d.error) h+=infoRow('错误', '<span class="error-cell" style="max-width:480px;">'+esc(d.error)+'</span>');
|
||||
h+='</div>';
|
||||
if(d.events && d.events.length){
|
||||
h+='<h3 class="section-subtitle" style="margin-top:18px;">事件时间线</h3>';
|
||||
h+='<div class="admin-table-wrap" style="max-height:220px;overflow:auto;"><table class="admin-table admin-table-compact"><thead><tr><th>阶段</th><th>状态</th><th>时间</th><th>消息</th></tr></thead><tbody>';
|
||||
d.events.forEach(e=>{
|
||||
h+='<tr><td>'+esc(e.stage)+'</td><td><span class="status-badge '+eventBadge(e.status)+'">'+esc(e.status)+'</span></td><td class="time-cell">'+fmtTime(e.created_at)+'</td><td class="error-cell" style="max-width:240px;">'+esc(e.message||'')+'</td></tr>';
|
||||
});
|
||||
h+='</tbody></table></div>';
|
||||
}
|
||||
document.getElementById('job-detail-body').innerHTML=h;
|
||||
}
|
||||
function closeJobDetail(){ document.getElementById('job-detail-overlay').style.display='none'; }
|
||||
document.addEventListener('keydown',e=>{if(e.key==='Escape')closeJobDetail();});
|
||||
</script>
|
||||
{% endblock %}
|
||||
@@ -109,6 +109,22 @@
|
||||
<span class="summary-stat summary-stat-failed">失败 <strong>{{ summary_failed or 0 }}</strong></span>
|
||||
<span class="summary-stat summary-stat-done">已完成 <strong>{{ summary_done or 0 }}</strong></span>
|
||||
</div>
|
||||
{% if failure_breakdown %}
|
||||
<div class="summary-dist" style="margin-top:12px;">
|
||||
<h3 class="section-subtitle">失败原因分布({{ summary_failed or 0 }} 篇)</h3>
|
||||
<div class="summary-dist-bars">
|
||||
{% set fb_total = (failure_breakdown | map(attribute='count') | sum) or 1 %}
|
||||
{% set error_labels = {"pdf_download_failed":"PDF下载失败","timeout":"超时","process_error":"进程错误","json_not_found":"JSON缺失","json_invalid":"JSON无效","field_missing":"字段缺失","schema_error":"结构错误","unknown":"未分类"} %}
|
||||
{% for item in failure_breakdown %}
|
||||
<div class="dist-row">
|
||||
<span class="dist-label">{{ error_labels.get(item.error_type, item.error_type) }}</span>
|
||||
<div class="dist-bar-wrap"><div class="dist-bar dist-bar-failed" style="width:{{ (item.count / fb_total * 100)|round(1) }}%"></div></div>
|
||||
<span class="dist-count">{{ item.count }}</span>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
<div id="summary-list"
|
||||
hx-get="/admin/summary-status"
|
||||
hx-trigger="load"
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
<option value="title_asc" {% if current_sort == 'title_asc' %}selected{% endif %}>标题 A→Z</option>
|
||||
</select>
|
||||
<button type="submit" class="paper-search-btn">搜索</button>
|
||||
<a class="admin-action-btn admin-action-btn-sm" href="/admin/papers/export.csv{% if request.query_params %}?{{ request.query_params }}{% endif %}">⬇ 导出 CSV</a>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
@@ -37,6 +38,7 @@
|
||||
<span class="paper-batch-label">批量操作</span>
|
||||
<span class="paper-selected-count" id="selected-count">已选 0 篇</span>
|
||||
<button class="admin-action-btn admin-action-btn-sm" onclick="batchAction('summarize')" id="batch-summarize-btn" disabled>📝 批量总结</button>
|
||||
<button class="admin-action-btn admin-action-btn-sm" onclick="batchAction('recrawl')" id="batch-recrawl-btn" disabled>🔄 批量重抓</button>
|
||||
<button class="admin-action-btn admin-action-btn-sm admin-action-btn-danger" onclick="batchAction('delete')" id="batch-delete-btn" disabled>🗑 批量删除</button>
|
||||
</div>
|
||||
|
||||
@@ -72,6 +74,7 @@
|
||||
</td>
|
||||
<td class="action-cell">
|
||||
<button class="action-btn-sm" title="重新总结" onclick="retryOne('{{ paper.arxiv_id }}', this)">↻</button>
|
||||
<button class="action-btn-sm" title="重新抓取元数据" onclick="recrawlOne('{{ paper.arxiv_id }}', this)">🔄</button>
|
||||
<button class="action-btn-sm action-btn-danger" title="删除" onclick="confirmDeleteSingle('{{ paper.arxiv_id }}', '{{ (paper.title_zh or paper.title_en)[:40] | replace("'", "\\'") }}')">🗑</button>
|
||||
</td>
|
||||
</tr>
|
||||
@@ -124,6 +127,7 @@
|
||||
const n=document.querySelectorAll('.paper-check:checked').length;
|
||||
document.getElementById('selected-count').textContent='已选 '+n+' 篇';
|
||||
document.getElementById('batch-summarize-btn').disabled=n===0;
|
||||
document.getElementById('batch-recrawl-btn').disabled=n===0;
|
||||
document.getElementById('batch-delete-btn').disabled=n===0;
|
||||
}
|
||||
function retryOne(arxivId,btn) {
|
||||
@@ -134,6 +138,14 @@
|
||||
.catch(()=>showToast('❌ 请求失败'))
|
||||
.finally(()=>{btn.disabled=false;btn.textContent='↻';});
|
||||
}
|
||||
function recrawlOne(arxivId,btn) {
|
||||
btn.disabled=true;btn.textContent='...';
|
||||
fetch('/admin/paper-recrawl/'+arxivId,{method:'POST',headers:{'Content-Type':'application/json'}})
|
||||
.then(r=>r.json())
|
||||
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 重抓任务已创建,可在任务页查看'))
|
||||
.catch(()=>showToast('❌ 请求失败'))
|
||||
.finally(()=>{btn.disabled=false;btn.textContent='🔄';});
|
||||
}
|
||||
function confirmDeleteSingle(arxivId,title) {
|
||||
document.getElementById('confirm-msg').textContent='确定删除论文「'+title+'」?此操作不可恢复。';
|
||||
_confirmAction='delete-single'; _confirmTarget=arxivId;
|
||||
@@ -151,6 +163,11 @@
|
||||
.then(r=>r.json())
|
||||
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 已提交批量总结'))
|
||||
.catch(()=>showToast('❌ 请求失败'));
|
||||
} else if(action==='recrawl'){
|
||||
fetch('/admin/papers-batch-action',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({action:'recrawl',arxiv_ids:ids})})
|
||||
.then(r=>r.json())
|
||||
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 已提交批量重抓,可在任务页查看'))
|
||||
.catch(()=>showToast('❌ 请求失败'));
|
||||
}
|
||||
}
|
||||
function doConfirmAction() {
|
||||
|
||||
@@ -29,8 +29,6 @@
|
||||
<a href="/reading-list">阅读列表</a>
|
||||
{% if is_admin %}
|
||||
<a href="/admin/">管理</a>
|
||||
<a href="/admin/logout" onclick="event.preventDefault();this.closest('form').submit()">退出</a>
|
||||
<form action="/admin/logout" method="post" style="display:none"></form>
|
||||
{% else %}
|
||||
<a href="/admin/login">管理</a>
|
||||
{% endif %}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{# Admin subnav — 管理后台三个页面共享。active 参数: "dashboard" / "papers" / "logs" #}
|
||||
{# Admin subnav — 管理后台共享。active 参数: "dashboard" / "papers" / "jobs" / "logs" #}
|
||||
<nav class="admin-subnav">
|
||||
<a href="/admin/" class="admin-subnav-link {{ 'active' if active == 'dashboard' else '' }}">仪表盘</a>
|
||||
<a href="/admin/papers" class="admin-subnav-link {{ 'active' if active == 'papers' else '' }}">论文管理</a>
|
||||
<a href="/admin/jobs" class="admin-subnav-link {{ 'active' if active == 'jobs' else '' }}">任务</a>
|
||||
<a href="/admin/logs" class="admin-subnav-link {{ 'active' if active == 'logs' else '' }}">日志</a>
|
||||
<span class="admin-subnav-spacer"></span>
|
||||
<form action="/admin/logout" method="post" class="admin-subnav-form">
|
||||
|
||||
+24
-19
@@ -4,37 +4,42 @@ data = {
|
||||
"arxiv_id": "2602.21760",
|
||||
"title_zh": "基于条件引导调度的混合数据-流水线并行加速扩散模型",
|
||||
"one_line": "提出混合并行框架,通过条件划分与自适应流水线切换加速扩散推理,实现2.31倍提速。",
|
||||
"tags": ["Diffusion Models", "Distributed Inference", "Parallel Computing", "Image Generation"],
|
||||
"tags": [
|
||||
"Diffusion Models",
|
||||
"Distributed Inference",
|
||||
"Parallel Computing",
|
||||
"Image Generation",
|
||||
],
|
||||
"difficulty": "进阶",
|
||||
"prerequisites": {
|
||||
"concepts": [
|
||||
{
|
||||
"term": "Diffusion Models",
|
||||
"explanation": "扩散模型是一类基于去噪过程的生成模型。在正向过程中,它逐渐向数据添加高斯噪声直到变成纯噪声;在反向过程中,模型学习逐步去噪以恢复原始数据。这种迭代特性虽然能生成高质量的样本,但也导致了高昂的推理计算成本。",
|
||||
"why_matters": "理解扩散模型的迭代去噪机制是理解本文如何通过并行化减少推理延迟的基础。"
|
||||
"why_matters": "理解扩散模型的迭代去噪机制是理解本文如何通过并行化减少推理延迟的基础。",
|
||||
},
|
||||
{
|
||||
"term": "Classifier-Free Guidance (CFG)",
|
||||
"explanation": "无分类器引导是一种在推理时提升生成样本与文本条件一致性的技术。模型同时预测有条件噪声(给定文本提示)和无条件噪声(不给定提示),最终通过加权组合两者来获得最终预测。公式为 $\epsilon_{cfg} = \epsilon_\theta(x_t, c, t) + w (\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t))$,其中 $w$ 是引导强度。",
|
||||
"why_matters": "本文的核心创新点在于利用CFG中存在的有条件和无条件双路径作为数据划分的基础。"
|
||||
"why_matters": "本文的核心创新点在于利用CFG中存在的有条件和无条件双路径作为数据划分的基础。",
|
||||
},
|
||||
{
|
||||
"term": "Distributed Inference",
|
||||
"explanation": "分布式推理利用多个GPU并行处理计算任务以减少延迟。主要分为数据并行(如将图像切片处理)和流水线并行(如将模型层切分)。然而,现有的分布式方法在扩散模型中往往面临通信开销大或生成图像出现拼接伪影的问题。",
|
||||
"why_matters": "本文提出的混合并行框架正是为了解决现有分布式推理方法中的这些痛点。"
|
||||
}
|
||||
"why_matters": "本文提出的混合并行框架正是为了解决现有分布式推理方法中的这些痛点。",
|
||||
},
|
||||
]
|
||||
},
|
||||
"motivation": {
|
||||
"problem": "现有的扩散模型加速方法,无论是单卡优化(如减少采样步数、模型剪枝)还是多卡分布式并行(如DistriFusion和AsyncDiff),都存在明显的局限性。单卡优化受限于硬件算力上限,而现有多卡并行方法通常只能实现次线性的加速比。例如,DistriFusion将图像切片并行处理,容易在拼接处产生明显的伪影;AsyncDiff采用异步流水线,虽然加速了但会引入估计误差,且通信开销巨大(在SDXL上高达9.83GB)。",
|
||||
"goal": "本文旨在提出一种新颖的混合并行框架,在仅使用两张GPU的情况下,不仅能实现超过线性的加速比(即 $>2\times$),还要严格保持甚至提升生成图像的质量,同时将通信开销降到最低。",
|
||||
"gap": "与以往将图像空间切片(Patch-based)的思路不同,本文独辟蹊径,利用无分类器引导(CFG)中天然存在的“有条件”和“无条件”两条路径作为新的数据划分维度(Condition-based Partitioning)。同时,作者发现这两条路径的预测误差差异在整个去噪过程中呈现出先大后小再变大的U型曲线,因此引入了自适应的并行切换策略,只在误差差异最小时才进行并行流水线处理。"
|
||||
"gap": "与以往将图像空间切片(Patch-based)的思路不同,本文独辟蹊径,利用无分类器引导(CFG)中天然存在的“有条件”和“无条件”两条路径作为新的数据划分维度(Condition-based Partitioning)。同时,作者发现这两条路径的预测误差差异在整个去噪过程中呈现出先大后小再变大的U型曲线,因此引入了自适应的并行切换策略,只在误差差异最小时才进行并行流水线处理。",
|
||||
},
|
||||
"method": {
|
||||
"overview": "该框架的核心思想是将扩散推理过程划分为三个阶段:预热阶段(Warm-Up)、并行阶段(Parallelism)和完全连接阶段(Fully-Connecting)。在预热和完全连接阶段,使用“基于条件的划分”策略,即一张GPU处理有条件的预测,另一张处理无条件的预测。而在中间的并行阶段,由于两个预测结果非常接近,框架切换到“自适应流水线并行”,利用两张GPU交替执行推理步骤,从而大幅压缩时间。",
|
||||
"key_idea": "核心创新在于不再将图片在空间上切片,而是沿“条件”维度切分数据。这保证了每个GPU都能看到整张图片的全局信息,从而避免了拼接伪影。此外,引入了“去噪差异度”(Denoising Discrepancy,即 rel-MAE)这一指标来动态评估两条路径的相似性,并以此自动决定何时开启和关闭流水线并行,实现了最优的加速-质量平衡。",
|
||||
"steps": "1. 数据划分:输入潜变量同时送入GPU 1(有条件预测 $\epsilon_\theta(x_t, c, t)$)和GPU 2(无条件预测 $\epsilon_\theta(x_t, t)$)。2. 阶段判断:根据实时计算的“去噪差异度” $G_t$ 与阈值 $g_{slope}$ 的关系,确定切换点 $\tau_1$ 和 $\tau_2$。3. 混合执行:在 $[T, \tau_1]$ 阶段同步运行;在 $[\tau_1, \tau_2]$ 阶段启用流水线并行(如GPU 1处理 $t-1$ 步时GPU 2处理 $t$ 步);在 $[\tau_2, 0]$ 阶段重新恢复同步以精细调整细节。",
|
||||
"novelty": "该方法的另一大新颖之处在于其“安全性”设计:通过设置 $\tau_{cap}$ 作为安全上限,确保即使自动算法失效,也不会在错误的时间点引入并行,从而保证了算法的鲁棒性。此外,该框架对U-Net(如SDXL)和DiT(如SD3)架构均具有良好的泛化性。"
|
||||
"novelty": "该方法的另一大新颖之处在于其“安全性”设计:通过设置 $\tau_{cap}$ 作为安全上限,确保即使自动算法失效,也不会在错误的时间点引入并行,从而保证了算法的鲁棒性。此外,该框架对U-Net(如SDXL)和DiT(如SD3)架构均具有良好的泛化性。",
|
||||
},
|
||||
"results": {
|
||||
"main_findings": "实验在SDXL和SD3模型上进行,使用MS-COCO 2014验证集。结果显示,在SDXL上,该方法实现了2.31倍加速,延迟从16.49秒降至7.12秒,且FID指标与原始单卡模型持平(甚至略优)。相比此前最强的DistriFusion(1.22倍)和AsyncDiff(1.31倍),提速效果显著。在通信开销方面,本方法仅为0.516GB,比AsyncDiff的9.83GB降低了19.6倍。在SD3模型上,同样实现了2.07倍的加速。",
|
||||
@@ -44,29 +49,29 @@ data = {
|
||||
"metric": "Speed-Up",
|
||||
"this_work": "2.31x",
|
||||
"baseline": "1.31x (AsyncDiff)",
|
||||
"improvement": "1.0x (Extra speed)"
|
||||
"improvement": "1.0x (Extra speed)",
|
||||
},
|
||||
{
|
||||
"task": "Text-to-Image (SDXL)",
|
||||
"metric": "Comm. (GB)",
|
||||
"this_work": "0.516",
|
||||
"baseline": "9.830 (AsyncDiff)",
|
||||
"improvement": "Reduced by 19.6x"
|
||||
"improvement": "Reduced by 19.6x",
|
||||
},
|
||||
{
|
||||
"task": "Text-to-Image (SD3)",
|
||||
"metric": "Speed-Up",
|
||||
"this_work": "2.07x",
|
||||
"baseline": "1.97x (AsyncDiff)",
|
||||
"improvement": "0.1x (Extra speed)"
|
||||
}
|
||||
"improvement": "0.1x (Extra speed)",
|
||||
},
|
||||
],
|
||||
"limitations": "尽管该方法在通用性上表现出色,但在处理极高分辨率(如4K以上)时,加速比会随分辨率提升而有所下降(从2.72x降至1.62x)。此外,目前的实现仅针对两张GPU进行了深度优化,虽然文中提出了多卡扩展策略,但在单个样本推理场景下,如何高效地扩展到四卡或更多卡仍是一个挑战。最后,参数 $k$ 的选取目前仍需人工根据经验设定。"
|
||||
"limitations": "尽管该方法在通用性上表现出色,但在处理极高分辨率(如4K以上)时,加速比会随分辨率提升而有所下降(从2.72x降至1.62x)。此外,目前的实现仅针对两张GPU进行了深度优化,虽然文中提出了多卡扩展策略,但在单个样本推理场景下,如何高效地扩展到四卡或更多卡仍是一个挑战。最后,参数 $k$ 的选取目前仍需人工根据经验设定。",
|
||||
},
|
||||
"improvements": {
|
||||
"weaknesses": "主要弱点在于自适应切换参数(如 $k$ 和 $\tau_{cap}$)的确定目前仍偏向经验性,缺乏完全自动化的端到端学习机制。此外,虽然避免了图像切片,但条件分支的“信息量”并不总是完全对等的,特别是在极早期的噪声阶段,可能导致其中一张GPU负载不均衡。改进方向可以是结合动态负载均衡算法,根据当前步骤的预测难度动态分配计算资源。",
|
||||
"future_work": "未来的研究方向包括:1. 将该混合并行策略扩展到视频生成模型(Video Diffusion)中,利用时间轴上的相关性进行更细粒度的流水线调度。2. 结合模型量化(Quantization)和蒸馏技术,在多卡并行的基础上进一步压缩单步推理时间。3. 探索在“去噪差异度”指标指导下自动学习最优的 $k$ 值和切换点。",
|
||||
"reproducibility": "代码已在GitHub开源(https://github.com/kaist-dmlab/Hybridiff)。实验环境基于PyTorch,使用的GPU为NVIDIA GeForce 3090,硬件门槛相对较低。文中详细列出了关键超参数(如SDXL上的 $L=12, k=5, \tau_{cap}=15$),使得复现结果的难度较低。"
|
||||
"reproducibility": "代码已在GitHub开源(https://github.com/kaist-dmlab/Hybridiff)。实验环境基于PyTorch,使用的GPU为NVIDIA GeForce 3090,硬件门槛相对较低。文中详细列出了关键超参数(如SDXL上的 $L=12, k=5, \tau_{cap}=15$),使得复现结果的难度较低。",
|
||||
},
|
||||
"figures": [
|
||||
{
|
||||
@@ -74,30 +79,30 @@ data = {
|
||||
"caption": "Summary of the proposed hybrid data-pipeline parallelism",
|
||||
"description": "五维雷达图展示了该方法在速度、图像质量、通用性、高分辨率能力和通信开销五个方面均优于现有分布式框架。",
|
||||
"reason": "直观概括了本文的核心优势,即全方位的性能提升。",
|
||||
"section": "results"
|
||||
"section": "results",
|
||||
},
|
||||
{
|
||||
"id": "Figure 2",
|
||||
"caption": "Comparison of parallel strategies",
|
||||
"description": "对比了三种并行策略:(a)基于切片的数据并行容易产生伪影,(b)流水线并行通信开销大,(c)本文提出的混合并行既保留全局一致性又实现了高效并行。",
|
||||
"reason": "通过对比展示了本文方法设计的合理性和必要性。",
|
||||
"section": "method"
|
||||
"section": "method",
|
||||
},
|
||||
{
|
||||
"id": "Figure 3",
|
||||
"caption": "Overview of the hybrid parallel framework",
|
||||
"description": "详细展示了三个阶段(Warm-Up, Parallelism, Fully-Connecting)的数据流和通信模式,清晰地说明了自适应切换的动态过程。",
|
||||
"reason": "这是理解整个算法执行流程的关键示意图。",
|
||||
"section": "method"
|
||||
"section": "method",
|
||||
},
|
||||
{
|
||||
"id": "Table 1",
|
||||
"caption": "Quantitative comparison on SDXL and SD3",
|
||||
"description": "表格列出了该方法与基线方法在延迟、加速比、通信开销及生成质量指标(FID, LPIPS, PSNR)上的详细对比数据。",
|
||||
"reason": "提供了最核心的定量证据,证明了该方法的有效性。",
|
||||
"section": "results"
|
||||
}
|
||||
]
|
||||
"section": "results",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
with open("data/papers/2602.21760/summary.json", "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -27,6 +27,18 @@ dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.24",
|
||||
]
|
||||
# 导出 DocLayout-YOLO ONNX 用(一次性脚本 scripts/export_doclayout_yolo_onnx.py,独立 venv 运行)
|
||||
# GPU 推理:onnxruntime 与 onnxruntime-gpu/-directml 同环境冲突,不在此声明,
|
||||
# 需手动二选一(见 .env.example 布局检测段说明)
|
||||
export = [
|
||||
"torch>=2.0",
|
||||
"torchvision>=0.15",
|
||||
"doclayout-yolo",
|
||||
"onnx>=1.14",
|
||||
"onnxscript", # torch 2.12+ 的 onnx exporter 需要
|
||||
"onnxsim",
|
||||
"huggingface-hub>=0.20",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
"""导出 DocLayout-YOLO (DocStructBench, imgsz=1024) 为 ONNX 格式.
|
||||
|
||||
一次性脚本,在独立 venv 中运行(不进运行时依赖):
|
||||
python -m venv .venv-export && source .venv-export/bin/activate
|
||||
pip install torch torchvision onnx onnxscript onnxsim onnxruntime huggingface-hub doclayout-yolo
|
||||
|
||||
两种权重来源:
|
||||
# 1) 用本地已下载的 .pt(推荐,省下载)
|
||||
.venv/bin/python scripts/export_doclayout_yolo_onnx.py \
|
||||
--weights /path/to/doclayout_yolo_docstructbench_imgsz1024.pt
|
||||
|
||||
# 2) 从 HuggingFace 下载(不传 --weights)
|
||||
HF_ENDPOINT=https://hf-mirror.com .venv/bin/python scripts/export_doclayout_yolo_onnx.py
|
||||
|
||||
输出:
|
||||
data/models/doclayout_yolo_docstructbench_imgsz1024.onnx
|
||||
|
||||
注意:model.export(simplify=True) 会清空 ONNX metadata,本脚本在导出后
|
||||
用 onnx 包把 names 重新写回 metadata,供运行时 _parse_names_from_meta 读取。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# hf-mirror(国内加速,仅 --weights 未传、走 HF 下载时生效)
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
MODEL_DIR = PROJECT_ROOT / "data" / "models"
|
||||
DEFAULT_OUTPUT = MODEL_DIR / "doclayout_yolo_docstructbench_imgsz1024.onnx"
|
||||
REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
|
||||
PT_FILENAME = "doclayout_yolo_docstructbench.pt"
|
||||
IMGSZ = 1024
|
||||
|
||||
|
||||
def resolve_weights(arg: str | None) -> Path:
|
||||
"""返回 .pt 路径:传 --weights 用本地,否则从 HuggingFace 下载。"""
|
||||
if arg:
|
||||
p = Path(arg)
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"--weights not found: {p}")
|
||||
print(f"[1/5] Using local weights: {p}")
|
||||
return p
|
||||
|
||||
print(f"[1/5] Downloading .pt from HuggingFace ({REPO_ID}) ...")
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
pt_path = Path(hf_hub_download(repo_id=REPO_ID, filename=PT_FILENAME))
|
||||
print(f" ✓ {pt_path}")
|
||||
return pt_path
|
||||
|
||||
|
||||
def export_onnx(pt_path: Path, output: Path) -> None:
|
||||
print("\n[2/5] Loading model with doclayout_yolo ...")
|
||||
from doclayout_yolo import YOLOv10
|
||||
|
||||
model = YOLOv10(str(pt_path))
|
||||
names = model.names # dict[int, str],与 model.model.names 等价
|
||||
print(f" ✓ Loaded. names = {names}")
|
||||
|
||||
print(f"\n[3/5] Exporting ONNX (imgsz={IMGSZ}, opset=12, simplify=True) ...")
|
||||
try:
|
||||
exported = model.export(
|
||||
format="onnx",
|
||||
imgsz=IMGSZ,
|
||||
opset=12,
|
||||
simplify=True, # 需要 onnxsim;失败则下面回退
|
||||
dynamic=False, # 固定 batch=1 + 固定 1024,部署最稳
|
||||
half=False, # FP32,保证 CPU 推理精度
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" ⚠ export with simplify failed ({e}); retrying without simplify")
|
||||
exported = model.export(
|
||||
format="onnx", imgsz=IMGSZ, opset=12, dynamic=False, half=False
|
||||
)
|
||||
exported_path = Path(exported)
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(exported_path), str(output))
|
||||
print(f" ✓ Exported → {output} ({output.stat().st_size / 1024 / 1024:.1f} MB)")
|
||||
|
||||
print("\n[4/5] Re-writing names metadata (simplify may have dropped it) ...")
|
||||
write_names_metadata(output, names)
|
||||
|
||||
|
||||
def write_names_metadata(onnx_path: Path, names: dict) -> None:
|
||||
"""把 names dict 写入 ONNX model.metadata_props(simplify 后通常丢失)。"""
|
||||
import onnx
|
||||
|
||||
m = onnx.load(str(onnx_path))
|
||||
keep = [p for p in m.metadata_props if p.key != "names"]
|
||||
del m.metadata_props[:]
|
||||
m.metadata_props.extend(keep)
|
||||
names_json = json.dumps({str(k): v for k, v in names.items()}, ensure_ascii=False)
|
||||
m.metadata_props.append(onnx.StringStringEntryProto(key="names", value=names_json))
|
||||
onnx.save(m, str(onnx_path))
|
||||
print(f" ✓ names metadata written: {names_json}")
|
||||
|
||||
|
||||
def inspect_onnx(onnx_path: Path) -> None:
|
||||
"""用 onnxruntime 加载模型,打印输入输出 + names metadata + 试推理。"""
|
||||
print("\n[5/5] Verifying with onnxruntime ...")
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
|
||||
print(" Inputs:")
|
||||
for inp in session.get_inputs():
|
||||
print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
|
||||
print(" Outputs:")
|
||||
for out in session.get_outputs():
|
||||
print(f" {out.name}: shape={out.shape}, dtype={out.type}")
|
||||
|
||||
meta = session.get_modelmeta()
|
||||
print(f" metadata keys: {list(meta.custom_metadata_map.keys())}")
|
||||
print(f" names: {meta.custom_metadata_map.get('names')}")
|
||||
|
||||
# dummy 推理
|
||||
input_info = session.get_inputs()[0]
|
||||
h = input_info.shape[2] if isinstance(input_info.shape[2], int) else IMGSZ
|
||||
w = input_info.shape[3] if isinstance(input_info.shape[3], int) else IMGSZ
|
||||
dummy = np.random.rand(1, 3, h, w).astype(np.float32)
|
||||
outputs = session.run(None, {input_info.name: dummy})
|
||||
print(f" Inference test: output[0] shape = {outputs[0].shape}")
|
||||
|
||||
out_shape = outputs[0].shape
|
||||
if len(out_shape) == 3 and out_shape[2] == 6:
|
||||
print(" ✓ output is [1, N, 6] (YOLOv10 end-to-end, NMS applied)")
|
||||
else:
|
||||
print(
|
||||
f" ⚠️ output shape {out_shape} ≠ [1, N, 6]; "
|
||||
"layout_detector._postprocess_output will warn and skip pages — "
|
||||
"adjust export (e.g. end2end/nms) or postprocess.",
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
ap.add_argument("--weights", help="本地 .pt 路径(不传则从 HuggingFace 下载)")
|
||||
ap.add_argument(
|
||||
"--output",
|
||||
default=str(DEFAULT_OUTPUT),
|
||||
help=f"输出 ONNX 路径(默认 {DEFAULT_OUTPUT})",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
output = Path(args.output)
|
||||
pt_path = resolve_weights(args.weights)
|
||||
export_onnx(pt_path, output)
|
||||
inspect_onnx(output)
|
||||
print(f"\n✓ Done! ONNX model saved to {output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,172 +0,0 @@
|
||||
"""导出 PicoDet-S_layout_3cls 为 ONNX 格式.
|
||||
|
||||
一次性脚本,在独立 venv 中运行:
|
||||
python -m venv .venv-export && source .venv-export/bin/activate
|
||||
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle paddleocr paddle2onnx onnxruntime opencv-python-headless
|
||||
HF_ENDPOINT=https://hf-mirror.com python scripts/export_picodet_onnx.py
|
||||
|
||||
输出:
|
||||
data/models/picodet_layout_3cls.onnx
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# hf-mirror
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
MODEL_DIR = PROJECT_ROOT / "data" / "models"
|
||||
OUTPUT_PATH = MODEL_DIR / "picodet_layout_3cls.onnx"
|
||||
MODEL_NAME = "PicoDet-S_layout_3cls"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ── Step 1: 用 PaddleOCR paddle_static 引擎加载模型,触发下载 ──
|
||||
print(f"[1/4] Loading model '{MODEL_NAME}' (paddle_static engine, triggers download) ...")
|
||||
from paddleocr import LayoutDetection
|
||||
|
||||
model = LayoutDetection(
|
||||
model_name=MODEL_NAME,
|
||||
engine="paddle_static",
|
||||
device="cpu",
|
||||
)
|
||||
print(" ✓ Model loaded and cached")
|
||||
|
||||
# ── Step 2: 找到 PaddleX 缓存的 Paddle 模型文件 ────────────────
|
||||
paddlex_cache = Path.home() / ".paddlex"
|
||||
print(f"\n[2/4] Searching Paddle model cache in {paddlex_cache} ...")
|
||||
|
||||
# 搜索 layout 相关的缓存目录
|
||||
candidates = []
|
||||
for d in paddlex_cache.rglob("*"):
|
||||
if d.is_dir() and (d / "inference.pdiparams").exists():
|
||||
# 检查是否是 layout 模型
|
||||
marker = d.name
|
||||
parent_name = d.parent.name
|
||||
if "layout" in marker.lower() or "layout" in parent_name.lower() or "picodet" in marker.lower():
|
||||
candidates.append(d)
|
||||
elif "PicoDet" in str(d):
|
||||
candidates.append(d)
|
||||
|
||||
if not candidates:
|
||||
# 如果没找到明确的 layout 目录,列出所有含 inference.pdiparams 的目录
|
||||
all_model_dirs = [d for d in paddlex_cache.rglob("*") if d.is_dir() and (d / "inference.pdiparams").exists()]
|
||||
print(" No layout-specific dir found. All model dirs with inference.pdiparams:")
|
||||
for d in all_model_dirs:
|
||||
files = [f.name for f in d.iterdir()]
|
||||
print(f" {d} ({', '.join(files)})")
|
||||
if all_model_dirs:
|
||||
# 取最新的(刚下载的)
|
||||
candidates = sorted(all_model_dirs, key=lambda d: (d / "inference.pdiparams").stat().st_mtime, reverse=True)[:1]
|
||||
|
||||
if not candidates:
|
||||
print(" ✗ No cached model found")
|
||||
sys.exit(1)
|
||||
|
||||
model_cache_dir = candidates[0]
|
||||
files_in_dir = list(model_cache_dir.iterdir())
|
||||
print(f" Using: {model_cache_dir}")
|
||||
for f in files_in_dir:
|
||||
print(f" {f.name} ({f.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
# ── Step 3: 用 paddle2onnx 转换 ─────────────────────────────────
|
||||
print("\n[3/4] Converting to ONNX with paddle2onnx ...")
|
||||
tmp_onnx = OUTPUT_PATH.with_suffix(".tmp.onnx")
|
||||
|
||||
# 确定 model_filename
|
||||
pdmodel = model_cache_dir / "inference.pdmodel"
|
||||
has_pdmodel = pdmodel.exists()
|
||||
|
||||
cmd = [
|
||||
sys.executable, "-m", "paddle2onnx",
|
||||
"--model_dir", str(model_cache_dir),
|
||||
"--save_file", str(tmp_onnx),
|
||||
"--opset_version", "11",
|
||||
"--enable_onnx_checker", "True",
|
||||
]
|
||||
if has_pdmodel:
|
||||
cmd.extend(["--model_filename", "inference.pdmodel"])
|
||||
cmd.extend(["--params_filename", "inference.pdiparams"])
|
||||
|
||||
print(f" Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.stdout:
|
||||
print(f" stdout: {result.stdout[:500]}")
|
||||
if result.returncode != 0:
|
||||
print(f" ✗ paddle2onnx failed (exit {result.returncode})")
|
||||
print(f" stderr: {result.stderr[:500]}")
|
||||
|
||||
# 尝试不带 model_filename(combined format)
|
||||
if has_pdmodel:
|
||||
print(" Retrying without explicit model_filename ...")
|
||||
cmd2 = [
|
||||
sys.executable, "-m", "paddle2onnx",
|
||||
"--model_dir", str(model_cache_dir),
|
||||
"--params_filename", "inference.pdiparams",
|
||||
"--save_file", str(tmp_onnx),
|
||||
"--opset_version", "11",
|
||||
]
|
||||
result2 = subprocess.run(cmd2, capture_output=True, text=True)
|
||||
if result2.returncode != 0:
|
||||
print(f" ✗ Retry also failed: {result2.stderr[:500]}")
|
||||
sys.exit(1)
|
||||
|
||||
if not tmp_onnx.exists() or tmp_onnx.stat().st_size < 1000:
|
||||
print(" ✗ ONNX file not created or too small")
|
||||
sys.exit(1)
|
||||
|
||||
shutil.move(str(tmp_onnx), str(OUTPUT_PATH))
|
||||
print(f" ✓ ONNX saved ({OUTPUT_PATH.stat().st_size / 1024 / 1024:.2f} MB)")
|
||||
|
||||
# ── Step 4: 用 onnxruntime 验证 ─────────────────────────────────
|
||||
print("\n[4/4] Verifying with onnxruntime ...")
|
||||
_inspect_onnx(OUTPUT_PATH)
|
||||
|
||||
print(f"\n✓ Done! ONNX model saved to {OUTPUT_PATH}")
|
||||
|
||||
|
||||
def _inspect_onnx(onnx_path: Path) -> None:
|
||||
"""用 onnxruntime 加载模型,打印输入输出信息."""
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
|
||||
|
||||
print(" Inputs:")
|
||||
for inp in session.get_inputs():
|
||||
print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
|
||||
|
||||
print(" Outputs:")
|
||||
for out in session.get_outputs():
|
||||
print(f" {out.name}: shape={out.shape}, dtype={out.type}")
|
||||
|
||||
# 试推理
|
||||
input_info = session.get_inputs()[0]
|
||||
input_name = input_info.name
|
||||
batch_size = input_info.shape[0] if isinstance(input_info.shape[0], int) else 1
|
||||
channels = input_info.shape[1] if isinstance(input_info.shape[1], int) else 3
|
||||
height = input_info.shape[2] if isinstance(input_info.shape[2], int) else 480
|
||||
width = input_info.shape[3] if isinstance(input_info.shape[3], int) else 480
|
||||
|
||||
dummy_input = np.random.rand(batch_size, channels, height, width).astype(np.float32)
|
||||
outputs = session.run(None, {input_name: dummy_input})
|
||||
|
||||
print(" Inference test outputs:")
|
||||
for i, (out_info, out_val) in enumerate(zip(session.get_outputs(), outputs)):
|
||||
print(f" output[{i}] '{out_info.name}': shape={out_val.shape}, dtype={out_val.dtype}")
|
||||
if out_val.size <= 20:
|
||||
print(f" values: {out_val}")
|
||||
|
||||
print(" ✓ Inference OK")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,7 +0,0 @@
|
||||
"""快捷脚本:手动抓取指定日期。用法: python scripts/manual_crawl.py [YYYY-MM-DD] [--top N]"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
from app.cli import cli_app
|
||||
|
||||
cli_app(["crawl"] + sys.argv[1:])
|
||||
@@ -1,4 +1,4 @@
|
||||
"""批量重新提取所有论文的图片 — 下载 PDF + PicoDet 检测 + caption 匹配.
|
||||
"""批量重新提取所有论文的图片 — 下载 PDF + DocLayout-YOLO 检测 + caption 匹配.
|
||||
|
||||
用法:
|
||||
PROXY_SERVER=http://... uv run python scripts/reextract_images.py
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"required": ["arxiv_id", "title_zh", "one_line", "tags", "difficulty",
|
||||
"prerequisites", "motivation", "method", "results", "improvements", "figures"],
|
||||
"properties": {
|
||||
"arxiv_id": {"type": "string"},
|
||||
"title_zh": {"type": "string"},
|
||||
"one_line": {"type": "string"},
|
||||
"tags": {"type": "array", "items": {"type": "string"}},
|
||||
"difficulty": {"type": "string", "enum": ["入门", "进阶", "前沿"]},
|
||||
"prerequisites": {
|
||||
"type": "object",
|
||||
"required": ["concepts"],
|
||||
"properties": {
|
||||
"concepts": {"type": "array", "items": {
|
||||
"type": "object",
|
||||
"required": ["term", "explanation", "why_matters"],
|
||||
"properties": {
|
||||
"term": {"type": "string"},
|
||||
"explanation": {"type": "string"},
|
||||
"why_matters": {"type": "string"}
|
||||
}
|
||||
}}
|
||||
}
|
||||
},
|
||||
"motivation": {
|
||||
"type": "object",
|
||||
"required": ["problem", "goal", "gap"],
|
||||
"properties": {
|
||||
"problem": {"type": "string"},
|
||||
"goal": {"type": "string"},
|
||||
"gap": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"method": {
|
||||
"type": "object",
|
||||
"required": ["overview", "key_idea", "steps", "novelty"],
|
||||
"properties": {
|
||||
"overview": {"type": "string"},
|
||||
"key_idea": {"type": "string"},
|
||||
"steps": {"type": "string"},
|
||||
"novelty": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"results": {
|
||||
"type": "object",
|
||||
"required": ["main_findings", "benchmarks", "limitations"],
|
||||
"properties": {
|
||||
"main_findings": {"type": "string"},
|
||||
"benchmarks": {"type": "array", "items": {
|
||||
"type": "object",
|
||||
"required": ["task", "metric", "this_work", "baseline", "improvement"],
|
||||
"properties": {
|
||||
"task": {"type": "string"},
|
||||
"metric": {"type": "string"},
|
||||
"this_work": {"type": "string"},
|
||||
"baseline": {"type": "string"},
|
||||
"improvement": {"type": "string"}
|
||||
}
|
||||
}},
|
||||
"limitations": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"improvements": {
|
||||
"type": "object",
|
||||
"required": ["weaknesses", "future_work", "reproducibility"],
|
||||
"properties": {
|
||||
"weaknesses": {"type": "string"},
|
||||
"future_work": {"type": "string"},
|
||||
"reproducibility": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"figures": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["id", "caption", "description", "reason", "section"],
|
||||
"properties": {
|
||||
"id": {"type": "string"},
|
||||
"caption": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"reason": {"type": "string"},
|
||||
"section": {"type": "string", "enum": ["motivation", "method", "results", "limitations"]}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def validate_file(filepath):
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Check required fields
|
||||
for field in schema["required"]:
|
||||
if field not in data:
|
||||
print(f"❌ Missing field: {field}")
|
||||
return False
|
||||
|
||||
# Validate nested structure
|
||||
for field, spec in schema["properties"].items():
|
||||
if field in data:
|
||||
if spec["type"] == "string":
|
||||
if not isinstance(data[field], str):
|
||||
print(f"❌ Field '{field}' should be string")
|
||||
return False
|
||||
elif spec["type"] == "array":
|
||||
if not isinstance(data[field], list):
|
||||
print(f"❌ Field '{field}' should be array")
|
||||
return False
|
||||
elif spec["type"] == "object":
|
||||
if not isinstance(data[field], dict):
|
||||
print(f"❌ Field '{field}' should be object")
|
||||
return False
|
||||
if "required" in spec:
|
||||
for subfield in spec["required"]:
|
||||
if subfield not in data[field]:
|
||||
print(f"❌ Missing subfield: {field}.{subfield}")
|
||||
return False
|
||||
|
||||
# Validate section enum in figures
|
||||
valid_sections = ["motivation", "method", "results", "limitations"]
|
||||
for fig in data.get("figures", []):
|
||||
if fig["section"] not in valid_sections:
|
||||
print(f"❌ Invalid section in figure: {fig['section']}")
|
||||
return False
|
||||
|
||||
print("✅ JSON validation passed!")
|
||||
return True
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"❌ JSON decode error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Validation error: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
filepath = sys.argv[1] if len(sys.argv) > 1 else "data/papers/2601.10592/summary.json"
|
||||
validate_file(filepath)
|
||||
@@ -24,6 +24,26 @@ from app.models import (
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
# ── ChromaDB 隔离(autouse,所有测试)──────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_chroma(monkeypatch, tmp_path):
|
||||
"""所有测试把 ChromaDB 隔离到临时目录 + 重置单例,绝不污染 data/chroma。
|
||||
|
||||
与内存 DB 隔离同理:summarize 后处理经真实 _maybe_index_chroma → index_paper
|
||||
写入,不隔离会把测试夹具(2401.*)泄漏到生产 data/chroma,污染语义搜索。
|
||||
每个测试前重置 _chroma 单例,确保 CHROMA_DIR 指向本次 tmp。
|
||||
"""
|
||||
import app.services.embedder as emb
|
||||
from app.config import settings
|
||||
|
||||
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||
emb._chroma.reset()
|
||||
yield
|
||||
emb._chroma.reset()
|
||||
|
||||
|
||||
# ── 内存数据库 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
+1
-4
@@ -301,10 +301,7 @@ class TestAdminPapers:
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["count"] == 2
|
||||
remaining = db_session.execute(
|
||||
text(
|
||||
"SELECT rowid FROM papers_fts "
|
||||
"WHERE rowid IN (:id1, :id2)"
|
||||
),
|
||||
text("SELECT rowid FROM papers_fts WHERE rowid IN (:id1, :id2)"),
|
||||
{"id1": target_ids[0], "id2": target_ids[1]},
|
||||
).fetchall()
|
||||
assert remaining == []
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
"""管理后台新功能测试 — 任务监控、锁释放、重抓、失败分布、配置、导出、重建索引。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.models import Job, JobStatus, SummaryState, SummaryStatus, TaskLock
|
||||
from app.services import admin as admin_svc
|
||||
from app.services.crawler import recrawl_single
|
||||
from app.services.jobs import create_job, run_job
|
||||
from app.utils import utc_now
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_enqueue(monkeypatch):
|
||||
"""禁用路由层的 enqueue_job,阻止 background task 在测试中真实执行。"""
|
||||
from app.routes import admin as admin_route
|
||||
|
||||
monkeypatch.setattr(admin_route, "enqueue_job", lambda *a, **k: None)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 任务监控
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _make_job(db_session, *, type="crawl_daily", status=JobStatus.QUEUED, owner="t"):
|
||||
job = Job(
|
||||
type=type,
|
||||
status=status,
|
||||
owner=owner,
|
||||
payload_json="{}",
|
||||
created_at=utc_now(),
|
||||
)
|
||||
db_session.add(job)
|
||||
db_session.commit()
|
||||
return job
|
||||
|
||||
|
||||
def test_query_jobs_filter_and_pagination(db_session):
|
||||
for i in range(25):
|
||||
_make_job(db_session, status=JobStatus.SUCCESS)
|
||||
for i in range(5):
|
||||
_make_job(db_session, status=JobStatus.FAILED)
|
||||
|
||||
# 无过滤:分页
|
||||
page1, total = admin_svc.query_jobs(db_session, page=1, per_page=20)
|
||||
assert total == 30
|
||||
assert len(page1) == 20
|
||||
page2, _ = admin_svc.query_jobs(db_session, page=2, per_page=20)
|
||||
assert len(page2) == 10
|
||||
|
||||
# status 过滤
|
||||
failed, ftotal = admin_svc.query_jobs(db_session, status="failed", per_page=50)
|
||||
assert ftotal == 5
|
||||
assert len(failed) == 5
|
||||
assert all(j["status"] == "failed" for j in failed)
|
||||
|
||||
# type 过滤
|
||||
_make_job(db_session, type="reindex_fts", status=JobStatus.QUEUED)
|
||||
typed, ttotal = admin_svc.query_jobs(db_session, job_type="reindex_fts")
|
||||
assert ttotal == 1
|
||||
assert typed[0]["type"] == "reindex_fts"
|
||||
|
||||
|
||||
def test_serialize_job_includes_duration(db_session):
|
||||
job = _make_job(db_session, status=JobStatus.SUCCESS)
|
||||
job.started_at = utc_now()
|
||||
job.completed_at = utc_now()
|
||||
db_session.commit()
|
||||
serialized = admin_svc.serialize_job(job)
|
||||
assert serialized["duration_seconds"] is not None
|
||||
assert serialized["duration_seconds"] >= 0
|
||||
|
||||
|
||||
def test_serialize_job_running_without_completed(db_session):
|
||||
# 运行中的 job:completed_at=None,started_at 经 db 读回为 naive UTC,
|
||||
# 不能与 aware 的 utc_now() 直接相减(回归测试)。
|
||||
job = _make_job(db_session, status=JobStatus.RUNNING)
|
||||
job.started_at = utc_now()
|
||||
db_session.commit()
|
||||
serialized = admin_svc.serialize_job(job)
|
||||
assert serialized["duration_seconds"] is not None
|
||||
assert serialized["duration_seconds"] >= 0
|
||||
|
||||
|
||||
def test_get_job_status_counts(db_session):
|
||||
_make_job(db_session, status=JobStatus.QUEUED)
|
||||
_make_job(db_session, status=JobStatus.QUEUED)
|
||||
_make_job(db_session, status=JobStatus.RUNNING)
|
||||
counts = admin_svc.get_job_status_counts(db_session)
|
||||
assert counts.get("queued") == 2
|
||||
assert counts.get("running") == 1
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 锁释放
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def test_force_release_lock(db_session):
|
||||
running = TaskLock(
|
||||
task="crawl", lock_key="k1", status="running", acquired_at=utc_now()
|
||||
)
|
||||
stale = TaskLock(task="crawl", lock_key="k2", status="stale", acquired_at=utc_now())
|
||||
finished = TaskLock(
|
||||
task="crawl", lock_key="k3", status="finished", acquired_at=utc_now()
|
||||
)
|
||||
db_session.add_all([running, stale, finished])
|
||||
db_session.commit()
|
||||
|
||||
assert admin_svc.force_release_lock(db_session, running.id) is True
|
||||
assert admin_svc.force_release_lock(db_session, stale.id) is True
|
||||
# finished 的不应被再次释放
|
||||
assert admin_svc.force_release_lock(db_session, finished.id) is False
|
||||
# 不存在的 id
|
||||
assert admin_svc.force_release_lock(db_session, 999999) is False
|
||||
|
||||
db_session.refresh(running)
|
||||
db_session.refresh(stale)
|
||||
db_session.refresh(finished)
|
||||
assert running.status == "finished"
|
||||
assert running.released_at is not None
|
||||
assert stale.status == "finished"
|
||||
assert finished.status == "finished"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 失败原因分布
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def test_get_failure_breakdown(db_session, sample_papers_range):
|
||||
statuses = (
|
||||
db_session.execute(select(SummaryStatus).order_by(SummaryStatus.id))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
statuses[0].status = SummaryState.FAILED
|
||||
statuses[0].error_type = "pdf_download_failed"
|
||||
statuses[1].status = SummaryState.FAILED
|
||||
statuses[1].error_type = "timeout"
|
||||
statuses[2].status = SummaryState.PERMANENT_FAILURE
|
||||
statuses[2].error_type = None # 归 unknown
|
||||
db_session.commit()
|
||||
|
||||
breakdown = admin_svc.get_failure_breakdown(db_session)
|
||||
by_type = {b["error_type"]: b["count"] for b in breakdown}
|
||||
assert by_type.get("pdf_download_failed") == 1
|
||||
assert by_type.get("timeout") == 1
|
||||
assert by_type.get("unknown") == 1
|
||||
# 降序
|
||||
counts = [b["count"] for b in breakdown]
|
||||
assert counts == sorted(counts, reverse=True)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 配置概览
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def test_get_config_overview_no_secrets():
|
||||
cfg = admin_svc.get_config_overview()
|
||||
assert "summary_backend" in cfg
|
||||
assert "schedule_time" in cfg
|
||||
assert "api_key_configured" in cfg # 只标是否配置,不显值
|
||||
text = str(cfg)
|
||||
# 不应泄露默认密钥值
|
||||
assert "change-me" not in text
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 单篇/批量重抓
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestRecrawl:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self, db_session):
|
||||
res = await recrawl_single(db_session, "9999.99999")
|
||||
assert res["updated"] is False
|
||||
assert res["reason"] == "not_found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_full_metadata(self, db_session, sample_paper):
|
||||
new_item = {
|
||||
"paper": {
|
||||
"id": sample_paper.arxiv_id,
|
||||
"title": "Updated Title",
|
||||
"abstract": "New abstract",
|
||||
"publishedAt": "2024-01-15T00:00:00",
|
||||
"authors": [{"name": "New Author"}],
|
||||
"tags": [{"name": "CV"}, {"name": "Diffusion"}],
|
||||
"upvotes": 100,
|
||||
}
|
||||
}
|
||||
with patch(
|
||||
"app.services.crawler.fetch_daily",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[new_item],
|
||||
):
|
||||
res = await recrawl_single(db_session, sample_paper.arxiv_id)
|
||||
|
||||
assert res["updated"] is True
|
||||
db_session.refresh(sample_paper)
|
||||
assert sample_paper.title_en == "Updated Title"
|
||||
assert sample_paper.abstract == "New abstract"
|
||||
assert sample_paper.upvotes == 100
|
||||
# authors 重建(原 Alice/Bob → New Author)
|
||||
assert sorted(a.name for a in sample_paper.authors) == ["New Author"]
|
||||
# tags 重建(原 NLP/LLM → CV/Diffusion)
|
||||
assert sorted(t.tag for t in sample_paper.tags) == ["CV", "Diffusion"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_in_daily(self, db_session, sample_paper):
|
||||
with patch(
|
||||
"app.services.crawler.fetch_daily",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
res = await recrawl_single(db_session, sample_paper.arxiv_id)
|
||||
assert res["updated"] is False
|
||||
assert res["reason"] == "not_in_daily"
|
||||
assert "date" in res
|
||||
|
||||
|
||||
class TestDispatchRecrawl:
|
||||
@pytest.mark.asyncio
|
||||
async def test_recrawl_one_via_run_job(self, db_session, sample_paper):
|
||||
new_item = {
|
||||
"paper": {
|
||||
"id": sample_paper.arxiv_id,
|
||||
"title": "Via Job",
|
||||
"authors": [],
|
||||
"tags": [],
|
||||
"upvotes": 5,
|
||||
}
|
||||
}
|
||||
with patch(
|
||||
"app.services.crawler.fetch_daily",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[new_item],
|
||||
):
|
||||
job = create_job(
|
||||
db_session,
|
||||
"recrawl_one",
|
||||
owner="test",
|
||||
payload={"arxiv_id": sample_paper.arxiv_id},
|
||||
)
|
||||
result = await run_job(db_session, job.id)
|
||||
|
||||
assert result["updated"] is True
|
||||
db_session.refresh(sample_paper)
|
||||
assert sample_paper.title_en == "Via Job"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recrawl_batch_via_run_job(self, db_session, sample_papers_range):
|
||||
arxiv_ids = [p.arxiv_id for p in sample_papers_range[:2]]
|
||||
items = [
|
||||
{
|
||||
"paper": {
|
||||
"id": aid,
|
||||
"title": "Batch " + aid,
|
||||
"authors": [],
|
||||
"tags": [],
|
||||
"upvotes": 1,
|
||||
}
|
||||
}
|
||||
for aid in arxiv_ids
|
||||
]
|
||||
with patch(
|
||||
"app.services.crawler.fetch_daily",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=lambda d: items,
|
||||
):
|
||||
job = create_job(
|
||||
db_session,
|
||||
"recrawl_batch",
|
||||
owner="test",
|
||||
payload={"arxiv_ids": arxiv_ids},
|
||||
)
|
||||
result = await run_job(db_session, job.id)
|
||||
|
||||
assert result["updated"] == 2
|
||||
assert result["skipped"] == 0
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 路由
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestRoutes:
|
||||
def test_jobs_page_renders(self, auth_client):
|
||||
resp = auth_client.get("/admin/jobs")
|
||||
assert resp.status_code == 200
|
||||
assert "任务监控" in resp.text
|
||||
|
||||
def test_jobs_page_filters_by_status(self, auth_client, db_session):
|
||||
_make_job(db_session, status=JobStatus.FAILED)
|
||||
resp = auth_client.get("/admin/jobs?status=failed")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_export_csv(self, auth_client, sample_papers_range):
|
||||
resp = auth_client.get("/admin/papers/export.csv")
|
||||
assert resp.status_code == 200
|
||||
assert "text/csv" in resp.headers["content-type"]
|
||||
# UTF-8 BOM for Excel
|
||||
assert resp.content.startswith(b"\xef\xbb\xbf")
|
||||
# 表头 + 数据
|
||||
assert "arxiv_id" in resp.text
|
||||
assert "2401.10001" in resp.text
|
||||
|
||||
def test_export_csv_respects_filter(self, auth_client, sample_papers_range):
|
||||
resp = auth_client.get("/admin/papers/export.csv?q=Paper%203")
|
||||
assert resp.status_code == 200
|
||||
assert "2401.10003" in resp.text
|
||||
assert "2401.10001" not in resp.text
|
||||
|
||||
def test_rebuild_indexes_fts(self, auth_client, db_session, no_enqueue):
|
||||
resp = auth_client.post("/admin/rebuild-indexes", json={"target": "fts"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "queued"
|
||||
assert len(data["job_ids"]) == 1
|
||||
jobs = (
|
||||
db_session.execute(select(Job).where(Job.type == "reindex_fts"))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
assert len(jobs) == 1
|
||||
|
||||
def test_rebuild_indexes_both(self, auth_client, db_session, no_enqueue):
|
||||
resp = auth_client.post("/admin/rebuild-indexes", json={"target": "both"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["job_ids"]) == 2
|
||||
|
||||
def test_release_lock_route(self, auth_client, db_session):
|
||||
lock = TaskLock(
|
||||
task="crawl", lock_key="rt", status="running", acquired_at=utc_now()
|
||||
)
|
||||
db_session.add(lock)
|
||||
db_session.commit()
|
||||
resp = auth_client.post(f"/admin/locks/{lock.id}/release")
|
||||
assert resp.status_code == 200
|
||||
db_session.refresh(lock)
|
||||
assert lock.status == "finished"
|
||||
|
||||
def test_paper_recrawl_route(
|
||||
self, auth_client, sample_paper, db_session, no_enqueue
|
||||
):
|
||||
resp = auth_client.post(f"/admin/paper-recrawl/{sample_paper.arxiv_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "queued"
|
||||
jobs = (
|
||||
db_session.execute(select(Job).where(Job.type == "recrawl_one"))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
assert len(jobs) == 1
|
||||
|
||||
def test_batch_recrawl_route(
|
||||
self, auth_client, sample_papers_range, db_session, no_enqueue
|
||||
):
|
||||
ids = [p.arxiv_id for p in sample_papers_range[:3]]
|
||||
resp = auth_client.post(
|
||||
"/admin/papers-batch-action", json={"action": "recrawl", "arxiv_ids": ids}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "queued"
|
||||
jobs = (
|
||||
db_session.execute(select(Job).where(Job.type == "recrawl_batch"))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
assert len(jobs) == 1
|
||||
@@ -54,7 +54,9 @@ class TestReindexChroma:
|
||||
def test_reindex_chroma_indexes_only_summarized_papers(
|
||||
self, db_session, sample_papers_with_summary
|
||||
):
|
||||
with patch("app.services.embedder.index_paper", return_value=True) as mock_index:
|
||||
with patch(
|
||||
"app.services.embedder.index_paper", return_value=True
|
||||
) as mock_index:
|
||||
result = reindex_chroma(db_session)
|
||||
|
||||
assert result["status"] == "success"
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@@ -154,3 +156,84 @@ class TestEmbeddingApi:
|
||||
)
|
||||
result = emb._get_embedding("test")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 并发安全:init() 双重检查锁 + 集合访问串行化
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestEmbedderConcurrency:
|
||||
"""后处理经 asyncio.to_thread 多 worker 并发调 index_paper 的安全性。"""
|
||||
|
||||
def test_init_serialized_under_concurrency(self, monkeypatch, tmp_path):
|
||||
"""并发 init() 只调一次 PersistentClient(chromadb SharedSystemClient 缓存竞争修复)。
|
||||
|
||||
复现崩坏条件:10 线程同时 init(),fake PersistentClient 故意 sleep 拉长建连窗口。
|
||||
修复前会有多线程同时进入 _create_system_if_not_exists → 并发 mutate 类级缓存;
|
||||
修复后(双重检查锁)只有抢到锁的那个线程建连。
|
||||
"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
|
||||
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
|
||||
counter = {"n": 0}
|
||||
counter_lock = threading.Lock()
|
||||
|
||||
def fake_persistent_client(path):
|
||||
with counter_lock:
|
||||
counter["n"] += 1
|
||||
time.sleep(0.05) # 拉长建连窗口,放大并发竞争
|
||||
client = MagicMock()
|
||||
client.get_collection.side_effect = Exception(
|
||||
"not exist"
|
||||
) # 触发 create 路径
|
||||
client.create_collection.return_value = MagicMock()
|
||||
return client
|
||||
|
||||
with patch("chromadb.PersistentClient", side_effect=fake_persistent_client):
|
||||
threads = [threading.Thread(target=emb._chroma.init) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert counter["n"] == 1, f"PersistentClient 应只调一次,实际 {counter['n']}"
|
||||
assert emb._chroma._client is not None
|
||||
emb._chroma.reset()
|
||||
|
||||
def test_index_paper_concurrent_no_error(self, monkeypatch, tmp_path):
|
||||
"""并发 index_paper:embedding 锁外并行,集合写入串行化,全部成功。"""
|
||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
|
||||
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||
import app.services.embedder as emb
|
||||
|
||||
emb._chroma.reset()
|
||||
# 跳过 init,直接注入 mock collection
|
||||
emb._chroma._client = MagicMock()
|
||||
col = MagicMock()
|
||||
col.count.return_value = 0
|
||||
emb._chroma._collection = col
|
||||
|
||||
with patch.object(emb, "_get_embedding", return_value=[0.1, 0.2, 0.3]):
|
||||
errors: list[BaseException] = []
|
||||
|
||||
def worker(i: int) -> None:
|
||||
try:
|
||||
emb.index_paper(
|
||||
f"id-{i}", {"arxiv_id": f"id-{i}", "title_zh": f"标题{i}"}
|
||||
)
|
||||
except BaseException as exc: # noqa: BLE001 — 收集所有错误
|
||||
errors.append(exc)
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert errors == []
|
||||
assert col.upsert.call_count == 10
|
||||
emb._chroma.reset()
|
||||
|
||||
@@ -0,0 +1,559 @@
|
||||
"""layout_detector 测试 — 坐标还原数学、设备探测、类别映射、端到端 detect_page.
|
||||
|
||||
纯函数测试不依赖真实模型;TestDetectPage 用 MagicMock mock ort.InferenceSession
|
||||
与 pymupdf.Page,参考 test_summary_utils.py 的 mock 模式。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import pytest
|
||||
|
||||
from app.config import settings
|
||||
from app.services import layout_detector as mod
|
||||
from app.services.layout_detector import (
|
||||
LayoutBox,
|
||||
_compute_render_geometry,
|
||||
_FALLBACK_NAMES,
|
||||
_letterbox_padding,
|
||||
_map_class_to_boxclass,
|
||||
_model_to_pdf,
|
||||
_parse_names_from_meta,
|
||||
_postprocess_output,
|
||||
detect_page_layout,
|
||||
resolve_providers,
|
||||
)
|
||||
|
||||
IMGSZ = 1024
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 渲染几何与 letterbox padding
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestComputeRenderGeometry:
|
||||
def test_a4_portrait_short_edge_pads(self):
|
||||
# A4 595×842,高度贴边,宽度方向留灰边
|
||||
ratio = _compute_render_geometry(595, 842, IMGSZ)
|
||||
assert ratio == pytest.approx(min(IMGSZ / 595, IMGSZ / 842))
|
||||
assert ratio == pytest.approx(IMGSZ / 842) # 高度方向贴边
|
||||
|
||||
def test_wide_page_width_pads(self):
|
||||
# 1600×900 横向,宽度贴边
|
||||
ratio = _compute_render_geometry(1600, 900, IMGSZ)
|
||||
assert ratio == pytest.approx(IMGSZ / 1600)
|
||||
|
||||
def test_square_no_letterbox(self):
|
||||
ratio = _compute_render_geometry(100, 100, IMGSZ)
|
||||
assert ratio == pytest.approx(10.24)
|
||||
|
||||
|
||||
class TestLetterboxPadding:
|
||||
def test_centered_padding(self):
|
||||
# pixmap 723×1024 贴满高度,宽度两侧各 (1024-723)/2
|
||||
dw, dh = _letterbox_padding(723, 1024, IMGSZ)
|
||||
assert dw == pytest.approx((IMGSZ - 723) / 2)
|
||||
assert dh == pytest.approx(0.0)
|
||||
|
||||
def test_square_no_padding(self):
|
||||
dw, dh = _letterbox_padding(IMGSZ, IMGSZ, IMGSZ)
|
||||
assert dw == 0.0
|
||||
assert dh == 0.0
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 坐标还原(核心)—— pdf = (model - padding) / ratio
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestModelToPdf:
|
||||
def test_padding_corner_maps_to_origin(self):
|
||||
# 模型空间左上角 (dw, dh) → PDF (0, 0)
|
||||
dw, dh, ratio = 150.5, 0.0, 1.2157
|
||||
x, y = _model_to_pdf(dw, dh, dw, dh, ratio)
|
||||
assert x == pytest.approx(0.0, abs=0.01)
|
||||
assert y == pytest.approx(0.0, abs=0.01)
|
||||
|
||||
def test_round_trip(self):
|
||||
dw, dh, ratio = 150.5, 0.0, 1.2157
|
||||
# PDF (100, 200) → 模型空间 → 再还原回 PDF
|
||||
mx, my = 100 * ratio + dw, 200 * ratio + dh
|
||||
px, py = _model_to_pdf(mx, my, dw, dh, ratio)
|
||||
assert px == pytest.approx(100, abs=0.01)
|
||||
assert py == pytest.approx(200, abs=0.01)
|
||||
|
||||
def test_full_a4_page_box(self):
|
||||
# 整页框在模型空间为 (dw,dh)-(dw+pix_w, dh+pix_h),还原回页面尺寸
|
||||
ratio = _compute_render_geometry(595, 842, IMGSZ)
|
||||
pix_w, pix_h = round(595 * ratio), round(842 * ratio)
|
||||
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
|
||||
x0, y0 = _model_to_pdf(dw, dh, dw, dh, ratio)
|
||||
x1, y1 = _model_to_pdf(dw + pix_w, dh + pix_h, dw, dh, ratio)
|
||||
assert x0 == pytest.approx(0.0, abs=1.0)
|
||||
assert y0 == pytest.approx(0.0, abs=1.0)
|
||||
assert x1 == pytest.approx(595, abs=1.0)
|
||||
assert y1 == pytest.approx(842, abs=1.0)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 设备探测
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestResolveProviders:
|
||||
def test_cpu(self):
|
||||
assert resolve_providers("cpu", 0) == [("CPUExecutionProvider", {})]
|
||||
|
||||
def test_cuda_with_cpu_fallback(self):
|
||||
eps = resolve_providers("cuda", 0)
|
||||
assert eps[0] == ("CUDAExecutionProvider", {"device_id": "0"})
|
||||
assert eps[1] == ("CPUExecutionProvider", {})
|
||||
|
||||
def test_directml_device_id(self):
|
||||
eps = resolve_providers("directml", 2)
|
||||
assert eps[0] == ("DmlExecutionProvider", {"device_id": "2"})
|
||||
|
||||
def test_auto_picks_cuda_if_available(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
ort,
|
||||
"get_available_providers",
|
||||
lambda: ["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
)
|
||||
eps = resolve_providers("auto", 0)
|
||||
assert eps[0][0] == "CUDAExecutionProvider"
|
||||
assert eps[-1] == ("CPUExecutionProvider", {})
|
||||
|
||||
def test_auto_falls_back_to_cpu(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
ort, "get_available_providers", lambda: ["CPUExecutionProvider"]
|
||||
)
|
||||
assert resolve_providers("auto", 0) == [("CPUExecutionProvider", {})]
|
||||
|
||||
def test_auto_prefers_cuda_over_directml(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
ort,
|
||||
"get_available_providers",
|
||||
lambda: [
|
||||
"DmlExecutionProvider",
|
||||
"CUDAExecutionProvider",
|
||||
"CPUExecutionProvider",
|
||||
],
|
||||
)
|
||||
eps = resolve_providers("auto", 0)
|
||||
assert eps[0][0] == "CUDAExecutionProvider"
|
||||
|
||||
def test_unknown_device_falls_back(self):
|
||||
assert resolve_providers("tpu", 0) == [("CPUExecutionProvider", {})]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 类别映射与 names 解析
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestClassMapping:
|
||||
def test_figure_to_picture(self):
|
||||
assert _map_class_to_boxclass(3, {3: "figure"}) == "picture"
|
||||
|
||||
def test_figure_group_to_picture(self):
|
||||
assert _map_class_to_boxclass(0, {0: "figure_group"}) == "picture"
|
||||
|
||||
def test_table(self):
|
||||
assert _map_class_to_boxclass(5, {5: "table"}) == "table"
|
||||
|
||||
def test_caption_classes(self):
|
||||
names = {4: "figure_caption", 6: "table_caption"}
|
||||
assert _map_class_to_boxclass(4, names) == "figure_caption"
|
||||
assert _map_class_to_boxclass(6, names) == "table_caption"
|
||||
|
||||
def test_other_classes_ignored(self):
|
||||
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
|
||||
for k in names:
|
||||
assert _map_class_to_boxclass(k, names) is None
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _map_class_to_boxclass(0, {0: "Figure"}) == "picture"
|
||||
assert _map_class_to_boxclass(0, {0: "TABLE"}) == "table"
|
||||
|
||||
def test_unknown_class_id(self):
|
||||
assert _map_class_to_boxclass(99, {0: "figure"}) is None
|
||||
|
||||
|
||||
class TestParseNamesFromMeta:
|
||||
def test_reads_json_metadata(self):
|
||||
sess = MagicMock()
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {
|
||||
"names": '{"0": "title", "3": "figure", "5": "table"}'
|
||||
}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
assert _parse_names_from_meta(sess) == {0: "title", 3: "figure", 5: "table"}
|
||||
|
||||
def test_fallback_when_missing(self):
|
||||
sess = MagicMock()
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
assert _parse_names_from_meta(sess) == _FALLBACK_NAMES
|
||||
|
||||
def test_fallback_on_garbage(self):
|
||||
sess = MagicMock()
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {"names": "not json"}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
assert _parse_names_from_meta(sess) == _FALLBACK_NAMES
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 后处理
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestPostprocessOutput:
|
||||
def test_parses_end_to_end_filters_by_conf(self):
|
||||
out = np.array(
|
||||
[[[10, 20, 30, 40, 0.9, 3], [50, 60, 70, 80, 0.1, 5]]],
|
||||
dtype=np.float32,
|
||||
)
|
||||
res = _postprocess_output(out, 0.2, {3: "figure", 5: "table"})
|
||||
assert res == [(3, 10.0, 20.0, 30.0, 40.0)]
|
||||
|
||||
def test_empty_output(self):
|
||||
out = np.zeros((1, 0, 6), dtype=np.float32)
|
||||
assert _postprocess_output(out, 0.2, {}) == []
|
||||
|
||||
def test_unexpected_shape_returns_empty(self):
|
||||
out = np.zeros((1, 84, 8400), dtype=np.float32)
|
||||
assert _postprocess_output(out, 0.2, {}) == []
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# detect_page 端到端(mock ort.InferenceSession + pymupdf.Page)
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestDetectPage:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_detector(self):
|
||||
"""每个测试前重建单例(带新锁 + 空 session),避免复用上个测试的 mock session。"""
|
||||
mod._LayoutDetector.reset_instance()
|
||||
mod._detector = mod._LayoutDetector()
|
||||
yield
|
||||
mod._LayoutDetector.reset_instance()
|
||||
mod._detector = mod._LayoutDetector()
|
||||
|
||||
@staticmethod
|
||||
def _build_mock_session(page_w, page_h, boxes, names):
|
||||
"""构造 mock InferenceSession。
|
||||
|
||||
boxes: list of (cls_id, pdf_x0, pdf_y0, pdf_x1, pdf_y1, conf)
|
||||
坐标为 PDF 点,内部转成模型空间坐标塞进 output。
|
||||
names: dict[int, str] —— 写入 metadata 供 _parse_names_from_meta 读取。
|
||||
"""
|
||||
ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
|
||||
pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
|
||||
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
|
||||
|
||||
rows = []
|
||||
for cls_id, x0, y0, x1, y1, conf in boxes:
|
||||
rows.append(
|
||||
[
|
||||
x0 * ratio + dw,
|
||||
y0 * ratio + dh,
|
||||
x1 * ratio + dw,
|
||||
y1 * ratio + dh,
|
||||
conf,
|
||||
cls_id,
|
||||
]
|
||||
)
|
||||
fake_output = (
|
||||
np.array([rows], dtype=np.float32)
|
||||
if rows
|
||||
else np.zeros((1, 0, 6), dtype=np.float32)
|
||||
)
|
||||
|
||||
sess = MagicMock()
|
||||
inp = MagicMock()
|
||||
inp.name = "images"
|
||||
sess.get_inputs.return_value = [inp]
|
||||
sess.run.return_value = [fake_output]
|
||||
sess.get_providers.return_value = ["CPUExecutionProvider"]
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {
|
||||
"names": json.dumps({str(k): v for k, v in names.items()})
|
||||
}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
return sess, (pix_w, pix_h)
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_page(page_w, page_h, pix_w, pix_h):
|
||||
pix = MagicMock()
|
||||
pix.width = pix_w
|
||||
pix.height = pix_h
|
||||
pix.n = 3
|
||||
pix.samples = bytes([128] * (pix_w * pix_h * 3))
|
||||
page = MagicMock()
|
||||
page.rect.width = page_w
|
||||
page.rect.height = page_h
|
||||
page.get_pixmap.return_value = pix
|
||||
return page
|
||||
|
||||
def _setup(self, monkeypatch, tmp_path, sess):
|
||||
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
|
||||
(tmp_path / "m.onnx").write_bytes(b"x")
|
||||
monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)
|
||||
|
||||
def test_returns_picture_box(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure", 5: "table"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 300, 400, 0.9)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
|
||||
assert len(boxes) == 1
|
||||
b = boxes[0]
|
||||
assert isinstance(b, LayoutBox)
|
||||
assert b.boxclass == "picture"
|
||||
assert b.x0 == pytest.approx(100, abs=1.0)
|
||||
assert b.y0 == pytest.approx(100, abs=1.0)
|
||||
assert b.x1 == pytest.approx(300, abs=1.0)
|
||||
assert b.y1 == pytest.approx(400, abs=1.0)
|
||||
|
||||
def test_returns_table_box(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure", 5: "table"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(5, 50, 50, 400, 300, 0.85)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "table"
|
||||
|
||||
def test_returns_caption_box_with_small_height(self, monkeypatch, tmp_path):
|
||||
names = {4: "figure_caption"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(4, 100, 405, 300, 417, 0.9)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "figure_caption"
|
||||
assert boxes[0].y1 - boxes[0].y0 == pytest.approx(12, abs=1.0)
|
||||
|
||||
def test_filters_low_confidence(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure"}
|
||||
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 300, 400, 0.1)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
assert detect_page_layout(page) == []
|
||||
|
||||
def test_filters_small_box(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure"}
|
||||
# 还原后 5×5 pt < _MIN_BOX_SIZE(20) → 过滤
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 105, 105, 0.9)], names
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
assert detect_page_layout(page) == []
|
||||
|
||||
def test_mixed_picture_and_table(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure", 5: "table"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595,
|
||||
842,
|
||||
[
|
||||
(3, 100, 100, 300, 400, 0.9),
|
||||
(5, 50, 500, 400, 700, 0.8),
|
||||
],
|
||||
names,
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
classes = sorted(b.boxclass for b in boxes)
|
||||
assert classes == ["picture", "table"]
|
||||
|
||||
def test_empty_output(self, monkeypatch, tmp_path):
|
||||
names = {3: "figure"}
|
||||
sess, (pw, ph) = self._build_mock_session(595, 842, [], names)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
assert detect_page_layout(page) == []
|
||||
|
||||
def test_ignored_class_skipped(self, monkeypatch, tmp_path):
|
||||
# title 类(cls_id=0)不应产出 LayoutBox
|
||||
names = {0: "title", 3: "figure"}
|
||||
sess, (pw, ph) = self._build_mock_session(
|
||||
595,
|
||||
842,
|
||||
[(0, 100, 100, 400, 150, 0.9), (3, 100, 200, 300, 400, 0.9)],
|
||||
names,
|
||||
)
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
page = self._make_mock_page(595, 842, pw, ph)
|
||||
|
||||
boxes = detect_page_layout(page)
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].boxclass == "picture"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 并发安全:锁串行化推理 + 单例 session 只初始化一次
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestDetectPageConcurrency:
|
||||
"""锁包裹整段 detect_page 后,并发调用的安全性。"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_detector(self):
|
||||
"""重建单例(带新锁),避免跨测试锁状态污染。"""
|
||||
mod._LayoutDetector.reset_instance()
|
||||
mod._detector = mod._LayoutDetector()
|
||||
yield
|
||||
mod._LayoutDetector.reset_instance()
|
||||
mod._detector = mod._LayoutDetector()
|
||||
|
||||
@staticmethod
|
||||
def _build_mock_session(page_w, page_h, boxes, names):
|
||||
"""同 TestDetectPage._build_mock_session,额外返回 fake_output 供 side_effect。"""
|
||||
ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
|
||||
pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
|
||||
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
|
||||
rows = []
|
||||
for cls_id, x0, y0, x1, y1, conf in boxes:
|
||||
rows.append(
|
||||
[
|
||||
x0 * ratio + dw,
|
||||
y0 * ratio + dh,
|
||||
x1 * ratio + dw,
|
||||
y1 * ratio + dh,
|
||||
conf,
|
||||
cls_id,
|
||||
]
|
||||
)
|
||||
fake_output = (
|
||||
np.array([rows], dtype=np.float32)
|
||||
if rows
|
||||
else np.zeros((1, 0, 6), dtype=np.float32)
|
||||
)
|
||||
sess = MagicMock()
|
||||
inp = MagicMock()
|
||||
inp.name = "images"
|
||||
sess.get_inputs.return_value = [inp]
|
||||
sess.run.return_value = [fake_output]
|
||||
sess.get_providers.return_value = ["CPUExecutionProvider"]
|
||||
meta = MagicMock()
|
||||
meta.custom_metadata_map = {
|
||||
"names": json.dumps({str(k): v for k, v in names.items()})
|
||||
}
|
||||
sess.get_modelmeta.return_value = meta
|
||||
return sess, (pix_w, pix_h), fake_output
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_page(page_w, page_h, pix_w, pix_h):
|
||||
pix = MagicMock()
|
||||
pix.width = pix_w
|
||||
pix.height = pix_h
|
||||
pix.n = 3
|
||||
pix.samples = bytes([128] * (pix_w * pix_h * 3))
|
||||
page = MagicMock()
|
||||
page.rect.width = page_w
|
||||
page.rect.height = page_h
|
||||
page.get_pixmap.return_value = pix
|
||||
return page
|
||||
|
||||
def _setup(self, monkeypatch, tmp_path, sess):
|
||||
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
|
||||
(tmp_path / "m.onnx").write_bytes(b"x")
|
||||
monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)
|
||||
|
||||
def test_detect_page_serializes_concurrent_calls(self, monkeypatch, tmp_path):
|
||||
"""多线程并发调 detect_page_layout,session.run 临界区同时只有一个。"""
|
||||
sess, (pw, ph), fake_output = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
|
||||
)
|
||||
in_critical = 0
|
||||
max_concurrent = 0
|
||||
counter_lock = threading.Lock()
|
||||
|
||||
def counting_run(*args, **kwargs):
|
||||
nonlocal in_critical, max_concurrent
|
||||
with counter_lock:
|
||||
in_critical += 1
|
||||
max_concurrent = max(max_concurrent, in_critical)
|
||||
time.sleep(0.02) # 放大竞争窗口,让并发线程有机会重叠
|
||||
try:
|
||||
return [fake_output]
|
||||
finally:
|
||||
with counter_lock:
|
||||
in_critical -= 1
|
||||
|
||||
sess.run.side_effect = counting_run
|
||||
self._setup(monkeypatch, tmp_path, sess)
|
||||
|
||||
pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(8)]
|
||||
threads = [
|
||||
threading.Thread(target=detect_page_layout, args=(p,)) for p in pages
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# 锁生效 → 临界区同时只有一个;不加锁时此值会 > 1(回归保护)
|
||||
assert max_concurrent == 1
|
||||
|
||||
def test_session_created_once_under_concurrency(self, monkeypatch, tmp_path):
|
||||
"""多线程并发首次调用,InferenceSession 只创建一次(锁间接保护 _init_session)。"""
|
||||
sess, (pw, ph), _fake_output = self._build_mock_session(
|
||||
595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
|
||||
)
|
||||
create_count = 0
|
||||
create_lock = threading.Lock()
|
||||
|
||||
def counting_init(*args, **kwargs):
|
||||
nonlocal create_count
|
||||
with create_lock:
|
||||
create_count += 1
|
||||
time.sleep(0.02) # 放大窗口,让并发首调都来抢
|
||||
return sess
|
||||
|
||||
monkeypatch.setattr(ort, "InferenceSession", counting_init)
|
||||
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
|
||||
(tmp_path / "m.onnx").write_bytes(b"x")
|
||||
|
||||
pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(6)]
|
||||
threads = [
|
||||
threading.Thread(target=detect_page_layout, args=(p,)) for p in pages
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert create_count == 1
|
||||
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pymupdf
|
||||
|
||||
from app.services import pdf_image_extractor as mod
|
||||
from app.services.layout_detector import LayoutBox
|
||||
|
||||
|
||||
def test_process_page_extracts_doclayout_caption(tmp_path):
|
||||
images_dest = tmp_path / "images"
|
||||
images_dest.mkdir()
|
||||
manifest: dict[str, dict] = {}
|
||||
|
||||
pix = MagicMock()
|
||||
pix.tobytes.return_value = b"jpeg"
|
||||
|
||||
page = MagicMock()
|
||||
page.rect.width = 600
|
||||
page.get_pixmap.return_value = pix
|
||||
page.get_text.return_value = "Figure 1: Overall architecture.\n"
|
||||
|
||||
doc = MagicMock()
|
||||
doc.__getitem__.return_value = page
|
||||
|
||||
boxes = [
|
||||
LayoutBox(100, 100, 300, 300, "picture"),
|
||||
LayoutBox(95, 310, 320, 325, "figure_caption"),
|
||||
]
|
||||
|
||||
extracted = mod._process_page(
|
||||
doc,
|
||||
0,
|
||||
boxes,
|
||||
images_dest=images_dest,
|
||||
manifest=manifest,
|
||||
seen_labels=set(),
|
||||
arxiv_id="2401.00001",
|
||||
)
|
||||
|
||||
assert extracted == 1
|
||||
info = manifest["figure_(p1-1).jpg"]
|
||||
assert info["caption_text"] == "Figure 1: Overall architecture."
|
||||
assert info["caption_source"] == "doclayout"
|
||||
assert info["caption_box"] == [95.0, 310.0, 320.0, 325.0]
|
||||
|
||||
|
||||
def test_process_page_includes_caption_in_render(tmp_path):
|
||||
"""渲染时把 caption 区域合并进同一张截图。"""
|
||||
images_dest = tmp_path / "images"
|
||||
images_dest.mkdir()
|
||||
manifest: dict[str, dict] = {}
|
||||
|
||||
pix = MagicMock()
|
||||
pix.tobytes.return_value = b"jpeg"
|
||||
|
||||
page = MagicMock()
|
||||
page.rect.width = 600
|
||||
page.get_pixmap.return_value = pix
|
||||
page.get_text.return_value = "Figure 1: Caption text.\n"
|
||||
|
||||
doc = MagicMock()
|
||||
doc.__getitem__.return_value = page
|
||||
|
||||
boxes = [
|
||||
LayoutBox(100, 100, 300, 300, "picture"),
|
||||
LayoutBox(95, 310, 320, 325, "figure_caption"),
|
||||
]
|
||||
|
||||
mod._process_page(
|
||||
doc,
|
||||
0,
|
||||
boxes,
|
||||
images_dest=images_dest,
|
||||
manifest=manifest,
|
||||
seen_labels=set(),
|
||||
arxiv_id="2401.00001",
|
||||
)
|
||||
|
||||
# 内容 [100,100,300,300] ∪ caption [95,310,320,325],各方向加 _REGION_PADDING=5
|
||||
# → Rect(90, 95, 325, 330)
|
||||
clip = page.get_pixmap.call_args.kwargs["clip"]
|
||||
assert clip == pymupdf.Rect(90, 95, 325, 330)
|
||||
|
||||
|
||||
def test_label_images_preserves_doclayout_caption(tmp_path, monkeypatch):
|
||||
arxiv_id = "2401.00001"
|
||||
paper_root = tmp_path / arxiv_id
|
||||
images_dest = paper_root / "images"
|
||||
images_dest.mkdir(parents=True)
|
||||
(images_dest / "figure_(p1-1).jpg").write_bytes(b"jpeg")
|
||||
(images_dest / "manifest.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"figure_(p1-1).jpg": {
|
||||
"page": 1,
|
||||
"type": "figure",
|
||||
"label": "Figure (p1-1)",
|
||||
"box": [100, 100, 300, 300],
|
||||
"caption_text": "Figure 1: PDF original caption.",
|
||||
"caption_source": "doclayout",
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
pdf_path = tmp_path / "paper.pdf"
|
||||
pdf_path.write_bytes(b"%PDF")
|
||||
monkeypatch.setattr(mod, "paper_dir", lambda _arxiv_id: paper_root)
|
||||
|
||||
page = MagicMock()
|
||||
page.search_for.return_value = [pymupdf.Rect(120, 305, 180, 320)]
|
||||
|
||||
fake_doc = MagicMock()
|
||||
fake_doc.page_count = 1
|
||||
fake_doc.__getitem__.return_value = page
|
||||
fake_doc.__enter__.return_value = fake_doc
|
||||
fake_doc.__exit__.return_value = False
|
||||
monkeypatch.setattr(mod.pymupdf, "open", lambda _path: fake_doc)
|
||||
|
||||
labeled = mod.label_images_by_summary(
|
||||
arxiv_id,
|
||||
[{"id": "Figure 1", "caption": "Summary caption."}],
|
||||
pdf_path=pdf_path,
|
||||
)
|
||||
|
||||
assert labeled == 1
|
||||
manifest = json.loads((images_dest / "manifest.json").read_text())
|
||||
info = manifest["figure_1.jpg"]
|
||||
assert info["caption_text"] == "Figure 1: PDF original caption."
|
||||
assert info["caption_source"] == "doclayout"
|
||||
assert info["summary_caption_text"] == "Summary caption."
|
||||
@@ -7,11 +7,13 @@ import json
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.services.pi_client import (
|
||||
from app.services.summary_utils import (
|
||||
JsonNotFoundError,
|
||||
extract_json as _extract_json,
|
||||
)
|
||||
from app.services.pi_client import (
|
||||
PiProcessError,
|
||||
PiTimeoutError,
|
||||
extract_json as _extract_json,
|
||||
)
|
||||
from app.services.pdf_downloader import PdfDownloadError
|
||||
from app.services.schemas import (
|
||||
|
||||
@@ -366,6 +366,38 @@ class TestSummarizeOneFlow:
|
||||
result = await summarize_one(db_session, sample_paper)
|
||||
assert result["status"] == "skipped"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_processing_runs_in_thread(
|
||||
self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths
|
||||
):
|
||||
"""后处理(图片提取/ChromaDB)在工作线程而非事件循环线程执行。"""
|
||||
import threading
|
||||
|
||||
seen_threads: list[int] = []
|
||||
main_thread = threading.current_thread().ident
|
||||
|
||||
def spy_extract(arxiv_id, schema):
|
||||
seen_threads.append(threading.current_thread().ident)
|
||||
|
||||
with (
|
||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.summary_generator.call_pi",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_pi_output, "test-session-id"),
|
||||
),
|
||||
patch(
|
||||
"app.services.summary_persister._maybe_extract_images",
|
||||
side_effect=spy_extract,
|
||||
),
|
||||
patch("app.services.summary_persister._maybe_index_chroma"),
|
||||
):
|
||||
result = await summarize_one(db_session, sample_paper)
|
||||
|
||||
assert result["status"] == "done"
|
||||
assert seen_threads, "后处理未被调用"
|
||||
assert seen_threads[0] != main_thread, "后处理应在工作线程执行,不阻塞事件循环"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# 批量操作测试
|
||||
|
||||
Reference in New Issue
Block a user