Files
daily-paper/app/services/jobs.py
T
Rain-Bus 90fe705e8f refactor: 将布局检测模型从 PicoDet 迁移到 DocLayout-YOLO
- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox(保持宽高比缩放 + 灰边 padding),坐标还原公式为 (model_coord - padding) / ratio
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
2026-06-14 10:41:44 +08:00

247 lines
7.0 KiB
Python

"""统一后台任务系统 — 创建、运行、事件记录、失败恢复。"""
from __future__ import annotations
import json
import logging
from datetime import date, timedelta
from typing import Any
from fastapi import BackgroundTasks
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from app.config import settings
from app.database import SessionLocal
from app.models import Job, JobEvent, JobEventStatus, JobStatus, TaskLock
from app.utils import truncate_error, utc_now
logger = logging.getLogger(__name__)
# Staleness threshold used by recover_stale_jobs(): RUNNING jobs whose
# heartbeat_at is missing or older than (now - STALE_JOB_AFTER), and "running"
# task locks acquired before that cutoff, are marked stale.
STALE_JOB_AFTER = timedelta(hours=6)
def _dumps(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, default=str)
def _loads(value: str | None) -> dict:
if not value:
return {}
try:
data = json.loads(value)
return data if isinstance(data, dict) else {}
except json.JSONDecodeError:
return {}
def create_job(
    db: Session,
    job_type: str,
    *,
    owner: str,
    payload: dict | None = None,
) -> Job:
    """Create the main record for a background job.

    Inserts a ``Job`` row in QUEUED status, commits it, and then appends a
    "created" INFO event carrying the same payload.

    Args:
        db: Session used for the insert and the event.
        job_type: Job type string (dispatched later by type).
        owner: Owner recorded on the job.
        payload: Job parameters; ``None``/empty is stored as ``{}``.

    Returns:
        The persisted ``Job`` instance.
    """
    job_payload = payload or {}
    new_job = Job(
        type=job_type,
        status=JobStatus.QUEUED,
        owner=owner,
        payload_json=_dumps(job_payload),
        created_at=utc_now(),
    )
    db.add(new_job)
    db.commit()
    # Reload the committed row so generated columns (e.g. id, used by the
    # event below) are populated.
    db.refresh(new_job)

    add_job_event(
        db,
        new_job,
        stage="created",
        status=JobEventStatus.INFO,
        message=f"Job queued: {job_type}",
        payload=job_payload,
    )
    return new_job
def add_job_event(
    db: Session,
    job: Job,
    *,
    stage: str,
    status: str,
    message: str | None = None,
    payload: dict | None = None,
) -> None:
    """Append one stage event to *job*, bump its heartbeat, and commit.

    Args:
        db: Session the event is added to; it is committed before returning.
        job: Job the event belongs to; its ``heartbeat_at`` is set to now.
        stage: Stage label (e.g. "created", "run").
        status: Event status; stored as ``str(status)``.
        message: Optional human-readable message.
        payload: Optional dict serialized to ``payload_json``; ``None`` keeps
            the column NULL.
    """
    encoded_payload = None if payload is None else _dumps(payload)
    # NOTE(review): callers pass JobEventStatus members here. If that is a
    # plain ``(str, Enum)`` rather than a StrEnum or constants class,
    # str() yields "JobEventStatus.X" instead of the value. Confirm the type.
    event = JobEvent(
        job_id=job.id,
        stage=stage,
        status=str(status),
        message=message,
        payload_json=encoded_payload,
        created_at=utc_now(),
    )
    db.add(event)
    job.heartbeat_at = utc_now()
    db.commit()
def enqueue_job(background_tasks: BackgroundTasks, job_id: int) -> None:
    """Schedule ``run_job_by_id(job_id)`` on FastAPI's BackgroundTasks."""
    schedule = background_tasks.add_task
    schedule(run_job_by_id, job_id)
async def run_job_by_id(job_id: int) -> None:
    """Run an already-created job in a dedicated DB session.

    The session is opened here and always closed afterwards, even when
    ``run_job`` raises. The job's result dict is discarded.
    """
    session = SessionLocal()
    try:
        await run_job(session, job_id)
    finally:
        session.close()
