Files
daily-paper/app/services/admin.py
T
Rain-Bus 743d69efd0 refactor: extract admin business logic to services, introduce job queue, add derived index helpers
- Move DB operations from routes/admin.py to services/admin.py (get_logs_context, query_summary_statuses, retry_failed, delete/reset operations)
- Add services/jobs.py with Job/JobEvent-based async job queue (create_job, run_job, enqueue_job)
- Add services/derived.py with FTS5 reindex and paper index deletion helpers
- Refactor scheduler to use job queue instead of direct pipeline calls
- Add heartbeat_at/expires_at to TaskLock for lock health tracking
- Remove DESIGN_REVIEW.md
- Update tests: remove redundant integration tests, add unit tests for new services
2026-06-13 18:31:43 +08:00

514 lines
16 KiB
Python

"""管理后台服务 — 统计聚合、系统状态、管理操作。"""
from __future__ import annotations
import json
from datetime import date
from pathlib import Path
from typing import Callable
from sqlalchemy import func, select, text
from sqlalchemy.orm import Session
from app.config import settings
from app.models import (
CrawlLog,
DataDeleteJob,
Job,
JobEvent,
Paper,
PaperTag,
SummaryState,
SummaryStatus,
TaskLock,
)
from app.services.derived import delete_paper_indexes
from app.services.scheduler import get_scheduler
from app.utils import PAPERS_DIR, TMP_DIR, utc_now
# Sort options for the admin paper list: maps the sort key accepted by
# query_papers() to the ORDER BY expression applied to the Paper query.
SORT_MAP = {
    "date_desc": Paper.paper_date.desc(),  # newest paper_date first
    "date_asc": Paper.paper_date.asc(),  # oldest paper_date first
    "upvotes_desc": Paper.upvotes.desc(),  # most upvoted first
    "title_asc": Paper.title_en.asc(),  # English title, ascending
}
def _dir_size(path: Path) -> int:
    """Return the total size in bytes of all regular files under *path*.

    The walk is best-effort: entries that disappear or become unreadable
    while the tree is being scanned (for example a temp directory that is
    being cleaned concurrently) are skipped instead of raising.

    Args:
        path: Directory to measure recursively.

    Returns:
        Sum of ``st_size`` over every regular file found; ``0`` when *path*
        does not exist.
    """
    if not path.exists():
        return 0

    total = 0
    try:
        for entry in path.rglob("*"):
            try:
                if entry.is_file():
                    total += entry.stat().st_size
            except OSError:
                # The file vanished or is unreadable between listing and
                # stat(); leave it out of the total.
                continue
    except OSError:
        # A directory disappeared mid-walk; report what was measured so far.
        return total
    return total
def _fmt_size(nbytes: int) -> str:
    """Format a byte count as a human-readable string.

    The value is divided by 1024 until it drops below 1024 or the unit
    reaches TB, and is always rendered with one decimal place.
    """
    size = nbytes
    for unit in ("B", "KB", "MB", "GB"):
        if size >= 1024:
            size = size / 1024
            continue
        return f"{size:.1f} {unit}"
    # Anything still >= 1024 after the GB step is reported in TB.
    return f"{size:.1f} TB"
