90fe705e8f
- 核心变更: - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024) - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备 - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式 - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls] - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table) - 新增文件: - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行) - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例) - 配置更新: - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID) - app/config.py: Settings 类对应字段 - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等) - 删除旧文件: - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本 - 文档更新: - README.md: 更新环境变量说明 - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py) 此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
247 lines
7.0 KiB
Python
247 lines
7.0 KiB
Python
"""统一后台任务系统 — 创建、运行、事件记录、失败恢复。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from datetime import date, timedelta
|
|
from typing import Any
|
|
|
|
from fastapi import BackgroundTasks
|
|
from sqlalchemy import or_, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.config import settings
|
|
from app.database import SessionLocal
|
|
from app.models import Job, JobEvent, JobEventStatus, JobStatus, TaskLock
|
|
from app.utils import truncate_error, utc_now
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Age threshold used by recover_stale_jobs(): RUNNING jobs whose heartbeat is
# older than this (or missing) and running locks acquired before it are
# marked stale.
STALE_JOB_AFTER = timedelta(hours=6)
|
|
|
|
|
|
def _dumps(value: Any) -> str:
|
|
return json.dumps(value, ensure_ascii=False, default=str)
|
|
|
|
|
|
def _loads(value: str | None) -> dict:
|
|
if not value:
|
|
return {}
|
|
try:
|
|
data = json.loads(value)
|
|
return data if isinstance(data, dict) else {}
|
|
except json.JSONDecodeError:
|
|
return {}
|
|
|
|
|
|
def create_job(
    db: Session,
    job_type: str,
    *,
    owner: str,
    payload: dict | None = None,
) -> Job:
    """Create the main record for a background job and log its first event.

    The job is stored with status ``JobStatus.QUEUED`` and the payload
    serialized to JSON (an empty object when ``payload`` is falsy). The row
    is committed and refreshed, then a ``created`` event carrying the same
    payload is appended through ``add_job_event`` (which commits again).

    Returns:
        The persisted ``Job`` instance.
    """
    job_payload = payload or {}
    job = Job(
        type=job_type,
        status=JobStatus.QUEUED,
        owner=owner,
        payload_json=_dumps(job_payload),
        created_at=utc_now(),
    )
    db.add(job)
    # Persist the job before recording the event; the event row refers to
    # ``job.id``.
    db.commit()
    db.refresh(job)

    add_job_event(
        db,
        job,
        stage="created",
        status=JobEventStatus.INFO,
        message=f"Job queued: {job_type}",
        payload=job_payload,
    )
    return job
|
|
|
|
|
|
def add_job_event(
    db: Session,
    job: Job,
    *,
    stage: str,
    status: str,
    message: str | None = None,
    payload: dict | None = None,
) -> None:
    """Append one stage event to ``job`` and commit.

    Besides inserting the ``JobEvent`` row, this updates
    ``job.heartbeat_at`` to the current time, so every recorded event also
    acts as a heartbeat. ``payload`` is stored as JSON, or as ``None`` when
    no payload is given.
    """
    # NOTE(review): callers in this module pass JobEventStatus members; what
    # str() yields for them depends on how that enum is declared — confirm the
    # stored value is the intended one.
    event = JobEvent(
        job_id=job.id,
        stage=stage,
        status=str(status),
        message=message,
        payload_json=None if payload is None else _dumps(payload),
        created_at=utc_now(),
    )
    db.add(event)
    job.heartbeat_at = utc_now()
    db.commit()
|
|
|
|
|
|
def enqueue_job(background_tasks: BackgroundTasks, job_id: int) -> None:
    """Hand an already-created job over to FastAPI ``BackgroundTasks``.

    Registers ``run_job_by_id(job_id)`` as a background task on
    ``background_tasks``; nothing is executed here.
    """
    runner = run_job_by_id
    background_tasks.add_task(runner, job_id)
|
|
|
|
|
|
async def run_job_by_id(job_id: int) -> None:
    """Run a previously created job using its own, dedicated DB session.

    A fresh session is opened from ``SessionLocal`` and is always closed,
    whether ``run_job`` returns or raises. The value returned by ``run_job``
    is discarded.
    """
    session = SessionLocal()
    try:
        await run_job(session, job_id)
    finally:
        session.close()
|
|
|
|
|
|
async def run_job(db: Session, job_id: int) -> dict:
    """Run a job and write its status/result/error back to jobs/job_events.

    The job is marked RUNNING, dispatched by type, and then marked SUCCESS
    (with the JSON-serialized result) or FAILED (with a truncated error
    message). A start event and a final success/failure event are recorded.

    Args:
        db: Session used both for bookkeeping and by the dispatched handler.
        job_id: Primary key of the job to run.

    Returns:
        On success, the handler's result if it is a dict, otherwise
        ``{"status": "success", "result": result}``. On failure,
        ``{"status": "failed", "error": <message>}``.

    Raises:
        ValueError: No job with ``job_id`` exists.
        RuntimeError: The job is already in RUNNING status.
    """
    job = db.get(Job, job_id)
    if not job:
        raise ValueError(f"Job not found: {job_id}")
    if job.status == JobStatus.RUNNING:
        raise RuntimeError(f"Job already running: {job_id}")

    # Cached so the failure path can log without touching ORM attributes,
    # which may need a reload from a session that is in a failed state.
    job_type = job.type
    payload = _loads(job.payload_json)

    job.status = JobStatus.RUNNING
    job.started_at = utc_now()
    job.heartbeat_at = job.started_at
    db.commit()
    add_job_event(db, job, stage="run", status=JobEventStatus.STARTED)

    try:
        result = await _dispatch_job(db, job, payload)
    except Exception as exc:
        logger.exception("Job failed: id=%s type=%s", job_id, job_type)
        # If the handler failed on a DB error, the session's transaction is
        # invalid and any further commit would raise, leaving the job stuck
        # in RUNNING. Roll back first so the failure can be recorded. This
        # also discards whatever the handler left uncommitted.
        db.rollback()
        error = truncate_error(exc, limit=4000)
        job.status = JobStatus.FAILED
        job.error = error
        job.completed_at = utc_now()
        db.commit()
        add_job_event(db, job, stage="run", status=JobEventStatus.FAILED, message=error)
        return {"status": "failed", "error": error}

    job.status = JobStatus.SUCCESS
    job.result_json = _dumps(result)
    job.completed_at = utc_now()
    job.error = None
    db.commit()
    add_job_event(
        db,
        job,
        stage="run",
        status=JobEventStatus.SUCCESS,
        payload=result if isinstance(result, dict) else {"result": result},
    )
    return (
        result if isinstance(result, dict) else {"status": "success", "result": result}
    )
|
|
|
|
|
|
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
    """Route a job to the service function that matches ``job.type``.

    Required payload keys are read with ``payload[...]`` (a missing key
    raises ``KeyError``); optional ones use ``payload.get``. The service
    modules are imported inside the function rather than at module level.

    Raises:
        ValueError: ``job.type`` is not one of the supported job types.
    """
    from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
    from app.services.crawler import refresh_upvotes
    from app.services.derived import reindex_chroma, reindex_fts
    from app.services.pipeline import run_crawl, run_pipeline
    from app.services.summarizer import summarize_batch, summarize_single

    kind = job.type

    if kind == "crawl_daily":
        target = payload["target_date"]
        # Fall back to a synthetic owner label when the job has none.
        owner = job.owner or f"job:{job.id}"
        limit = payload.get("top_n")
        return await run_crawl(db, target, owner=owner, top_n=limit)

    if kind == "pipeline_daily":
        target = payload["target_date"]
        owner = job.owner or f"job:{job.id}"
        return await run_pipeline(db, target, owner=owner)

    if kind == "summarize_batch":
        mode = payload.get("pdf_mode", settings.SUMMARY_PDF_MODE)
        return await summarize_batch(db, pdf_mode=mode)

    if kind == "summarize_one":
        arxiv_id = payload["arxiv_id"]
        force = payload.get("force", True)
        mode = payload.get("pdf_mode", settings.SUMMARY_PDF_MODE)
        return await summarize_single(db, arxiv_id, force=force, pdf_mode=mode)

    if kind == "refresh_upvotes":
        return await refresh_upvotes(db, days=payload.get("days"))

    if kind == "delete_range":
        # Dates arrive in the payload as ISO-8601 strings.
        start = date.fromisoformat(payload["date_start"])
        end = date.fromisoformat(payload["date_end"])
        keep_notes_flag = payload.get("include_notes", True)
        return await delete_papers_by_date_range(
            db,
            start,
            end,
            include_notes=keep_notes_flag,
        )

    # The remaining handlers are called without ``await``.
    if kind == "cleanup_tmp":
        return cleanup_tmp()

    if kind == "reindex_fts":
        return reindex_fts(db)

    if kind == "reindex_chroma":
        return reindex_chroma(db)

    raise ValueError(f"Unsupported job type: {kind}")
|
|
|
|
|
|
def recover_stale_jobs(db: Session) -> int:
    """Mark expired running jobs and locks as stale so they cannot stay stuck.

    A job is stale when its status is RUNNING and its heartbeat is missing or
    older than ``STALE_JOB_AFTER``. A lock is stale when its status is
    ``"running"`` and it was acquired before the same cutoff. Each stale job
    also gets a ``recovery`` event. All changes go in a single commit.

    Returns:
        The number of jobs plus the number of locks that were marked stale.
    """
    now = utc_now()
    cutoff = now - STALE_JOB_AFTER

    stale_job_query = select(Job).where(
        Job.status == JobStatus.RUNNING,
        or_(Job.heartbeat_at.is_(None), Job.heartbeat_at < cutoff),
    )
    stale_jobs = db.execute(stale_job_query).scalars().all()
    for job in stale_jobs:
        job.status = JobStatus.STALE
        job.error = "Marked stale after process restart or missed heartbeat"
        job.completed_at = now
        # NOTE(review): status is passed here as the enum member, while
        # add_job_event() stores str(status) — confirm both produce the same
        # stored value.
        recovery_event = JobEvent(
            job_id=job.id,
            stage="recovery",
            status=JobEventStatus.FAILED,
            message=job.error,
            created_at=now,
        )
        db.add(recovery_event)

    stale_lock_query = select(TaskLock).where(
        TaskLock.status == "running",
        TaskLock.acquired_at < cutoff,
    )
    stale_locks = db.execute(stale_lock_query).scalars().all()
    for lock in stale_locks:
        lock.status = "stale"
        lock.released_at = now

    db.commit()

    recovered = len(stale_jobs) + len(stale_locks)
    if recovered:
        logger.warning("Recovered stale runtime records: %d", recovered)
    return recovered
|