async def run_job(db: Session, job_id: int) -> dict:
    """Run a job and write its status/result/error back to jobs/job_events.

    Marks the job RUNNING, dispatches it by type, then records either
    SUCCESS with ``result_json`` or FAILED with ``error``. A "run" event is
    appended at start and at completion.

    Args:
        db: Session used for bookkeeping and passed to the dispatched job.
        job_id: Primary key of the job to run.

    Returns:
        The job's result if it is a dict, otherwise
        ``{"status": "success", "result": result}``. On failure, returns
        ``{"status": "failed", "error": <truncated error>}``.

    Raises:
        ValueError: If no job with *job_id* exists.
        RuntimeError: If the job is already in RUNNING status.
    """
    job = db.get(Job, job_id)
    if not job:
        raise ValueError(f"Job not found: {job_id}")
    if job.status == JobStatus.RUNNING:
        raise RuntimeError(f"Job already running: {job_id}")

    # Captured up front so the failure path can log without reading `job`
    # attributes. Those may be expired (the dispatched job can commit) and
    # cannot be reloaded while the session is in a failed state.
    job_type = job.type
    payload = _loads(job.payload_json)
    job.status = JobStatus.RUNNING
    job.started_at = utc_now()
    job.heartbeat_at = job.started_at
    db.commit()
    add_job_event(db, job, stage="run", status=JobEventStatus.STARTED)

    try:
        result = await _dispatch_job(db, job, payload)
    except Exception as exc:
        logger.exception("Job failed: id=%s type=%s", job_id, job_type)
        error = truncate_error(exc, limit=4000)
        # A job that failed mid-transaction (e.g. a flush/DB error) leaves
        # the session unusable until rolled back. Without this, the commit
        # below would raise and the job would stay RUNNING. The rollback also
        # discards the job's uncommitted partial writes instead of committing
        # them alongside the FAILED status.
        db.rollback()
        job.status = JobStatus.FAILED
        job.error = error
        job.completed_at = utc_now()
        db.commit()
        add_job_event(db, job, stage="run", status=JobEventStatus.FAILED, message=error)
        return {"status": "failed", "error": error}

    job.status = JobStatus.SUCCESS
    job.result_json = _dumps(result)
    job.completed_at = utc_now()
    job.error = None
    db.commit()
    add_job_event(
        db,
        job,
        stage="run",
        status=JobEventStatus.SUCCESS,
        payload=result if isinstance(result, dict) else {"result": result},
    )
    return (
        result if isinstance(result, dict) else {"status": "success", "result": result}
    )
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
    """Route *job* to its service implementation based on ``job.type``.

    Args:
        db: Session handed to the service function (where it takes one).
        job: The job being run. Its ``owner`` (or ``"job:<id>"`` fallback) is
            forwarded as ``owner`` to the crawl/pipeline runners.
        payload: Decoded job parameters. Required keys raise ``KeyError``
            when missing; optional keys fall back to defaults.

    Returns:
        Whatever the selected service returns.

    Raises:
        ValueError: If ``job.type`` is not a known job type.
    """
    # Service imports are deferred to call time; presumably this avoids
    # circular imports between app.services modules. TODO confirm.
    from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
    from app.services.crawler import refresh_upvotes
    from app.services.derived import reindex_chroma, reindex_fts
    from app.services.pipeline import run_crawl, run_pipeline
    from app.services.summarizer import summarize_batch, summarize_single

    kind = job.type

    def owner_label() -> str:
        # Evaluated only by the branches that need it, as in a per-branch inline expression.
        return job.owner or f"job:{job.id}"

    if kind == "crawl_daily":
        return await run_crawl(
            db,
            payload["target_date"],
            owner=owner_label(),
            top_n=payload.get("top_n"),
        )
    elif kind == "pipeline_daily":
        return await run_pipeline(
            db,
            payload["target_date"],
            owner=owner_label(),
        )
    elif kind == "summarize_batch":
        return await summarize_batch(
            db,
            pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE),
        )
    elif kind == "summarize_one":
        return await summarize_single(
            db,
            payload["arxiv_id"],
            force=payload.get("force", True),
            pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE),
        )
    elif kind == "refresh_upvotes":
        return await refresh_upvotes(db, days=payload.get("days"))
    elif kind == "delete_range":
        return await delete_papers_by_date_range(
            db,
            date.fromisoformat(payload["date_start"]),
            date.fromisoformat(payload["date_end"]),
            include_notes=payload.get("include_notes", True),
        )
    # The remaining services are synchronous; their results are returned directly.
    elif kind == "cleanup_tmp":
        return cleanup_tmp()
    elif kind == "reindex_fts":
        return reindex_fts(db)
    elif kind == "reindex_chroma":
        return reindex_chroma(db)

    raise ValueError(f"Unsupported job type: {job.type}")
def recover_stale_jobs(db: Session) -> int:
    """Mark stale RUNNING jobs and "running" task locks as stale.

    Intended to run at startup so interrupted work does not stay stuck
    forever. Affected records:

    * RUNNING jobs whose ``heartbeat_at`` is NULL or older than
      ``now - STALE_JOB_AFTER``. Each becomes STALE, gets an error message and
      ``completed_at``, and receives a FAILED "recovery" event.
    * Task locks with status "running" acquired before that cutoff. Each
      becomes "stale" with ``released_at`` set.

    Jobs in any other status (e.g. QUEUED) and RUNNING jobs with a recent
    heartbeat are left untouched. All changes are committed in one commit.

    Returns:
        The number of jobs plus locks that were marked stale.
    """
    now = utc_now()
    cutoff = now - STALE_JOB_AFTER

    stale_jobs = (
        db.execute(
            select(Job).where(
                Job.status == JobStatus.RUNNING,
                # .is_(None) emits "IS NULL", same as == None, without the lint suppression.
                or_(Job.heartbeat_at.is_(None), Job.heartbeat_at < cutoff),
            )
        )
        .scalars()
        .all()
    )
    for job in stale_jobs:
        job.status = JobStatus.STALE
        job.error = "Marked stale after process restart or missed heartbeat"
        job.completed_at = now
        # Added directly, not via add_job_event(), so nothing is committed
        # per job and heartbeat_at is not bumped on a job being retired.
        db.add(
            JobEvent(
                job_id=job.id,
                stage="recovery",
                status=JobEventStatus.FAILED,
                message=job.error,
                created_at=now,
            )
        )

    # NOTE(review): locks use raw "running"/"stale" strings rather than an
    # enum. A lock with a NULL acquired_at is never matched here. Confirm
    # the column is non-nullable.
    stale_locks = (
        db.execute(
            select(TaskLock).where(
                TaskLock.status == "running",
                TaskLock.acquired_at < cutoff,
            )
        )
        .scalars()
        .all()
    )
    for lock in stale_locks:
        lock.status = "stale"
        lock.released_at = now

    db.commit()
    recovered = len(stale_jobs) + len(stale_locks)
    if recovered:
        logger.warning("Recovered stale runtime records: %d", recovered)
    return recovered