def get_admin_stats(db: Session) -> dict:
    """Gather the figures shown on the admin dashboard.

    Returns:
        A dict with paper counts, the summary-status distribution, storage
        sizes, scheduler state, the five most recent crawl logs and the
        task locks whose status is ``"running"``.
    """
    current_day = date.today()

    # -- Paper counts ---------------------------------------------------
    paper_total = db.scalar(select(func.count(Paper.id)))
    paper_today = db.scalar(
        select(func.count(Paper.id)).where(Paper.paper_date == current_day)
    )

    # -- Summary status distribution ------------------------------------
    # Papers without a summary_status row are reported under 'none'.
    distribution_sql = text("""
SELECT COALESCE(ss.status, 'none') AS status, COUNT(*) AS cnt
FROM papers p
LEFT JOIN summary_status ss ON ss.paper_id = p.id
GROUP BY status
""")
    status_counts = {}
    for status_name, count in db.execute(distribution_sql).fetchall():
        status_counts[status_name] = count

    def count_of(*states) -> int:
        """Add up the counts of the given status values (missing ones are 0)."""
        return sum(status_counts.get(state, 0) for state in states)

    # -- Storage overview -----------------------------------------------
    if settings.db_path.exists():
        db_size = _fmt_size(settings.db_path.stat().st_size)
    else:
        db_size = "0 B"
    papers_size = _fmt_size(_dir_size(PAPERS_DIR))
    tmp_size = _fmt_size(_dir_size(TMP_DIR))

    # -- Scheduler state ------------------------------------------------
    scheduler = get_scheduler()
    scheduler_enabled = scheduler is not None
    next_run = None
    if scheduler_enabled:
        # Next run time of the first job registered as "daily_pipeline".
        next_run = next(
            (
                job.next_run_time
                for job in scheduler.get_jobs()
                if job.id == "daily_pipeline"
            ),
            None,
        )

    # -- Most recent logs (5 entries) -----------------------------------
    recent_stmt = select(CrawlLog).order_by(CrawlLog.started_at.desc()).limit(5)
    recent_logs = db.execute(recent_stmt).scalars().all()

    # -- Active locks ---------------------------------------------------
    locks_stmt = select(TaskLock).where(TaskLock.status == "running")
    active_locks = db.execute(locks_stmt).scalars().all()

    return {
        "total_papers": paper_total or 0,
        "today_papers": paper_today or 0,
        "pending_count": count_of(SummaryState.PENDING),
        "failed_count": count_of(
            SummaryState.FAILED, SummaryState.PERMANENT_FAILURE
        ),
        "done_count": count_of(SummaryState.DONE),
        "running_count": count_of("running", SummaryState.PROCESSING),
        "none_count": count_of("none"),
        "status_counts": status_counts,
        "db_size": db_size,
        "papers_size": papers_size,
        "tmp_size": tmp_size,
        "scheduler_enabled": scheduler_enabled,
        "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
        "timezone": settings.APP_TIMEZONE,
        "next_run": next_run.isoformat() if next_run else None,
        "recent_logs": recent_logs,
        "active_locks": active_locks,
        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
    }
