Files
daily-paper/app/services/jobs.py
T

261 lines
7.5 KiB
Python

"""统一后台任务系统 — 创建、运行、事件记录、失败恢复。"""
from __future__ import annotations
import json
import logging
from datetime import date, timedelta
from typing import Any
from fastapi import BackgroundTasks
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from app.config import settings
from app.database import SessionLocal
from app.models import Job, JobEvent, JobEventStatus, JobStatus, TaskLock
from app.utils import truncate_error, utc_now
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Age threshold used by recover_stale_jobs(): running jobs whose heartbeat is
# missing or older than this, and running locks acquired longer ago than this,
# are marked stale.
STALE_JOB_AFTER = timedelta(hours=6)
def _dumps(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, default=str)
def _loads(value: str | None) -> dict:
if not value:
return {}
try:
data = json.loads(value)
return data if isinstance(data, dict) else {}
except json.JSONDecodeError:
return {}
def create_job(
    db: Session,
    job_type: str,
    *,
    owner: str,
    payload: dict | None = None,
) -> Job:
    """Create the main record for a background job.

    The job is inserted with status ``JobStatus.QUEUED`` and committed, then a
    "created" event is appended through ``add_job_event``.

    Args:
        db: Session used to persist the job and its first event.
        job_type: Value stored in the job's ``type`` column.
        owner: Value stored in the job's ``owner`` column.
        payload: Optional job arguments; a missing or empty payload is stored
            as an empty JSON object.

    Returns:
        The refreshed ``Job`` instance.
    """
    job_payload = payload or {}
    record = Job(
        type=job_type,
        status=JobStatus.QUEUED,
        owner=owner,
        payload_json=_dumps(job_payload),
        created_at=utc_now(),
    )
    db.add(record)
    db.commit()
    # Reload the row so database-generated fields are populated on the instance.
    db.refresh(record)

    add_job_event(
        db,
        record,
        stage="created",
        status=JobEventStatus.INFO,
        message=f"Job queued: {job_type}",
        payload=job_payload,
    )
    return record
def add_job_event(
    db: Session,
    job: Job,
    *,
    stage: str,
    status: str,
    message: str | None = None,
    payload: dict | None = None,
) -> None:
    """Append one stage event to a job and refresh its heartbeat.

    Args:
        db: Session the event is added to; it is committed before returning.
        job: Job the event belongs to; its ``heartbeat_at`` is updated.
        stage: Name of the stage the event describes.
        status: Event status; stored as ``str(status)``.
        message: Optional free-form text.
        payload: Optional extra data, stored as JSON. ``None`` is stored as
            NULL rather than as an empty object.
    """
    # NOTE(review): callers pass JobEventStatus members here; confirm that
    # str() of those members yields the value the column is expected to hold.
    event = JobEvent(
        job_id=job.id,
        stage=stage,
        status=str(status),
        message=message,
        payload_json=None if payload is None else _dumps(payload),
        created_at=utc_now(),
    )
    db.add(event)
    # Every recorded event also counts as a sign of life for the job.
    job.heartbeat_at = utc_now()
    db.commit()
def enqueue_job(background_tasks: BackgroundTasks, job_id: int) -> None:
    """Submit the job to FastAPI ``BackgroundTasks``.

    Registers ``run_job_by_id`` with ``job_id`` as its only argument.
    """
    runner = run_job_by_id
    background_tasks.add_task(runner, job_id)
async def run_job_by_id(job_id: int) -> None:
    """Run an already-created job using its own DB session.

    A session is opened from ``SessionLocal`` and closed afterwards, whether
    ``run_job`` returns normally or raises.
    """
    session = SessionLocal()
    try:
        await run_job(session, job_id)
    finally:
        session.close()
async def run_job(db: Session, job_id: int) -> dict:
    """Run a job and write its status, result, and error back to the database.

    The job is marked RUNNING, dispatched through ``_dispatch_job``, and then
    marked SUCCESS or FAILED. A "run" event is recorded at the start and at
    the end.

    Args:
        db: Session used for all reads and writes of the job and its events.
        job_id: Primary key of the job to run.

    Returns:
        On success, the handler's result if it is a dict, otherwise
        ``{"status": "success", "result": result}``. On failure,
        ``{"status": "failed", "error": <truncated error text>}``.

    Raises:
        ValueError: If no job with ``job_id`` exists.
        RuntimeError: If the job is already in the RUNNING state.
    """
    job = db.get(Job, job_id)
    if not job:
        raise ValueError(f"Job not found: {job_id}")
    if job.status == JobStatus.RUNNING:
        raise RuntimeError(f"Job already running: {job_id}")

    payload = _loads(job.payload_json)
    # Read the type while the instance is known to be loaded, so the failure
    # path can log it without querying a possibly broken session.
    job_type = job.type

    job.status = JobStatus.RUNNING
    job.started_at = utc_now()
    job.heartbeat_at = job.started_at
    db.commit()
    add_job_event(db, job, stage="run", status=JobEventStatus.STARTED)

    try:
        result = await _dispatch_job(db, job, payload)
    except Exception as exc:
        logger.exception("Job failed: id=%s type=%s", job_id, job_type)
        # The handler may have failed mid-flush or left the transaction in an
        # invalid state. Without a rollback, the writes below would raise and
        # the job would remain RUNNING instead of being marked FAILED.
        db.rollback()
        error = truncate_error(exc, limit=4000)
        job.status = JobStatus.FAILED
        job.error = error
        job.completed_at = utc_now()
        db.commit()
        add_job_event(db, job, stage="run", status=JobEventStatus.FAILED, message=error)
        return {"status": "failed", "error": error}

    job.status = JobStatus.SUCCESS
    job.result_json = _dumps(result)
    job.completed_at = utc_now()
    # Clear any error text left on the job record.
    job.error = None
    db.commit()
    add_job_event(
        db,
        job,
        stage="run",
        status=JobEventStatus.SUCCESS,
        payload=result if isinstance(result, dict) else {"result": result},
    )
    return (
        result if isinstance(result, dict) else {"status": "success", "result": result}
    )
