Files
daily-paper/app/services/admin.py
T

641 lines
21 KiB
Python

"""管理后台服务 — 统计聚合、系统状态、管理操作。"""
from __future__ import annotations
import json
from datetime import date
from pathlib import Path
from typing import Callable
from sqlalchemy import func, select, text, update
from sqlalchemy.orm import Session
from app.config import settings
from app.models import (
CrawlLog,
DataDeleteJob,
Job,
JobEvent,
Paper,
PaperTag,
SummaryState,
SummaryStatus,
TaskLock,
)
from app.services.derived import delete_paper_indexes
from app.services.scheduler import get_scheduler
from app.utils import PAPERS_DIR, TMP_DIR, utc_now
# Sort keys accepted by the admin paper listing, mapped to ORDER BY clauses.
SORT_MAP = dict(
    date_desc=Paper.paper_date.desc(),
    date_asc=Paper.paper_date.asc(),
    upvotes_desc=Paper.upvotes.desc(),
    title_asc=Paper.title_en.asc(),
)
def _dir_size(path: Path) -> int:
    """Return the combined size in bytes of every regular file under *path*.

    The tree is walked recursively. A *path* that does not exist counts as 0.
    """
    total = 0
    if path.exists():
        for entry in path.rglob("*"):
            if entry.is_file():
                total += entry.stat().st_size
    return total