def query_papers(
    db: Session,
    *,
    q: str = "",
    date_from: str | None = None,
    date_to: str | None = None,
    tag: str = "",
    summary_status: str = "all",
    sort: str = "date_desc",
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[Paper], int, dict[str, str]]:
    """Query papers for the admin list with filtering, sorting and paging.

    Args:
        db: Active database session.
        q: Free-text term matched case-insensitively as a literal substring
            of the English title, Chinese title or abstract. Surrounding
            whitespace is ignored.
        date_from: Inclusive lower bound on ``Paper.paper_date``.
        date_to: Inclusive upper bound on ``Paper.paper_date``.
        tag: When non-empty, only papers carrying this tag are returned.
        summary_status: ``"all"`` for no filter, ``"none"`` for papers
            without a summary-status row, otherwise the exact status value.
        sort: Key into ``SORT_MAP``; unknown keys fall back to newest first.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size.

    Returns:
        ``(papers, total, statuses)`` - the page of papers, the number of
        papers matching the filters, and ``{arxiv_id: summary_status}`` for
        the papers on the page that have a status row.
    """
    query = select(Paper)

    # Search: use the stripped term and escape LIKE wildcards so user input
    # such as "50%" or "a_b" is matched literally.
    term = q.strip()
    if term:
        escaped = (
            term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        pattern = f"%{escaped}%"
        query = query.where(
            Paper.title_en.ilike(pattern, escape="\\")
            | Paper.title_zh.ilike(pattern, escape="\\")
            | Paper.abstract.ilike(pattern, escape="\\")
        )

    # Date range.
    # NOTE(review): the bounds are annotated as strings but compared against
    # the same column that is elsewhere compared to a ``date`` object -
    # confirm the caller passes values the column type accepts.
    if date_from:
        query = query.where(Paper.paper_date >= date_from)
    if date_to:
        query = query.where(Paper.paper_date <= date_to)

    # Tag filter.
    if tag:
        query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
            PaperTag.tag == tag
        )

    # Summary status filter.
    if summary_status == "none":
        query = query.outerjoin(
            SummaryStatus, SummaryStatus.paper_id == Paper.id
        ).where(SummaryStatus.paper_id == None)  # noqa: E711
    elif summary_status != "all":
        query = query.join(SummaryStatus, SummaryStatus.paper_id == Paper.id).where(
            SummaryStatus.status == summary_status
        )

    # Count before ordering: the count subquery does not need an ORDER BY.
    total = db.scalar(select(func.count()).select_from(query.subquery())) or 0

    # Many papers share the same sort value (e.g. the same paper_date), so
    # add the primary key as a tie-breaker to keep paging deterministic.
    order = SORT_MAP.get(sort, Paper.paper_date.desc())
    query = query.order_by(order, Paper.id.desc())

    # Paging; clamp the page so an out-of-range value cannot yield a
    # negative OFFSET.
    offset = (max(page, 1) - 1) * per_page
    papers = db.execute(query.offset(offset).limit(per_page)).scalars().all()

    # Summary status for each paper on the page, keyed by arxiv_id.
    arxiv_by_id = {paper.id: paper.arxiv_id for paper in papers}
    statuses: dict[str, str] = {}
    if arxiv_by_id:
        rows = db.execute(
            select(SummaryStatus.paper_id, SummaryStatus.status).where(
                SummaryStatus.paper_id.in_(list(arxiv_by_id))
            )
        ).all()
        statuses = {arxiv_by_id[paper_id]: status for paper_id, status in rows}

    return papers, total, statuses
def get_scheduler_history(db: Session, limit: int = 10) -> list[CrawlLog]:
    """Return the most recent crawl logs whose task is ``"scheduler"``.

    Args:
        db: Active database session.
        limit: Maximum number of log entries to return.

    Returns:
        Log entries ordered by ``started_at``, newest first.
    """
    stmt = select(CrawlLog).where(CrawlLog.task == "scheduler")
    stmt = stmt.order_by(CrawlLog.started_at.desc()).limit(limit)
    return db.execute(stmt).scalars().all()
def get_scheduler_status() -> dict:
    """Describe the current scheduler state.

    Returns:
        A dict with the enabled flag, the configured schedule time and
        timezone, the ISO-formatted next run times of the
        ``"daily_pipeline"`` and ``"upvote_refresh"`` jobs (``None`` when
        unavailable) and the upvote refresh window in days.
    """
    scheduler = get_scheduler()

    # Next run time per job id; empty when no scheduler is available.
    run_times = {}
    if scheduler:
        run_times = {job.id: job.next_run_time for job in scheduler.get_jobs()}

    next_run = run_times.get("daily_pipeline")
    upvote_next_run = run_times.get("upvote_refresh")

    schedule_time = f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}"
    return {
        "enabled": scheduler is not None,
        "schedule_time": schedule_time,
        "timezone": settings.APP_TIMEZONE,
        "next_run": next_run.isoformat() if next_run else None,
        "upvote_next_run": upvote_next_run.isoformat() if upvote_next_run else None,
        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
    }