async def _dispatch_job(db: Session, job: Job, payload: dict) -> dict:
    """Route a job to the service function that matches ``job.type``.

    Args:
        db: Session handed to the service functions that take one.
        job: Job being run; its ``type``, ``owner`` and ``id`` are read.
        payload: Decoded job arguments. Required keys are read with ``[]``
            and optional ones with ``.get``.

    Returns:
        Whatever the selected service function returns. For
        ``recrawl_batch``, a dict with ``updated`` and ``skipped`` counts and
        the list of per-paper ``results``.

    Raises:
        ValueError: If ``job.type`` is not one of the supported types.
        KeyError: If a required payload key is missing.
    """
    # NOTE(review): the service modules are imported inside the function,
    # presumably to avoid circular imports at module load time; confirm.
    from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
    from app.services.crawler import recrawl_single, refresh_upvotes
    from app.services.derived import reindex_chroma, reindex_fts
    from app.services.pipeline import run_crawl, run_pipeline
    from app.services.summarizer import summarize_batch, summarize_single

    kind = job.type

    if kind == "crawl_daily":
        target_date = payload["target_date"]
        owner = job.owner or f"job:{job.id}"
        top_n = payload.get("top_n")
        return await run_crawl(db, target_date, owner=owner, top_n=top_n)

    if kind == "pipeline_daily":
        target_date = payload["target_date"]
        owner = job.owner or f"job:{job.id}"
        return await run_pipeline(db, target_date, owner=owner)

    if kind == "summarize_batch":
        pdf_mode = payload.get("pdf_mode", settings.SUMMARY_PDF_MODE)
        return await summarize_batch(db, pdf_mode=pdf_mode)

    if kind == "summarize_one":
        arxiv_id = payload["arxiv_id"]
        force = payload.get("force", True)
        pdf_mode = payload.get("pdf_mode", settings.SUMMARY_PDF_MODE)
        return await summarize_single(db, arxiv_id, force=force, pdf_mode=pdf_mode)

    if kind == "refresh_upvotes":
        return await refresh_upvotes(db, days=payload.get("days"))

    if kind == "delete_range":
        range_start = date.fromisoformat(payload["date_start"])
        range_end = date.fromisoformat(payload["date_end"])
        include_notes = payload.get("include_notes", True)
        return await delete_papers_by_date_range(
            db,
            range_start,
            range_end,
            include_notes=include_notes,
        )

    # The next three handlers are called without ``await``.
    if kind == "cleanup_tmp":
        return cleanup_tmp()

    if kind == "reindex_fts":
        return reindex_fts(db)

    if kind == "reindex_chroma":
        return reindex_chroma(db)

    if kind == "recrawl_one":
        return await recrawl_single(db, payload["arxiv_id"])

    if kind == "recrawl_batch":
        tally = {"updated": 0, "skipped": 0}
        outcomes = []
        # Papers are recrawled one at a time, in payload order.
        for paper_id in payload.get("arxiv_ids", []):
            outcome = await recrawl_single(db, paper_id)
            outcomes.append(outcome)
            tally["updated" if outcome.get("updated") else "skipped"] += 1
        return {
            "updated": tally["updated"],
            "skipped": tally["skipped"],
            "results": outcomes,
        }

    raise ValueError(f"Unsupported job type: {job.type}")
def recover_stale_jobs(db: Session) -> int:
    """Mark expired running jobs and locks as stale at startup.

    This keeps records from staying stuck in the running state forever.
    A job is stale when it is RUNNING and its heartbeat is missing or older
    than ``STALE_JOB_AFTER``. A lock is stale when its status is "running"
    and it was acquired longer ago than ``STALE_JOB_AFTER``. Each stale job
    also gets a "recovery" event. All changes are committed together.

    Args:
        db: Session used to query and update jobs and locks.

    Returns:
        The number of jobs plus the number of locks marked stale.
    """
    now = utc_now()
    cutoff = now - STALE_JOB_AFTER
    reason = "Marked stale after process restart or missed heartbeat"

    heartbeat_missing_or_old = or_(
        Job.heartbeat_at.is_(None),
        Job.heartbeat_at < cutoff,
    )
    jobs_query = select(Job).where(
        Job.status == JobStatus.RUNNING,
        heartbeat_missing_or_old,
    )
    stale_jobs = db.execute(jobs_query).scalars().all()
    for job in stale_jobs:
        job.status = JobStatus.STALE
        job.error = reason
        job.completed_at = now
        recovery_event = JobEvent(
            job_id=job.id,
            stage="recovery",
            status=JobEventStatus.FAILED,
            message=reason,
            created_at=now,
        )
        db.add(recovery_event)

    locks_query = select(TaskLock).where(
        TaskLock.status == "running",
        TaskLock.acquired_at < cutoff,
    )
    stale_locks = db.execute(locks_query).scalars().all()
    for lock in stale_locks:
        lock.status = "stale"
        lock.released_at = now

    db.commit()

    total = len(stale_jobs) + len(stale_locks)
    if total:
        logger.warning("Recovered stale runtime records: %d", total)
    return total