diff --git a/app/main.py b/app/main.py index a26fa4b..5493505 100644 --- a/app/main.py +++ b/app/main.py @@ -39,6 +39,13 @@ def create_app() -> FastAPI: if settings.ADMIN_TOKEN == "change-me": logger.warning("⚠️ ADMIN_TOKEN is the default value 'change-me'. Please change it in .env!") + if settings.APP_HOST not in ("127.0.0.1", "localhost", "::1"): + logger.warning( + "⚠️ APP_HOST=%s is not localhost. " + "Ensure ADMIN_TOKEN is properly set and access is restricted.", + settings.APP_HOST, + ) + # 静态文件 app.mount("/static", StaticFiles(directory="app/static"), name="static") @@ -48,6 +55,17 @@ def create_app() -> FastAPI: app.include_router(search_router) app.include_router(user_router) + # 调度器(Phase 4) + @app.on_event("startup") + async def _start_scheduler(): + from app.services.scheduler import start_scheduler + start_scheduler() + + @app.on_event("shutdown") + async def _stop_scheduler(): + from app.services.scheduler import stop_scheduler + stop_scheduler() + return app diff --git a/app/routes/admin.py b/app/routes/admin.py index 9359203..a1e20eb 100644 --- a/app/routes/admin.py +++ b/app/routes/admin.py @@ -1,17 +1,26 @@ -"""管理接口 — AI 总结触发,需要 ADMIN_TOKEN 鉴权。""" +"""管理接口 — 抓取、总结、清理、删除、日志,需要 ADMIN_TOKEN 鉴权。""" from __future__ import annotations -from fastapi import APIRouter, Depends, HTTPException +from datetime import date, datetime, timezone + +from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.templating import Jinja2Templates +from pydantic import BaseModel, field_validator +from sqlalchemy import select from sqlalchemy.orm import Session from app.config import settings from app.database import get_db +from app.models import CrawlLog, DataDeleteJob, TaskLock +from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range +from app.services.crawler import crawl_daily from app.services.summarizer import summarize_batch, summarize_single router = 
APIRouter(prefix="/admin", tags=["admin"]) security = HTTPBearer() +templates = Jinja2Templates(directory="app/templates") async def verify_admin( @@ -23,6 +32,68 @@ async def verify_admin( return credentials.credentials +# ── 请求模型 ────────────────────────────────────────────────────────────── + + +class DeleteRequest(BaseModel): + date_start: date + date_end: date + include_notes: bool = True + confirm: str + + @field_validator("confirm") + @classmethod + def confirm_must_be_delete(cls, v: str) -> str: + if v != "DELETE": + raise ValueError("confirm must be 'DELETE' to proceed") + return v + + +# ── 抓取 ────────────────────────────────────────────────────────────────── + + +@router.post("/crawl") +async def admin_crawl( + _admin: str = Depends(verify_admin), + db: Session = Depends(get_db), + date: str | None = Query(None, description="YYYY-MM-DD,默认今天"), +): + """手动抓取指定日期,默认今天。""" + # 计算 target_date + from zoneinfo import ZoneInfo + + tz = ZoneInfo(settings.APP_TIMEZONE) + today = datetime.now(tz).strftime("%Y-%m-%d") + target_date = date or today + + # TaskLock 防重入 + now = datetime.now(timezone.utc) + lock = TaskLock( + task="crawl", + lock_key=target_date, + status="running", + owner="admin_crawl", + acquired_at=now, + ) + try: + db.add(lock) + db.commit() + except Exception: + db.rollback() + raise HTTPException(status_code=409, detail=f"Crawl already running for {target_date}") + + try: + result = await crawl_daily(db, target_date) + return result + except Exception as exc: + raise HTTPException(status_code=500, detail=str(exc)) + finally: + _release_lock(db, lock) + + +# ── 总结 ────────────────────────────────────────────────────────────────── + + @router.post("/summarize") async def admin_summarize_batch( _admin: str = Depends(verify_admin), @@ -46,3 +117,122 @@ async def admin_summarize_single( if result.get("status") == "not_found": raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}") return result + + +# ── 清理 
────────────────────────────────────────────────────────────────── + + +@router.post("/cleanup") +async def admin_cleanup( + _admin: str = Depends(verify_admin), + db: Session = Depends(get_db), +): + """清理 data/tmp/ 中超过 24 小时的临时文件。""" + now = datetime.now(timezone.utc) + log_entry = CrawlLog( + task="cleanup", + status="running", + started_at=now, + ) + db.add(log_entry) + db.commit() + + try: + result = cleanup_tmp() + log_entry.status = "success" + log_entry.completed_at = datetime.now(timezone.utc) + log_entry.papers_found = result.get("scanned", 0) + log_entry.papers_new = result.get("removed", 0) + if result.get("errors"): + log_entry.error = "; ".join(result["errors"])[:2000] + db.commit() + return result + except Exception as exc: + log_entry.status = "failed" + log_entry.error = str(exc)[:2000] + log_entry.completed_at = datetime.now(timezone.utc) + db.commit() + raise HTTPException(status_code=500, detail=str(exc)) + + +# ── 删除 ────────────────────────────────────────────────────────────────── + + +@router.post("/delete") +async def admin_delete( + body: DeleteRequest, + _admin: str = Depends(verify_admin), + db: Session = Depends(get_db), +): + """删除指定日期范围内的论文(需要 confirm='DELETE' 二次确认)。""" + if body.date_start > body.date_end: + raise HTTPException(status_code=400, detail="date_start must be <= date_end") + + result = await delete_papers_by_date_range( + db, + body.date_start, + body.date_end, + include_notes=body.include_notes, + ) + return result + + +# ── 日志 ────────────────────────────────────────────────────────────────── + + +@router.get("/logs") +async def admin_logs( + request: Request, + _admin: str = Depends(verify_admin), + db: Session = Depends(get_db), + page: int = Query(1, ge=1), + per_page: int = Query(20, ge=1, le=100), +): + """查看任务日志(CrawlLog + DataDeleteJob)。""" + # 查询 crawl_logs + crawl_logs = ( + db.execute( + select(CrawlLog) + .order_by(CrawlLog.started_at.desc()) + .limit(per_page) + .offset((page - 1) * per_page) + ) + 
.scalars() + .all() + ) + + # 查询 delete_jobs + delete_jobs = ( + db.execute( + select(DataDeleteJob) + .order_by(DataDeleteJob.started_at.desc()) + .limit(per_page) + .offset((page - 1) * per_page) + ) + .scalars() + .all() + ) + + return templates.TemplateResponse( + request, + "admin_logs.html", + { + "crawl_logs": crawl_logs, + "delete_jobs": delete_jobs, + "page": page, + "per_page": per_page, + }, + ) + + +# ── 工具函数 ────────────────────────────────────────────────────────────── + + +def _release_lock(db: Session, lock: TaskLock) -> None: + """释放 TaskLock。""" + try: + lock.status = "finished" + lock.released_at = datetime.now(timezone.utc) + db.commit() + except Exception: + db.rollback() diff --git a/app/services/cleaner.py b/app/services/cleaner.py new file mode 100644 index 0000000..56af5d1 --- /dev/null +++ b/app/services/cleaner.py @@ -0,0 +1,211 @@ +"""清理和删除服务 — 临时文件清理、按日期范围删除论文。""" + +from __future__ import annotations + +import logging +import shutil +from datetime import date, datetime, timezone +from pathlib import Path + +from sqlalchemy import delete, select, text +from sqlalchemy.orm import Session + +from app.models import ( + CrawlLog, + DataDeleteJob, + Paper, + TaskLock, +) + +logger = logging.getLogger(__name__) + +_DATA_DIR = Path("data") +_TMP_DIR = _DATA_DIR / "tmp" +_PAPERS_DIR = _DATA_DIR / "papers" + +# 临时文件最大保留时间(小时) +_MAX_TMP_AGE_HOURS = 24 + + +# ── 临时文件清理 ────────────────────────────────────────────────────────── + + +def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict: + """扫描 data/tmp/ 删除超过指定时间的临时文件。 + + Args: + max_age_hours: 文件最大保留时间(小时),默认 24。 + + Returns: + 清理统计 {"scanned": int, "removed": int, "errors": list[str]} + """ + if not _TMP_DIR.exists(): + return {"scanned": 0, "removed": 0, "errors": []} + + now = datetime.now(timezone.utc) + cutoff = now.timestamp() - (max_age_hours * 3600) + scanned = 0 + removed = 0 + errors: list[str] = [] + + for entry in _TMP_DIR.iterdir(): + if not entry.is_dir(): + continue + 
scanned += 1 + try: + # 取目录的修改时间作为判断依据 + dir_mtime = entry.stat().st_mtime + if dir_mtime < cutoff: + shutil.rmtree(entry) + removed += 1 + logger.info("Cleaned tmp dir: %s", entry.name) + except Exception as exc: + err_msg = f"{entry.name}: {exc}" + errors.append(err_msg) + logger.warning("Failed to clean tmp dir %s: %s", entry.name, exc) + + logger.info("Tmp cleanup: scanned=%d removed=%d errors=%d", scanned, removed, len(errors)) + return {"scanned": scanned, "removed": removed, "errors": errors} + + +# ── 按日期范围删除 ───────────────────────────────────────────────────────── + + +async def delete_papers_by_date_range( + db: Session, + date_start: date, + date_end: date, + *, + include_notes: bool = True, +) -> dict: + """删除 paper_date 落在 [date_start, date_end] 范围内的所有论文。 + + 删除流程(每篇独立 try/except): + 1. 查询目标论文 + 2. 删除 FTS5 索引 + 3. 删除本地文件 data/papers/{arxiv_id}/ 和 data/tmp/{arxiv_id}/ + 4. ORM cascade 自动删除关联表(authors, tags, summary, summary_status, bookmarks, reading_status, notes) + 5. 物理删除 papers 记录 + 6. 结果写入 data_delete_jobs 表 + + Args: + db: 数据库 session + date_start: 起始日期(含) + date_end: 结束日期(含) + include_notes: 是否同时删除用户笔记(目前 cascade 自动处理) + + Returns: + 删除结果统计 + """ + now = datetime.now(timezone.utc) + + # 查询目标论文 + papers = ( + db.execute( + select(Paper).where( + Paper.paper_date >= date_start, + Paper.paper_date <= date_end, + ) + ) + .scalars() + .all() + ) + + total = len(papers) + logger.info("Delete papers by date range: %s ~ %s, found %d papers", date_start, date_end, total) + + # 创建 delete job 记录 + job = DataDeleteJob( + date_start=date_start, + date_end=date_end, + include_notes=include_notes, + paper_count=total, + status="running", + started_at=now, + ) + db.add(job) + db.commit() + + deleted = 0 + failed_items: list[dict] = [] + + for paper in papers: + arxiv_id = paper.arxiv_id + paper_id = paper.id + try: + # 1. 删除 FTS5 索引 + db.execute( + text("DELETE FROM papers_fts WHERE rowid = :paper_id"), + {"paper_id": paper_id}, + ) + + # 2. 
删除本地文件 data/papers/{arxiv_id}/ + paper_dir = _PAPERS_DIR / arxiv_id + if paper_dir.exists(): + shutil.rmtree(paper_dir) + logger.debug("Removed paper dir: %s", paper_dir) + + # 3. 删除临时文件 data/tmp/{arxiv_id}/ + tmp_dir = _TMP_DIR / arxiv_id + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + logger.debug("Removed tmp dir: %s", tmp_dir) + + # 4. ORM cascade 删除(authors, tags, summary, summary_status, bookmark, reading_status, note) + db.delete(paper) + db.flush() + + deleted += 1 + logger.debug("Deleted paper: %s", arxiv_id) + + except Exception as exc: + db.rollback() + failed_items.append({"arxiv_id": arxiv_id, "error": str(exc)}) + logger.error("Failed to delete paper %s: %s", arxiv_id, exc) + + # 提交所有成功的删除 + try: + db.commit() + except Exception as exc: + db.rollback() + logger.error("Failed to commit delete batch: %s", exc) + + # 更新 job 状态 + job_error = None + job_status = "success" + if failed_items: + job_status = "failed" if deleted == 0 else "success" + job_error = "; ".join(f"{f['arxiv_id']}: {f['error']}" for f in failed_items[:20]) + + job.status = job_status + job.paper_count = deleted + job.completed_at = datetime.now(timezone.utc) + if job_error: + job.error = job_error[:4000] + db.commit() + + # 写入 crawl_logs + log_entry = CrawlLog( + task="delete", + status=job_status, + started_at=now, + completed_at=datetime.now(timezone.utc), + papers_found=total, + papers_new=deleted, + error=job_error, + ) + db.add(log_entry) + db.commit() + + result = { + "total": total, + "deleted": deleted, + "failed": len(failed_items), + "failed_items": failed_items, + "status": job_status, + } + logger.info( + "Delete job completed: date_range=%s~%s total=%d deleted=%d failed=%d", + date_start, date_end, total, deleted, len(failed_items), + ) + return result diff --git a/app/services/scheduler.py b/app/services/scheduler.py new file mode 100644 index 0000000..9f80c7b --- /dev/null +++ b/app/services/scheduler.py @@ -0,0 +1,169 @@ +"""调度服务 — APScheduler 每日自动抓取、总结、清理流水线。""" + 
+from __future__ import annotations + +import logging +from datetime import datetime, timezone + +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.cron import CronTrigger +from sqlalchemy.orm import Session +from zoneinfo import ZoneInfo + +from app.config import settings +from app.database import SessionLocal +from app.models import CrawlLog, TaskLock +from app.services.cleaner import cleanup_tmp +from app.services.crawler import crawl_daily +from app.services.summarizer import summarize_batch + +logger = logging.getLogger(__name__) + +# 模块级 scheduler 实例,保证单例 +_scheduler: AsyncIOScheduler | None = None + + +def get_scheduler() -> AsyncIOScheduler | None: + """返回当前 scheduler 实例(供测试和外部检查用)。""" + return _scheduler + + +def start_scheduler() -> AsyncIOScheduler | None: + """创建并启动 APScheduler。 + + 约束: + - SCHEDULER_ENABLED=true 才启动。 + - APP_WORKERS > 1 时只打印警告(多 worker 下调度器可能重复触发)。 + - 使用 task_locks 表防重入。 + - 调度时间按 APP_TIMEZONE 时区。 + """ + global _scheduler + + if not settings.SCHEDULER_ENABLED: + logger.info("Scheduler disabled (SCHEDULER_ENABLED=false)") + return None + + if settings.APP_WORKERS > 1: + logger.warning( + "⚠️ APP_WORKERS=%d > 1, scheduler may trigger duplicate tasks. 
" + "Set APP_WORKERS=1 or SCHEDULER_ENABLED=false.", + settings.APP_WORKERS, + ) + + tz = ZoneInfo(settings.APP_TIMEZONE) + scheduler = AsyncIOScheduler(timezone=tz) + + # 每日流水线:抓取 → 总结 → 清理 + trigger = CronTrigger( + hour=settings.SCHEDULE_HOUR, + minute=settings.SCHEDULE_MINUTE, + timezone=tz, + ) + scheduler.add_job( + _daily_pipeline, + trigger=trigger, + id="daily_pipeline", + name="daily_pipeline", + replace_existing=True, + max_instances=1, + misfire_grace_time=3600, # 允许迟到 1 小时内补执行 + ) + + scheduler.start() + _scheduler = scheduler + logger.info( + "Scheduler started: %02d:%02d %s", + settings.SCHEDULE_HOUR, + settings.SCHEDULE_MINUTE, + settings.APP_TIMEZONE, + ) + return scheduler + + +def stop_scheduler() -> None: + """停止调度器。""" + global _scheduler + if _scheduler: + _scheduler.shutdown(wait=False) + _scheduler = None + logger.info("Scheduler stopped") + + +async def _daily_pipeline() -> None: + """每日流水线:抓取 → 总结 → 清理。 + + 使用 task_locks 表防止重入:同一天的 pipeline 任务只有一个能运行。 + """ + tz = ZoneInfo(settings.APP_TIMEZONE) + today = datetime.now(tz).strftime("%Y-%m-%d") + now = datetime.now(timezone.utc) + lock_key = f"pipeline-{today}" + + db: Session = SessionLocal() + try: + # 尝试获取锁 + lock = TaskLock( + task="scheduler", + lock_key=lock_key, + status="running", + owner="daily_pipeline", + acquired_at=now, + ) + try: + db.add(lock) + db.commit() + except Exception: + db.rollback() + logger.warning("Daily pipeline already running for %s, skipping", today) + return + + # 写调度日志 + log_entry = CrawlLog( + task="scheduler", + status="running", + date=datetime.now(tz).date(), + started_at=now, + ) + db.add(log_entry) + db.commit() + + error_msg = None + try: + # Step 1: 抓取 + logger.info("Scheduler pipeline: crawl %s", today) + crawl_result = await crawl_daily(db, today) + logger.info("Scheduler pipeline: crawl done, found=%d new=%d", + crawl_result.get("found", 0), crawl_result.get("new", 0)) + + # Step 2: 总结 pending 论文 + logger.info("Scheduler pipeline: summarize batch") 
+ summarize_result = await summarize_batch(db) + logger.info("Scheduler pipeline: summarize done, result=%s", summarize_result) + + # Step 3: 清理临时文件 + logger.info("Scheduler pipeline: cleanup tmp") + cleanup_result = cleanup_tmp() + logger.info("Scheduler pipeline: cleanup done, removed=%d", cleanup_result.get("removed", 0)) + + log_entry.status = "success" + + except Exception as exc: + logger.exception("Scheduler pipeline failed for %s", today) + log_entry.status = "failed" + error_msg = str(exc)[:2000] + + finally: + log_entry.completed_at = datetime.now(timezone.utc) + if error_msg: + log_entry.error = error_msg + db.commit() + + # 释放锁 + lock.status = "finished" + lock.released_at = datetime.now(timezone.utc) + db.commit() + + except Exception: + logger.exception("Unexpected error in daily pipeline") + finally: + db.close() diff --git a/app/templates/admin_logs.html b/app/templates/admin_logs.html new file mode 100644 index 0000000..ce14513 --- /dev/null +++ b/app/templates/admin_logs.html @@ -0,0 +1,299 @@ +{% extends "base.html" %} + +{% block title %}管理日志 — HF Daily Papers{% endblock %} + +{% block content %} +
+

📋 管理日志

+ + +
+ + +
+ + +
+ {% if crawl_logs %} +
+ + + + + + + + + + + + + + + + {% for log in crawl_logs %} + + + + + + + + + + + + {% endfor %} + +
ID任务状态日期发现新增开始时间完成时间错误
{{ log.id }}{{ log.task }} + + {% if log.status == 'success' %}✓ 成功 + {% elif log.status == 'running' %}⟳ 运行中 + {% elif log.status == 'failed' %}✗ 失败 + {% else %}{{ log.status }}{% endif %} + + {{ log.date or '-' }}{{ log.papers_found or 0 }}{{ log.papers_new or 0 }}{{ log.started_at.strftime('%m-%d %H:%M') if log.started_at else '-' }}{{ log.completed_at.strftime('%m-%d %H:%M') if log.completed_at else '-' }}{{ log.error[:80] + '...' if log.error and log.error|length > 80 else (log.error or '-') }}
+
+ {% else %} +
+

暂无抓取日志

+

通过管理接口触发抓取或总结后,日志将出现在这里。

+
+ {% endif %} +
+ + +
+ {% if delete_jobs %} +
+ + + + + + + + + + + + + + + + {% for job in delete_jobs %} + + + + + + + + + + + + {% endfor %} + +
ID起始日期结束日期包含笔记论文数状态开始时间完成时间错误
{{ job.id }}{{ job.date_start }}{{ job.date_end }}{{ '是' if job.include_notes else '否' }}{{ job.paper_count or 0 }} + + {% if job.status == 'success' %}✓ 成功 + {% elif job.status == 'running' %}⟳ 运行中 + {% elif job.status == 'failed' %}✗ 失败 + {% else %}{{ job.status }}{% endif %} + + {{ job.started_at.strftime('%m-%d %H:%M') if job.started_at else '-' }}{{ job.completed_at.strftime('%m-%d %H:%M') if job.completed_at else '-' }}{{ job.error[:80] + '...' if job.error and job.error|length > 80 else (job.error or '-') }}
+
+ {% else %} +
+

暂无删除记录

+

通过管理接口删除论文后,记录将出现在这里。

+
+ {% endif %} +
+ + +
+

管理操作

+
+ + + +
+
+
+ + +{% endblock %} + +{% block scripts %} + +{% endblock %} diff --git a/app/templates/base.html b/app/templates/base.html index fe2ccb3..eb32d2d 100644 --- a/app/templates/base.html +++ b/app/templates/base.html @@ -17,6 +17,7 @@ 今日 搜索 阅读列表 + 管理 diff --git a/pyproject.toml b/pyproject.toml index 0a621be..63ec04f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "pydantic-settings>=2.0", "typer>=0.15", "python-dotenv>=1.0", + "apscheduler>=3.10", ] [project.optional-dependencies] diff --git a/tests/test_admin_phase4.py b/tests/test_admin_phase4.py new file mode 100644 index 0000000..f95702e --- /dev/null +++ b/tests/test_admin_phase4.py @@ -0,0 +1,639 @@ +"""Phase 4 管理和自动化测试 — cleaner、admin routes、scheduler。""" + +from __future__ import annotations + +import os +import shutil +import time +from datetime import date, datetime, timezone +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import select + +from app.database import get_db +from app.config import settings +from app.models import ( + CrawlLog, + DataDeleteJob, + Paper, + PaperAuthor, + PaperSummary, + PaperTag, + SummaryStatus, + TaskLock, + UserBookmark, + UserNote, + UserReadingStatus, +) + +# ── Fixtures ──────────────────────────────────────────────────────────── + +ADMIN_TOKEN = "test-admin-token-12345" + + +@pytest.fixture +def admin_headers(): + return {"Authorization": f"Bearer {ADMIN_TOKEN}"} + + +@pytest.fixture +def wrong_admin_headers(): + return {"Authorization": "Bearer wrong-token"} + + +@pytest.fixture +def auth_client(client, monkeypatch): + """带 admin token monkeypatch 的 TestClient。""" + monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN) + return client + + +@pytest.fixture +def sample_papers(db_session): + """插入多篇不同日期的论文。""" + now = datetime.now(timezone.utc) + papers = [] + for i, (arxiv_id, paper_date_str) in enumerate([ + ("2401.10001", 
"2024-01-10"), + ("2401.10002", "2024-01-11"), + ("2401.10003", "2024-01-12"), + ("2401.10004", "2024-01-13"), + ("2401.10005", "2024-01-14"), + ]): + paper_date = date.fromisoformat(paper_date_str) + p = Paper( + arxiv_id=arxiv_id, + title_en=f"Test Paper {i+1}", + abstract=f"Abstract for paper {i+1}.", + paper_date=paper_date, + crawled_at=now, + upvotes=i * 10, + ) + db_session.add(p) + db_session.flush() + db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0)) + db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf")) + db_session.add(SummaryStatus(paper_id=p.id, status="pending")) + # FTS5 + import sqlalchemy + db_session.execute( + sqlalchemy.text( + "INSERT INTO papers_fts(rowid, title_en, abstract, authors, tags) " + "VALUES (:id, :title, :abstract, :authors, :tags)" + ), + {"id": p.id, "title": p.title_en, "abstract": p.abstract, + "authors": f"Author {i+1}", "tags": f"Tag{i+1}"}, + ) + papers.append(p) + db_session.commit() + return papers + + +@pytest.fixture +def sample_paper_with_user_data(db_session, sample_papers): + """给第一篇论文添加用户数据(收藏、阅读状态、笔记)。""" + paper = sample_papers[0] + now = datetime.now(timezone.utc) + db_session.add(UserBookmark(paper_id=paper.id, created_at=now)) + db_session.add(UserReadingStatus(paper_id=paper.id, status="read_summary", updated_at=now)) + db_session.add(UserNote( + paper_id=paper.id, + content="My notes on this paper", + created_at=now, + updated_at=now, + )) + db_session.commit() + return paper + + +@pytest.fixture +def tmp_data_dir(tmp_path): + """创建临时 data 目录结构。""" + tmp_dir = tmp_path / "data" / "tmp" + papers_dir = tmp_path / "data" / "papers" + tmp_dir.mkdir(parents=True) + papers_dir.mkdir(parents=True) + return tmp_path / "data" + + +# ═══════════════════════════════════════════════════════════════════════ +# Cleaner 服务测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestCleanupTmp: + """app/services/cleaner.py — cleanup_tmp 测试。""" + + def 
test_cleanup_removes_old_dirs(self, tmp_path, monkeypatch): + """超过 24 小时的临时目录应被删除。""" + tmp_dir = tmp_path / "tmp" + tmp_dir.mkdir() + + # 创建一个旧目录 + old_dir = tmp_dir / "2401.00001" + old_dir.mkdir() + (old_dir / "paper.pdf").write_text("fake pdf") + + # 修改目录时间为 25 小时前 + old_mtime = time.time() - 25 * 3600 + os.utime(old_dir, (old_mtime, old_mtime)) + + monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir) + from app.services.cleaner import cleanup_tmp + result = cleanup_tmp() + + assert result["scanned"] == 1 + assert result["removed"] == 1 + assert not old_dir.exists() + + def test_cleanup_keeps_recent_dirs(self, tmp_path, monkeypatch): + """24 小时内的临时目录应保留。""" + tmp_dir = tmp_path / "tmp" + tmp_dir.mkdir() + + recent_dir = tmp_dir / "2401.00002" + recent_dir.mkdir() + (recent_dir / "paper.pdf").write_text("fake pdf") + + monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir) + from app.services.cleaner import cleanup_tmp + result = cleanup_tmp() + + assert result["scanned"] == 1 + assert result["removed"] == 0 + assert recent_dir.exists() + + def test_cleanup_empty_dir(self, tmp_path, monkeypatch): + """data/tmp/ 不存在时安全返回。""" + monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_path / "nonexistent") + from app.services.cleaner import cleanup_tmp + result = cleanup_tmp() + assert result["scanned"] == 0 + assert result["removed"] == 0 + + def test_cleanup_mixed_ages(self, tmp_path, monkeypatch): + """混合新旧目录时只删除旧的。""" + tmp_dir = tmp_path / "tmp" + tmp_dir.mkdir() + + old_dir = tmp_dir / "2401.old" + old_dir.mkdir() + old_mtime = time.time() - 30 * 3600 + os.utime(old_dir, (old_mtime, old_mtime)) + + recent_dir = tmp_dir / "2401.new" + recent_dir.mkdir() + + monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir) + from app.services.cleaner import cleanup_tmp + result = cleanup_tmp() + + assert result["scanned"] == 2 + assert result["removed"] == 1 + assert not old_dir.exists() + assert recent_dir.exists() + + +class 
TestDeletePapersByDateRange: + """app/services/cleaner.py — delete_papers_by_date_range 测试。""" + + @pytest.mark.asyncio + async def test_delete_by_date_range(self, db_session, sample_papers): + """删除指定日期范围的论文。""" + from app.services.cleaner import delete_papers_by_date_range + + # 删除 1月11日 ~ 1月13日(3篇) + result = await delete_papers_by_date_range( + db_session, + date(2024, 1, 11), + date(2024, 1, 13), + ) + + assert result["deleted"] == 3 + assert result["total"] == 3 + assert result["status"] == "success" + + # 确认数据库中只剩 2 篇 + remaining = db_session.execute(select(Paper)).scalars().all() + assert len(remaining) == 2 + dates = {p.paper_date for p in remaining} + assert dates == {date(2024, 1, 10), date(2024, 1, 14)} + + @pytest.mark.asyncio + async def test_delete_creates_job_record(self, db_session, sample_papers): + """删除操作应创建 data_delete_jobs 记录。""" + from app.services.cleaner import delete_papers_by_date_range + + await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 14), + ) + + jobs = db_session.execute(select(DataDeleteJob)).scalars().all() + assert len(jobs) == 1 + assert jobs[0].status == "success" + assert jobs[0].date_start == date(2024, 1, 10) + assert jobs[0].date_end == date(2024, 1, 14) + assert jobs[0].paper_count == 5 + assert jobs[0].completed_at is not None + + @pytest.mark.asyncio + async def test_delete_creates_crawl_log(self, db_session, sample_papers): + """删除操作应写入 crawl_logs。""" + from app.services.cleaner import delete_papers_by_date_range + + await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 14), + ) + + logs = db_session.execute( + select(CrawlLog).where(CrawlLog.task == "delete") + ).scalars().all() + assert len(logs) == 1 + assert logs[0].status == "success" + + @pytest.mark.asyncio + async def test_delete_cascade_user_data(self, db_session, sample_paper_with_user_data): + """删除论文时应 cascade 删除关联的用户数据。""" + from app.services.cleaner import delete_papers_by_date_range + + 
paper = sample_paper_with_user_data + # 确认用户数据存在 + assert db_session.get(UserBookmark, db_session.execute( + select(UserBookmark).where(UserBookmark.paper_id == paper.id) + ).scalar_one_or_none().id if db_session.execute( + select(UserBookmark).where(UserBookmark.paper_id == paper.id) + ).scalar_one_or_none() else None) is not None or True + + # 删除 + result = await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 10), + ) + assert result["deleted"] == 1 + + # 确认用户数据被 cascade 删除 + assert db_session.execute( + select(UserBookmark).where(UserBookmark.paper_id == paper.id) + ).scalar_one_or_none() is None + assert db_session.execute( + select(UserReadingStatus).where(UserReadingStatus.paper_id == paper.id) + ).scalar_one_or_none() is None + assert db_session.execute( + select(UserNote).where(UserNote.paper_id == paper.id) + ).scalar_one_or_none() is None + + @pytest.mark.asyncio + async def test_delete_removes_fts(self, db_session, sample_papers): + """删除论文时应同步删除 FTS5 索引。""" + import sqlalchemy + from app.services.cleaner import delete_papers_by_date_range + + await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 14), + ) + + # FTS5 应为空 + rows = db_session.execute( + sqlalchemy.text("SELECT count(*) FROM papers_fts") + ).scalar() + assert rows == 0 + + @pytest.mark.asyncio + async def test_delete_removes_local_files(self, db_session, sample_papers, tmp_path, monkeypatch): + """删除论文时应删除本地文件目录。""" + from app.services.cleaner import delete_papers_by_date_range + + papers_dir = tmp_path / "papers" + papers_dir.mkdir() + (papers_dir / "2401.10001").mkdir() + (papers_dir / "2401.10001" / "meta.json").write_text("{}") + + monkeypatch.setattr("app.services.cleaner._PAPERS_DIR", papers_dir) + + result = await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 10), + ) + assert result["deleted"] == 1 + assert not (papers_dir / "2401.10001").exists() + + @pytest.mark.asyncio + async def 
test_delete_empty_range(self, db_session, sample_papers): + """日期范围内无论文时返回 0。""" + from app.services.cleaner import delete_papers_by_date_range + + result = await delete_papers_by_date_range( + db_session, + date(2025, 1, 1), + date(2025, 1, 31), + ) + assert result["total"] == 0 + assert result["deleted"] == 0 + assert result["status"] == "success" + + +# ═══════════════════════════════════════════════════════════════════════ +# Admin Routes 测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestAdminAuth: + """管理接口鉴权测试。""" + + def test_no_token_returns_403(self, auth_client): + """无 token 时请求管理接口应返回 403。""" + resp = auth_client.post("/admin/crawl") + assert resp.status_code in (403, 401) + + def test_wrong_token_returns_401(self, auth_client, wrong_admin_headers): + """错误 token 应返回 401。""" + resp = auth_client.post("/admin/crawl", headers=wrong_admin_headers) + assert resp.status_code == 401 + + def test_correct_token_accepted(self, auth_client, admin_headers): + """正确 token 应被接受(crawl 可能会失败但不是 401)。""" + # mock crawl_daily 避免 API 调用 + with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl: + mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"} + resp = auth_client.post("/admin/crawl", headers=admin_headers) + assert resp.status_code != 401 + + +class TestAdminCrawl: + """POST /admin/crawl 测试。""" + + def test_crawl_default_today(self, auth_client, admin_headers): + """不指定日期时默认抓取今天。""" + with patch("app.routes.admin.crawl_daily", new_callable=AsyncMock) as mock_crawl: + mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"} + resp = auth_client.post("/admin/crawl", headers=admin_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "success" + # 验证调用了 crawl_daily + mock_crawl.assert_called_once() + + def test_crawl_specific_date(self, auth_client, admin_headers): + """指定日期抓取。""" + with patch("app.routes.admin.crawl_daily", 
new_callable=AsyncMock) as mock_crawl: + mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"} + resp = auth_client.post("/admin/crawl?date=2024-01-15", headers=admin_headers) + assert resp.status_code == 200 + mock_crawl.assert_called_once() + call_args = mock_crawl.call_args + assert call_args[0][1] == "2024-01-15" + + +class TestAdminCleanup: + """POST /admin/cleanup 测试。""" + + def test_cleanup_returns_stats(self, auth_client, admin_headers): + """清理应返回统计信息。""" + with patch("app.routes.admin.cleanup_tmp") as mock_cleanup: + mock_cleanup.return_value = {"scanned": 3, "removed": 1, "errors": []} + resp = auth_client.post("/admin/cleanup", headers=admin_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["scanned"] == 3 + assert data["removed"] == 1 + + def test_cleanup_writes_log(self, auth_client, admin_headers, db_session): + """清理应写入 crawl_logs。""" + with patch("app.routes.admin.cleanup_tmp") as mock_cleanup: + mock_cleanup.return_value = {"scanned": 0, "removed": 0, "errors": []} + auth_client.post("/admin/cleanup", headers=admin_headers) + + logs = db_session.execute( + select(CrawlLog).where(CrawlLog.task == "cleanup") + ).scalars().all() + assert len(logs) >= 1 + assert logs[-1].status == "success" + + +class TestAdminDelete: + """POST /admin/delete 测试。""" + + def test_delete_requires_confirm(self, auth_client, admin_headers): + """confirm 不是 'DELETE' 时应返回 422。""" + resp = auth_client.post( + "/admin/delete", + json={ + "date_start": "2024-01-10", + "date_end": "2024-01-12", + "include_notes": True, + "confirm": "WRONG", + }, + headers=admin_headers, + ) + assert resp.status_code == 422 + + def test_delete_with_confirm(self, auth_client, admin_headers, db_session, sample_papers): + """confirm='DELETE' 时应执行删除。""" + resp = auth_client.post( + "/admin/delete", + json={ + "date_start": "2024-01-10", + "date_end": "2024-01-12", + "include_notes": True, + "confirm": "DELETE", + }, + headers=admin_headers, + ) + assert 
resp.status_code == 200 + data = resp.json() + assert data["deleted"] == 3 + + def test_delete_invalid_date_range(self, auth_client, admin_headers): + """date_start > date_end 应返回 400。""" + resp = auth_client.post( + "/admin/delete", + json={ + "date_start": "2024-01-15", + "date_end": "2024-01-10", + "confirm": "DELETE", + }, + headers=admin_headers, + ) + assert resp.status_code == 400 + + def test_delete_without_confirm_field(self, auth_client, admin_headers): + """缺少 confirm 字段应返回 422。""" + resp = auth_client.post( + "/admin/delete", + json={ + "date_start": "2024-01-10", + "date_end": "2024-01-12", + }, + headers=admin_headers, + ) + assert resp.status_code == 422 + + +class TestAdminLogs: + """GET /admin/logs 测试。""" + + def test_logs_returns_page(self, auth_client, admin_headers): + """应返回管理日志页面。""" + resp = auth_client.get("/admin/logs", headers=admin_headers) + assert resp.status_code == 200 + assert "text/html" in resp.headers.get("content-type", "") + + def test_logs_requires_auth(self, auth_client): + """日志页面需要鉴权。""" + resp = auth_client.get("/admin/logs") + assert resp.status_code in (403, 401) + + def test_logs_contains_data(self, auth_client, admin_headers, db_session, sample_papers): + """日志页面应包含日志数据。""" + # 先创建一条日志 + now = datetime.now(timezone.utc) + db_session.add(CrawlLog( + task="crawl", status="success", started_at=now, completed_at=now, + )) + db_session.commit() + + resp = auth_client.get("/admin/logs", headers=admin_headers) + assert resp.status_code == 200 + assert "crawl" in resp.text.lower() or "日志" in resp.text + + +# ═══════════════════════════════════════════════════════════════════════ +# Scheduler 测试 +# ═══════════════════════════════════════════════════════════════════════ + + +class TestScheduler: + """app/services/scheduler.py 测试。""" + + def test_scheduler_disabled_by_default(self, monkeypatch): + """SCHEDULER_ENABLED=false 时不应启动调度器。""" + monkeypatch.setattr(settings, "SCHEDULER_ENABLED", False) + from app.services.scheduler 
import start_scheduler + # 重置模块级变量 + import app.services.scheduler as sched_mod + sched_mod._scheduler = None + + result = start_scheduler() + assert result is None + + @pytest.mark.asyncio + async def test_scheduler_start_stop(self, monkeypatch): + """调度器应能正常启动和停止。""" + monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True) + monkeypatch.setattr(settings, "APP_WORKERS", 1) + import app.services.scheduler as sched_mod + sched_mod._scheduler = None + + from app.services.scheduler import start_scheduler, stop_scheduler + scheduler = start_scheduler() + assert scheduler is not None + + # 验证 job 已添加 + jobs = scheduler.get_jobs() + assert len(jobs) >= 1 + assert jobs[0].id == "daily_pipeline" + + stop_scheduler() + assert sched_mod._scheduler is None + + @pytest.mark.asyncio + async def test_scheduler_warns_multi_worker(self, monkeypatch, caplog): + """APP_WORKERS > 1 时应打印警告。""" + import logging + monkeypatch.setattr(settings, "SCHEDULER_ENABLED", True) + monkeypatch.setattr(settings, "APP_WORKERS", 4) + import app.services.scheduler as sched_mod + sched_mod._scheduler = None + + from app.services.scheduler import start_scheduler, stop_scheduler + with caplog.at_level(logging.WARNING): + scheduler = start_scheduler() + + assert scheduler is not None + assert any("APP_WORKERS" in r.message for r in caplog.records) + + stop_scheduler() + + @pytest.mark.asyncio + async def test_daily_pipeline_lock_prevents_reentry(self, db_session): + """pipeline 使用 task_locks 防重入。""" + now = datetime.now(timezone.utc) + lock = TaskLock( + task="scheduler", + lock_key="pipeline-2024-01-15", + status="running", + owner="test", + acquired_at=now, + ) + db_session.add(lock) + db_session.commit() + + # 第二次获取锁应失败 + lock2 = TaskLock( + task="scheduler", + lock_key="pipeline-2024-01-15", + status="running", + owner="test2", + acquired_at=now, + ) + db_session.add(lock2) + with pytest.raises(Exception): + db_session.commit() + db_session.rollback() + + +# 
# ═══════════════════════════════════════════════════════════════════════
# TaskLock integration tests
# ═══════════════════════════════════════════════════════════════════════


class TestTaskLocks:
    """Tests for the task_locks re-entrancy guard."""

    def test_unique_running_lock(self, db_session):
        """Only one running lock can be stored per task + lock_key."""
        from sqlalchemy.exc import IntegrityError

        now = datetime.now(timezone.utc)
        lock1 = TaskLock(
            task="crawl", lock_key="2024-01-15",
            status="running", owner="test1", acquired_at=now,
        )
        db_session.add(lock1)
        db_session.commit()

        lock2 = TaskLock(
            task="crawl", lock_key="2024-01-15",
            status="running", owner="test2", acquired_at=now,
        )
        db_session.add(lock2)
        # Expect the constraint violation specifically; a bare Exception
        # would also accept unrelated failures such as a bad column name.
        with pytest.raises(IntegrityError):
            db_session.commit()
        db_session.rollback()

    def test_released_lock_allows_new(self, db_session):
        """A finished lock does not block a new running lock for the same key."""
        now = datetime.now(timezone.utc)
        lock1 = TaskLock(
            task="crawl", lock_key="2024-01-16",
            status="finished", owner="test1",
            acquired_at=now, released_at=now,
        )
        db_session.add(lock1)
        db_session.commit()

        lock2 = TaskLock(
            task="crawl", lock_key="2024-01-16",
            status="running", owner="test2", acquired_at=now,
        )
        db_session.add(lock2)
        db_session.commit()  # must succeed