def run_cleanup_now(db: Session, cleanup_func: Callable[[], dict]) -> dict:
    """Run the temp-directory cleanup synchronously and record it in CrawlLog.

    A ``CrawlLog`` row with task ``"cleanup"`` is committed in the
    ``"running"`` state before the cleanup starts and is updated to
    ``"success"`` or ``"failed"`` afterwards.

    Args:
        db: Active database session.
        cleanup_func: Callable performing the cleanup. Its result dict may
            contain ``scanned``, ``removed`` and ``errors``.

    Returns:
        The dict returned by ``cleanup_func``.

    Raises:
        Exception: Whatever ``cleanup_func`` (or recording the success)
            raised, re-raised after the log row is marked as failed.
    """
    log_entry = CrawlLog(task="cleanup", status="running", started_at=utc_now())
    db.add(log_entry)
    db.commit()
    try:
        result = cleanup_func()
        log_entry.status = "success"
        log_entry.completed_at = utc_now()
        log_entry.details_json = json.dumps(
            {
                "scanned": result.get("scanned", 0),
                "removed": result.get("removed", 0),
            },
            ensure_ascii=False,
        )
        errors = result.get("errors")
        if errors:
            # Stringify each entry so non-string error objects cannot break
            # the join.
            log_entry.error = "; ".join(str(err) for err in errors)[:2000]
        db.commit()
        return result
    except Exception as exc:
        # Discard any half-applied success state and reset the session; if
        # the failure came from a flush/commit the session cannot be used
        # again until it has been rolled back.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = str(exc)[:2000]
        log_entry.completed_at = utc_now()
        db.commit()
        raise
def get_job_detail(db: Session, job_id: int) -> dict | None:
    """Return a background job and its stage events as a JSON-serializable dict.

    Args:
        db: Active database session.
        job_id: Primary key of the job to look up.

    Returns:
        The job's fields plus an ``events`` list ordered by creation time
        (oldest first), or ``None`` when no job has this id.
    """
    job = db.get(Job, job_id)
    if not job:
        return None

    events_stmt = (
        select(JobEvent)
        .where(JobEvent.job_id == job_id)
        .order_by(JobEvent.created_at.asc())
    )
    job_events = db.execute(events_stmt).scalars().all()

    def iso_or_none(moment):
        """ISO-format an optional timestamp."""
        return moment.isoformat() if moment else None

    def json_or_none(raw):
        """Decode an optional JSON column; empty/NULL becomes None."""
        return json.loads(raw) if raw else None

    detail = {
        "id": job.id,
        "type": job.type,
        "status": job.status,
        "owner": job.owner,
        # An empty payload decodes to {} rather than None.
        "payload": json.loads(job.payload_json or "{}"),
        "result": json_or_none(job.result_json),
        "error": job.error,
        "created_at": job.created_at.isoformat(),
        "started_at": iso_or_none(job.started_at),
        "completed_at": iso_or_none(job.completed_at),
    }

    serialized_events = []
    for item in job_events:
        serialized_events.append(
            {
                "stage": item.stage,
                "status": item.status,
                "message": item.message,
                "payload": json_or_none(item.payload_json),
                "created_at": item.created_at.isoformat(),
            }
        )
    detail["events"] = serialized_events
    return detail