def _fmt_size(nbytes: int) -> str:
    """Render a byte count as a short human-readable string.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (TB) is reached, and is always shown with one decimal place.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = nbytes
    index = 0
    # Stop on the last unit so very large values are still reported in TB.
    while index < len(units) - 1 and not value < 1024:
        value /= 1024
        index += 1
    return f"{value:.1f} {units[index]}"
def get_admin_stats(db: Session) -> dict:
    """Gather the figures shown on the admin dashboard.

    Returns:
        A dict with paper totals, summary counts per status, formatted storage
        sizes, scheduler state, the five newest crawl logs, the task locks
        whose status is ``running`` or ``stale``, and the configuration
        overview.
    """
    today = date.today()

    # ── Paper counts ──
    paper_count = select(func.count(Paper.id))
    total_papers = db.scalar(paper_count)
    today_papers = db.scalar(paper_count.where(Paper.paper_date == today))

    # ── Summary status distribution ──
    # Papers without a summary_status row are reported under the key 'none'.
    status_sql = text(
        "\n"
        "SELECT COALESCE(ss.status, 'none') AS status, COUNT(*) AS cnt\n"
        "FROM papers p\n"
        "LEFT JOIN summary_status ss ON ss.paper_id = p.id\n"
        "GROUP BY status\n"
    )
    status_counts = {
        status: cnt for status, cnt in db.execute(status_sql).fetchall()
    }
    count_for = status_counts.get

    # ── Storage overview ──
    db_file = settings.db_path
    if db_file.exists():
        db_size = _fmt_size(db_file.stat().st_size)
    else:
        db_size = "0 B"
    papers_size = _fmt_size(_dir_size(PAPERS_DIR))
    tmp_size = _fmt_size(_dir_size(TMP_DIR))

    # ── Scheduler state ──
    scheduler = get_scheduler()
    scheduler_enabled = scheduler is not None
    next_run = None
    if scheduler_enabled:
        # Next run time of the first job registered as "daily_pipeline", if any.
        next_run = next(
            (
                job.next_run_time
                for job in scheduler.get_jobs()
                if job.id == "daily_pipeline"
            ),
            None,
        )

    # ── Most recent logs (5 entries) ──
    recent_stmt = select(CrawlLog).order_by(CrawlLog.started_at.desc()).limit(5)
    recent_logs = db.execute(recent_stmt).scalars().all()

    # ── Active locks ──
    locks_stmt = select(TaskLock).where(TaskLock.status.in_(["running", "stale"]))
    active_locks = db.execute(locks_stmt).scalars().all()

    return {
        "total_papers": total_papers or 0,
        "today_papers": today_papers or 0,
        "pending_count": count_for(SummaryState.PENDING, 0),
        "failed_count": (
            count_for(SummaryState.FAILED, 0)
            + count_for(SummaryState.PERMANENT_FAILURE, 0)
        ),
        "done_count": count_for(SummaryState.DONE, 0),
        "running_count": (
            count_for("running", 0) + count_for(SummaryState.PROCESSING, 0)
        ),
        "none_count": count_for("none", 0),
        "status_counts": status_counts,
        "db_size": db_size,
        "papers_size": papers_size,
        "tmp_size": tmp_size,
        "scheduler_enabled": scheduler_enabled,
        "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
        "timezone": settings.APP_TIMEZONE,
        "next_run": next_run.isoformat() if next_run else None,
        "recent_logs": recent_logs,
        "active_locks": active_locks,
        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
        "config_overview": get_config_overview(),
    }
def query_papers(
    db: Session,
    *,
    q: str = "",
    date_from: str | None = None,
    date_to: str | None = None,
    tag: str = "",
    summary_status: str = "all",
    sort: str = "date_desc",
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[Paper], int, dict[str, str]]:
    """Query papers for the admin listing with filtering, sorting and paging.

    Args:
        db: Active database session.
        q: Free-text keyword matched case-insensitively against the English
            title, the Chinese title and the abstract. Surrounding whitespace
            is ignored and ``%``/``_`` are matched literally.
        date_from: Inclusive lower bound for ``Paper.paper_date``.
        date_to: Inclusive upper bound for ``Paper.paper_date``.
        tag: Only papers carrying exactly this tag; empty disables the filter.
        summary_status: ``"all"`` (no filter), ``"none"`` (papers without a
            status row) or a concrete status value.
        sort: Key into ``SORT_MAP``; unknown keys fall back to newest first.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size.

    Returns:
        ``(papers, total, statuses)``: the requested page of papers, the
        number of papers matching the filters, and a mapping
        ``{arxiv_id: summary_status}`` for the papers on the page that have a
        status row.
    """
    query = select(Paper)

    # Keyword search
    keyword = q.strip()
    if keyword:
        # Escape LIKE metacharacters so user input is matched literally.
        escaped = (
            keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        pattern = f"%{escaped}%"
        query = query.where(
            Paper.title_en.ilike(pattern, escape="\\")
            | Paper.title_zh.ilike(pattern, escape="\\")
            | Paper.abstract.ilike(pattern, escape="\\")
        )

    # Date range
    # NOTE(review): the bounds are passed through as received; confirm that the
    # paper_date column accepts the string form callers supply.
    if date_from:
        query = query.where(Paper.paper_date >= date_from)
    if date_to:
        query = query.where(Paper.paper_date <= date_to)

    # Tag filter
    if tag:
        query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
            PaperTag.tag == tag
        )

    # Summary status filter
    if summary_status == "none":
        query = query.outerjoin(
            SummaryStatus, SummaryStatus.paper_id == Paper.id
        ).where(SummaryStatus.paper_id.is_(None))
    elif summary_status != "all":
        query = query.join(SummaryStatus, SummaryStatus.paper_id == Paper.id).where(
            SummaryStatus.status == summary_status
        )

    # Count before ordering: the ORDER BY is irrelevant for the total.
    total = db.scalar(select(func.count()).select_from(query.subquery())) or 0

    # Ordering and paging; never send a negative OFFSET to the database.
    order = SORT_MAP.get(sort, Paper.paper_date.desc())
    offset = max(page - 1, 0) * per_page
    papers = (
        db.execute(query.order_by(order).offset(offset).limit(per_page))
        .scalars()
        .all()
    )

    # Summary status of each paper on the page
    statuses: dict[str, str] = {}
    if papers:
        arxiv_by_id = {paper.id: paper.arxiv_id for paper in papers}
        rows = db.execute(
            select(SummaryStatus.paper_id, SummaryStatus.status).where(
                SummaryStatus.paper_id.in_(list(arxiv_by_id))
            )
        ).all()
        for paper_id, status in rows:
            statuses[arxiv_by_id.get(paper_id, "")] = status

    return papers, total, statuses
def get_scheduler_history(db: Session, limit: int = 10) -> list[CrawlLog]:
    """Return up to *limit* crawl-log rows whose task is ``scheduler``, newest first."""
    stmt = (
        select(CrawlLog)
        .where(CrawlLog.task == "scheduler")
        .order_by(CrawlLog.started_at.desc())
        .limit(limit)
    )
    history = db.execute(stmt).scalars().all()
    return history
def get_scheduler_status() -> dict:
    """Describe the scheduler: whether it exists and when its jobs run next.

    The next run times of the ``daily_pipeline`` and ``upvote_refresh`` jobs
    are reported as ISO strings, or ``None`` when the job (or the scheduler)
    is absent or has no next run time.
    """
    scheduler = get_scheduler()

    # Next run time per tracked job id; stays None when the job is not found.
    upcoming = {"daily_pipeline": None, "upvote_refresh": None}
    if scheduler:
        for job in scheduler.get_jobs():
            if job.id in upcoming:
                upcoming[job.id] = job.next_run_time

    next_run = upcoming["daily_pipeline"]
    upvote_next_run = upcoming["upvote_refresh"]
    schedule_time = f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}"
    return {
        "enabled": scheduler is not None,
        "schedule_time": schedule_time,
        "timezone": settings.APP_TIMEZONE,
        "next_run": next_run.isoformat() if next_run else None,
        "upvote_next_run": upvote_next_run.isoformat() if upvote_next_run else None,
        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
    }
def run_cleanup_now(db: Session, cleanup_func: Callable[[], dict]) -> dict:
    """Run the temp-directory cleanup synchronously and record it in CrawlLog.

    A ``cleanup`` log row is committed with status ``running`` before the work
    starts, then updated to ``success`` or ``failed`` once it finishes.

    Args:
        db: Active database session.
        cleanup_func: Zero-argument callable doing the cleanup. Its result
            dict is read for the keys ``scanned``, ``removed`` and ``errors``.

    Returns:
        The dict returned by *cleanup_func*.

    Raises:
        Exception: Whatever *cleanup_func* (or recording the success) raised,
            re-raised after the log row has been marked ``failed``.
    """
    log_entry = CrawlLog(task="cleanup", status="running", started_at=utc_now())
    db.add(log_entry)
    db.commit()
    try:
        result = cleanup_func()
        log_entry.status = "success"
        log_entry.completed_at = utc_now()
        log_entry.details_json = json.dumps(
            {
                "scanned": result.get("scanned", 0),
                "removed": result.get("removed", 0),
            },
            ensure_ascii=False,
        )
        errors = result.get("errors")
        if errors:
            # str() each entry so a non-string error cannot break the join and
            # turn a finished cleanup into a recorded failure.
            log_entry.error = "; ".join(map(str, errors))[:2000]
        db.commit()
    except Exception as exc:
        # Discard any half-applied state first: if the commit above was what
        # failed, the session must be rolled back before it can be used again.
        db.rollback()
        log_entry.status = "failed"
        log_entry.error = str(exc)[:2000]
        log_entry.completed_at = utc_now()
        db.commit()
        raise
    return result
def get_job_detail(db: Session, job_id: int) -> dict | None:
    """Return one background job and its stage events as a JSON-serializable dict.

    Returns ``None`` when no job with *job_id* exists. Events are listed in
    ascending ``created_at`` order.
    """
    job = db.get(Job, job_id)
    if not job:
        return None

    def iso_or_none(moment):
        """ISO-format an optional timestamp."""
        return moment.isoformat() if moment else None

    def decode_or_none(raw):
        """Parse an optional JSON column; empty/missing values yield None."""
        return json.loads(raw) if raw else None

    events_stmt = (
        select(JobEvent)
        .where(JobEvent.job_id == job_id)
        .order_by(JobEvent.created_at.asc())
    )
    events = db.execute(events_stmt).scalars().all()

    return {
        "id": job.id,
        "type": job.type,
        "status": job.status,
        "owner": job.owner,
        # A missing payload is reported as an empty object rather than None.
        "payload": json.loads(job.payload_json or "{}"),
        "result": decode_or_none(job.result_json),
        "error": job.error,
        "created_at": job.created_at.isoformat(),
        "started_at": iso_or_none(job.started_at),
        "completed_at": iso_or_none(job.completed_at),
        "events": [
            {
                "stage": event.stage,
                "status": event.status,
                "message": event.message,
                "payload": decode_or_none(event.payload_json),
                "created_at": event.created_at.isoformat(),
            }
            for event in events
        ],
    }
def get_logs_context(db: Session, *, page: int, per_page: int) -> dict:
    """Build the template context for the admin logs page.

    Contains one page of crawl logs and one page of data-delete jobs (both
    newest first), summary totals grouped into done / pending / failed, and
    the failure breakdown by error type.
    """
    skip = (page - 1) * per_page

    def newest_page(model):
        """Fetch the requested page of *model* rows ordered by started_at desc."""
        stmt = (
            select(model)
            .order_by(model.started_at.desc())
            .limit(per_page)
            .offset(skip)
        )
        return db.execute(stmt).scalars().all()

    def count_statuses(condition):
        """Count summary-status rows matching *condition* (0 when none)."""
        stmt = select(func.count(SummaryStatus.id)).where(condition)
        return db.scalar(stmt) or 0

    crawl_logs = newest_page(CrawlLog)
    delete_jobs = newest_page(DataDeleteJob)

    summary_total = db.scalar(select(func.count(Paper.id))) or 0
    summary_done = count_statuses(SummaryStatus.status == SummaryState.DONE)
    summary_pending = count_statuses(
        SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.PROCESSING])
    )
    summary_failed = count_statuses(
        SummaryStatus.status.in_(
            [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
        )
    )

    context = {
        "crawl_logs": crawl_logs,
        "delete_jobs": delete_jobs,
        "page": page,
        "per_page": per_page,
        "summary_total": summary_total,
        "summary_done": summary_done,
        "summary_pending": summary_pending,
        "summary_failed": summary_failed,
        "failure_breakdown": get_failure_breakdown(db),
    }
    return context
def query_summary_statuses(
    db: Session,
    *,
    status: str,
    page: int,
    per_page: int,
) -> tuple[list[tuple[Paper, SummaryStatus | None]], int]:
    """List papers together with their (optional) summary status row.

    *status* may be ``"all"`` (no filter), ``"none"`` (papers lacking a status
    row) or a concrete status value. Rows are ordered by paper date, newest
    first.

    Returns:
        ``(rows, total)``: the requested page of ``(Paper, SummaryStatus)``
        pairs and the number of rows matching the filter.
    """
    stmt = (
        select(Paper, SummaryStatus)
        .outerjoin(SummaryStatus, SummaryStatus.paper_id == Paper.id)
        .order_by(Paper.paper_date.desc())
    )
    if status == "none":
        stmt = stmt.where(SummaryStatus.paper_id.is_(None))
    elif status != "all":
        stmt = stmt.where(SummaryStatus.status == status)

    count_stmt = select(func.count()).select_from(stmt.subquery())
    total = db.scalar(count_stmt) or 0

    page_stmt = stmt.offset((page - 1) * per_page).limit(per_page)
    results = db.execute(page_stmt).all()
    return results, total
def serialize_summary_statuses(
    results: list[tuple[Paper, SummaryStatus | None]],
    *,
    total: int,
    page: int,
    per_page: int,
) -> dict:
    """Turn ``(paper, status)`` pairs into the JSON payload of the status list.

    A paper without a status row is reported with ``summary_status`` set to
    ``"none"``, a retry count of 0 and no error information. The title prefers
    the Chinese title and falls back to the English one.
    """

    def to_item(paper, ss):
        """Serialize one paper and its optional status row."""
        item = {
            "arxiv_id": paper.arxiv_id,
            "title": paper.title_zh or paper.title_en,
            "paper_date": str(paper.paper_date),
        }
        if ss:
            item.update(
                summary_status=ss.status,
                retry_count=ss.retry_count,
                error_type=ss.error_type,
                error=ss.error,
            )
        else:
            item.update(
                summary_status="none",
                retry_count=0,
                error_type=None,
                error=None,
            )
        return item

    items = [to_item(paper, ss) for paper, ss in results]
    return {"items": items, "total": total, "page": page, "per_page": per_page}
def retry_failed_summaries(db: Session) -> int:
    """Reset every failed or permanently failed summary back to pending.

    The error and error type of the affected rows are cleared as well.

    Returns:
        The number of papers found with a failed summary before the reset
        (0 when there were none, in which case nothing is written).
    """
    is_failed = SummaryStatus.status.in_(
        [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
    )

    lookup = (
        select(Paper.arxiv_id)
        .join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
        .where(is_failed)
    )
    failed_ids = db.execute(lookup).scalars().all()
    if not failed_ids:
        return 0

    reset = (
        SummaryStatus.__table__.update()
        .where(is_failed)
        .values(status=SummaryState.PENDING, error=None, error_type=None)
    )
    db.execute(reset)
    db.commit()
    return len(failed_ids)
def delete_paper_by_arxiv(db: Session, arxiv_id: str) -> bool:
    """Delete one paper and its derived indexes.

    The paper row is deleted and committed first; the derived indexes are
    removed afterwards in a second commit.

    Returns:
        ``True`` when a paper with *arxiv_id* existed and was deleted,
        ``False`` otherwise.
    """
    lookup = select(Paper).where(Paper.arxiv_id == arxiv_id)
    paper = db.scalar(lookup)
    if paper:
        # Remember the primary key: it is needed after the row is gone.
        paper_id = paper.id
        db.delete(paper)
        db.commit()
        delete_paper_indexes(db, paper_id=paper_id, arxiv_id=arxiv_id)
        db.commit()
        return True
    return False
def delete_papers_by_arxiv_ids(db: Session, arxiv_ids: list[str]) -> int:
    """Delete several papers and their derived indexes.

    All matching paper rows are deleted in one commit; their derived indexes
    are removed afterwards in a second commit.

    Returns:
        The number of papers that were found and deleted.
    """
    lookup = select(Paper).where(Paper.arxiv_id.in_(arxiv_ids))
    deleted = []
    for paper in db.execute(lookup).scalars().all():
        # Keep both identifiers; they are needed once the row is gone.
        deleted.append((paper.id, paper.arxiv_id))
        db.delete(paper)
    db.commit()

    for paper_id, arxiv_id in deleted:
        delete_paper_indexes(db, paper_id=paper_id, arxiv_id=arxiv_id)
    db.commit()
    return len(deleted)
def reset_summaries_pending(db: Session, arxiv_ids: list[str]) -> int:
    """Put the summaries of the given papers back into the pending state.

    Existing status rows are reset (quality, error fields, raw-output flag and
    timestamps are cleared); papers without a status row get a new pending one.

    Returns:
        The number of papers matched by *arxiv_ids* (0 when none matched).
    """
    paper_ids = (
        db.execute(select(Paper.id).where(Paper.arxiv_id.in_(arxiv_ids)))
        .scalars()
        .all()
    )
    if not paper_ids:
        return 0

    current_rows = (
        db.execute(select(SummaryStatus).where(SummaryStatus.paper_id.in_(paper_ids)))
        .scalars()
        .all()
    )

    # Values every existing row is reset to, applied in this order.
    pristine = {
        "status": SummaryState.PENDING,
        "quality": None,
        "error": None,
        "error_type": None,
        "raw_output_saved": False,
        "started_at": None,
        "completed_at": None,
    }
    covered = set()
    for row in current_rows:
        covered.add(row.paper_id)
        for field, value in pristine.items():
            setattr(row, field, value)

    # Papers that never had a status row get a fresh pending one.
    db.add_all(
        [
            SummaryStatus(paper_id=paper_id, status=SummaryState.PENDING)
            for paper_id in paper_ids
            if paper_id not in covered
        ]
    )
    db.commit()
    return len(paper_ids)
# ── Job monitoring ────────────────────────────────────────────────────
def query_jobs(
    db: Session,
    *,
    status: str | None = None,
    job_type: str | None = None,
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[dict], int]:
    """List background jobs with optional status/type filters and paging.

    A filter value that is empty, ``None`` or ``"all"`` disables that filter.
    Jobs are ordered by creation time, newest first.

    Returns:
        ``(jobs, total)``: the requested page as dicts produced by
        ``serialize_job`` and the number of jobs matching the filters.
    """
    stmt = select(Job)
    wanted = ((Job.status, status), (Job.type, job_type))
    for column, value in wanted:
        if value and value != "all":
            stmt = stmt.where(column == value)

    total = db.scalar(select(func.count()).select_from(stmt.subquery())) or 0

    page_stmt = (
        stmt.order_by(Job.created_at.desc())
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    jobs = db.execute(page_stmt).scalars().all()
    return [serialize_job(job) for job in jobs], total
def _as_naive(dt):
    """Return *dt* as a naive datetime expressed in UTC.

    Datetimes read back from SQLite are naive UTC, so values must be brought
    to the same naive-UTC form before they are subtracted from one another.

    Args:
        dt: ``None``, a naive datetime, or a timezone-aware datetime.

    Returns:
        ``None`` and naive values unchanged; aware datetimes converted to UTC
        with ``tzinfo`` removed.
    """
    from datetime import datetime, timezone

    if dt is None or getattr(dt, "tzinfo", None) is None:
        return dt
    if isinstance(dt, datetime):
        # Shift to UTC first; merely dropping tzinfo would keep the local
        # wall-clock time of a non-UTC value and skew any duration.
        dt = dt.astimezone(timezone.utc)
    return dt.replace(tzinfo=None)
def serialize_job(job: Job) -> dict:
    """Serialize one job into a display dict, including its duration.

    ``duration_seconds`` is ``None`` for a job that has not started. For a
    started job it is the time from start to completion, or to the current
    time while the job has no completion timestamp, rounded to 0.1 s.
    """
    duration = None
    begin = _as_naive(job.started_at)
    if begin:
        finish = _as_naive(job.completed_at)
        if not finish:
            # Still running: measure against the current time.
            finish = _as_naive(utc_now())
        duration = round((finish - begin).total_seconds(), 1)

    copied = (
        "id",
        "type",
        "status",
        "owner",
        "created_at",
        "started_at",
        "completed_at",
    )
    payload = {field: getattr(job, field) for field in copied}
    payload["duration_seconds"] = duration
    payload["error"] = job.error
    return payload
def get_job_status_counts(db: Session) -> dict:
    """Count jobs per status, for the small statistics row on the jobs page."""
    stmt = select(Job.status, func.count(Job.id)).group_by(Job.status)
    counts = {}
    for status, amount in db.execute(stmt).fetchall():
        counts[status] = amount
    return counts
# ── Lock management ───────────────────────────────────────────────────
def force_release_lock(db: Session, lock_id: int) -> bool:
    """Force-release a stuck TaskLock.

    Only a lock whose status is ``running`` or ``stale`` is touched; it is set
    to ``finished`` and stamped with the current release time.

    Returns:
        ``True`` when a lock row was updated, ``False`` otherwise.
    """
    releasable = TaskLock.status.in_(["running", "stale"])
    stmt = (
        update(TaskLock)
        .where(TaskLock.id == lock_id, releasable)
        .values(status="finished", released_at=utc_now())
    )
    outcome = db.execute(stmt)
    db.commit()
    affected = outcome.rowcount or 0
    return affected > 0
# ── Failure reason breakdown ──────────────────────────────────────────
def get_failure_breakdown(db: Session) -> list[dict]:
    """Count failed and permanently failed summaries per error type.

    Rows with a NULL error type are grouped under ``"unknown"``. The result is
    ordered by count, largest first.
    """
    error_kind = func.coalesce(SummaryStatus.error_type, "unknown")
    occurrences = func.count(SummaryStatus.id)
    stmt = (
        select(error_kind, occurrences)
        .where(
            SummaryStatus.status.in_(
                [SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
            )
        )
        .group_by(error_kind)
        .order_by(occurrences.desc())
    )
    breakdown = []
    for kind, count in db.execute(stmt).fetchall():
        breakdown.append({"error_type": kind, "count": count})
    return breakdown
# ── Runtime configuration overview ────────────────────────────────────
def _redact_db_url(url):
    """Return *url* with the password of its authority part replaced by ``***``.

    URLs that carry no password (e.g. ``sqlite:///app.db``) and non-string
    values are returned unchanged. A URL that cannot be parsed is replaced by
    ``***`` entirely rather than risk exposing credentials.
    """
    from urllib.parse import urlsplit, urlunsplit

    if not isinstance(url, str):
        return url
    try:
        parts = urlsplit(url)
        password = parts.password
    except ValueError:
        return "***"
    if password is None:
        return url
    # Split on the last "@" so a password that itself contains "@" is covered.
    userinfo, _, hostinfo = parts.netloc.rpartition("@")
    username = userinfo.partition(":")[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{hostinfo}"))


def get_config_overview() -> dict:
    """Collect non-sensitive configuration values for the dashboard.

    Sensitive settings are never shown in clear text: the embedding API key is
    reported only as configured / not configured, and any password embedded in
    the database URL is masked.
    """
    return {
        "summary_backend": settings.SUMMARY_BACKEND,
        "summary_pdf_mode": settings.SUMMARY_PDF_MODE,
        "summary_concurrency": settings.SUMMARY_CONCURRENCY,
        "summary_timeout_seconds": settings.SUMMARY_TIMEOUT_SECONDS,
        "summary_max_retries": settings.SUMMARY_MAX_RETRIES,
        "scheduler_enabled": settings.SCHEDULER_ENABLED,
        "schedule_time": f"{settings.SCHEDULE_HOUR:02d}:{settings.SCHEDULE_MINUTE:02d}",
        "chroma_enabled": settings.CHROMA_ENABLED,
        "embed_model": settings.EMBED_MODEL or "(未配置)",
        "top_n": settings.TOP_N,
        "upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
        "app_workers": settings.APP_WORKERS,
        # Only the file name of the layout model is exposed, not its full path.
        "layout_model": Path(settings.LAYOUT_MODEL_PATH).name,
        "database_url": _redact_db_url(settings.DATABASE_URL),
        "api_key_configured": bool(settings.EMBED_API_KEY),
    }