feat: add concurrency safety, caption detection, admin enhancements, and performance improvements
This commit is contained in:
@@ -26,6 +26,7 @@ class Settings(BaseSettings):
|
|||||||
HTTP_TIMEOUT_SECONDS: int = 30
|
HTTP_TIMEOUT_SECONDS: int = 30
|
||||||
HTTP_MAX_RETRIES: int = 3
|
HTTP_MAX_RETRIES: int = 3
|
||||||
HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1"
|
HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1"
|
||||||
|
HF_PROXY: str = ""
|
||||||
PDF_DOWNLOAD_TIMEOUT: int = 120
|
PDF_DOWNLOAD_TIMEOUT: int = 120
|
||||||
|
|
||||||
# AI 总结
|
# AI 总结
|
||||||
|
|||||||
+200
-5
@@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
|
import io
|
||||||
from datetime import date
|
from datetime import date
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
@@ -15,7 +17,7 @@ from fastapi import (
|
|||||||
Query,
|
Query,
|
||||||
Request,
|
Request,
|
||||||
)
|
)
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse, Response
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -298,6 +300,183 @@ async def admin_job_detail(
|
|||||||
return detail
|
return detail
|
||||||
|
|
||||||
|
|
||||||
|
# ── 任务监控 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/jobs")
|
||||||
|
async def admin_jobs(
|
||||||
|
request: Request,
|
||||||
|
_admin: None = Depends(verify_admin),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
status: str = Query("all"),
|
||||||
|
job_type: str = Query("all"),
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
per_page: int = Query(20, ge=1, le=100),
|
||||||
|
):
|
||||||
|
"""后台任务监控页。"""
|
||||||
|
jobs, total = admin_svc.query_jobs(
|
||||||
|
db, status=status, job_type=job_type, page=page, per_page=per_page
|
||||||
|
)
|
||||||
|
counts = admin_svc.get_job_status_counts(db)
|
||||||
|
|
||||||
|
def pagination_url(p: int) -> str:
|
||||||
|
params = dict(request.query_params)
|
||||||
|
params["page"] = str(p)
|
||||||
|
return "/admin/jobs?" + "&".join(f"{k}={v}" for k, v in params.items())
|
||||||
|
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
request,
|
||||||
|
"admin_jobs.html",
|
||||||
|
{
|
||||||
|
"jobs": jobs,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
"current_status": status,
|
||||||
|
"current_type": job_type,
|
||||||
|
"status_counts": counts,
|
||||||
|
"pagination_url": pagination_url,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 锁管理 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/locks/{lock_id}/release")
|
||||||
|
async def admin_release_lock(
|
||||||
|
lock_id: int,
|
||||||
|
_admin: None = Depends(verify_admin),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""强制释放一个卡死的任务锁。"""
|
||||||
|
if not admin_svc.force_release_lock(db, lock_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail=f"Lock not found or already released: {lock_id}"
|
||||||
|
)
|
||||||
|
return {"status": "success", "lock_id": lock_id}
|
||||||
|
|
||||||
|
|
||||||
|
# ── 重抓 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/paper-recrawl/{arxiv_id}")
|
||||||
|
async def admin_paper_recrawl(
|
||||||
|
arxiv_id: str,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
_admin: None = Depends(verify_admin),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""重新抓取单篇已存在论文的完整元数据。"""
|
||||||
|
job = create_job(
|
||||||
|
db, "recrawl_one", owner="admin_recrawl", payload={"arxiv_id": arxiv_id}
|
||||||
|
)
|
||||||
|
enqueue_job(background_tasks, job.id)
|
||||||
|
return {"status": "queued", "job_id": job.id, "arxiv_id": arxiv_id}
|
||||||
|
|
||||||
|
|
||||||
|
# ── 索引重建 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class RebuildIndexRequest(BaseModel):
|
||||||
|
target: str # "fts" / "chroma" / "both"
|
||||||
|
|
||||||
|
@field_validator("target")
|
||||||
|
@classmethod
|
||||||
|
def target_must_be_valid(cls, v: str) -> str:
|
||||||
|
if v not in ("fts", "chroma", "both"):
|
||||||
|
raise ValueError("target must be 'fts', 'chroma' or 'both'")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/rebuild-indexes")
|
||||||
|
async def admin_rebuild_indexes(
|
||||||
|
body: RebuildIndexRequest,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
_admin: None = Depends(verify_admin),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""重建搜索索引(FTS5 / ChromaDB)。"""
|
||||||
|
job_ids: list[int] = []
|
||||||
|
if body.target in ("fts", "both"):
|
||||||
|
job = create_job(db, "reindex_fts", owner="admin_reindex", payload={})
|
||||||
|
enqueue_job(background_tasks, job.id)
|
||||||
|
job_ids.append(job.id)
|
||||||
|
if body.target in ("chroma", "both"):
|
||||||
|
job = create_job(db, "reindex_chroma", owner="admin_reindex", payload={})
|
||||||
|
enqueue_job(background_tasks, job.id)
|
||||||
|
job_ids.append(job.id)
|
||||||
|
return {"status": "queued", "job_ids": job_ids, "target": body.target}
|
||||||
|
|
||||||
|
|
||||||
|
# ── 导出 CSV ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/papers/export.csv")
|
||||||
|
async def admin_papers_export(
|
||||||
|
_admin: None = Depends(verify_admin),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
q: str = Query(""),
|
||||||
|
date_from: str | None = Query(None),
|
||||||
|
date_to: str | None = Query(None),
|
||||||
|
tag: str = Query(""),
|
||||||
|
summary_status: str = Query("all"),
|
||||||
|
sort: str = Query("date_desc"),
|
||||||
|
):
|
||||||
|
"""导出当前过滤条件下的论文为 CSV(含 UTF-8 BOM,Excel 友好)。"""
|
||||||
|
papers, _total, statuses = admin_svc.query_papers(
|
||||||
|
db,
|
||||||
|
q=q,
|
||||||
|
date_from=date_from,
|
||||||
|
date_to=date_to,
|
||||||
|
tag=tag,
|
||||||
|
summary_status=summary_status,
|
||||||
|
sort=sort,
|
||||||
|
page=1,
|
||||||
|
per_page=10**6,
|
||||||
|
)
|
||||||
|
|
||||||
|
buf = io.StringIO()
|
||||||
|
buf.write("") # UTF-8 BOM for Excel
|
||||||
|
writer = csv.writer(buf)
|
||||||
|
writer.writerow(
|
||||||
|
[
|
||||||
|
"arxiv_id",
|
||||||
|
"title_en",
|
||||||
|
"title_zh",
|
||||||
|
"paper_date",
|
||||||
|
"upvotes",
|
||||||
|
"summary_status",
|
||||||
|
"authors",
|
||||||
|
"tags",
|
||||||
|
"pdf_url",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for paper in papers:
|
||||||
|
authors = ";".join(a.name for a in paper.authors)
|
||||||
|
tags = ";".join(t.tag for t in paper.tags)
|
||||||
|
writer.writerow(
|
||||||
|
[
|
||||||
|
paper.arxiv_id,
|
||||||
|
paper.title_en or "",
|
||||||
|
paper.title_zh or "",
|
||||||
|
str(paper.paper_date) if paper.paper_date else "",
|
||||||
|
paper.upvotes or 0,
|
||||||
|
statuses.get(paper.arxiv_id, "none"),
|
||||||
|
authors,
|
||||||
|
tags,
|
||||||
|
paper.pdf_url or "",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
filename = f"papers_{today_str().replace('-', '')}.csv"
|
||||||
|
return Response(
|
||||||
|
content=buf.getvalue(),
|
||||||
|
media_type="text/csv; charset=utf-8",
|
||||||
|
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── 日志 ──────────────────────────────────────────────────────────────
|
# ── 日志 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -438,24 +617,25 @@ async def admin_paper_delete(
|
|||||||
|
|
||||||
|
|
||||||
class BatchActionRequest(BaseModel):
|
class BatchActionRequest(BaseModel):
|
||||||
action: str # "delete" or "summarize"
|
action: str # "delete" / "summarize" / "recrawl"
|
||||||
arxiv_ids: list[str]
|
arxiv_ids: list[str]
|
||||||
|
|
||||||
@field_validator("action")
|
@field_validator("action")
|
||||||
@classmethod
|
@classmethod
|
||||||
def action_must_be_valid(cls, v: str) -> str:
|
def action_must_be_valid(cls, v: str) -> str:
|
||||||
if v not in ("delete", "summarize"):
|
if v not in ("delete", "summarize", "recrawl"):
|
||||||
raise ValueError("action must be 'delete' or 'summarize'")
|
raise ValueError("action must be 'delete', 'summarize' or 'recrawl'")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
@router.post("/papers-batch-action")
|
@router.post("/papers-batch-action")
|
||||||
async def admin_papers_batch_action(
|
async def admin_papers_batch_action(
|
||||||
body: BatchActionRequest,
|
body: BatchActionRequest,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
_admin: None = Depends(verify_admin),
|
_admin: None = Depends(verify_admin),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""批量操作论文(删除或总结)。"""
|
"""批量操作论文(删除 / 总结 / 重抓)。"""
|
||||||
if not body.arxiv_ids:
|
if not body.arxiv_ids:
|
||||||
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
|
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
|
||||||
|
|
||||||
@@ -475,3 +655,18 @@ async def admin_papers_batch_action(
|
|||||||
"message": f"已将 {count} 篇论文重置为待总结",
|
"message": f"已将 {count} 篇论文重置为待总结",
|
||||||
"count": count,
|
"count": count,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
elif body.action == "recrawl":
|
||||||
|
job = create_job(
|
||||||
|
db,
|
||||||
|
"recrawl_batch",
|
||||||
|
owner="admin_recrawl",
|
||||||
|
payload={"arxiv_ids": body.arxiv_ids},
|
||||||
|
)
|
||||||
|
enqueue_job(background_tasks, job.id)
|
||||||
|
return {
|
||||||
|
"status": "queued",
|
||||||
|
"job_id": job.id,
|
||||||
|
"count": len(body.arxiv_ids),
|
||||||
|
"message": f"已将 {len(body.arxiv_ids)} 篇论文加入重抓队列",
|
||||||
|
}
|
||||||
|
|||||||
+129
-2
@@ -7,7 +7,7 @@ from datetime import date
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from sqlalchemy import func, select, text
|
from sqlalchemy import func, select, text, update
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@@ -100,7 +100,9 @@ def get_admin_stats(db: Session) -> dict:
|
|||||||
|
|
||||||
# ── 活跃锁 ────────────────────────────────────────────────────────
|
# ── 活跃锁 ────────────────────────────────────────────────────────
|
||||||
active_locks = (
|
active_locks = (
|
||||||
db.execute(select(TaskLock).where(TaskLock.status == "running")).scalars().all()
|
db.execute(select(TaskLock).where(TaskLock.status.in_(["running", "stale"])))
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -124,6 +126,7 @@ def get_admin_stats(db: Session) -> dict:
|
|||||||
"recent_logs": recent_logs,
|
"recent_logs": recent_logs,
|
||||||
"active_locks": active_locks,
|
"active_locks": active_locks,
|
||||||
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
|
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
|
||||||
|
"config_overview": get_config_overview(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -370,6 +373,7 @@ def get_logs_context(db: Session, *, page: int, per_page: int) -> dict:
|
|||||||
"summary_done": summary_done,
|
"summary_done": summary_done,
|
||||||
"summary_pending": summary_pending,
|
"summary_pending": summary_pending,
|
||||||
"summary_failed": summary_failed,
|
"summary_failed": summary_failed,
|
||||||
|
"failure_breakdown": get_failure_breakdown(db),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -511,3 +515,126 @@ def reset_summaries_pending(db: Session, arxiv_ids: list[str]) -> int:
|
|||||||
db.add(SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING))
|
db.add(SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING))
|
||||||
db.commit()
|
db.commit()
|
||||||
return len(paper_ids)
|
return len(paper_ids)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 任务监控 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def query_jobs(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
job_type: str | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
per_page: int = 20,
|
||||||
|
) -> tuple[list[dict], int]:
|
||||||
|
"""后台任务列表查询 — 支持 status/type 过滤 + 分页,返回已 enrich 的 dict 列表。"""
|
||||||
|
query = select(Job)
|
||||||
|
if status and status != "all":
|
||||||
|
query = query.where(Job.status == status)
|
||||||
|
if job_type and job_type != "all":
|
||||||
|
query = query.where(Job.type == job_type)
|
||||||
|
|
||||||
|
total = db.scalar(select(func.count()).select_from(query.subquery())) or 0
|
||||||
|
jobs = (
|
||||||
|
db.execute(
|
||||||
|
query.order_by(Job.created_at.desc())
|
||||||
|
.offset((page - 1) * per_page)
|
||||||
|
.limit(per_page)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [serialize_job(j) for j in jobs], total
|
||||||
|
|
||||||
|
|
||||||
|
def _as_naive(dt):
|
||||||
|
"""去掉 tzinfo — SQLite 读回的 datetime 是 naive UTC,与 utc_now() 运算前需统一。"""
|
||||||
|
if dt is not None and getattr(dt, "tzinfo", None) is not None:
|
||||||
|
return dt.replace(tzinfo=None)
|
||||||
|
return dt
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_job(job: Job) -> dict:
|
||||||
|
"""单条 job 序列化为展示用 dict(含耗时)。"""
|
||||||
|
duration = None
|
||||||
|
started = _as_naive(job.started_at)
|
||||||
|
if started:
|
||||||
|
end = _as_naive(job.completed_at) or _as_naive(utc_now())
|
||||||
|
duration = round((end - started).total_seconds(), 1)
|
||||||
|
return {
|
||||||
|
"id": job.id,
|
||||||
|
"type": job.type,
|
||||||
|
"status": job.status,
|
||||||
|
"owner": job.owner,
|
||||||
|
"created_at": job.created_at,
|
||||||
|
"started_at": job.started_at,
|
||||||
|
"completed_at": job.completed_at,
|
||||||
|
"duration_seconds": duration,
|
||||||
|
"error": job.error,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_status_counts(db: Session) -> dict:
|
||||||
|
"""按 status 聚合 job 计数,供任务页顶部小统计行用。"""
|
||||||
|
rows = db.execute(
|
||||||
|
select(Job.status, func.count(Job.id)).group_by(Job.status)
|
||||||
|
).fetchall()
|
||||||
|
return {row[0]: row[1] for row in rows}
|
||||||
|
|
||||||
|
|
||||||
|
# ── 锁管理 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def force_release_lock(db: Session, lock_id: int) -> bool:
|
||||||
|
"""强制释放一个卡死的 TaskLock(仅对 running/stale 生效)。"""
|
||||||
|
result = db.execute(
|
||||||
|
update(TaskLock)
|
||||||
|
.where(TaskLock.id == lock_id, TaskLock.status.in_(["running", "stale"]))
|
||||||
|
.values(status="finished", released_at=utc_now())
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
return (result.rowcount or 0) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── 失败原因分布 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_failure_breakdown(db: Session) -> list[dict]:
|
||||||
|
"""按 error_type 聚合失败/永久失败的总结,按数量降序。NULL 归 unknown。"""
|
||||||
|
error_expr = func.coalesce(SummaryStatus.error_type, "unknown")
|
||||||
|
rows = db.execute(
|
||||||
|
select(error_expr, func.count(SummaryStatus.id))
|
||||||
|
.where(
|
||||||
|
SummaryStatus.status.in_(
|
||||||
|
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.group_by(error_expr)
|
||||||
|
.order_by(func.count(SummaryStatus.id).desc())
|
||||||
|
).fetchall()
|
||||||
|
return [{"error_type": row[0], "count": row[1]} for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
# ── 运行配置概览 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_overview() -> dict:
|
||||||
|
"""聚合非敏感配置,供仪表盘展示。敏感字段只标是否已配置,不显示值。"""
|
||||||
|
return {
|
||||||
|
"summary_backend": settings.SUMMARY_BACKEND,
|
||||||
|
"summary_pdf_mode": settings.SUMMARY_PDF_MODE,
|
||||||
|
"summary_concurrency": settings.SUMMARY_CONCURRENCY,
|
||||||
|
"summary_timeout_seconds": settings.SUMMARY_TIMEOUT_SECONDS,
|
||||||
|
"summary_max_retries": settings.SUMMARY_MAX_RETRIES,
|
||||||
|
"scheduler_enabled": settings.SCHEDULER_ENABLED,
|
||||||
|
"schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
|
||||||
|
"chroma_enabled": settings.CHROMA_ENABLED,
|
||||||
|
"embed_model": settings.EMBED_MODEL or "(未配置)",
|
||||||
|
"top_n": settings.TOP_N,
|
||||||
|
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
|
||||||
|
"app_workers": settings.APP_WORKERS,
|
||||||
|
"layout_model": Path(settings.LAYOUT_MODEL_PATH).name,
|
||||||
|
"database_url": settings.DATABASE_URL,
|
||||||
|
"api_key_configured": bool(settings.EMBED_API_KEY),
|
||||||
|
}
|
||||||
|
|||||||
@@ -270,3 +270,67 @@ def _update_upvotes_only(db: Session, papers_raw: list[dict]) -> int:
|
|||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return updated
|
return updated
|
||||||
|
|
||||||
|
|
||||||
|
async def recrawl_single(db: Session, arxiv_id: str) -> dict:
|
||||||
|
"""重新抓取一篇已存在论文的完整元数据。
|
||||||
|
|
||||||
|
基于 paper.paper_date 重新拉取 HF Daily 列表,命中后全字段刷新
|
||||||
|
(标题/摘要/作者/标签/链接/upvotes)并重建 FTS。若该论文不在其收录日的
|
||||||
|
列表中则无法重抓。
|
||||||
|
"""
|
||||||
|
paper = db.execute(
|
||||||
|
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if not paper:
|
||||||
|
return {"updated": False, "reason": "not_found", "arxiv_id": arxiv_id}
|
||||||
|
|
||||||
|
target_date = paper.paper_date.isoformat()
|
||||||
|
raw_papers = await fetch_daily(target_date)
|
||||||
|
|
||||||
|
target = None
|
||||||
|
for item in raw_papers:
|
||||||
|
if _parse_paper(item)["arxiv_id"] == arxiv_id:
|
||||||
|
target = item
|
||||||
|
break
|
||||||
|
|
||||||
|
if target is None:
|
||||||
|
return {
|
||||||
|
"updated": False,
|
||||||
|
"reason": "not_in_daily",
|
||||||
|
"arxiv_id": arxiv_id,
|
||||||
|
"date": target_date,
|
||||||
|
}
|
||||||
|
|
||||||
|
meta = _parse_paper(target)
|
||||||
|
now = utc_now()
|
||||||
|
|
||||||
|
# 全字段刷新
|
||||||
|
paper.title_en = meta["title_en"]
|
||||||
|
paper.abstract = meta["abstract"]
|
||||||
|
paper.published_at = meta["published_at"]
|
||||||
|
paper.hf_url = meta["hf_url"]
|
||||||
|
paper.arxiv_url = meta["arxiv_url"]
|
||||||
|
paper.pdf_url = meta["pdf_url"]
|
||||||
|
paper.upvotes = meta["upvotes"]
|
||||||
|
paper.crawled_at = now
|
||||||
|
|
||||||
|
# 重建 authors(删旧再加新)
|
||||||
|
paper.authors.clear()
|
||||||
|
seen_authors: set[str] = set()
|
||||||
|
for idx, name in enumerate(meta["authors"]):
|
||||||
|
if name and name not in seen_authors:
|
||||||
|
seen_authors.add(name)
|
||||||
|
db.add(PaperAuthor(paper_id=paper.id, name=name, position=idx))
|
||||||
|
|
||||||
|
# 重建 tags
|
||||||
|
paper.tags.clear()
|
||||||
|
for tag_name in meta["tags"]:
|
||||||
|
if tag_name:
|
||||||
|
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="hf"))
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
reindex_paper_fts(db, paper)
|
||||||
|
db.commit()
|
||||||
|
logger.info("Re-crawled paper %s (full metadata refresh)", arxiv_id)
|
||||||
|
return {"updated": True, "arxiv_id": arxiv_id, "date": target_date}
|
||||||
|
|||||||
+78
-40
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -18,14 +19,27 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ChromaManager:
|
class ChromaManager:
|
||||||
"""封装 ChromaDB 客户端和 collection 的生命周期。"""
|
"""封装 ChromaDB 客户端和 collection 的生命周期。
|
||||||
|
|
||||||
|
所有客户端/集合访问经 ``self._lock`` 串行化:后处理经 ``asyncio.to_thread``
|
||||||
|
多 worker 并发调 ``index_paper``,若不串行化会并发建连,触发 chromadb 1.5.x
|
||||||
|
``SharedSystemClient`` 类级缓存的并发竞争(``_create_system_if_not_exists``
|
||||||
|
无锁 + refcount release 弹 key)→ ``KeyError: '<persist_dir>'``。
|
||||||
|
锁用 RLock:``index_paper`` 持锁后经 ``get_collection()`` 间接再调 ``init()``,
|
||||||
|
同线程可重入,不死锁。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
self._lock = threading.RLock()
|
||||||
self._client = None
|
self._client = None
|
||||||
self._collection = None
|
self._collection = None
|
||||||
|
|
||||||
def init(self) -> None:
|
def init(self) -> None:
|
||||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
|
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。
|
||||||
|
|
||||||
|
双重检查锁串行化首次建连 —— 外层快判已建好就直接返回(快路径),抢到锁后再
|
||||||
|
判一次(防并发下另一线程已建好),确保 ``PersistentClient`` 全进程只调一次。
|
||||||
|
"""
|
||||||
if not settings.CHROMA_ENABLED:
|
if not settings.CHROMA_ENABLED:
|
||||||
logger.debug("ChromaDB disabled, skip init")
|
logger.debug("ChromaDB disabled, skip init")
|
||||||
return
|
return
|
||||||
@@ -33,19 +47,23 @@ class ChromaManager:
|
|||||||
if self._client is not None:
|
if self._client is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
with self._lock:
|
||||||
import chromadb
|
if self._client is not None: # 双重检查:抢到锁后可能已被别的线程建好
|
||||||
|
return
|
||||||
|
|
||||||
chroma_path = Path(settings.CHROMA_DIR)
|
try:
|
||||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
import chromadb
|
||||||
|
|
||||||
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
chroma_path = Path(settings.CHROMA_DIR)
|
||||||
self._collection = self._get_or_create_collection()
|
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
|
||||||
except Exception:
|
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
||||||
logger.exception("Failed to initialize ChromaDB")
|
self._collection = self._get_or_create_collection()
|
||||||
self._client = None
|
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||||
self._collection = None
|
except Exception:
|
||||||
|
logger.exception("Failed to initialize ChromaDB")
|
||||||
|
self._client = None
|
||||||
|
self._collection = None
|
||||||
|
|
||||||
def _get_or_create_collection(self):
|
def _get_or_create_collection(self):
|
||||||
"""获取或创建 papers_embeddings collection。"""
|
"""获取或创建 papers_embeddings collection。"""
|
||||||
@@ -102,6 +120,8 @@ def _get_embedding(text: str) -> list[float] | None:
|
|||||||
POST /v1/embeddings, model=EMBED_MODEL
|
POST /v1/embeddings, model=EMBED_MODEL
|
||||||
校验返回向量长度 == EMBED_DIMENSIONS
|
校验返回向量长度 == EMBED_DIMENSIONS
|
||||||
失败时返回 None 并记录日志。
|
失败时返回 None 并记录日志。
|
||||||
|
|
||||||
|
纯远程 HTTP 调用、线程安全 —— 留在锁外,让多 worker 并行调。
|
||||||
"""
|
"""
|
||||||
if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
|
if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
|
||||||
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
|
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
|
||||||
@@ -177,9 +197,11 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True 表示成功,False 表示失败或跳过。
|
True 表示成功,False 表示失败或跳过。
|
||||||
|
|
||||||
|
并发设计:远程 embedding 调用在锁外(多 worker 并行),chroma 集合访问
|
||||||
|
(含首次 init)在 ``_chroma._lock`` 内串行化。
|
||||||
"""
|
"""
|
||||||
col = get_collection()
|
if not settings.CHROMA_ENABLED:
|
||||||
if col is None:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -227,17 +249,25 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
|||||||
logger.warning("Empty index text for %s, skip", paper_id)
|
logger.warning("Empty index text for %s, skip", paper_id)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
vec = _get_embedding(index_text)
|
vec = _get_embedding(index_text) # 远程 HTTP,锁外并行
|
||||||
if vec is None:
|
if vec is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
col.upsert(
|
with _chroma._lock: # 串行化集合访问(首次含 init)
|
||||||
ids=[arxiv_id],
|
col = _chroma.get_collection()
|
||||||
embeddings=[vec],
|
if col is None:
|
||||||
metadatas=[
|
return False
|
||||||
{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}
|
col.upsert(
|
||||||
],
|
ids=[arxiv_id],
|
||||||
)
|
embeddings=[vec],
|
||||||
|
metadatas=[
|
||||||
|
{
|
||||||
|
"arxiv_id": arxiv_id,
|
||||||
|
"title_zh": title_zh,
|
||||||
|
"paper_date": paper_date,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
logger.info("Indexed paper %s in ChromaDB", arxiv_id)
|
logger.info("Indexed paper %s in ChromaDB", arxiv_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -255,17 +285,20 @@ def delete_paper(paper_id: str) -> bool:
|
|||||||
Args:
|
Args:
|
||||||
paper_id: arxiv_id
|
paper_id: arxiv_id
|
||||||
"""
|
"""
|
||||||
col = get_collection()
|
if not settings.CHROMA_ENABLED:
|
||||||
if col is None:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
with _chroma._lock:
|
||||||
col.delete(ids=[paper_id])
|
col = _chroma.get_collection()
|
||||||
logger.info("Deleted paper %s from ChromaDB", paper_id)
|
if col is None:
|
||||||
return True
|
return False
|
||||||
except Exception:
|
try:
|
||||||
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
|
col.delete(ids=[paper_id])
|
||||||
return False
|
logger.info("Deleted paper %s from ChromaDB", paper_id)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# ── 相似查询 ────────────────────────────────────────────────────────────
|
# ── 相似查询 ────────────────────────────────────────────────────────────
|
||||||
@@ -280,21 +313,26 @@ def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[{"arxiv_id": str, "distance": float}, ...]
|
[{"arxiv_id": str, "distance": float}, ...]
|
||||||
|
|
||||||
|
并发设计:远程 embedding 在锁外,集合查询在 ``_chroma._lock`` 内。
|
||||||
"""
|
"""
|
||||||
col = get_collection()
|
if not settings.CHROMA_ENABLED:
|
||||||
if col is None:
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
vec = _get_embedding(query_text)
|
vec = _get_embedding(query_text) # 远程 HTTP,锁外
|
||||||
if vec is None:
|
if vec is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
results = col.query(
|
with _chroma._lock:
|
||||||
query_embeddings=[vec],
|
col = _chroma.get_collection()
|
||||||
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
|
if col is None:
|
||||||
include=["metadatas", "distances"],
|
return []
|
||||||
)
|
results = col.query(
|
||||||
|
query_embeddings=[vec],
|
||||||
|
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
|
||||||
|
include=["metadatas", "distances"],
|
||||||
|
)
|
||||||
|
|
||||||
if not results["ids"] or not results["ids"][0]:
|
if not results["ids"] or not results["ids"][0]:
|
||||||
return []
|
return []
|
||||||
|
|||||||
+15
-1
@@ -148,7 +148,7 @@ async def run_job(db: Session, job_id: int) -> dict:
|
|||||||
|
|
||||||
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
|
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
|
||||||
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
||||||
from app.services.crawler import refresh_upvotes
|
from app.services.crawler import recrawl_single, refresh_upvotes
|
||||||
from app.services.derived import reindex_chroma, reindex_fts
|
from app.services.derived import reindex_chroma, reindex_fts
|
||||||
from app.services.pipeline import run_crawl, run_pipeline
|
from app.services.pipeline import run_crawl, run_pipeline
|
||||||
from app.services.summarizer import summarize_batch, summarize_single
|
from app.services.summarizer import summarize_batch, summarize_single
|
||||||
@@ -193,6 +193,20 @@ async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
|
|||||||
return reindex_fts(db)
|
return reindex_fts(db)
|
||||||
if job.type == "reindex_chroma":
|
if job.type == "reindex_chroma":
|
||||||
return reindex_chroma(db)
|
return reindex_chroma(db)
|
||||||
|
if job.type == "recrawl_one":
|
||||||
|
return await recrawl_single(db, payload["arxiv_id"])
|
||||||
|
if job.type == "recrawl_batch":
|
||||||
|
updated = 0
|
||||||
|
skipped = 0
|
||||||
|
results = []
|
||||||
|
for arxiv_id in payload.get("arxiv_ids", []):
|
||||||
|
res = await recrawl_single(db, arxiv_id)
|
||||||
|
results.append(res)
|
||||||
|
if res.get("updated"):
|
||||||
|
updated += 1
|
||||||
|
else:
|
||||||
|
skipped += 1
|
||||||
|
return {"updated": updated, "skipped": skipped, "results": results}
|
||||||
|
|
||||||
raise ValueError(f"Unsupported job type: {job.type}")
|
raise ValueError(f"Unsupported job type: {job.type}")
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
@@ -47,14 +49,18 @@ _FALLBACK_NAMES: dict[int, str] = {
|
|||||||
8: "isolate_formula",
|
8: "isolate_formula",
|
||||||
9: "formula_caption",
|
9: "formula_caption",
|
||||||
}
|
}
|
||||||
# 下游只需 picture/table —— 按 class name 字符串动态匹配(不依赖 class index,
|
# 下游需要 picture/table 及其 caption —— 按 class name 字符串动态匹配(不依赖 class index,
|
||||||
# 规避 DocStructBench 不同发布的类别顺序差异)
|
# 规避 DocStructBench 不同发布的类别顺序差异)
|
||||||
_PICTURE_NAMES = {"figure", "figure_group"}
|
_PICTURE_NAMES = {"figure", "figure_group"}
|
||||||
_TABLE_NAMES = {"table", "table_group"}
|
_TABLE_NAMES = {"table", "table_group"}
|
||||||
|
_FIGURE_CAPTION_NAMES = {"figure_caption"}
|
||||||
|
_TABLE_CAPTION_NAMES = {"table_caption"}
|
||||||
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
|
# letterbox 灰边值(ultralytics 训练标准,不可改为 0/128,否则精度下降)
|
||||||
_PAD_VALUE = 114
|
_PAD_VALUE = 114
|
||||||
# 最小 bbox 尺寸(PDF 点)
|
# 最小 bbox 尺寸(PDF 点)
|
||||||
_MIN_BOX_SIZE = 20
|
_MIN_BOX_SIZE = 20
|
||||||
|
_MIN_CAPTION_BOX_WIDTH = 30
|
||||||
|
_MIN_CAPTION_BOX_HEIGHT = 6
|
||||||
|
|
||||||
# device → ExecutionProvider 映射
|
# device → ExecutionProvider 映射
|
||||||
_PROVIDER_MAP: dict[str, str] = {
|
_PROVIDER_MAP: dict[str, str] = {
|
||||||
@@ -72,7 +78,7 @@ _AUTO_PRIORITY = ["cuda", "directml", "openvino", "cann", "tensorrt", "qnn"]
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LayoutBox:
|
class LayoutBox:
|
||||||
"""检测到的布局区域,坐标为 PDF 点,boxclass ∈ {"picture", "table"}。"""
|
"""检测到的布局区域,坐标为 PDF 点。"""
|
||||||
|
|
||||||
x0: float
|
x0: float
|
||||||
y0: float
|
y0: float
|
||||||
@@ -191,13 +197,17 @@ def _postprocess_output(
|
|||||||
|
|
||||||
|
|
||||||
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
|
def _map_class_to_boxclass(cls_id: int, names: dict[int, str]) -> str | None:
|
||||||
"""按 class name 匹配 figure→picture / table→table,其余返回 None。"""
|
"""按 class name 匹配下游关心的布局类别,其余返回 None。"""
|
||||||
name = names.get(cls_id, "")
|
name = names.get(cls_id, "")
|
||||||
n = name.strip().lower()
|
n = name.strip().lower()
|
||||||
if n in _PICTURE_NAMES:
|
if n in _PICTURE_NAMES:
|
||||||
return "picture"
|
return "picture"
|
||||||
if n in _TABLE_NAMES:
|
if n in _TABLE_NAMES:
|
||||||
return "table"
|
return "table"
|
||||||
|
if n in _FIGURE_CAPTION_NAMES:
|
||||||
|
return "figure_caption"
|
||||||
|
if n in _TABLE_CAPTION_NAMES:
|
||||||
|
return "table_caption"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -220,15 +230,50 @@ def _parse_names_from_meta(session: ort.InferenceSession) -> dict[int, str]:
|
|||||||
# ── 检测器单例 ──────────────────────────────────────────────────────────
|
# ── 检测器单例 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class _LayoutDetector:
|
class _Singleton(type):
|
||||||
"""单例:管理 ONNX InferenceSession 生命周期。"""
|
"""元类单例:``cls()`` 永远返回同一实例;``reset_instance()`` 清缓存以便重建。
|
||||||
|
|
||||||
|
生产代码只应在模块级 ``_detector = _LayoutDetector()`` 创建一次。任何第二处
|
||||||
|
``_LayoutDetector()`` 都会拿到同一实例(含同一 ONNX session + 同一锁),杜绝
|
||||||
|
并发推理时各建一份 session 导致内存峰值翻倍(8GB 机器崩溃根因)。双检锁保证
|
||||||
|
首次实例化线程安全。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instances: dict[type, Any] = {}
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __call__(cls, *args, **kwargs):
|
||||||
|
if cls in _Singleton._instances:
|
||||||
|
return _Singleton._instances[cls]
|
||||||
|
with _Singleton._lock:
|
||||||
|
if cls not in _Singleton._instances:
|
||||||
|
_Singleton._instances[cls] = super().__call__(*args, **kwargs)
|
||||||
|
return _Singleton._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
|
class _LayoutDetector(metaclass=_Singleton):
|
||||||
|
"""强约束单例:管理 ONNX InferenceSession 生命周期。
|
||||||
|
|
||||||
|
由 ``_Singleton`` 元类保证全进程唯一实例 —— 重复 ``_LayoutDetector()`` 只会返回
|
||||||
|
已有实例(含已加载的 session 和锁),不会新建。``reset_instance()`` 清缓存,仅供
|
||||||
|
测试隔离用。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
self._lock = threading.Lock()
|
||||||
self._session: ort.InferenceSession | None = None
|
self._session: ort.InferenceSession | None = None
|
||||||
self._names: dict[int, str] = {}
|
self._names: dict[int, str] = {}
|
||||||
self._input_name: str = ""
|
self._input_name: str = ""
|
||||||
self._imgsz: int = settings.LAYOUT_IMGSZ
|
self._imgsz: int = settings.LAYOUT_IMGSZ
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_instance(cls) -> None:
|
||||||
|
"""清空单例缓存,下次 ``_LayoutDetector()`` 重建新实例(含新锁 + 空 session)。
|
||||||
|
|
||||||
|
仅用于测试隔离 —— 生产代码永远不该调用(否则会丢掉已加载的模型 session)。
|
||||||
|
"""
|
||||||
|
_Singleton._instances.pop(cls, None)
|
||||||
|
|
||||||
def _init_session(self) -> ort.InferenceSession:
|
def _init_session(self) -> ort.InferenceSession:
|
||||||
if self._session is not None:
|
if self._session is not None:
|
||||||
return self._session
|
return self._session
|
||||||
@@ -275,7 +320,7 @@ class _LayoutDetector:
|
|||||||
self._imgsz = settings.LAYOUT_IMGSZ
|
self._imgsz = settings.LAYOUT_IMGSZ
|
||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
def _detect_page_impl(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||||
"""检测单页 PDF 的 figure / table 区域。
|
"""检测单页 PDF 的 figure / table 区域。
|
||||||
|
|
||||||
流程:
|
流程:
|
||||||
@@ -323,21 +368,38 @@ class _LayoutDetector:
|
|||||||
y0 = max(0.0, min(y0, page_h))
|
y0 = max(0.0, min(y0, page_h))
|
||||||
x1 = max(0.0, min(x1, page_w))
|
x1 = max(0.0, min(x1, page_w))
|
||||||
y1 = max(0.0, min(y1, page_h))
|
y1 = max(0.0, min(y1, page_h))
|
||||||
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
if boxclass in ("figure_caption", "table_caption"):
|
||||||
continue
|
if (x1 - x0) < _MIN_CAPTION_BOX_WIDTH or (
|
||||||
|
y1 - y0
|
||||||
|
) < _MIN_CAPTION_BOX_HEIGHT:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||||
|
continue
|
||||||
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
|
result.append(LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=boxclass))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||||
|
"""公共入口:加锁串行化推理。
|
||||||
|
|
||||||
# 模块级单例
|
包裹整段 _detect_page_impl(含 pixmap 渲染 + tensor 构造 + session.run),
|
||||||
|
保证同一时刻只有一个推理在跑——避免 SUMMARY_CONCURRENCY>1 时多个 to_thread
|
||||||
|
线程并发推理导致内存峰值翻倍(8GB 机器崩溃根因)。锁由 _detect_page_impl
|
||||||
|
间接保护 _init_session,首次加载也串行,杜绝并发各建一份 session。
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
return self._detect_page_impl(page)
|
||||||
|
|
||||||
|
|
||||||
|
# 模块级单例 —— 生产代码唯一的实例化点(_Singleton 元类保证不会再有第二个)
|
||||||
_detector = _LayoutDetector()
|
_detector = _LayoutDetector()
|
||||||
|
|
||||||
|
|
||||||
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
|
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
|
||||||
"""检测 PDF 页面中的 figure / table 区域。
|
"""检测 PDF 页面中的 figure / table / caption 区域。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LayoutBox 列表,坐标为 PDF 点,仅含 picture/table。
|
LayoutBox 列表,坐标为 PDF 点,仅含 picture/table 及其 caption。
|
||||||
"""
|
"""
|
||||||
return _detector.detect_page(page)
|
return _detector.detect_page(page)
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ _CLUSTER_GAP = 15
|
|||||||
_MIN_BOX_AREA = 2000
|
_MIN_BOX_AREA = 2000
|
||||||
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
|
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
|
||||||
_LABEL_MATCH_DISTANCE = 100
|
_LABEL_MATCH_DISTANCE = 100
|
||||||
|
# DocLayout caption 与 figure/table 匹配的最大距离(单位: pt)
|
||||||
|
_CAPTION_MATCH_DISTANCE = 120
|
||||||
|
|
||||||
|
|
||||||
# ── Box 聚类 ─────────────────────────────────────────────────────────
|
# ── Box 聚类 ─────────────────────────────────────────────────────────
|
||||||
@@ -53,6 +55,15 @@ class _BoxCluster:
|
|||||||
self.boxclass = "table" if raw == "table-fallback" else raw
|
self.boxclass = "table" if raw == "table-fallback" else raw
|
||||||
|
|
||||||
|
|
||||||
|
def _cluster_to_box(cluster: _BoxCluster) -> list[float]:
|
||||||
|
return [
|
||||||
|
round(float(cluster.x0), 1),
|
||||||
|
round(float(cluster.y0), 1),
|
||||||
|
round(float(cluster.x1), 1),
|
||||||
|
round(float(cluster.y1), 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
||||||
"""将相邻的同类型 box 合并为聚类。"""
|
"""将相邻的同类型 box 合并为聚类。"""
|
||||||
if not boxes:
|
if not boxes:
|
||||||
@@ -92,6 +103,67 @@ def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
|||||||
return [_BoxCluster(members) for members in groups.values()]
|
return [_BoxCluster(members) for members in groups.values()]
|
||||||
|
|
||||||
|
|
||||||
|
def _caption_class_for_content(boxclass: str) -> str:
|
||||||
|
return "figure_caption" if boxclass == "picture" else "table_caption"
|
||||||
|
|
||||||
|
|
||||||
|
def _caption_distance(content: _BoxCluster, caption: _BoxCluster) -> float | None:
|
||||||
|
"""Return a spatial score for pairing a caption with a content box."""
|
||||||
|
h_overlap = min(content.x1, caption.x1) - max(content.x0, caption.x0)
|
||||||
|
min_width = min(content.x1 - content.x0, caption.x1 - caption.x0)
|
||||||
|
if min_width <= 0 or h_overlap < min_width * 0.25:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if caption.y1 < content.y0:
|
||||||
|
v_gap = content.y0 - caption.y1
|
||||||
|
elif caption.y0 > content.y1:
|
||||||
|
v_gap = caption.y0 - content.y1
|
||||||
|
else:
|
||||||
|
v_gap = 0.0
|
||||||
|
|
||||||
|
return v_gap if v_gap <= _CAPTION_MATCH_DISTANCE else None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_caption_text(page, caption: _BoxCluster) -> str:
|
||||||
|
rect = pymupdf.Rect(caption.x0, caption.y0, caption.x1, caption.y1)
|
||||||
|
try:
|
||||||
|
text = page.get_text("text", clip=rect)
|
||||||
|
except Exception:
|
||||||
|
return ""
|
||||||
|
return " ".join(text.split())
|
||||||
|
|
||||||
|
|
||||||
|
def _match_captions(
|
||||||
|
page,
|
||||||
|
content_clusters: list[_BoxCluster],
|
||||||
|
caption_clusters: list[_BoxCluster],
|
||||||
|
) -> dict[int, tuple[_BoxCluster, str]]:
|
||||||
|
"""Match each content cluster to its nearest same-type DocLayout caption."""
|
||||||
|
matches: dict[int, tuple[_BoxCluster, str]] = {}
|
||||||
|
used_captions: set[int] = set()
|
||||||
|
candidates: list[tuple[float, int, int]] = []
|
||||||
|
|
||||||
|
for content_idx, content in enumerate(content_clusters):
|
||||||
|
wanted_caption_class = _caption_class_for_content(content.boxclass)
|
||||||
|
for caption_idx, caption in enumerate(caption_clusters):
|
||||||
|
if caption.boxclass != wanted_caption_class:
|
||||||
|
continue
|
||||||
|
dist = _caption_distance(content, caption)
|
||||||
|
if dist is not None:
|
||||||
|
candidates.append((dist, content_idx, caption_idx))
|
||||||
|
|
||||||
|
for _dist, content_idx, caption_idx in sorted(candidates):
|
||||||
|
if content_idx in matches or caption_idx in used_captions:
|
||||||
|
continue
|
||||||
|
text = _extract_caption_text(page, caption_clusters[caption_idx])
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
matches[content_idx] = (caption_clusters[caption_idx], text)
|
||||||
|
used_captions.add(caption_idx)
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
|
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -102,14 +174,25 @@ def _render_box(
|
|||||||
filename: str,
|
filename: str,
|
||||||
cap_type: str,
|
cap_type: str,
|
||||||
page_num: int,
|
page_num: int,
|
||||||
|
caption: _BoxCluster | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""渲染单个 box 区域并保存 JPEG,成功返回 True。"""
|
"""渲染单个 box 区域并保存 JPEG,成功返回 True。
|
||||||
|
|
||||||
|
若提供 caption,则将内容与 caption 区域合并后一起截取,
|
||||||
|
使同一张截图同时包含图/表及其标题文字。
|
||||||
|
"""
|
||||||
page_width = page.rect.width
|
page_width = page.rect.width
|
||||||
|
x0, y0, x1, y1 = box.x0, box.y0, box.x1, box.y1
|
||||||
|
if caption is not None:
|
||||||
|
x0 = min(x0, caption.x0)
|
||||||
|
y0 = min(y0, caption.y0)
|
||||||
|
x1 = max(x1, caption.x1)
|
||||||
|
y1 = max(y1, caption.y1)
|
||||||
clip = pymupdf.Rect(
|
clip = pymupdf.Rect(
|
||||||
max(0, box.x0 - _REGION_PADDING),
|
max(0, x0 - _REGION_PADDING),
|
||||||
max(0, box.y0 - _REGION_PADDING),
|
max(0, y0 - _REGION_PADDING),
|
||||||
min(page_width, box.x1 + _REGION_PADDING),
|
min(page_width, x1 + _REGION_PADDING),
|
||||||
box.y1 + _REGION_PADDING,
|
y1 + _REGION_PADDING,
|
||||||
)
|
)
|
||||||
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
|
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
|
||||||
try:
|
try:
|
||||||
@@ -136,25 +219,31 @@ def _process_page(
|
|||||||
fig_counter = 0
|
fig_counter = 0
|
||||||
tbl_counter = 0
|
tbl_counter = 0
|
||||||
|
|
||||||
# 收集本页的 table/picture box(跳过极小区域)
|
# 收集本页的 table/picture box 与 caption box(跳过极小区域)
|
||||||
raw_boxes = []
|
raw_boxes = []
|
||||||
|
raw_caption_boxes = []
|
||||||
for box in page_boxes:
|
for box in page_boxes:
|
||||||
if box.boxclass not in ("table", "table-fallback", "picture"):
|
|
||||||
continue
|
|
||||||
w = box.x1 - box.x0
|
w = box.x1 - box.x0
|
||||||
h = box.y1 - box.y0
|
h = box.y1 - box.y0
|
||||||
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
|
if box.boxclass in ("table", "table-fallback", "picture"):
|
||||||
continue
|
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
|
||||||
raw_boxes.append(box)
|
continue
|
||||||
|
raw_boxes.append(box)
|
||||||
|
elif box.boxclass in ("figure_caption", "table_caption"):
|
||||||
|
if w < 30 or h < 6:
|
||||||
|
continue
|
||||||
|
raw_caption_boxes.append(box)
|
||||||
|
|
||||||
if not raw_boxes:
|
if not raw_boxes:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 聚类:将同一 figure/table 的碎片 box 合并
|
# 聚类:将同一 figure/table 的碎片 box 合并
|
||||||
clusters = _cluster_boxes(raw_boxes)
|
clusters = _cluster_boxes(raw_boxes)
|
||||||
|
caption_clusters = _cluster_boxes(raw_caption_boxes)
|
||||||
|
caption_matches = _match_captions(page, clusters, caption_clusters)
|
||||||
|
|
||||||
extracted = 0
|
extracted = 0
|
||||||
for cluster in clusters:
|
for cluster_idx, cluster in enumerate(clusters):
|
||||||
cap_type = "figure" if cluster.boxclass == "picture" else "table"
|
cap_type = "figure" if cluster.boxclass == "picture" else "table"
|
||||||
|
|
||||||
if cap_type == "figure":
|
if cap_type == "figure":
|
||||||
@@ -168,21 +257,33 @@ def _process_page(
|
|||||||
continue
|
continue
|
||||||
seen_labels.add(label)
|
seen_labels.add(label)
|
||||||
|
|
||||||
|
caption_match = caption_matches.get(cluster_idx)
|
||||||
|
caption_cluster = caption_match[0] if caption_match else None
|
||||||
|
|
||||||
filename = f"{label.replace(' ', '_').lower()}.jpg"
|
filename = f"{label.replace(' ', '_').lower()}.jpg"
|
||||||
if not _render_box(page, cluster, images_dest, filename, cap_type, page_num):
|
if not _render_box(
|
||||||
|
page,
|
||||||
|
cluster,
|
||||||
|
images_dest,
|
||||||
|
filename,
|
||||||
|
cap_type,
|
||||||
|
page_num,
|
||||||
|
caption=caption_cluster,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
manifest[filename] = {
|
info = {
|
||||||
"page": page_num,
|
"page": page_num,
|
||||||
"type": cap_type,
|
"type": cap_type,
|
||||||
"label": label,
|
"label": label,
|
||||||
"box": [
|
"box": _cluster_to_box(cluster),
|
||||||
round(float(cluster.x0), 1),
|
|
||||||
round(float(cluster.y0), 1),
|
|
||||||
round(float(cluster.x1), 1),
|
|
||||||
round(float(cluster.y1), 1),
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
if caption_match:
|
||||||
|
info["caption_text"] = caption_match[1][:500]
|
||||||
|
info["caption_box"] = _cluster_to_box(caption_cluster)
|
||||||
|
info["caption_source"] = "doclayout"
|
||||||
|
|
||||||
|
manifest[filename] = info
|
||||||
extracted += 1
|
extracted += 1
|
||||||
|
|
||||||
return extracted
|
return extracted
|
||||||
@@ -446,14 +547,20 @@ def label_images_by_summary(
|
|||||||
cap_type = info.get("type", "figure")
|
cap_type = info.get("type", "figure")
|
||||||
|
|
||||||
# 读取 caption 文本(从 figures 列表)
|
# 读取 caption 文本(从 figures 列表)
|
||||||
caption_text = ""
|
summary_caption_text = ""
|
||||||
for fig in figures:
|
for fig in figures:
|
||||||
if fig.get("id") == fig_id:
|
if fig.get("id") == fig_id:
|
||||||
caption_text = fig.get("caption", "")
|
summary_caption_text = fig.get("caption", "")
|
||||||
break
|
break
|
||||||
|
|
||||||
info["label"] = fig_id
|
info["label"] = fig_id
|
||||||
info["caption_text"] = caption_text[:200] if caption_text else ""
|
existing_caption_text = info.get("caption_text", "")
|
||||||
|
if existing_caption_text and summary_caption_text:
|
||||||
|
info["summary_caption_text"] = summary_caption_text[:500]
|
||||||
|
else:
|
||||||
|
info["caption_text"] = (
|
||||||
|
summary_caption_text[:500] if summary_caption_text else ""
|
||||||
|
)
|
||||||
info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
|
info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
|
||||||
fig_id
|
fig_id
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from app.services.summary_persister import (
|
|||||||
_cleanup_old_images,
|
_cleanup_old_images,
|
||||||
_handle_summary_failure,
|
_handle_summary_failure,
|
||||||
_persist_summary,
|
_persist_summary,
|
||||||
|
_run_post_processing,
|
||||||
)
|
)
|
||||||
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
|
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
|
||||||
|
|
||||||
@@ -115,12 +116,31 @@ async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -
|
|||||||
_t3 = _time.monotonic()
|
_t3 = _time.monotonic()
|
||||||
logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2)
|
logger.info(" [%s] pi生成: %.2fs", arxiv_id, _t3 - _t2)
|
||||||
|
|
||||||
quality = _persist_summary(db, paper, json_data, raw_output)
|
quality, schema = _persist_summary(db, paper, json_data, raw_output)
|
||||||
_t4 = _time.monotonic()
|
_t4 = _time.monotonic()
|
||||||
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
|
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
|
||||||
|
|
||||||
|
# 后处理(图片提取 + ChromaDB 索引)搬到线程池跑,避免 CPU 密集推理冻结
|
||||||
|
# 事件循环。paper 字段在此(事件循环线程)提取成纯值再传入,规避 worker
|
||||||
|
# 线程跨线程访问 ORM 的 DetachedInstanceError。DocLayout 推理由单例的
|
||||||
|
# threading.Lock 串行化,并发 worker 不会同时压模型。
|
||||||
|
paper_meta = {
|
||||||
|
"title_en": paper.title_en or "",
|
||||||
|
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
|
||||||
|
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
|
||||||
|
}
|
||||||
|
_t5 = _time.monotonic()
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(_run_post_processing, arxiv_id, schema, paper_meta)
|
||||||
|
except Exception:
|
||||||
|
# 双保险:_run_post_processing 内部已 try/except,此处兜底,
|
||||||
|
# 确保后处理失败绝不影响已 DONE 的总结。
|
||||||
|
logger.warning("Post-processing error for %s", arxiv_id, exc_info=True)
|
||||||
|
_t6 = _time.monotonic()
|
||||||
|
logger.info(" [%s] 后处理(线程池): %.2fs", arxiv_id, _t6 - _t5)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0
|
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t6 - _t0
|
||||||
)
|
)
|
||||||
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
|
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
|
||||||
|
|
||||||
|
|||||||
@@ -131,8 +131,12 @@ def _handle_summary_failure(
|
|||||||
|
|
||||||
def _persist_summary(
|
def _persist_summary(
|
||||||
db: Session, paper: Paper, json_data: dict, raw_output: str
|
db: Session, paper: Paper, json_data: dict, raw_output: str
|
||||||
) -> str:
|
) -> tuple[str, SummarySchema]:
|
||||||
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
|
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 (quality, schema)。
|
||||||
|
|
||||||
|
后处理(图片提取/ChromaDB)不再在此函数内执行,由调用方搬到线程池,
|
||||||
|
以免阻塞事件循环。返回 schema 供调用方在线程池里跑后处理。
|
||||||
|
"""
|
||||||
import time as _time
|
import time as _time
|
||||||
|
|
||||||
arxiv_id = paper.arxiv_id
|
arxiv_id = paper.arxiv_id
|
||||||
@@ -165,21 +169,10 @@ def _persist_summary(
|
|||||||
_t4 - _t3,
|
_t4 - _t3,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 触发性增强(失败不影响总结)
|
# 后处理(图片提取 + ChromaDB 索引)已上移到调用方 _do_summarize_one,
|
||||||
_t5 = _time.monotonic()
|
# 经 asyncio.to_thread 在线程池跑——DB session 必须留在事件循环线程,
|
||||||
_maybe_extract_images(arxiv_id, schema)
|
# 而 CPU/IO 密集的后处理搬走才不冻结事件循环。
|
||||||
_t6 = _time.monotonic()
|
return quality, schema
|
||||||
_maybe_index_chroma(arxiv_id, paper, schema)
|
|
||||||
_t7 = _time.monotonic()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
|
|
||||||
arxiv_id,
|
|
||||||
_t6 - _t5,
|
|
||||||
_t7 - _t6,
|
|
||||||
)
|
|
||||||
|
|
||||||
return quality
|
|
||||||
|
|
||||||
|
|
||||||
# ── 清理 ────────────────────────────────────────────────────────────────
|
# ── 清理 ────────────────────────────────────────────────────────────────
|
||||||
@@ -226,21 +219,44 @@ def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
|
|||||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
|
def _maybe_index_chroma(arxiv_id: str, schema: SummarySchema, paper_meta: dict) -> None:
|
||||||
"""写入 ChromaDB 语义索引(失败不影响总结)。"""
|
"""写入 ChromaDB 语义索引(失败不影响总结)。
|
||||||
|
|
||||||
|
paper_meta 是调用方在事件循环线程从 ORM 提取的纯值(title_en/tags/paper_date),
|
||||||
|
规避此函数在线程池跑时跨线程访问 ORM 的 DetachedInstanceError 风险。
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
from app.services.embedder import index_paper
|
from app.services.embedder import index_paper
|
||||||
|
|
||||||
texts_dict = {
|
texts_dict = {
|
||||||
"arxiv_id": arxiv_id,
|
"arxiv_id": arxiv_id,
|
||||||
"title_zh": schema.title_zh or "",
|
"title_zh": schema.title_zh or "",
|
||||||
"title_en": paper.title_en or "",
|
"title_en": paper_meta.get("title_en", ""),
|
||||||
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
|
"tags": paper_meta.get("tags", ""),
|
||||||
"one_line": schema.one_line or "",
|
"one_line": schema.one_line or "",
|
||||||
"motivation_problem": schema.motivation.problem or "",
|
"motivation_problem": schema.motivation.problem or "",
|
||||||
"method_key_idea": schema.method.key_idea or "",
|
"method_key_idea": schema.method.key_idea or "",
|
||||||
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
|
"paper_date": paper_meta.get("paper_date", ""),
|
||||||
}
|
}
|
||||||
index_paper(arxiv_id, texts_dict)
|
index_paper(arxiv_id, texts_dict)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
|
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_post_processing(
|
||||||
|
arxiv_id: str, schema: SummarySchema, paper_meta: dict
|
||||||
|
) -> None:
|
||||||
|
"""线程池里跑的 CPU/IO 密集后处理(由 _do_summarize_one 经 asyncio.to_thread 调用)。
|
||||||
|
|
||||||
|
顺序与原 _persist_summary 内部一致:图片提取 → ChromaDB 索引。两者各自
|
||||||
|
try/except(失败不影响已成功的总结),此处再包一层做双保险。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
_maybe_extract_images(arxiv_id, schema)
|
||||||
|
_maybe_index_chroma(arxiv_id, schema, paper_meta)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Post-processing failed for %s (summary already persisted)",
|
||||||
|
arxiv_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|||||||
@@ -42,6 +42,9 @@
|
|||||||
.status-success { background:var(--success-bg); color:#388e3c; }
|
.status-success { background:var(--success-bg); color:#388e3c; }
|
||||||
.status-running { background:var(--info-bg); color:#1976d2; }
|
.status-running { background:var(--info-bg); color:#1976d2; }
|
||||||
.status-failed { background:var(--danger-bg); color:var(--danger-bright); }
|
.status-failed { background:var(--danger-bg); color:var(--danger-bright); }
|
||||||
|
.status-queued { background:#fff8e1; color:#8a6d3b; }
|
||||||
|
.status-stale { background:var(--border); color:var(--ink-muted); }
|
||||||
|
.task-reindex { background:#fff3e0; color:#e65100; }
|
||||||
.time-cell { white-space:nowrap; color:var(--ink-light); }
|
.time-cell { white-space:nowrap; color:var(--ink-light); }
|
||||||
.error-cell { max-width:200px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; color:var(--danger-bright); font-size:.8rem; }
|
.error-cell { max-width:200px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; color:var(--danger-bright); font-size:.8rem; }
|
||||||
|
|
||||||
|
|||||||
@@ -69,7 +69,8 @@
|
|||||||
<span class="info-label">活跃任务</span>
|
<span class="info-label">活跃任务</span>
|
||||||
<span class="info-value">
|
<span class="info-value">
|
||||||
{% for lock in stats.active_locks %}
|
{% for lock in stats.active_locks %}
|
||||||
<span class="task-badge task-{{ lock.task }}">{{ lock.task }}</span>
|
<span class="task-badge task-{{ lock.task }}" title="{{ lock.status }} · #{{ lock.id }}">{{ lock.task }}</span>
|
||||||
|
<button class="admin-action-btn admin-action-btn-sm admin-action-btn-danger" title="强制释放锁 #{{ lock.id }}" onclick="releaseLock({{ lock.id }})">🔓</button>
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
@@ -118,6 +119,15 @@
|
|||||||
<div class="info-row"><span class="info-label">数据库</span><span class="info-value">{{ stats.db_size }}</span></div>
|
<div class="info-row"><span class="info-label">数据库</span><span class="info-value">{{ stats.db_size }}</span></div>
|
||||||
<div class="info-row"><span class="info-label">论文文件</span><span class="info-value">{{ stats.papers_size }}</span></div>
|
<div class="info-row"><span class="info-label">论文文件</span><span class="info-value">{{ stats.papers_size }}</span></div>
|
||||||
<div class="info-row"><span class="info-label">临时文件</span><span class="info-value">{{ stats.tmp_size }}</span></div>
|
<div class="info-row"><span class="info-label">临时文件</span><span class="info-value">{{ stats.tmp_size }}</span></div>
|
||||||
|
<div class="info-row">
|
||||||
|
<span class="info-label">搜索索引</span>
|
||||||
|
<span class="info-value">
|
||||||
|
<button class="admin-action-btn admin-action-btn-sm" onclick="rebuildIndexes('fts')">🔤 重建全文</button>
|
||||||
|
{% if stats.config_overview.chroma_enabled %}
|
||||||
|
<button class="admin-action-btn admin-action-btn-sm" onclick="rebuildIndexes('chroma')">🧠 重建语义</button>
|
||||||
|
{% endif %}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="summary-dist">
|
<div class="summary-dist">
|
||||||
<h3 class="section-subtitle">总结状态分布</h3>
|
<h3 class="section-subtitle">总结状态分布</h3>
|
||||||
@@ -136,6 +146,19 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="admin-info-card">
|
||||||
|
<h2 class="admin-info-title">⚙️ 运行配置</h2>
|
||||||
|
<div class="admin-info-body">
|
||||||
|
<div class="info-row"><span class="info-label">总结后端</span><span class="info-value">{{ stats.config_overview.summary_backend }} · {{ stats.config_overview.summary_pdf_mode }} 模式</span></div>
|
||||||
|
<div class="info-row"><span class="info-label">并发/超时</span><span class="info-value">{{ stats.config_overview.summary_concurrency }} 并发 · {{ stats.config_overview.summary_timeout_seconds }}s · 重试 {{ stats.config_overview.summary_max_retries }}</span></div>
|
||||||
|
<div class="info-row"><span class="info-label">调度</span><span class="info-value">{{ '启用' if stats.config_overview.scheduler_enabled else '未启用' }} · {{ stats.config_overview.schedule_time }} · {{ stats.config_overview.app_workers }} worker</span></div>
|
||||||
|
<div class="info-row"><span class="info-label">语义搜索</span><span class="info-value">{{ '启用' if stats.config_overview.chroma_enabled else '未启用' }} · {{ stats.config_overview.embed_model }}</span></div>
|
||||||
|
<div class="info-row"><span class="info-label">抓取</span><span class="info-value">TOP {{ stats.config_overview.top_n }} · 投票刷新 {{ stats.config_overview.upvote_refresh_days }} 天</span></div>
|
||||||
|
<div class="info-row"><span class="info-label">布局模型</span><span class="info-value">{{ stats.config_overview.layout_model }}</span></div>
|
||||||
|
<div class="info-row"><span class="info-label">数据库</span><span class="info-value">{{ stats.config_overview.database_url }}</span></div>
|
||||||
|
<div class="info-row"><span class="info-label">嵌入密钥</span><span class="info-value">{{ '已配置' if stats.config_overview.api_key_configured else '未配置' }}</span></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="admin-section">
|
<div class="admin-section">
|
||||||
@@ -193,5 +216,17 @@
|
|||||||
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : `✅ 已刷新 ${data.updated || 0} 篇论文投票`); })
|
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : `✅ 已刷新 ${data.updated || 0} 篇论文投票`); })
|
||||||
.catch(err => showToast("❌ 请求失败"));
|
.catch(err => showToast("❌ 请求失败"));
|
||||||
}
|
}
|
||||||
|
function releaseLock(lockId) {
|
||||||
|
fetch("/admin/locks/"+lockId+"/release", { method: "POST", headers: { "Content-Type": "application/json" } })
|
||||||
|
.then(r => { if (r.status===303||r.status===401) { window.location.href="/admin/login"; return; } return r.json(); })
|
||||||
|
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : "✅ 已释放锁,刷新中…", {callback:function(){location.reload();}}); })
|
||||||
|
.catch(err => showToast("❌ 请求失败"));
|
||||||
|
}
|
||||||
|
function rebuildIndexes(target) {
|
||||||
|
fetch("/admin/rebuild-indexes", { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({target: target}) })
|
||||||
|
.then(r => { if (r.status===303||r.status===401) { window.location.href="/admin/login"; return; } return r.json(); })
|
||||||
|
.then(data => { if (data) showToast(data.error ? "❌ " + data.error.substring(0,200) : "✅ 重建任务已创建,可在任务页查看"); })
|
||||||
|
.catch(err => showToast("❌ 请求失败"));
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
|
|||||||
@@ -0,0 +1,149 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}任务监控 — HF Daily Papers{% endblock %}
|
||||||
|
|
||||||
|
{% set type_label = {"crawl_daily":"抓取","pipeline_daily":"流水线","summarize_batch":"批量总结","summarize_one":"单篇总结","refresh_upvotes":"刷新投票","delete_range":"删除","cleanup_tmp":"清理","reindex_fts":"重建全文","reindex_chroma":"重建语义","recrawl_one":"重抓","recrawl_batch":"批量重抓"} %}
|
||||||
|
{% set type_badge = {"crawl_daily":"task-crawl","pipeline_daily":"task-crawl","recrawl_one":"task-crawl","recrawl_batch":"task-crawl","refresh_upvotes":"task-crawl","summarize_batch":"task-summarize","summarize_one":"task-summarize","cleanup_tmp":"task-cleanup","delete_range":"task-delete","reindex_fts":"task-reindex","reindex_chroma":"task-reindex"} %}
|
||||||
|
{% set status_label = {"queued":"排队","running":"运行中","success":"成功","failed":"失败","stale":"已过期","cancelled":"已取消"} %}
|
||||||
|
{% set status_badge = {"queued":"status-queued","running":"status-running","success":"status-success","failed":"status-failed","stale":"status-stale","cancelled":"status-stale"} %}
|
||||||
|
|
||||||
|
{% macro fmt_duration(s) -%}
|
||||||
|
{%- if s is none %}-
|
||||||
|
{%- elif s < 60 %}{{ "%.0f"|format(s) }}s
|
||||||
|
{%- elif s < 3600 %}{{ (s // 60)|int }}m {{ (s % 60)|round|int }}s
|
||||||
|
{%- else %}{{ (s // 3600)|int }}h {{ ((s % 3600) // 60)|int }}m
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endmacro %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
<div class="admin-page">
|
||||||
|
{% set active = "jobs" %}{% include "partials/admin_subnav.html" %}
|
||||||
|
|
||||||
|
<h1 class="page-heading">🧰 任务监控</h1>
|
||||||
|
|
||||||
|
{% set _total = (status_counts.values() | sum) if status_counts else 0 %}
|
||||||
|
<div class="summary-stats-row">
|
||||||
|
<span class="summary-stat">总计 <strong>{{ _total }}</strong></span>
|
||||||
|
<span class="summary-stat summary-stat-pending">排队 <strong>{{ status_counts.get('queued', 0) }}</strong></span>
|
||||||
|
<span class="summary-stat">运行中 <strong>{{ status_counts.get('running', 0) }}</strong></span>
|
||||||
|
<span class="summary-stat summary-stat-done">成功 <strong>{{ status_counts.get('success', 0) }}</strong></span>
|
||||||
|
<span class="summary-stat summary-stat-failed">失败 <strong>{{ status_counts.get('failed', 0) + status_counts.get('stale', 0) }}</strong></span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{% set statuses = [("all","全部"),("queued","排队"),("running","运行中"),("success","成功"),("failed","失败"),("stale","已过期")] %}
|
||||||
|
<div class="summary-filters">
|
||||||
|
<span class="summary-filter-label">状态:</span>
|
||||||
|
{% for key, label in statuses %}
|
||||||
|
<a class="filter-chip {{ 'active' if current_status == key else '' }}"
|
||||||
|
href="?status={{ key }}{% if current_type != 'all' %}&type={{ current_type }}{% endif %}">{{ label }}
|
||||||
|
({% if key == 'all' %}{{ _total }}{% else %}{{ status_counts.get(key, 0) }}{% endif %})</a>
|
||||||
|
{% endfor %}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{% set types = [("crawl_daily","抓取"),("pipeline_daily","流水线"),("summarize_batch","批量总结"),("summarize_one","单篇总结"),("refresh_upvotes","刷新投票"),("recrawl_one","重抓"),("recrawl_batch","批量重抓"),("delete_range","删除"),("cleanup_tmp","清理"),("reindex_fts","重建全文"),("reindex_chroma","重建语义")] %}
|
||||||
|
<form method="get" class="summary-filters">
|
||||||
|
<span class="summary-filter-label">类型:</span>
|
||||||
|
<input type="hidden" name="status" value="{{ current_status }}" />
|
||||||
|
<select name="type" class="paper-filter-input" onchange="this.form.submit()">
|
||||||
|
<option value="all" {{ 'selected' if current_type == 'all' }}>全部类型</option>
|
||||||
|
{% for key, label in types %}
|
||||||
|
<option value="{{ key }}" {{ 'selected' if current_type == key }}>{{ label }}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
{% if jobs %}
|
||||||
|
<div class="admin-table-wrap">
|
||||||
|
<table class="admin-table admin-table-compact">
|
||||||
|
<thead>
|
||||||
|
<tr><th>ID</th><th>类型</th><th>状态</th><th>触发者</th><th>创建时间</th><th>耗时</th><th>操作</th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for job in jobs %}
|
||||||
|
<tr>
|
||||||
|
<td>{{ job.id }}</td>
|
||||||
|
<td><span class="task-badge {{ type_badge.get(job.type, 'task-crawl') }}">{{ type_label.get(job.type, job.type) }}</span></td>
|
||||||
|
<td><span class="status-badge {{ status_badge.get(job.status, 'status-running') }}">{{ status_label.get(job.status, job.status) }}</span></td>
|
||||||
|
<td>{{ job.owner or '-' }}</td>
|
||||||
|
<td class="time-cell">{{ job.created_at.strftime('%Y-%m-%d %H:%M:%S') if job.created_at else '-' }}</td>
|
||||||
|
<td class="time-cell">{{ fmt_duration(job.duration_seconds) }}</td>
|
||||||
|
<td class="action-cell"><button class="action-btn-sm" title="详情" onclick="showJobDetail({{ job.id }})">📋</button></td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{% set total_pages = ((total + per_page - 1) // per_page) if total else 1 %}
|
||||||
|
{% if total_pages > 1 %}
|
||||||
|
<div class="pagination">
|
||||||
|
{% if page > 1 %}<a class="page-btn" href="{{ pagination_url(page - 1) }}">← 上一页</a>{% endif %}
|
||||||
|
<span class="page-info">第 {{ page }} / {{ total_pages }} 页(共 {{ total }} 个)</span>
|
||||||
|
{% if page < total_pages %}<a class="page-btn" href="{{ pagination_url(page + 1) }}">下一页 →</a>{% endif %}
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
{% else %}
|
||||||
|
<div class="empty-state">
|
||||||
|
<p>暂无任务记录</p>
|
||||||
|
<p class="hint">触发抓取、总结等操作后,任务会出现在这里。可在「详情」中查看阶段事件。</p>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 任务详情 modal -->
|
||||||
|
<div class="confirm-overlay" id="job-detail-overlay" style="display:none;">
|
||||||
|
<div class="confirm-dialog" style="max-width:660px;max-height:85vh;overflow:auto;">
|
||||||
|
<h3 class="admin-info-title" id="job-detail-title">任务详情</h3>
|
||||||
|
<div id="job-detail-body"><p class="hint">加载中...</p></div>
|
||||||
|
<div class="confirm-actions">
|
||||||
|
<button class="confirm-btn confirm-btn-cancel" onclick="closeJobDetail()">关闭</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
|
{% block scripts %}
|
||||||
|
<script>
|
||||||
|
const TYPE_LABEL = {"crawl_daily":"抓取","pipeline_daily":"流水线","summarize_batch":"批量总结","summarize_one":"单篇总结","refresh_upvotes":"刷新投票","delete_range":"删除","cleanup_tmp":"清理","reindex_fts":"重建全文","reindex_chroma":"重建语义","recrawl_one":"重抓","recrawl_batch":"批量重抓"};
|
||||||
|
const STATUS_LABEL = {"queued":"排队","running":"运行中","success":"成功","failed":"失败","stale":"已过期","cancelled":"已取消"};
|
||||||
|
|
||||||
|
function fmtTime(s){ return s ? s.replace('T',' ').slice(0,19) : '-'; }
|
||||||
|
function esc(s){ return String(s==null?'':s).replace(/[&<>"]/g, c=>({'&':'&','<':'<','>':'>','"':'"'}[c])); }
|
||||||
|
function eventBadge(s){ return {'success':'status-success','failed':'status-failed','started':'status-running','info':'status-queued'}[s] || 'status-queued'; }
|
||||||
|
function jobStatusBadge(s){ return {'success':'success','failed':'failed','running':'running','stale':'stale','cancelled':'stale','queued':'queued'}[s] || 'running'; }
|
||||||
|
function infoRow(label, val){ return '<div class="info-row"><span class="info-label">'+label+'</span><span class="info-value">'+val+'</span></div>'; }
|
||||||
|
|
||||||
|
function showJobDetail(id){
|
||||||
|
document.getElementById('job-detail-overlay').style.display='flex';
|
||||||
|
document.getElementById('job-detail-body').innerHTML='<p class="hint">加载中...</p>';
|
||||||
|
fetch('/admin/jobs/'+id)
|
||||||
|
.then(r=>{if(r.status===303||r.status===401){window.location.href='/admin/login';return;}return r.json();})
|
||||||
|
.then(d=>{ if(d) renderJobDetail(d); })
|
||||||
|
.catch(()=>{document.getElementById('job-detail-body').innerHTML='<p class="hint">加载失败</p>';});
|
||||||
|
}
|
||||||
|
function renderJobDetail(d){
|
||||||
|
let h='<div class="admin-info-body">';
|
||||||
|
h+=infoRow('ID', d.id);
|
||||||
|
h+=infoRow('类型', esc(TYPE_LABEL[d.type]||d.type));
|
||||||
|
h+=infoRow('状态', '<span class="status-badge status-'+jobStatusBadge(d.status)+'">'+esc(STATUS_LABEL[d.status]||d.status)+'</span>');
|
||||||
|
h+=infoRow('触发者', esc(d.owner||'-'));
|
||||||
|
h+=infoRow('创建', fmtTime(d.created_at));
|
||||||
|
h+=infoRow('开始', fmtTime(d.started_at));
|
||||||
|
h+=infoRow('完成', fmtTime(d.completed_at));
|
||||||
|
if(d.payload && Object.keys(d.payload).length) h+=infoRow('参数', '<code style="word-break:break-all;">'+esc(JSON.stringify(d.payload))+'</code>');
|
||||||
|
if(d.result) h+=infoRow('结果', '<code style="word-break:break-all;">'+esc(JSON.stringify(d.result))+'</code>');
|
||||||
|
if(d.error) h+=infoRow('错误', '<span class="error-cell" style="max-width:480px;">'+esc(d.error)+'</span>');
|
||||||
|
h+='</div>';
|
||||||
|
if(d.events && d.events.length){
|
||||||
|
h+='<h3 class="section-subtitle" style="margin-top:18px;">事件时间线</h3>';
|
||||||
|
h+='<div class="admin-table-wrap" style="max-height:220px;overflow:auto;"><table class="admin-table admin-table-compact"><thead><tr><th>阶段</th><th>状态</th><th>时间</th><th>消息</th></tr></thead><tbody>';
|
||||||
|
d.events.forEach(e=>{
|
||||||
|
h+='<tr><td>'+esc(e.stage)+'</td><td><span class="status-badge '+eventBadge(e.status)+'">'+esc(e.status)+'</span></td><td class="time-cell">'+fmtTime(e.created_at)+'</td><td class="error-cell" style="max-width:240px;">'+esc(e.message||'')+'</td></tr>';
|
||||||
|
});
|
||||||
|
h+='</tbody></table></div>';
|
||||||
|
}
|
||||||
|
document.getElementById('job-detail-body').innerHTML=h;
|
||||||
|
}
|
||||||
|
function closeJobDetail(){ document.getElementById('job-detail-overlay').style.display='none'; }
|
||||||
|
document.addEventListener('keydown',e=>{if(e.key==='Escape')closeJobDetail();});
|
||||||
|
</script>
|
||||||
|
{% endblock %}
|
||||||
@@ -109,6 +109,22 @@
|
|||||||
<span class="summary-stat summary-stat-failed">失败 <strong>{{ summary_failed or 0 }}</strong></span>
|
<span class="summary-stat summary-stat-failed">失败 <strong>{{ summary_failed or 0 }}</strong></span>
|
||||||
<span class="summary-stat summary-stat-done">已完成 <strong>{{ summary_done or 0 }}</strong></span>
|
<span class="summary-stat summary-stat-done">已完成 <strong>{{ summary_done or 0 }}</strong></span>
|
||||||
</div>
|
</div>
|
||||||
|
{% if failure_breakdown %}
|
||||||
|
<div class="summary-dist" style="margin-top:12px;">
|
||||||
|
<h3 class="section-subtitle">失败原因分布({{ summary_failed or 0 }} 篇)</h3>
|
||||||
|
<div class="summary-dist-bars">
|
||||||
|
{% set fb_total = (failure_breakdown | map(attribute='count') | sum) or 1 %}
|
||||||
|
{% set error_labels = {"pdf_download_failed":"PDF下载失败","timeout":"超时","process_error":"进程错误","json_not_found":"JSON缺失","json_invalid":"JSON无效","field_missing":"字段缺失","schema_error":"结构错误","unknown":"未分类"} %}
|
||||||
|
{% for item in failure_breakdown %}
|
||||||
|
<div class="dist-row">
|
||||||
|
<span class="dist-label">{{ error_labels.get(item.error_type, item.error_type) }}</span>
|
||||||
|
<div class="dist-bar-wrap"><div class="dist-bar dist-bar-failed" style="width:{{ (item.count / fb_total * 100)|round(1) }}%"></div></div>
|
||||||
|
<span class="dist-count">{{ item.count }}</span>
|
||||||
|
</div>
|
||||||
|
{% endfor %}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
<div id="summary-list"
|
<div id="summary-list"
|
||||||
hx-get="/admin/summary-status"
|
hx-get="/admin/summary-status"
|
||||||
hx-trigger="load"
|
hx-trigger="load"
|
||||||
|
|||||||
@@ -29,6 +29,7 @@
|
|||||||
<option value="title_asc" {% if current_sort == 'title_asc' %}selected{% endif %}>标题 A→Z</option>
|
<option value="title_asc" {% if current_sort == 'title_asc' %}selected{% endif %}>标题 A→Z</option>
|
||||||
</select>
|
</select>
|
||||||
<button type="submit" class="paper-search-btn">搜索</button>
|
<button type="submit" class="paper-search-btn">搜索</button>
|
||||||
|
<a class="admin-action-btn admin-action-btn-sm" href="/admin/papers/export.csv{% if request.query_params %}?{{ request.query_params }}{% endif %}">⬇ 导出 CSV</a>
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
@@ -37,6 +38,7 @@
|
|||||||
<span class="paper-batch-label">批量操作</span>
|
<span class="paper-batch-label">批量操作</span>
|
||||||
<span class="paper-selected-count" id="selected-count">已选 0 篇</span>
|
<span class="paper-selected-count" id="selected-count">已选 0 篇</span>
|
||||||
<button class="admin-action-btn admin-action-btn-sm" onclick="batchAction('summarize')" id="batch-summarize-btn" disabled>📝 批量总结</button>
|
<button class="admin-action-btn admin-action-btn-sm" onclick="batchAction('summarize')" id="batch-summarize-btn" disabled>📝 批量总结</button>
|
||||||
|
<button class="admin-action-btn admin-action-btn-sm" onclick="batchAction('recrawl')" id="batch-recrawl-btn" disabled>🔄 批量重抓</button>
|
||||||
<button class="admin-action-btn admin-action-btn-sm admin-action-btn-danger" onclick="batchAction('delete')" id="batch-delete-btn" disabled>🗑 批量删除</button>
|
<button class="admin-action-btn admin-action-btn-sm admin-action-btn-danger" onclick="batchAction('delete')" id="batch-delete-btn" disabled>🗑 批量删除</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -72,6 +74,7 @@
|
|||||||
</td>
|
</td>
|
||||||
<td class="action-cell">
|
<td class="action-cell">
|
||||||
<button class="action-btn-sm" title="重新总结" onclick="retryOne('{{ paper.arxiv_id }}', this)">↻</button>
|
<button class="action-btn-sm" title="重新总结" onclick="retryOne('{{ paper.arxiv_id }}', this)">↻</button>
|
||||||
|
<button class="action-btn-sm" title="重新抓取元数据" onclick="recrawlOne('{{ paper.arxiv_id }}', this)">🔄</button>
|
||||||
<button class="action-btn-sm action-btn-danger" title="删除" onclick="confirmDeleteSingle('{{ paper.arxiv_id }}', '{{ (paper.title_zh or paper.title_en)[:40] | replace("'", "\\'") }}')">🗑</button>
|
<button class="action-btn-sm action-btn-danger" title="删除" onclick="confirmDeleteSingle('{{ paper.arxiv_id }}', '{{ (paper.title_zh or paper.title_en)[:40] | replace("'", "\\'") }}')">🗑</button>
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
@@ -124,6 +127,7 @@
|
|||||||
const n=document.querySelectorAll('.paper-check:checked').length;
|
const n=document.querySelectorAll('.paper-check:checked').length;
|
||||||
document.getElementById('selected-count').textContent='已选 '+n+' 篇';
|
document.getElementById('selected-count').textContent='已选 '+n+' 篇';
|
||||||
document.getElementById('batch-summarize-btn').disabled=n===0;
|
document.getElementById('batch-summarize-btn').disabled=n===0;
|
||||||
|
document.getElementById('batch-recrawl-btn').disabled=n===0;
|
||||||
document.getElementById('batch-delete-btn').disabled=n===0;
|
document.getElementById('batch-delete-btn').disabled=n===0;
|
||||||
}
|
}
|
||||||
function retryOne(arxivId,btn) {
|
function retryOne(arxivId,btn) {
|
||||||
@@ -134,6 +138,14 @@
|
|||||||
.catch(()=>showToast('❌ 请求失败'))
|
.catch(()=>showToast('❌ 请求失败'))
|
||||||
.finally(()=>{btn.disabled=false;btn.textContent='↻';});
|
.finally(()=>{btn.disabled=false;btn.textContent='↻';});
|
||||||
}
|
}
|
||||||
|
function recrawlOne(arxivId,btn) {
|
||||||
|
btn.disabled=true;btn.textContent='...';
|
||||||
|
fetch('/admin/paper-recrawl/'+arxivId,{method:'POST',headers:{'Content-Type':'application/json'}})
|
||||||
|
.then(r=>r.json())
|
||||||
|
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 重抓任务已创建,可在任务页查看'))
|
||||||
|
.catch(()=>showToast('❌ 请求失败'))
|
||||||
|
.finally(()=>{btn.disabled=false;btn.textContent='🔄';});
|
||||||
|
}
|
||||||
function confirmDeleteSingle(arxivId,title) {
|
function confirmDeleteSingle(arxivId,title) {
|
||||||
document.getElementById('confirm-msg').textContent='确定删除论文「'+title+'」?此操作不可恢复。';
|
document.getElementById('confirm-msg').textContent='确定删除论文「'+title+'」?此操作不可恢复。';
|
||||||
_confirmAction='delete-single'; _confirmTarget=arxivId;
|
_confirmAction='delete-single'; _confirmTarget=arxivId;
|
||||||
@@ -151,6 +163,11 @@
|
|||||||
.then(r=>r.json())
|
.then(r=>r.json())
|
||||||
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 已提交批量总结'))
|
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 已提交批量总结'))
|
||||||
.catch(()=>showToast('❌ 请求失败'));
|
.catch(()=>showToast('❌ 请求失败'));
|
||||||
|
} else if(action==='recrawl'){
|
||||||
|
fetch('/admin/papers-batch-action',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({action:'recrawl',arxiv_ids:ids})})
|
||||||
|
.then(r=>r.json())
|
||||||
|
.then(data=>showToast(data.error?'❌ '+data.error.substring(0,100):'✅ 已提交批量重抓,可在任务页查看'))
|
||||||
|
.catch(()=>showToast('❌ 请求失败'));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
function doConfirmAction() {
|
function doConfirmAction() {
|
||||||
|
|||||||
@@ -29,8 +29,6 @@
|
|||||||
<a href="/reading-list">阅读列表</a>
|
<a href="/reading-list">阅读列表</a>
|
||||||
{% if is_admin %}
|
{% if is_admin %}
|
||||||
<a href="/admin/">管理</a>
|
<a href="/admin/">管理</a>
|
||||||
<a href="/admin/logout" onclick="event.preventDefault();this.closest('form').submit()">退出</a>
|
|
||||||
<form action="/admin/logout" method="post" style="display:none"></form>
|
|
||||||
{% else %}
|
{% else %}
|
||||||
<a href="/admin/login">管理</a>
|
<a href="/admin/login">管理</a>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
{# Admin subnav — 管理后台三个页面共享。active 参数: "dashboard" / "papers" / "logs" #}
|
{# Admin subnav — 管理后台共享。active 参数: "dashboard" / "papers" / "jobs" / "logs" #}
|
||||||
<nav class="admin-subnav">
|
<nav class="admin-subnav">
|
||||||
<a href="/admin/" class="admin-subnav-link {{ 'active' if active == 'dashboard' else '' }}">仪表盘</a>
|
<a href="/admin/" class="admin-subnav-link {{ 'active' if active == 'dashboard' else '' }}">仪表盘</a>
|
||||||
<a href="/admin/papers" class="admin-subnav-link {{ 'active' if active == 'papers' else '' }}">论文管理</a>
|
<a href="/admin/papers" class="admin-subnav-link {{ 'active' if active == 'papers' else '' }}">论文管理</a>
|
||||||
|
<a href="/admin/jobs" class="admin-subnav-link {{ 'active' if active == 'jobs' else '' }}">任务</a>
|
||||||
<a href="/admin/logs" class="admin-subnav-link {{ 'active' if active == 'logs' else '' }}">日志</a>
|
<a href="/admin/logs" class="admin-subnav-link {{ 'active' if active == 'logs' else '' }}">日志</a>
|
||||||
<span class="admin-subnav-spacer"></span>
|
<span class="admin-subnav-spacer"></span>
|
||||||
<form action="/admin/logout" method="post" class="admin-subnav-form">
|
<form action="/admin/logout" method="post" class="admin-subnav-form">
|
||||||
|
|||||||
@@ -24,6 +24,26 @@ from app.models import (
|
|||||||
from app.utils import utc_now
|
from app.utils import utc_now
|
||||||
|
|
||||||
|
|
||||||
|
# ── ChromaDB 隔离(autouse,所有测试)──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _isolate_chroma(monkeypatch, tmp_path):
|
||||||
|
"""所有测试把 ChromaDB 隔离到临时目录 + 重置单例,绝不污染 data/chroma。
|
||||||
|
|
||||||
|
与内存 DB 隔离同理:summarize 后处理经真实 _maybe_index_chroma → index_paper
|
||||||
|
写入,不隔离会把测试夹具(2401.*)泄漏到生产 data/chroma,污染语义搜索。
|
||||||
|
每个测试前重置 _chroma 单例,确保 CHROMA_DIR 指向本次 tmp。
|
||||||
|
"""
|
||||||
|
import app.services.embedder as emb
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||||
|
emb._chroma.reset()
|
||||||
|
yield
|
||||||
|
emb._chroma.reset()
|
||||||
|
|
||||||
|
|
||||||
# ── 内存数据库 ──────────────────────────────────────────────────────────
|
# ── 内存数据库 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,380 @@
|
|||||||
|
"""管理后台新功能测试 — 任务监控、锁释放、重抓、失败分布、配置、导出、重建索引。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from app.models import Job, JobStatus, SummaryState, SummaryStatus, TaskLock
|
||||||
|
from app.services import admin as admin_svc
|
||||||
|
from app.services.crawler import recrawl_single
|
||||||
|
from app.services.jobs import create_job, run_job
|
||||||
|
from app.utils import utc_now
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def no_enqueue(monkeypatch):
|
||||||
|
"""禁用路由层的 enqueue_job,阻止 background task 在测试中真实执行。"""
|
||||||
|
from app.routes import admin as admin_route
|
||||||
|
|
||||||
|
monkeypatch.setattr(admin_route, "enqueue_job", lambda *a, **k: None)
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 任务监控
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
def _make_job(db_session, *, type="crawl_daily", status=JobStatus.QUEUED, owner="t"):
|
||||||
|
job = Job(
|
||||||
|
type=type,
|
||||||
|
status=status,
|
||||||
|
owner=owner,
|
||||||
|
payload_json="{}",
|
||||||
|
created_at=utc_now(),
|
||||||
|
)
|
||||||
|
db_session.add(job)
|
||||||
|
db_session.commit()
|
||||||
|
return job
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_jobs_filter_and_pagination(db_session):
|
||||||
|
for i in range(25):
|
||||||
|
_make_job(db_session, status=JobStatus.SUCCESS)
|
||||||
|
for i in range(5):
|
||||||
|
_make_job(db_session, status=JobStatus.FAILED)
|
||||||
|
|
||||||
|
# 无过滤:分页
|
||||||
|
page1, total = admin_svc.query_jobs(db_session, page=1, per_page=20)
|
||||||
|
assert total == 30
|
||||||
|
assert len(page1) == 20
|
||||||
|
page2, _ = admin_svc.query_jobs(db_session, page=2, per_page=20)
|
||||||
|
assert len(page2) == 10
|
||||||
|
|
||||||
|
# status 过滤
|
||||||
|
failed, ftotal = admin_svc.query_jobs(db_session, status="failed", per_page=50)
|
||||||
|
assert ftotal == 5
|
||||||
|
assert len(failed) == 5
|
||||||
|
assert all(j["status"] == "failed" for j in failed)
|
||||||
|
|
||||||
|
# type 过滤
|
||||||
|
_make_job(db_session, type="reindex_fts", status=JobStatus.QUEUED)
|
||||||
|
typed, ttotal = admin_svc.query_jobs(db_session, job_type="reindex_fts")
|
||||||
|
assert ttotal == 1
|
||||||
|
assert typed[0]["type"] == "reindex_fts"
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_job_includes_duration(db_session):
|
||||||
|
job = _make_job(db_session, status=JobStatus.SUCCESS)
|
||||||
|
job.started_at = utc_now()
|
||||||
|
job.completed_at = utc_now()
|
||||||
|
db_session.commit()
|
||||||
|
serialized = admin_svc.serialize_job(job)
|
||||||
|
assert serialized["duration_seconds"] is not None
|
||||||
|
assert serialized["duration_seconds"] >= 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_job_running_without_completed(db_session):
|
||||||
|
# 运行中的 job:completed_at=None,started_at 经 db 读回为 naive UTC,
|
||||||
|
# 不能与 aware 的 utc_now() 直接相减(回归测试)。
|
||||||
|
job = _make_job(db_session, status=JobStatus.RUNNING)
|
||||||
|
job.started_at = utc_now()
|
||||||
|
db_session.commit()
|
||||||
|
serialized = admin_svc.serialize_job(job)
|
||||||
|
assert serialized["duration_seconds"] is not None
|
||||||
|
assert serialized["duration_seconds"] >= 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_job_status_counts(db_session):
|
||||||
|
_make_job(db_session, status=JobStatus.QUEUED)
|
||||||
|
_make_job(db_session, status=JobStatus.QUEUED)
|
||||||
|
_make_job(db_session, status=JobStatus.RUNNING)
|
||||||
|
counts = admin_svc.get_job_status_counts(db_session)
|
||||||
|
assert counts.get("queued") == 2
|
||||||
|
assert counts.get("running") == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 锁释放
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
def test_force_release_lock(db_session):
|
||||||
|
running = TaskLock(
|
||||||
|
task="crawl", lock_key="k1", status="running", acquired_at=utc_now()
|
||||||
|
)
|
||||||
|
stale = TaskLock(task="crawl", lock_key="k2", status="stale", acquired_at=utc_now())
|
||||||
|
finished = TaskLock(
|
||||||
|
task="crawl", lock_key="k3", status="finished", acquired_at=utc_now()
|
||||||
|
)
|
||||||
|
db_session.add_all([running, stale, finished])
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
assert admin_svc.force_release_lock(db_session, running.id) is True
|
||||||
|
assert admin_svc.force_release_lock(db_session, stale.id) is True
|
||||||
|
# finished 的不应被再次释放
|
||||||
|
assert admin_svc.force_release_lock(db_session, finished.id) is False
|
||||||
|
# 不存在的 id
|
||||||
|
assert admin_svc.force_release_lock(db_session, 999999) is False
|
||||||
|
|
||||||
|
db_session.refresh(running)
|
||||||
|
db_session.refresh(stale)
|
||||||
|
db_session.refresh(finished)
|
||||||
|
assert running.status == "finished"
|
||||||
|
assert running.released_at is not None
|
||||||
|
assert stale.status == "finished"
|
||||||
|
assert finished.status == "finished"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 失败原因分布
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_failure_breakdown(db_session, sample_papers_range):
|
||||||
|
statuses = (
|
||||||
|
db_session.execute(select(SummaryStatus).order_by(SummaryStatus.id))
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
statuses[0].status = SummaryState.FAILED
|
||||||
|
statuses[0].error_type = "pdf_download_failed"
|
||||||
|
statuses[1].status = SummaryState.FAILED
|
||||||
|
statuses[1].error_type = "timeout"
|
||||||
|
statuses[2].status = SummaryState.PERMANENT_FAILURE
|
||||||
|
statuses[2].error_type = None # 归 unknown
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
breakdown = admin_svc.get_failure_breakdown(db_session)
|
||||||
|
by_type = {b["error_type"]: b["count"] for b in breakdown}
|
||||||
|
assert by_type.get("pdf_download_failed") == 1
|
||||||
|
assert by_type.get("timeout") == 1
|
||||||
|
assert by_type.get("unknown") == 1
|
||||||
|
# 降序
|
||||||
|
counts = [b["count"] for b in breakdown]
|
||||||
|
assert counts == sorted(counts, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 配置概览
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_config_overview_no_secrets():
|
||||||
|
cfg = admin_svc.get_config_overview()
|
||||||
|
assert "summary_backend" in cfg
|
||||||
|
assert "schedule_time" in cfg
|
||||||
|
assert "api_key_configured" in cfg # 只标是否配置,不显值
|
||||||
|
text = str(cfg)
|
||||||
|
# 不应泄露默认密钥值
|
||||||
|
assert "change-me" not in text
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 单篇/批量重抓
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecrawl:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_found(self, db_session):
|
||||||
|
res = await recrawl_single(db_session, "9999.99999")
|
||||||
|
assert res["updated"] is False
|
||||||
|
assert res["reason"] == "not_found"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_updates_full_metadata(self, db_session, sample_paper):
|
||||||
|
new_item = {
|
||||||
|
"paper": {
|
||||||
|
"id": sample_paper.arxiv_id,
|
||||||
|
"title": "Updated Title",
|
||||||
|
"abstract": "New abstract",
|
||||||
|
"publishedAt": "2024-01-15T00:00:00",
|
||||||
|
"authors": [{"name": "New Author"}],
|
||||||
|
"tags": [{"name": "CV"}, {"name": "Diffusion"}],
|
||||||
|
"upvotes": 100,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with patch(
|
||||||
|
"app.services.crawler.fetch_daily",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=[new_item],
|
||||||
|
):
|
||||||
|
res = await recrawl_single(db_session, sample_paper.arxiv_id)
|
||||||
|
|
||||||
|
assert res["updated"] is True
|
||||||
|
db_session.refresh(sample_paper)
|
||||||
|
assert sample_paper.title_en == "Updated Title"
|
||||||
|
assert sample_paper.abstract == "New abstract"
|
||||||
|
assert sample_paper.upvotes == 100
|
||||||
|
# authors 重建(原 Alice/Bob → New Author)
|
||||||
|
assert sorted(a.name for a in sample_paper.authors) == ["New Author"]
|
||||||
|
# tags 重建(原 NLP/LLM → CV/Diffusion)
|
||||||
|
assert sorted(t.tag for t in sample_paper.tags) == ["CV", "Diffusion"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_in_daily(self, db_session, sample_paper):
|
||||||
|
with patch(
|
||||||
|
"app.services.crawler.fetch_daily",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=[],
|
||||||
|
):
|
||||||
|
res = await recrawl_single(db_session, sample_paper.arxiv_id)
|
||||||
|
assert res["updated"] is False
|
||||||
|
assert res["reason"] == "not_in_daily"
|
||||||
|
assert "date" in res
|
||||||
|
|
||||||
|
|
||||||
|
class TestDispatchRecrawl:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_recrawl_one_via_run_job(self, db_session, sample_paper):
|
||||||
|
new_item = {
|
||||||
|
"paper": {
|
||||||
|
"id": sample_paper.arxiv_id,
|
||||||
|
"title": "Via Job",
|
||||||
|
"authors": [],
|
||||||
|
"tags": [],
|
||||||
|
"upvotes": 5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with patch(
|
||||||
|
"app.services.crawler.fetch_daily",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=[new_item],
|
||||||
|
):
|
||||||
|
job = create_job(
|
||||||
|
db_session,
|
||||||
|
"recrawl_one",
|
||||||
|
owner="test",
|
||||||
|
payload={"arxiv_id": sample_paper.arxiv_id},
|
||||||
|
)
|
||||||
|
result = await run_job(db_session, job.id)
|
||||||
|
|
||||||
|
assert result["updated"] is True
|
||||||
|
db_session.refresh(sample_paper)
|
||||||
|
assert sample_paper.title_en == "Via Job"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_recrawl_batch_via_run_job(self, db_session, sample_papers_range):
|
||||||
|
arxiv_ids = [p.arxiv_id for p in sample_papers_range[:2]]
|
||||||
|
items = [
|
||||||
|
{
|
||||||
|
"paper": {
|
||||||
|
"id": aid,
|
||||||
|
"title": "Batch " + aid,
|
||||||
|
"authors": [],
|
||||||
|
"tags": [],
|
||||||
|
"upvotes": 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for aid in arxiv_ids
|
||||||
|
]
|
||||||
|
with patch(
|
||||||
|
"app.services.crawler.fetch_daily",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=lambda d: items,
|
||||||
|
):
|
||||||
|
job = create_job(
|
||||||
|
db_session,
|
||||||
|
"recrawl_batch",
|
||||||
|
owner="test",
|
||||||
|
payload={"arxiv_ids": arxiv_ids},
|
||||||
|
)
|
||||||
|
result = await run_job(db_session, job.id)
|
||||||
|
|
||||||
|
assert result["updated"] == 2
|
||||||
|
assert result["skipped"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 路由
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoutes:
|
||||||
|
def test_jobs_page_renders(self, auth_client):
|
||||||
|
resp = auth_client.get("/admin/jobs")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "任务监控" in resp.text
|
||||||
|
|
||||||
|
def test_jobs_page_filters_by_status(self, auth_client, db_session):
|
||||||
|
_make_job(db_session, status=JobStatus.FAILED)
|
||||||
|
resp = auth_client.get("/admin/jobs?status=failed")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_export_csv(self, auth_client, sample_papers_range):
|
||||||
|
resp = auth_client.get("/admin/papers/export.csv")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "text/csv" in resp.headers["content-type"]
|
||||||
|
# UTF-8 BOM for Excel
|
||||||
|
assert resp.content.startswith(b"\xef\xbb\xbf")
|
||||||
|
# 表头 + 数据
|
||||||
|
assert "arxiv_id" in resp.text
|
||||||
|
assert "2401.10001" in resp.text
|
||||||
|
|
||||||
|
def test_export_csv_respects_filter(self, auth_client, sample_papers_range):
|
||||||
|
resp = auth_client.get("/admin/papers/export.csv?q=Paper%203")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "2401.10003" in resp.text
|
||||||
|
assert "2401.10001" not in resp.text
|
||||||
|
|
||||||
|
def test_rebuild_indexes_fts(self, auth_client, db_session, no_enqueue):
|
||||||
|
resp = auth_client.post("/admin/rebuild-indexes", json={"target": "fts"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["status"] == "queued"
|
||||||
|
assert len(data["job_ids"]) == 1
|
||||||
|
jobs = (
|
||||||
|
db_session.execute(select(Job).where(Job.type == "reindex_fts"))
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
assert len(jobs) == 1
|
||||||
|
|
||||||
|
def test_rebuild_indexes_both(self, auth_client, db_session, no_enqueue):
|
||||||
|
resp = auth_client.post("/admin/rebuild-indexes", json={"target": "both"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data["job_ids"]) == 2
|
||||||
|
|
||||||
|
def test_release_lock_route(self, auth_client, db_session):
|
||||||
|
lock = TaskLock(
|
||||||
|
task="crawl", lock_key="rt", status="running", acquired_at=utc_now()
|
||||||
|
)
|
||||||
|
db_session.add(lock)
|
||||||
|
db_session.commit()
|
||||||
|
resp = auth_client.post(f"/admin/locks/{lock.id}/release")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
db_session.refresh(lock)
|
||||||
|
assert lock.status == "finished"
|
||||||
|
|
||||||
|
def test_paper_recrawl_route(
|
||||||
|
self, auth_client, sample_paper, db_session, no_enqueue
|
||||||
|
):
|
||||||
|
resp = auth_client.post(f"/admin/paper-recrawl/{sample_paper.arxiv_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["status"] == "queued"
|
||||||
|
jobs = (
|
||||||
|
db_session.execute(select(Job).where(Job.type == "recrawl_one"))
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
assert len(jobs) == 1
|
||||||
|
|
||||||
|
def test_batch_recrawl_route(
|
||||||
|
self, auth_client, sample_papers_range, db_session, no_enqueue
|
||||||
|
):
|
||||||
|
ids = [p.arxiv_id for p in sample_papers_range[:3]]
|
||||||
|
resp = auth_client.post(
|
||||||
|
"/admin/papers-batch-action", json={"action": "recrawl", "arxiv_ids": ids}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "queued"
|
||||||
|
jobs = (
|
||||||
|
db_session.execute(select(Job).where(Job.type == "recrawl_batch"))
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
assert len(jobs) == 1
|
||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
@@ -154,3 +156,84 @@ class TestEmbeddingApi:
|
|||||||
)
|
)
|
||||||
result = emb._get_embedding("test")
|
result = emb._get_embedding("test")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 并发安全:init() 双重检查锁 + 集合访问串行化
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbedderConcurrency:
|
||||||
|
"""后处理经 asyncio.to_thread 多 worker 并发调 index_paper 的安全性。"""
|
||||||
|
|
||||||
|
def test_init_serialized_under_concurrency(self, monkeypatch, tmp_path):
|
||||||
|
"""并发 init() 只调一次 PersistentClient(chromadb SharedSystemClient 缓存竞争修复)。
|
||||||
|
|
||||||
|
复现崩坏条件:10 线程同时 init(),fake PersistentClient 故意 sleep 拉长建连窗口。
|
||||||
|
修复前会有多线程同时进入 _create_system_if_not_exists → 并发 mutate 类级缓存;
|
||||||
|
修复后(双重检查锁)只有抢到锁的那个线程建连。
|
||||||
|
"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||||
|
import app.services.embedder as emb
|
||||||
|
|
||||||
|
emb._chroma.reset()
|
||||||
|
|
||||||
|
counter = {"n": 0}
|
||||||
|
counter_lock = threading.Lock()
|
||||||
|
|
||||||
|
def fake_persistent_client(path):
|
||||||
|
with counter_lock:
|
||||||
|
counter["n"] += 1
|
||||||
|
time.sleep(0.05) # 拉长建连窗口,放大并发竞争
|
||||||
|
client = MagicMock()
|
||||||
|
client.get_collection.side_effect = Exception(
|
||||||
|
"not exist"
|
||||||
|
) # 触发 create 路径
|
||||||
|
client.create_collection.return_value = MagicMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
with patch("chromadb.PersistentClient", side_effect=fake_persistent_client):
|
||||||
|
threads = [threading.Thread(target=emb._chroma.init) for _ in range(10)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert counter["n"] == 1, f"PersistentClient 应只调一次,实际 {counter['n']}"
|
||||||
|
assert emb._chroma._client is not None
|
||||||
|
emb._chroma.reset()
|
||||||
|
|
||||||
|
def test_index_paper_concurrent_no_error(self, monkeypatch, tmp_path):
|
||||||
|
"""并发 index_paper:embedding 锁外并行,集合写入串行化,全部成功。"""
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
|
||||||
|
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||||
|
import app.services.embedder as emb
|
||||||
|
|
||||||
|
emb._chroma.reset()
|
||||||
|
# 跳过 init,直接注入 mock collection
|
||||||
|
emb._chroma._client = MagicMock()
|
||||||
|
col = MagicMock()
|
||||||
|
col.count.return_value = 0
|
||||||
|
emb._chroma._collection = col
|
||||||
|
|
||||||
|
with patch.object(emb, "_get_embedding", return_value=[0.1, 0.2, 0.3]):
|
||||||
|
errors: list[BaseException] = []
|
||||||
|
|
||||||
|
def worker(i: int) -> None:
|
||||||
|
try:
|
||||||
|
emb.index_paper(
|
||||||
|
f"id-{i}", {"arxiv_id": f"id-{i}", "title_zh": f"标题{i}"}
|
||||||
|
)
|
||||||
|
except BaseException as exc: # noqa: BLE001 — 收集所有错误
|
||||||
|
errors.append(exc)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=worker, args=(i,)) for i in range(10)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert errors == []
|
||||||
|
assert col.upsert.call_count == 10
|
||||||
|
emb._chroma.reset()
|
||||||
|
|||||||
@@ -7,6 +7,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -166,10 +168,10 @@ class TestClassMapping:
|
|||||||
def test_table(self):
|
def test_table(self):
|
||||||
assert _map_class_to_boxclass(5, {5: "table"}) == "table"
|
assert _map_class_to_boxclass(5, {5: "table"}) == "table"
|
||||||
|
|
||||||
def test_caption_ignored(self):
|
def test_caption_classes(self):
|
||||||
names = {4: "figure_caption", 6: "table_caption"}
|
names = {4: "figure_caption", 6: "table_caption"}
|
||||||
assert _map_class_to_boxclass(4, names) is None
|
assert _map_class_to_boxclass(4, names) == "figure_caption"
|
||||||
assert _map_class_to_boxclass(6, names) is None
|
assert _map_class_to_boxclass(6, names) == "table_caption"
|
||||||
|
|
||||||
def test_other_classes_ignored(self):
|
def test_other_classes_ignored(self):
|
||||||
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
|
names = {0: "title", 1: "plain text", 2: "abandon", 8: "isolate_formula"}
|
||||||
@@ -240,9 +242,11 @@ class TestPostprocessOutput:
|
|||||||
class TestDetectPage:
|
class TestDetectPage:
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _reset_detector(self):
|
def _reset_detector(self):
|
||||||
"""每个测试前重置模块级单例,避免复用上个测试的 mock session。"""
|
"""每个测试前重建单例(带新锁 + 空 session),避免复用上个测试的 mock session。"""
|
||||||
|
mod._LayoutDetector.reset_instance()
|
||||||
mod._detector = mod._LayoutDetector()
|
mod._detector = mod._LayoutDetector()
|
||||||
yield
|
yield
|
||||||
|
mod._LayoutDetector.reset_instance()
|
||||||
mod._detector = mod._LayoutDetector()
|
mod._detector = mod._LayoutDetector()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -338,6 +342,20 @@ class TestDetectPage:
|
|||||||
assert len(boxes) == 1
|
assert len(boxes) == 1
|
||||||
assert boxes[0].boxclass == "table"
|
assert boxes[0].boxclass == "table"
|
||||||
|
|
||||||
|
def test_returns_caption_box_with_small_height(self, monkeypatch, tmp_path):
|
||||||
|
names = {4: "figure_caption"}
|
||||||
|
sess, (pw, ph) = self._build_mock_session(
|
||||||
|
595, 842, [(4, 100, 405, 300, 417, 0.9)], names
|
||||||
|
)
|
||||||
|
self._setup(monkeypatch, tmp_path, sess)
|
||||||
|
page = self._make_mock_page(595, 842, pw, ph)
|
||||||
|
|
||||||
|
boxes = detect_page_layout(page)
|
||||||
|
|
||||||
|
assert len(boxes) == 1
|
||||||
|
assert boxes[0].boxclass == "figure_caption"
|
||||||
|
assert boxes[0].y1 - boxes[0].y0 == pytest.approx(12, abs=1.0)
|
||||||
|
|
||||||
def test_filters_low_confidence(self, monkeypatch, tmp_path):
|
def test_filters_low_confidence(self, monkeypatch, tmp_path):
|
||||||
names = {3: "figure"}
|
names = {3: "figure"}
|
||||||
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
|
# conf=0.1 < LAYOUT_THRESHOLD(0.2) → 过滤
|
||||||
@@ -401,3 +419,141 @@ class TestDetectPage:
|
|||||||
boxes = detect_page_layout(page)
|
boxes = detect_page_layout(page)
|
||||||
assert len(boxes) == 1
|
assert len(boxes) == 1
|
||||||
assert boxes[0].boxclass == "picture"
|
assert boxes[0].boxclass == "picture"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 并发安全:锁串行化推理 + 单例 session 只初始化一次
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetectPageConcurrency:
|
||||||
|
"""锁包裹整段 detect_page 后,并发调用的安全性。"""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_detector(self):
|
||||||
|
"""重建单例(带新锁),避免跨测试锁状态污染。"""
|
||||||
|
mod._LayoutDetector.reset_instance()
|
||||||
|
mod._detector = mod._LayoutDetector()
|
||||||
|
yield
|
||||||
|
mod._LayoutDetector.reset_instance()
|
||||||
|
mod._detector = mod._LayoutDetector()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_mock_session(page_w, page_h, boxes, names):
|
||||||
|
"""同 TestDetectPage._build_mock_session,额外返回 fake_output 供 side_effect。"""
|
||||||
|
ratio = _compute_render_geometry(page_w, page_h, IMGSZ)
|
||||||
|
pix_w, pix_h = round(page_w * ratio), round(page_h * ratio)
|
||||||
|
dw, dh = _letterbox_padding(pix_w, pix_h, IMGSZ)
|
||||||
|
rows = []
|
||||||
|
for cls_id, x0, y0, x1, y1, conf in boxes:
|
||||||
|
rows.append(
|
||||||
|
[
|
||||||
|
x0 * ratio + dw,
|
||||||
|
y0 * ratio + dh,
|
||||||
|
x1 * ratio + dw,
|
||||||
|
y1 * ratio + dh,
|
||||||
|
conf,
|
||||||
|
cls_id,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
fake_output = (
|
||||||
|
np.array([rows], dtype=np.float32)
|
||||||
|
if rows
|
||||||
|
else np.zeros((1, 0, 6), dtype=np.float32)
|
||||||
|
)
|
||||||
|
sess = MagicMock()
|
||||||
|
inp = MagicMock()
|
||||||
|
inp.name = "images"
|
||||||
|
sess.get_inputs.return_value = [inp]
|
||||||
|
sess.run.return_value = [fake_output]
|
||||||
|
sess.get_providers.return_value = ["CPUExecutionProvider"]
|
||||||
|
meta = MagicMock()
|
||||||
|
meta.custom_metadata_map = {
|
||||||
|
"names": json.dumps({str(k): v for k, v in names.items()})
|
||||||
|
}
|
||||||
|
sess.get_modelmeta.return_value = meta
|
||||||
|
return sess, (pix_w, pix_h), fake_output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_mock_page(page_w, page_h, pix_w, pix_h):
|
||||||
|
pix = MagicMock()
|
||||||
|
pix.width = pix_w
|
||||||
|
pix.height = pix_h
|
||||||
|
pix.n = 3
|
||||||
|
pix.samples = bytes([128] * (pix_w * pix_h * 3))
|
||||||
|
page = MagicMock()
|
||||||
|
page.rect.width = page_w
|
||||||
|
page.rect.height = page_h
|
||||||
|
page.get_pixmap.return_value = pix
|
||||||
|
return page
|
||||||
|
|
||||||
|
def _setup(self, monkeypatch, tmp_path, sess):
|
||||||
|
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
|
||||||
|
(tmp_path / "m.onnx").write_bytes(b"x")
|
||||||
|
monkeypatch.setattr(ort, "InferenceSession", lambda *a, **kw: sess)
|
||||||
|
|
||||||
|
def test_detect_page_serializes_concurrent_calls(self, monkeypatch, tmp_path):
|
||||||
|
"""多线程并发调 detect_page_layout,session.run 临界区同时只有一个。"""
|
||||||
|
sess, (pw, ph), fake_output = self._build_mock_session(
|
||||||
|
595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
|
||||||
|
)
|
||||||
|
in_critical = 0
|
||||||
|
max_concurrent = 0
|
||||||
|
counter_lock = threading.Lock()
|
||||||
|
|
||||||
|
def counting_run(*args, **kwargs):
|
||||||
|
nonlocal in_critical, max_concurrent
|
||||||
|
with counter_lock:
|
||||||
|
in_critical += 1
|
||||||
|
max_concurrent = max(max_concurrent, in_critical)
|
||||||
|
time.sleep(0.02) # 放大竞争窗口,让并发线程有机会重叠
|
||||||
|
try:
|
||||||
|
return [fake_output]
|
||||||
|
finally:
|
||||||
|
with counter_lock:
|
||||||
|
in_critical -= 1
|
||||||
|
|
||||||
|
sess.run.side_effect = counting_run
|
||||||
|
self._setup(monkeypatch, tmp_path, sess)
|
||||||
|
|
||||||
|
pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(8)]
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=detect_page_layout, args=(p,)) for p in pages
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# 锁生效 → 临界区同时只有一个;不加锁时此值会 > 1(回归保护)
|
||||||
|
assert max_concurrent == 1
|
||||||
|
|
||||||
|
def test_session_created_once_under_concurrency(self, monkeypatch, tmp_path):
|
||||||
|
"""多线程并发首次调用,InferenceSession 只创建一次(锁间接保护 _init_session)。"""
|
||||||
|
sess, (pw, ph), _fake_output = self._build_mock_session(
|
||||||
|
595, 842, [(3, 100, 100, 300, 400, 0.9)], {3: "figure"}
|
||||||
|
)
|
||||||
|
create_count = 0
|
||||||
|
create_lock = threading.Lock()
|
||||||
|
|
||||||
|
def counting_init(*args, **kwargs):
|
||||||
|
nonlocal create_count
|
||||||
|
with create_lock:
|
||||||
|
create_count += 1
|
||||||
|
time.sleep(0.02) # 放大窗口,让并发首调都来抢
|
||||||
|
return sess
|
||||||
|
|
||||||
|
monkeypatch.setattr(ort, "InferenceSession", counting_init)
|
||||||
|
monkeypatch.setattr(settings, "LAYOUT_MODEL_PATH", str(tmp_path / "m.onnx"))
|
||||||
|
(tmp_path / "m.onnx").write_bytes(b"x")
|
||||||
|
|
||||||
|
pages = [self._make_mock_page(595, 842, pw, ph) for _ in range(6)]
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=detect_page_layout, args=(p,)) for p in pages
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert create_count == 1
|
||||||
|
|||||||
@@ -0,0 +1,134 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pymupdf
|
||||||
|
|
||||||
|
from app.services import pdf_image_extractor as mod
|
||||||
|
from app.services.layout_detector import LayoutBox
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_page_extracts_doclayout_caption(tmp_path):
|
||||||
|
images_dest = tmp_path / "images"
|
||||||
|
images_dest.mkdir()
|
||||||
|
manifest: dict[str, dict] = {}
|
||||||
|
|
||||||
|
pix = MagicMock()
|
||||||
|
pix.tobytes.return_value = b"jpeg"
|
||||||
|
|
||||||
|
page = MagicMock()
|
||||||
|
page.rect.width = 600
|
||||||
|
page.get_pixmap.return_value = pix
|
||||||
|
page.get_text.return_value = "Figure 1: Overall architecture.\n"
|
||||||
|
|
||||||
|
doc = MagicMock()
|
||||||
|
doc.__getitem__.return_value = page
|
||||||
|
|
||||||
|
boxes = [
|
||||||
|
LayoutBox(100, 100, 300, 300, "picture"),
|
||||||
|
LayoutBox(95, 310, 320, 325, "figure_caption"),
|
||||||
|
]
|
||||||
|
|
||||||
|
extracted = mod._process_page(
|
||||||
|
doc,
|
||||||
|
0,
|
||||||
|
boxes,
|
||||||
|
images_dest=images_dest,
|
||||||
|
manifest=manifest,
|
||||||
|
seen_labels=set(),
|
||||||
|
arxiv_id="2401.00001",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert extracted == 1
|
||||||
|
info = manifest["figure_(p1-1).jpg"]
|
||||||
|
assert info["caption_text"] == "Figure 1: Overall architecture."
|
||||||
|
assert info["caption_source"] == "doclayout"
|
||||||
|
assert info["caption_box"] == [95.0, 310.0, 320.0, 325.0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_page_includes_caption_in_render(tmp_path):
|
||||||
|
"""渲染时把 caption 区域合并进同一张截图。"""
|
||||||
|
images_dest = tmp_path / "images"
|
||||||
|
images_dest.mkdir()
|
||||||
|
manifest: dict[str, dict] = {}
|
||||||
|
|
||||||
|
pix = MagicMock()
|
||||||
|
pix.tobytes.return_value = b"jpeg"
|
||||||
|
|
||||||
|
page = MagicMock()
|
||||||
|
page.rect.width = 600
|
||||||
|
page.get_pixmap.return_value = pix
|
||||||
|
page.get_text.return_value = "Figure 1: Caption text.\n"
|
||||||
|
|
||||||
|
doc = MagicMock()
|
||||||
|
doc.__getitem__.return_value = page
|
||||||
|
|
||||||
|
boxes = [
|
||||||
|
LayoutBox(100, 100, 300, 300, "picture"),
|
||||||
|
LayoutBox(95, 310, 320, 325, "figure_caption"),
|
||||||
|
]
|
||||||
|
|
||||||
|
mod._process_page(
|
||||||
|
doc,
|
||||||
|
0,
|
||||||
|
boxes,
|
||||||
|
images_dest=images_dest,
|
||||||
|
manifest=manifest,
|
||||||
|
seen_labels=set(),
|
||||||
|
arxiv_id="2401.00001",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 内容 [100,100,300,300] ∪ caption [95,310,320,325],各方向加 _REGION_PADDING=5
|
||||||
|
# → Rect(90, 95, 325, 330)
|
||||||
|
clip = page.get_pixmap.call_args.kwargs["clip"]
|
||||||
|
assert clip == pymupdf.Rect(90, 95, 325, 330)
|
||||||
|
|
||||||
|
|
||||||
|
def test_label_images_preserves_doclayout_caption(tmp_path, monkeypatch):
|
||||||
|
arxiv_id = "2401.00001"
|
||||||
|
paper_root = tmp_path / arxiv_id
|
||||||
|
images_dest = paper_root / "images"
|
||||||
|
images_dest.mkdir(parents=True)
|
||||||
|
(images_dest / "figure_(p1-1).jpg").write_bytes(b"jpeg")
|
||||||
|
(images_dest / "manifest.json").write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"figure_(p1-1).jpg": {
|
||||||
|
"page": 1,
|
||||||
|
"type": "figure",
|
||||||
|
"label": "Figure (p1-1)",
|
||||||
|
"box": [100, 100, 300, 300],
|
||||||
|
"caption_text": "Figure 1: PDF original caption.",
|
||||||
|
"caption_source": "doclayout",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
pdf_path = tmp_path / "paper.pdf"
|
||||||
|
pdf_path.write_bytes(b"%PDF")
|
||||||
|
monkeypatch.setattr(mod, "paper_dir", lambda _arxiv_id: paper_root)
|
||||||
|
|
||||||
|
page = MagicMock()
|
||||||
|
page.search_for.return_value = [pymupdf.Rect(120, 305, 180, 320)]
|
||||||
|
|
||||||
|
fake_doc = MagicMock()
|
||||||
|
fake_doc.page_count = 1
|
||||||
|
fake_doc.__getitem__.return_value = page
|
||||||
|
fake_doc.__enter__.return_value = fake_doc
|
||||||
|
fake_doc.__exit__.return_value = False
|
||||||
|
monkeypatch.setattr(mod.pymupdf, "open", lambda _path: fake_doc)
|
||||||
|
|
||||||
|
labeled = mod.label_images_by_summary(
|
||||||
|
arxiv_id,
|
||||||
|
[{"id": "Figure 1", "caption": "Summary caption."}],
|
||||||
|
pdf_path=pdf_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert labeled == 1
|
||||||
|
manifest = json.loads((images_dest / "manifest.json").read_text())
|
||||||
|
info = manifest["figure_1.jpg"]
|
||||||
|
assert info["caption_text"] == "Figure 1: PDF original caption."
|
||||||
|
assert info["caption_source"] == "doclayout"
|
||||||
|
assert info["summary_caption_text"] == "Summary caption."
|
||||||
@@ -366,6 +366,38 @@ class TestSummarizeOneFlow:
|
|||||||
result = await summarize_one(db_session, sample_paper)
|
result = await summarize_one(db_session, sample_paper)
|
||||||
assert result["status"] == "skipped"
|
assert result["status"] == "skipped"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_processing_runs_in_thread(
|
||||||
|
self, db_session, sample_paper, mock_pi_output, _summarize_tmp_paths
|
||||||
|
):
|
||||||
|
"""后处理(图片提取/ChromaDB)在工作线程而非事件循环线程执行。"""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
seen_threads: list[int] = []
|
||||||
|
main_thread = threading.current_thread().ident
|
||||||
|
|
||||||
|
def spy_extract(arxiv_id, schema):
|
||||||
|
seen_threads.append(threading.current_thread().ident)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
|
patch(
|
||||||
|
"app.services.summary_generator.call_pi",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(mock_pi_output, "test-session-id"),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"app.services.summary_persister._maybe_extract_images",
|
||||||
|
side_effect=spy_extract,
|
||||||
|
),
|
||||||
|
patch("app.services.summary_persister._maybe_index_chroma"),
|
||||||
|
):
|
||||||
|
result = await summarize_one(db_session, sample_paper)
|
||||||
|
|
||||||
|
assert result["status"] == "done"
|
||||||
|
assert seen_threads, "后处理未被调用"
|
||||||
|
assert seen_threads[0] != main_thread, "后处理应在工作线程执行,不阻塞事件循环"
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
# 批量操作测试
|
# 批量操作测试
|
||||||
|
|||||||
Reference in New Issue
Block a user