def get_logs_context(db: Session, *, page: int, per_page: int) -> dict:
    """Build the template context for the admin logs page.

    Args:
        db: Active database session.
        page: 1-based page number, applied to both the crawl logs and the
            delete jobs.
        per_page: Page size for both lists.

    Returns:
        A dict with the page of crawl logs and delete jobs (newest first),
        the paging values, and summary counters: total papers, done,
        pending (pending + processing) and failed (failed + permanent
        failure).
    """
    skip = (page - 1) * per_page

    crawl_stmt = (
        select(CrawlLog)
        .order_by(CrawlLog.started_at.desc())
        .offset(skip)
        .limit(per_page)
    )
    crawl_logs = db.execute(crawl_stmt).scalars().all()

    delete_stmt = (
        select(DataDeleteJob)
        .order_by(DataDeleteJob.started_at.desc())
        .offset(skip)
        .limit(per_page)
    )
    delete_jobs = db.execute(delete_stmt).scalars().all()

    def count_statuses(condition) -> int:
        """Count summary-status rows matching *condition* (0 when none)."""
        stmt = select(func.count(SummaryStatus.id)).where(condition)
        return db.scalar(stmt) or 0

    summary_total = db.scalar(select(func.count(Paper.id))) or 0
    summary_done = count_statuses(SummaryStatus.status == SummaryState.DONE)
    summary_pending = count_statuses(
        SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.PROCESSING])
    )
    summary_failed = count_statuses(
        SummaryStatus.status.in_(
            [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
        )
    )

    return {
        "crawl_logs": crawl_logs,
        "delete_jobs": delete_jobs,
        "page": page,
        "per_page": per_page,
        "summary_total": summary_total,
        "summary_done": summary_done,
        "summary_pending": summary_pending,
        "summary_failed": summary_failed,
    }
def query_summary_statuses(
    db: Session,
    *,
    status: str,
    page: int,
    per_page: int,
) -> tuple[list[tuple[Paper, SummaryStatus | None]], int]:
    """List papers together with their summary status, one page at a time.

    Args:
        db: Active database session.
        status: ``"all"`` for no filter, ``"none"`` for papers without a
            summary-status row, otherwise the exact status value.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size.

    Returns:
        ``(results, total)`` - the page of ``(paper, status_or_None)`` rows
        ordered by paper date (newest first), and the number of rows
        matching the filter.
    """
    base = select(Paper, SummaryStatus).outerjoin(
        SummaryStatus, SummaryStatus.paper_id == Paper.id
    )
    if status == "none":
        base = base.where(SummaryStatus.paper_id == None)  # noqa: E711
    elif status != "all":
        base = base.where(SummaryStatus.status == status)

    # Count on the unordered query; sorting the count subquery is wasted work.
    total = db.scalar(select(func.count()).select_from(base.subquery())) or 0

    # Many papers share a paper_date, so break ties on the primary key to
    # keep LIMIT/OFFSET paging deterministic.
    ordered = base.order_by(Paper.paper_date.desc(), Paper.id.desc())

    # Clamp the page so an out-of-range value cannot yield a negative OFFSET.
    offset = (max(page, 1) - 1) * per_page
    results = db.execute(ordered.offset(offset).limit(per_page)).all()
    return results, total
def serialize_summary_statuses(
    results: list[tuple[Paper, SummaryStatus | None]],
    *,
    total: int,
    page: int,
    per_page: int,
) -> dict:
    """Build the JSON response for the summary-status list.

    Args:
        results: ``(paper, status_or_None)`` pairs for the current page.
        total: Total number of matching rows.
        page: Current page number.
        per_page: Page size.

    Returns:
        ``{"items": [...], "total": ..., "page": ..., "per_page": ...}``.
        Papers without a status row are reported as ``"none"`` with a retry
        count of 0 and no error information.
    """

    def status_fields(state) -> dict:
        """Status-related fields, with defaults when there is no status row."""
        if state:
            return {
                "summary_status": state.status,
                "retry_count": state.retry_count,
                "error_type": state.error_type,
                "error": state.error,
            }
        return {
            "summary_status": "none",
            "retry_count": 0,
            "error_type": None,
            "error": None,
        }

    items = [
        {
            "arxiv_id": paper.arxiv_id,
            # Prefer the Chinese title, falling back to the English one.
            "title": paper.title_zh or paper.title_en,
            "paper_date": str(paper.paper_date),
            **status_fields(state),
        }
        for paper, state in results
    ]
    return {"items": items, "total": total, "page": page, "per_page": per_page}
def retry_failed_summaries(db: Session) -> int:
    """Reset failed and permanently failed summary tasks to pending.

    The reset is a single UPDATE that also clears ``error`` and
    ``error_type``; the return value is that statement's row count, so it
    always matches the rows actually reset.

    NOTE(review): unlike ``reset_summaries_pending`` this leaves
    ``started_at`` / ``completed_at`` / ``quality`` untouched - confirm
    that is intended.

    Args:
        db: Active database session.

    Returns:
        Number of summary-status rows moved back to pending.
    """
    failed_states = [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
    result = db.execute(
        SummaryStatus.__table__.update()
        .where(SummaryStatus.status.in_(failed_states))
        .values(status=SummaryState.PENDING, error=None, error_type=None)
    )
    # Read the count before committing, while the cursor result is current.
    updated = result.rowcount
    db.commit()
    return updated
def delete_paper_by_arxiv(db: Session, arxiv_id: str) -> bool:
    """Delete one paper and its derived indexes.

    The paper row is deleted and committed first; the derived indexes are
    then removed via ``delete_paper_indexes`` and committed separately.

    Args:
        db: Active database session.
        arxiv_id: arXiv identifier of the paper to delete.

    Returns:
        ``True`` when a paper was found and deleted, ``False`` otherwise.
    """
    lookup = select(Paper).where(Paper.arxiv_id == arxiv_id)
    target = db.execute(lookup).scalars().first()
    if not target:
        return False

    # Keep the primary key: it is still needed after the row is gone.
    removed_id = target.id
    db.delete(target)
    db.commit()

    delete_paper_indexes(db, paper_id=removed_id, arxiv_id=arxiv_id)
    db.commit()
    return True
def delete_papers_by_arxiv_ids(db: Session, arxiv_ids: list[str]) -> int:
    """Delete several papers and their derived indexes.

    All matching paper rows are deleted in one commit; the derived indexes
    are then removed per paper and committed together.

    Args:
        db: Active database session.
        arxiv_ids: arXiv identifiers of the papers to delete.

    Returns:
        Number of papers that were found and deleted.
    """
    lookup = select(Paper).where(Paper.arxiv_id.in_(arxiv_ids))
    doomed = db.execute(lookup).scalars().all()

    # Capture the identifiers up front: they are needed after the rows are gone.
    removed_keys = [(entry.id, entry.arxiv_id) for entry in doomed]

    for entry in doomed:
        db.delete(entry)
    db.commit()

    for removed_id, removed_arxiv in removed_keys:
        delete_paper_indexes(db, paper_id=removed_id, arxiv_id=removed_arxiv)
    db.commit()
    return len(removed_keys)
def reset_summaries_pending(db: Session, arxiv_ids: list[str]) -> int:
    """Reset the summary status of the given papers to pending.

    Existing status rows are moved to pending with their quality, error,
    timing and raw-output fields cleared; papers without a status row get
    a new pending one.

    Args:
        db: Active database session.
        arxiv_ids: arXiv identifiers of the papers to reset.

    Returns:
        Number of papers matched by ``arxiv_ids`` (0 when none match).
    """
    ids_stmt = select(Paper.id).where(Paper.arxiv_id.in_(arxiv_ids))
    target_ids = db.execute(ids_stmt).scalars().all()
    if not target_ids:
        return 0

    rows_stmt = select(SummaryStatus).where(SummaryStatus.paper_id.in_(target_ids))
    current_rows = db.execute(rows_stmt).scalars().all()
    covered = {row.paper_id for row in current_rows}

    # Field values that put an existing row back into a clean pending state.
    cleared = {
        "status": SummaryState.PENDING,
        "quality": None,
        "error": None,
        "error_type": None,
        "raw_output_saved": False,
        "started_at": None,
        "completed_at": None,
    }
    for row in current_rows:
        for field, value in cleared.items():
            setattr(row, field, value)

    # Papers that never had a status row get a fresh pending one.
    for pid in target_ids:
        if pid in covered:
            continue
        db.add(SummaryStatus(paper_id=pid, status=SummaryState.PENDING))

    db.commit()
    return len(target_ids)