"""统一后台任务系统 — 创建、运行、事件记录、失败恢复。""" from __future__ import annotations import json import logging from datetime import date, timedelta from typing import Any from fastapi import BackgroundTasks from sqlalchemy import or_, select from sqlalchemy.orm import Session from app.config import settings from app.database import SessionLocal from app.models import Job, JobEvent, JobEventStatus, JobStatus, TaskLock from app.utils import truncate_error, utc_now logger = logging.getLogger(__name__) STALE_JOB_AFTER = timedelta(hours=6) def _dumps(value: Any) -> str: return json.dumps(value, ensure_ascii=False, default=str) def _loads(value: str | None) -> dict: if not value: return {} try: data = json.loads(value) return data if isinstance(data, dict) else {} except json.JSONDecodeError: return {} def create_job( db: Session, job_type: str, *, owner: str, payload: dict | None = None, ) -> Job: """创建后台任务主记录。""" job = Job( type=job_type, status=JobStatus.QUEUED, owner=owner, payload_json=_dumps(payload or {}), created_at=utc_now(), ) db.add(job) db.commit() db.refresh(job) add_job_event( db, job, stage="created", status=JobEventStatus.INFO, message=f"Job queued: {job_type}", payload=payload or {}, ) return job def add_job_event( db: Session, job: Job, *, stage: str, status: str, message: str | None = None, payload: dict | None = None, ) -> None: """追加一条任务阶段事件。""" db.add( JobEvent( job_id=job.id, stage=stage, status=str(status), message=message, payload_json=_dumps(payload) if payload is not None else None, created_at=utc_now(), ) ) job.heartbeat_at = utc_now() db.commit() def enqueue_job(background_tasks: BackgroundTasks, job_id: int) -> None: """把任务提交给 FastAPI BackgroundTasks。""" background_tasks.add_task(run_job_by_id, job_id) async def run_job_by_id(job_id: int) -> None: """使用独立 DB session 运行一个已创建的 job。""" db = SessionLocal() try: await run_job(db, job_id) finally: db.close() async def run_job(db: Session, job_id: int) -> dict: """运行 job,并把状态/result/error 写回 jobs/job_events。""" job = db.get(Job, job_id) if not job: raise ValueError(f"Job not found: {job_id}") if job.status == JobStatus.RUNNING: raise RuntimeError(f"Job already running: {job_id}") payload = _loads(job.payload_json) job.status = JobStatus.RUNNING job.started_at = utc_now() job.heartbeat_at = job.started_at db.commit() add_job_event(db, job, stage="run", status=JobEventStatus.STARTED) try: result = await _dispatch_job(db, job, payload) except Exception as exc: logger.exception("Job failed: id=%s type=%s", job.id, job.type) error = truncate_error(exc, limit=4000) job.status = JobStatus.FAILED job.error = error job.completed_at = utc_now() db.commit() add_job_event(db, job, stage="run", status=JobEventStatus.FAILED, message=error) return {"status": "failed", "error": error} job.status = JobStatus.SUCCESS job.result_json = _dumps(result) job.completed_at = utc_now() job.error = None db.commit() add_job_event( db, job, stage="run", status=JobEventStatus.SUCCESS, payload=result if isinstance(result, dict) else {"result": result}, ) return ( result if isinstance(result, dict) else {"status": "success", "result": result} ) async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict: from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range from app.services.crawler import refresh_upvotes from app.services.derived import reindex_chroma, reindex_fts from app.services.pipeline import run_crawl, run_pipeline from app.services.summarizer import summarize_batch, summarize_single if job.type == "crawl_daily": return await run_crawl( db, payload["target_date"], owner=job.owner or f"job:{job.id}", top_n=payload.get("top_n"), ) if job.type == "pipeline_daily": return await run_pipeline( db, payload["target_date"], owner=job.owner or f"job:{job.id}", ) if job.type == "summarize_batch": return await summarize_batch( db, pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE), ) if job.type == "summarize_one": return await summarize_single( db, payload["arxiv_id"], force=payload.get("force", True), pdf_mode=payload.get("pdf_mode", settings.SUMMARY_PDF_MODE), ) if job.type == "refresh_upvotes": return await refresh_upvotes(db, days=payload.get("days")) if job.type == "delete_range": return await delete_papers_by_date_range( db, date.fromisoformat(payload["date_start"]), date.fromisoformat(payload["date_end"]), include_notes=payload.get("include_notes", True), ) if job.type == "cleanup_tmp": return cleanup_tmp() if job.type == "reindex_fts": return reindex_fts(db) if job.type == "reindex_chroma": return reindex_chroma(db) raise ValueError(f"Unsupported job type: {job.type}") def recover_stale_jobs(db: Session) -> int: """启动时将过期 running job/lock 标记为 stale,避免永久卡住。""" now = utc_now() cutoff = now - STALE_JOB_AFTER stale_jobs = ( db.execute( select(Job).where( Job.status == JobStatus.RUNNING, or_(Job.heartbeat_at == None, Job.heartbeat_at < cutoff), # noqa: E711 ) ) .scalars() .all() ) for job in stale_jobs: job.status = JobStatus.STALE job.error = "Marked stale after process restart or missed heartbeat" job.completed_at = now db.add( JobEvent( job_id=job.id, stage="recovery", status=JobEventStatus.FAILED, message=job.error, created_at=now, ) ) stale_locks = ( db.execute( select(TaskLock).where( TaskLock.status == "running", TaskLock.acquired_at < cutoff, ) ) .scalars() .all() ) for lock in stale_locks: lock.status = "stale" lock.released_at = now db.commit() recovered = len(stale_jobs) + len(stale_locks) if recovered: logger.warning("Recovered stale runtime records: %d", recovered) return recovered