diff --git a/app/database.py b/app/database.py index 095e014..40e63cf 100644 --- a/app/database.py +++ b/app/database.py @@ -1,6 +1,6 @@ """数据库引擎、会话工厂、初始化。""" -from sqlalchemy import event, create_engine +from sqlalchemy import event, create_engine, text from sqlalchemy.orm import DeclarativeBase, sessionmaker from app.config import settings @@ -10,6 +10,27 @@ class Base(DeclarativeBase): pass +# ── FTS5 和索引 DDL(与 ORM 模型分开管理)─────────────────────────────── + +FTS5_CREATE_SQL = """ +CREATE VIRTUAL TABLE IF NOT EXISTS papers_fts USING fts5( + title_en, + title_zh, + abstract, + authors, + tags, + summary_text, + tokenize='unicode61' +); +""" + +FTS5_TRIGGER_INDEX = """ +-- partial index for task_locks running +CREATE UNIQUE INDEX IF NOT EXISTS uq_task_locks_running +ON task_locks(task, lock_key) WHERE status = 'running'; +""" + + def _make_engine(): """创建 SQLite 引擎,启用 foreign_keys。""" engine = create_engine( @@ -39,3 +60,14 @@ def get_db(): yield db finally: db.close() + + +def init_db(engine): + """创建所有 ORM 表 + FTS5 虚拟表。""" + from app.models import Base # noqa: F811 — 避免循环导入,延迟导入 + + Base.metadata.create_all(engine) + with engine.connect() as conn: + conn.execute(text(FTS5_CREATE_SQL)) + conn.execute(text(FTS5_TRIGGER_INDEX)) + conn.commit() diff --git a/app/main.py b/app/main.py index 0ab932b..33ceaaf 100644 --- a/app/main.py +++ b/app/main.py @@ -2,14 +2,13 @@ import logging import os +from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.staticfiles import StaticFiles -from starlette.staticfiles import StaticFiles as StarletteStaticFiles from app.config import settings -from app.database import engine -from app.models import init_db +from app.database import engine, init_db from app.routes.admin import router as admin_router from app.routes.compare import router as compare_router from app.routes.pages import router as pages_router @@ -24,11 +23,30 @@ logging.basicConfig( logger = logging.getLogger(__name__) +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理:启动与关闭。""" + # ── startup ── + from app.services.scheduler import start_scheduler + from app.services.embedder import init_chroma + + start_scheduler() + init_chroma() + + yield + + # ── shutdown ── + from app.services.scheduler import stop_scheduler + + stop_scheduler() + + def create_app() -> FastAPI: app = FastAPI( title="HF Daily Papers", description="HuggingFace Daily Papers — 中文论文导览站", version="0.1.0", + lifespan=lifespan, ) # 确保数据目录存在 @@ -65,23 +83,6 @@ def create_app() -> FastAPI: app.include_router(trends_router) app.include_router(compare_router) - # 调度器(Phase 4) - @app.on_event("startup") - async def _start_scheduler(): - from app.services.scheduler import start_scheduler - start_scheduler() - - # Phase 5: 初始化 ChromaDB - @app.on_event("startup") - async def _init_chroma(): - from app.services.embedder import init_chroma - init_chroma() - - @app.on_event("shutdown") - async def _stop_scheduler(): - from app.services.scheduler import stop_scheduler - stop_scheduler() - return app diff --git a/app/models.py b/app/models.py index 589c3f8..eb5006a 100644 --- a/app/models.py +++ b/app/models.py @@ -1,4 +1,4 @@ -"""SQLAlchemy ORM 模型 — papers, authors, tags, summaries, FTS5, logs, locks, user data。""" +"""SQLAlchemy ORM 模型 — papers, authors, tags, summaries, user data, logs, locks。""" from datetime import date, datetime @@ -13,7 +13,6 @@ from sqlalchemy import ( String, Text, UniqueConstraint, - text, ) from sqlalchemy.orm import relationship @@ -204,32 +203,3 @@ class DataDeleteJob(Base): error = Column(Text) started_at = Column(DateTime, nullable=False) completed_at = Column(DateTime) - - -# ── FTS5 索引初始化 SQL(普通虚拟表,由应用层维护)────────────────────── -FTS5_CREATE_SQL = """ -CREATE VIRTUAL TABLE IF NOT EXISTS papers_fts USING fts5( - title_en, - title_zh, - abstract, - authors, - tags, - summary_text, - tokenize='unicode61' -); -""" - -FTS5_TRIGGER_INDEX = """ --- partial index for task_locks running -CREATE UNIQUE INDEX IF NOT EXISTS uq_task_locks_running -ON task_locks(task, lock_key) WHERE status = 'running'; -""" - - -def init_db(engine): - """创建所有 ORM 表 + FTS5 虚拟表。""" - Base.metadata.create_all(engine) - with engine.connect() as conn: - conn.execute(text(FTS5_CREATE_SQL)) - conn.execute(text(FTS5_TRIGGER_INDEX)) - conn.commit() diff --git a/app/routes/admin.py b/app/routes/admin.py index a1e20eb..bac16d2 100644 --- a/app/routes/admin.py +++ b/app/routes/admin.py @@ -6,7 +6,6 @@ from datetime import date, datetime, timezone from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from fastapi.templating import Jinja2Templates from pydantic import BaseModel, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -17,10 +16,10 @@ from app.models import CrawlLog, DataDeleteJob, TaskLock from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range from app.services.crawler import crawl_daily from app.services.summarizer import summarize_batch, summarize_single +from app.utils import release_lock, templates, today_str router = APIRouter(prefix="/admin", tags=["admin"]) security = HTTPBearer() -templates = Jinja2Templates(directory="app/templates") async def verify_admin( @@ -32,7 +31,7 @@ async def verify_admin( return credentials.credentials -# ── 请求模型 ────────────────────────────────────────────────────────────── +# ── 请求模型 ────────────────────────────────────────────────────────── class DeleteRequest(BaseModel): @@ -49,7 +48,7 @@ class DeleteRequest(BaseModel): return v -# ── 抓取 ────────────────────────────────────────────────────────────────── +# ── 抓取 ────────────────────────────────────────────────────────────── @router.post("/crawl") @@ -59,12 +58,7 @@ async def admin_crawl( date: str | None = Query(None, description="YYYY-MM-DD,默认今天"), ): """手动抓取指定日期,默认今天。""" - # 计算 target_date - from zoneinfo import ZoneInfo - - tz = ZoneInfo(settings.APP_TIMEZONE) - today = datetime.now(tz).strftime("%Y-%m-%d") - target_date = date or today + target_date = date or today_str() # TaskLock 防重入 now = datetime.now(timezone.utc) @@ -88,10 +82,10 @@ async def admin_crawl( except Exception as exc: raise HTTPException(status_code=500, detail=str(exc)) finally: - _release_lock(db, lock) + release_lock(db, lock) -# ── 总结 ────────────────────────────────────────────────────────────────── +# ── 总结 ────────────────────────────────────────────────────────────── @router.post("/summarize") @@ -119,7 +113,7 @@ async def admin_summarize_single( return result -# ── 清理 ────────────────────────────────────────────────────────────────── +# ── 清理 ────────────────────────────────────────────────────────────── @router.post("/cleanup") @@ -155,7 +149,7 @@ async def admin_cleanup( raise HTTPException(status_code=500, detail=str(exc)) -# ── 删除 ────────────────────────────────────────────────────────────────── +# ── 删除 ────────────────────────────────────────────────────────────── @router.post("/delete") @@ -177,7 +171,7 @@ async def admin_delete( return result -# ── 日志 ────────────────────────────────────────────────────────────────── +# ── 日志 ────────────────────────────────────────────────────────────── @router.get("/logs") @@ -189,7 +183,6 @@ async def admin_logs( per_page: int = Query(20, ge=1, le=100), ): """查看任务日志(CrawlLog + DataDeleteJob)。""" - # 查询 crawl_logs crawl_logs = ( db.execute( select(CrawlLog) @@ -201,7 +194,6 @@ async def admin_logs( .all() ) - # 查询 delete_jobs delete_jobs = ( db.execute( select(DataDeleteJob) @@ -223,16 +215,3 @@ async def admin_logs( "per_page": per_page, }, ) - - -# ── 工具函数 ────────────────────────────────────────────────────────────── - - -def _release_lock(db: Session, lock: TaskLock) -> None: - """释放 TaskLock。""" - try: - lock.status = "finished" - lock.released_at = datetime.now(timezone.utc) - db.commit() - except Exception: - db.rollback() diff --git a/app/routes/compare.py b/app/routes/compare.py index 8c418c3..dbca7f6 100644 --- a/app/routes/compare.py +++ b/app/routes/compare.py @@ -2,19 +2,14 @@ from __future__ import annotations -import logging - from fastapi import APIRouter, Depends, HTTPException, Query, Request -from fastapi.templating import Jinja2Templates from sqlalchemy.orm import Session, joinedload from app.database import get_db from app.models import Paper - -logger = logging.getLogger(__name__) +from app.utils import templates router = APIRouter() -templates = Jinja2Templates(directory="app/templates") @router.get("/compare") diff --git a/app/routes/pages.py b/app/routes/pages.py index a034433..b249046 100644 --- a/app/routes/pages.py +++ b/app/routes/pages.py @@ -3,34 +3,27 @@ from __future__ import annotations import logging -from datetime import date, datetime, timedelta +from datetime import date, timedelta from pathlib import Path -from zoneinfo import ZoneInfo from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.responses import RedirectResponse -from fastapi.templating import Jinja2Templates from sqlalchemy.orm import Session, joinedload from app.config import settings from app.database import get_db from app.models import Paper +from app.utils import templates, today_str logger = logging.getLogger(__name__) router = APIRouter() -templates = Jinja2Templates(directory="app/templates") - - -def _today() -> str: - tz = ZoneInfo(settings.APP_TIMEZONE) - return datetime.now(tz).strftime("%Y-%m-%d") @router.get("/") def index(request: Request): """重定向到 /day/{today}。""" - return RedirectResponse(url=f"/day/{_today()}") + return RedirectResponse(url=f"/day/{today_str()}") @router.get("/day/{date_str}") @@ -43,7 +36,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)): prev_day = (target - timedelta(days=1)).isoformat() next_day = (target + timedelta(days=1)).isoformat() - today_str = _today() + today = today_str() papers = ( db.query(Paper) @@ -74,7 +67,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)): "current_date": date_str, "prev_day": prev_day, "next_day": next_day, - "today": today_str, + "today": today, "available_dates": available_dates, "page_title": f"{date_str} 论文列表", }, @@ -146,16 +139,9 @@ def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict return [] try: - from app.services.embedder import search_similar - - # 用论文的 arxiv_id 从 ChromaDB 查询 - col = None - try: - from app.services.embedder import get_collection - col = get_collection() - except Exception: - return [] + from app.services.embedder import get_collection + col = get_collection() if col is None: return [] diff --git a/app/routes/search.py b/app/routes/search.py index cc23ece..0ec43be 100644 --- a/app/routes/search.py +++ b/app/routes/search.py @@ -2,14 +2,11 @@ from __future__ import annotations -import math -from datetime import date, datetime, timedelta, timezone -from zoneinfo import ZoneInfo +from datetime import date, timedelta from xml.sax.saxutils import escape from fastapi import APIRouter, Depends, Query, Request from fastapi.responses import Response -from fastapi.templating import Jinja2Templates from sqlalchemy import text from sqlalchemy.orm import Session, joinedload @@ -17,9 +14,10 @@ from app.config import settings from app.database import get_db from app.models import Paper, PaperTag, UserReadingStatus from app.services.searcher import get_all_tags, search_papers +from app.services.user_data import query_reading_list +from app.utils import templates, today_str router = APIRouter() -templates = Jinja2Templates(directory="app/templates") # ── 搜索页 ──────────────────────────────────────────────────────────── @@ -56,7 +54,7 @@ def search_page( "total_pages": result["total_pages"], "all_tags": all_tags, "page_title": f"搜索: {q}" if q else "搜索", - "today": _today_str(), + "today": today_str(), }, ) @@ -114,7 +112,7 @@ def reading_list_page( db: Session = Depends(get_db), ): """阅读列表页面。""" - papers = _query_reading_list(db, filter, tag or None) + papers = query_reading_list(db, filter, tag or None) all_tags = get_all_tags(db) return templates.TemplateResponse( @@ -126,54 +124,11 @@ def reading_list_page( "current_tag": tag, "all_tags": all_tags, "page_title": "阅读列表", - "today": _today_str(), + "today": today_str(), }, ) -def _query_reading_list( - db: Session, - filter_type: str, - tag: str | None, -) -> list[Paper]: - """根据筛选条件查询阅读列表。""" - from sqlalchemy import or_ - - # 基础:有任意用户数据的论文 - base = db.query(Paper).filter( - or_( - Paper.bookmark.has(), - Paper.reading_status.has(), - Paper.note.has(), - ) - ) - - # 应用筛选 - if filter_type == "has_note": - base = base.filter(Paper.note.has()) - elif filter_type in ("unread", "skimmed", "read_summary", "read_full"): - base = base.filter( - Paper.reading_status.has(UserReadingStatus.status == filter_type) - ) - - # 应用标签 - if tag: - base = base.filter(Paper.tags.any(PaperTag.tag == tag)) - - return ( - base.options( - joinedload(Paper.authors), - joinedload(Paper.tags), - joinedload(Paper.summary_status), - joinedload(Paper.bookmark), - joinedload(Paper.reading_status), - joinedload(Paper.note), - ) - .order_by(Paper.paper_date.desc(), Paper.upvotes.desc()) - .all() - ) - - # ── RSS Feed ────────────────────────────────────────────────────────── @@ -216,7 +171,7 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st lines.append(f" {escape(channel_title)}") lines.append(f" {escape(base_url)}") lines.append(" HuggingFace Daily Papers — 中文论文导览站") - lines.append(f" zh-CN") + lines.append(" zh-CN") for paper in papers: title_text = paper.title_zh or paper.title_en @@ -245,12 +200,3 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st lines.append(" ") lines.append("") return "\n".join(lines) - - -# ── 工具函数 ────────────────────────────────────────────────────────── - - -def _today_str() -> str: - """当前日期字符串(按 APP_TIMEZONE)。""" - tz = ZoneInfo(settings.APP_TIMEZONE) - return datetime.now(tz).strftime("%Y-%m-%d") diff --git a/app/routes/trends.py b/app/routes/trends.py index 80e6ce5..84e40f5 100644 --- a/app/routes/trends.py +++ b/app/routes/trends.py @@ -2,34 +2,27 @@ from __future__ import annotations -import logging -from datetime import date, timedelta - from fastapi import APIRouter, Depends, Request -from fastapi.templating import Jinja2Templates -from sqlalchemy import func, text from sqlalchemy.orm import Session -from app.config import settings from app.database import get_db - -logger = logging.getLogger(__name__) +from app.services.trends import get_trends_data +from app.utils import templates, today_str router = APIRouter() -templates = Jinja2Templates(directory="app/templates") @router.get("/trends") def trends_page(request: Request, db: Session = Depends(get_db)): """趋势看板页面。""" - stats = _get_trends_data(db) + stats = get_trends_data(db) return templates.TemplateResponse( request, "trends.html", { "page_title": "趋势看板", "stats": stats, - "today": _today_str(), + "today": today_str(), }, ) @@ -37,84 +30,4 @@ def trends_page(request: Request, db: Session = Depends(get_db)): @router.get("/api/stats/trends") def trends_api(db: Session = Depends(get_db)): """趋势数据 JSON API。""" - return _get_trends_data(db) - - -def _get_trends_data(db: Session) -> dict: - """从 DB 聚合趋势数据。""" - thirty_days_ago = (date.today() - timedelta(days=30)).isoformat() - - # 1. 按日论文数量(近 30 天) - daily_rows = db.execute(text(""" - SELECT paper_date, COUNT(*) as cnt - FROM papers - WHERE paper_date >= :start_date - GROUP BY paper_date - ORDER BY paper_date ASC - """), {"start_date": thirty_days_ago}).fetchall() - daily_counts = [ - {"date": str(row[0]), "count": row[1]} - for row in daily_rows - ] - - # 2. 热门标签 Top 20 - tag_rows = db.execute(text(""" - SELECT tag, COUNT(*) as cnt - FROM paper_tags - GROUP BY tag - ORDER BY cnt DESC - LIMIT 20 - """)).fetchall() - top_tags = [ - {"tag": row[0], "count": row[1]} - for row in tag_rows - ] - - # 3. Upvotes 分布 - upvote_rows = db.execute(text(""" - SELECT - CASE - WHEN upvotes >= 100 THEN '100+' - WHEN upvotes >= 50 THEN '50-99' - WHEN upvotes >= 20 THEN '20-49' - WHEN upvotes >= 10 THEN '10-19' - WHEN upvotes >= 5 THEN '5-9' - ELSE '0-4' - END as bucket, - COUNT(*) as cnt - FROM papers - GROUP BY bucket - ORDER BY MIN(upvotes) DESC - """)).fetchall() - upvotes_dist = [ - {"range": row[0], "count": row[1]} - for row in upvote_rows - ] - - # 4. 总结完成率 - summary_rows = db.execute(text(""" - SELECT - COALESCE(ss.status, 'none') as status, - COUNT(*) as cnt - FROM papers p - LEFT JOIN summary_status ss ON ss.paper_id = p.id - GROUP BY status - """)).fetchall() - summary_completion = [ - {"status": row[0], "count": row[1]} - for row in summary_rows - ] - - return { - "daily_counts": daily_counts, - "top_tags": top_tags, - "upvotes_dist": upvotes_dist, - "summary_completion": summary_completion, - } - - -def _today_str() -> str: - from datetime import datetime - from zoneinfo import ZoneInfo - tz = ZoneInfo(settings.APP_TIMEZONE) - return datetime.now(tz).strftime("%Y-%m-%d") + return get_trends_data(db) diff --git a/app/services/cleaner.py b/app/services/cleaner.py index f5f4c9b..5916a43 100644 --- a/app/services/cleaner.py +++ b/app/services/cleaner.py @@ -16,13 +16,10 @@ from app.models import ( Paper, TaskLock, ) +from app.utils import PAPERS_DIR, TMP_DIR logger = logging.getLogger(__name__) -_DATA_DIR = Path("data") -_TMP_DIR = _DATA_DIR / "tmp" -_PAPERS_DIR = _DATA_DIR / "papers" - # 临时文件最大保留时间(小时) _MAX_TMP_AGE_HOURS = 24 @@ -39,7 +36,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict: Returns: 清理统计 {"scanned": int, "removed": int, "errors": list[str]} """ - if not _TMP_DIR.exists(): + if not TMP_DIR.exists(): return {"scanned": 0, "removed": 0, "errors": []} now = datetime.now(timezone.utc) @@ -48,7 +45,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict: removed = 0 errors: list[str] = [] - for entry in _TMP_DIR.iterdir(): + for entry in TMP_DIR.iterdir(): if not entry.is_dir(): continue scanned += 1 @@ -147,13 +144,13 @@ async def delete_papers_by_date_range( logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True) # 2. 删除本地文件 data/papers/{arxiv_id}/ - paper_dir = _PAPERS_DIR / arxiv_id + paper_dir = PAPERS_DIR / arxiv_id if paper_dir.exists(): shutil.rmtree(paper_dir) logger.debug("Removed paper dir: %s", paper_dir) # 3. 删除临时文件 data/tmp/{arxiv_id}/ - tmp_dir = _TMP_DIR / arxiv_id + tmp_dir = TMP_DIR / arxiv_id if tmp_dir.exists(): shutil.rmtree(tmp_dir) logger.debug("Removed tmp dir: %s", tmp_dir) diff --git a/app/services/crawler.py b/app/services/crawler.py index f0dd124..371c90e 100644 --- a/app/services/crawler.py +++ b/app/services/crawler.py @@ -16,6 +16,7 @@ from app.models import ( PaperTag, SummaryStatus, ) +from app.utils import make_http_client logger = logging.getLogger(__name__) @@ -34,15 +35,7 @@ async def fetch_daily(target_date: str, top_n: int | None = None) -> list[dict]: url = f"{settings.HF_API_BASE}/daily_papers" params = {"date": target_date} - transport = None - if settings.http_proxy: - transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy) - - async with httpx.AsyncClient( - timeout=settings.HTTP_TIMEOUT_SECONDS, - headers={"User-Agent": settings.HTTP_USER_AGENT}, - transport=transport, - ) as client: + async with make_http_client() as client: for attempt in range(1, settings.HTTP_MAX_RETRIES + 1): try: logger.info("Fetching HF Daily Papers: date=%s attempt=%d", target_date, attempt) diff --git a/app/services/embedder.py b/app/services/embedder.py index f153251..015a288 100644 --- a/app/services/embedder.py +++ b/app/services/embedder.py @@ -5,8 +5,6 @@ from __future__ import annotations import logging from pathlib import Path -import httpx -from sqlalchemy import select from sqlalchemy.orm import Session, joinedload from app.config import settings @@ -14,66 +12,82 @@ from app.models import Paper logger = logging.getLogger(__name__) -# ── 单例客户端和 collection ───────────────────────────────────────────── -_client = None -_collection = None + +# ── ChromaDB 管理器(替代全局可变状态)────────────────────────────────── -def _chroma_dir() -> Path: - return Path(settings.CHROMA_DIR) +class ChromaManager: + """封装 ChromaDB 客户端和 collection 的生命周期。""" + + def __init__(self) -> None: + self._client = None + self._collection = None + + def init(self) -> None: + """CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。""" + if not settings.CHROMA_ENABLED: + logger.debug("ChromaDB disabled, skip init") + return + + if self._client is not None: + return + + try: + import chromadb + + chroma_path = Path(settings.CHROMA_DIR) + chroma_path.mkdir(parents=True, exist_ok=True) + + self._client = chromadb.PersistentClient(path=str(chroma_path)) + self._collection = self._get_or_create_collection() + logger.info("ChromaDB initialized at %s", chroma_path) + except Exception: + logger.exception("Failed to initialize ChromaDB") + self._client = None + self._collection = None + + def _get_or_create_collection(self): + """获取或创建 papers_embeddings collection。""" + try: + col = self._client.get_collection("papers_embeddings") + logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count()) + return col + except Exception: + pass + + col = self._client.create_collection( + name="papers_embeddings", + metadata={"hnsw:space": "cosine"}, + ) + logger.info("ChromaDB collection 'papers_embeddings' created") + return col + + def get_collection(self): + """返回当前 collection,未初始化则自动初始化。""" + if not settings.CHROMA_ENABLED: + return None + if self._collection is None: + self.init() + return self._collection + + def reset(self) -> None: + """重置状态(供测试使用)。""" + self._client = None + self._collection = None + + +# 模块级单例 +_chroma = ChromaManager() def init_chroma() -> None: - """CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。""" - global _client, _collection - if not settings.CHROMA_ENABLED: - logger.debug("ChromaDB disabled, skip init") - return - - if _client is not None: - return - - try: - import chromadb - - chroma_path = _chroma_dir() - chroma_path.mkdir(parents=True, exist_ok=True) - - _client = chromadb.PersistentClient(path=str(chroma_path)) - _collection = _get_or_create_collection() - logger.info("ChromaDB initialized at %s", chroma_path) - except Exception: - logger.exception("Failed to initialize ChromaDB") - _client = None - _collection = None - - -def _get_or_create_collection(): - """获取或创建 papers_embeddings collection,维度不匹配时记录日志并跳过。""" - import chromadb - - try: - col = _client.get_collection("papers_embeddings") - logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count()) - return col - except Exception: - pass - - col = _client.create_collection( - name="papers_embeddings", - metadata={"hnsw:space": "cosine"}, - ) - logger.info("ChromaDB collection 'papers_embeddings' created") - return col + """初始化 ChromaDB(供 lifespan 调用)。""" + _chroma.init() def get_collection(): """返回当前 collection,未初始化则返回 None。""" - if not settings.CHROMA_ENABLED: - return None - if _collection is None: - init_chroma() - return _collection + return _chroma.get_collection() # ── Embedding API 调用 ────────────────────────────────────────────────── @@ -90,6 +104,8 @@ def _get_embedding(text: str) -> list[float] | None: logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding") return None + from app.utils import make_http_client + url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings" headers = {"Content-Type": "application/json"} if settings.EMBED_API_KEY: @@ -101,7 +117,7 @@ def _get_embedding(text: str) -> list[float] | None: } try: - with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client: + with make_http_client(sync=True) as client: resp = client.post(url, json=payload, headers=headers) resp.raise_for_status() data = resp.json() diff --git a/app/services/image_extractor.py b/app/services/image_extractor.py new file mode 100644 index 0000000..11028c6 --- /dev/null +++ b/app/services/image_extractor.py @@ -0,0 +1,83 @@ +"""LaTeX 图片提取 — 从 arXiv 源码中扫描 \\includegraphics 并提取图片文件。""" + +from __future__ import annotations + +import logging +import re +import shutil +from pathlib import Path + +from app.services.pdf_downloader import download_source_zip, paper_dir, tmp_dir + +logger = logging.getLogger(__name__) + +_INCLUDEGRAPHICS_RE = re.compile( + r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE +) +_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"} + + +async def extract_images_from_source(arxiv_id: str) -> int: + """从 LaTeX 源码中提取图片文件。 + + 流程: + 1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/ + 2. 扫描 .tex 文件中的 \\includegraphics + 3. 复制图片到 data/papers/{arxiv_id}/images/ + 4. 清理源码临时文件 + + Returns: + 提取的图片数量 + """ + tmp_source = tmp_dir(arxiv_id) / "source" + images_dest = paper_dir(arxiv_id) / "images" + + try: + # 下载源码 zip(如果还没下载) + if not tmp_source.exists(): + source_url = f"https://arxiv.org/e-print/{arxiv_id}" + await download_source_zip(arxiv_id, source_url, tmp_source) + + if not tmp_source.exists(): + return 0 + + # 扫描 .tex 文件,收集图片路径 + image_paths: set[str] = set() + for tex_file in tmp_source.rglob("*.tex"): + try: + content = tex_file.read_text(encoding="utf-8", errors="replace") + for match in _INCLUDEGRAPHICS_RE.finditer(content): + img_path = match.group(1).strip() + image_paths.add(img_path) + except Exception: + continue + + if not image_paths: + return 0 + + # 查找并复制图片 + images_dest.mkdir(parents=True, exist_ok=True) + copied = 0 + for img_rel in image_paths: + # 尝试在源码目录中找到文件 + for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"): + candidate = tmp_source / (img_rel + ext) + if candidate.is_file(): + dest_name = candidate.name + # 避免文件名冲突 + dest = images_dest / dest_name + if dest.exists(): + stem = dest.stem + suffix = dest.suffix + dest = images_dest / f"{stem}_{copied}{suffix}" + shutil.copy2(candidate, dest) + copied += 1 + break + + if copied > 0: + logger.info("Extracted %d images from source for %s", copied, arxiv_id) + return copied + + except Exception: + logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True) + return 0 diff --git a/app/services/pdf_downloader.py b/app/services/pdf_downloader.py new file mode 100644 index 0000000..cfb0d5a --- /dev/null +++ b/app/services/pdf_downloader.py @@ -0,0 +1,105 @@ +"""PDF 下载与源码下载 — 从 arXiv 下载论文 PDF 和 LaTeX 源码包。""" + +from __future__ import annotations + +import logging +import shutil +import zipfile +from pathlib import Path + +from app.utils import PAPERS_DIR, TMP_DIR, make_http_client + +logger = logging.getLogger(__name__) + + +# ── 自定义异常 ────────────────────────────────────────────────────────── + + +class PdfDownloadError(Exception): + pass + + +# ── 路径工具 ──────────────────────────────────────────────────────────── + + +def paper_dir(arxiv_id: str) -> Path: + return PAPERS_DIR / arxiv_id + + +def tmp_dir(arxiv_id: str) -> Path: + return TMP_DIR / arxiv_id + + +# ── PDF 下载 ──────────────────────────────────────────────────────────── + + +async def download_pdf(arxiv_id: str, pdf_url: str) -> Path: + """下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。""" + if not pdf_url: + raise PdfDownloadError(f"no pdf_url for {arxiv_id}") + + dest_dir = tmp_dir(arxiv_id) + dest_dir.mkdir(parents=True, exist_ok=True) + dest = dest_dir / "paper.pdf" + + try: + async with make_http_client(follow_redirects=True) as client: + resp = await client.get(pdf_url) + resp.raise_for_status() + dest.write_bytes(resp.content) + except Exception as exc: + raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc + + logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size) + return dest + + +# ── 源码下载 ──────────────────────────────────────────────────────────── + + +async def download_source_zip(arxiv_id: str, source_url: str, dest_dir: Path) -> None: + """下载 arXiv 源码并解压。""" + dest_dir.mkdir(parents=True, exist_ok=True) + zip_path = tmp_dir(arxiv_id) / "source.zip" + + try: + async with make_http_client(follow_redirects=True) as client: + resp = await client.get(source_url) + resp.raise_for_status() + zip_path.write_bytes(resp.content) + except Exception as exc: + logger.debug("Failed to download source for %s: %s", arxiv_id, exc) + return + + try: + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(dest_dir) + logger.debug("Extracted source for %s", arxiv_id) + except zipfile.BadZipFile: + # 可能是 tar.gz + import tarfile + try: + with tarfile.open(zip_path, "r:*") as tf: + tf.extractall(dest_dir, filter="data") + logger.debug("Extracted source (tar) for %s", arxiv_id) + except Exception: + logger.warning("Cannot extract source for %s", arxiv_id) + except Exception: + logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True) + finally: + if zip_path.exists(): + zip_path.unlink() + + +# ── 临时文件清理 ──────────────────────────────────────────────────────── + + +def cleanup_tmp(arxiv_id: str) -> None: + """清理 data/tmp/{arxiv_id}/ 目录。""" + td = tmp_dir(arxiv_id) + if td.exists(): + try: + shutil.rmtree(td) + logger.debug("Cleaned tmp: %s", arxiv_id) + except Exception: + logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True) diff --git a/app/services/pi_client.py b/app/services/pi_client.py new file mode 100644 index 0000000..b7aae2b --- /dev/null +++ b/app/services/pi_client.py @@ -0,0 +1,160 @@ +"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。""" + +from __future__ import annotations + +import asyncio +import json +import logging +import re +from pathlib import Path + +from app.config import settings + +logger = logging.getLogger(__name__) + + +# ── 自定义异常 ────────────────────────────────────────────────────────── + + +class PiTimeoutError(Exception): + pass + + +class PiProcessError(Exception): + def __init__(self, returncode: int, stderr: str): + self.returncode = returncode + self.stderr = stderr + super().__init__(f"pi exited with code {returncode}: {stderr[:500]}") + + +class JsonNotFoundError(Exception): + pass + + +# ── meta.json ─────────────────────────────────────────────────────────── + + +def write_meta_json(paper) -> Path: + """写入 data/papers/{arxiv_id}/meta.json,返回路径。""" + from app.services.pdf_downloader import paper_dir + + d = paper_dir(paper.arxiv_id) + d.mkdir(parents=True, exist_ok=True) + meta_path = d / "meta.json" + + authors = [a.name for a in paper.authors] + tags = [t.tag for t in paper.tags] + meta = { + "arxiv_id": paper.arxiv_id, + "title_en": paper.title_en, + "abstract": paper.abstract or "", + "published_at": paper.published_at.isoformat() if paper.published_at else None, + "authors": authors, + "tags": tags, + "upvotes": paper.upvotes, + } + meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8") + return meta_path + + +# ── pi CLI 调用 ──────────────────────────────────────────────────────── + + +async def call_pi(meta_path: Path, pdf_path: Path) -> str: + """调用 pi CLI 非交互模式,返回 stdout 文本。""" + arxiv_id = meta_path.parent.name + cmd = [ + settings.PI_BIN, + "-p", + "--no-tools", + "--skill", + settings.SUMMARY_SKILL, + "请深度解读以下论文,并按指定 JSON schema 输出:", + f"@{meta_path}", + f"@{pdf_path}", + ] + logger.info("Calling pi for %s", arxiv_id) + + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), + timeout=settings.SUMMARY_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + raise PiTimeoutError( + f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s" + ) + + if proc.returncode != 0: + raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace")) + + return stdout.decode("utf-8", errors="replace") + + +# ── JSON 提取 ────────────────────────────────────────────────────────── + + +def extract_json(raw_output: str) -> dict: + """从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。""" + # 策略 1:整体直接解析 + stripped = raw_output.strip() + try: + result = json.loads(stripped) + if isinstance(result, dict) and "title_zh" in result: + return result + except json.JSONDecodeError: + pass + + # 策略 2:提取 ```json ... ``` 代码块 + fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL) + for match in fence_pattern.finditer(raw_output): + try: + result = json.loads(match.group(1).strip()) + if isinstance(result, dict) and "title_zh" in result: + return result + except json.JSONDecodeError: + continue + + # 策略 3:匹配包含 title_zh 的最大 {...} 块 + brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL) + for match in brace_pattern.finditer(raw_output): + try: + return json.loads(match.group(0)) + except json.JSONDecodeError: + continue + + # 更宽松:找到最大的 { ... } 平衡块 + best = None + best_len = 0 + for i, ch in enumerate(raw_output): + if ch != "{": + continue + depth = 0 + for j in range(i, len(raw_output)): + if raw_output[j] == "{": + depth += 1 + elif raw_output[j] == "}": + depth -= 1 + if depth == 0: + candidate = raw_output[i : j + 1] + if len(candidate) > best_len: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + best = parsed + best_len = len(candidate) + except json.JSONDecodeError: + pass + break + + if best is not None: + return best + + raise JsonNotFoundError("no JSON object found in pi output") diff --git a/app/services/summarizer.py b/app/services/summarizer.py index 523eb18..b3450d8 100644 --- a/app/services/summarizer.py +++ b/app/services/summarizer.py @@ -1,18 +1,15 @@ -"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。""" +"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。""" from __future__ import annotations -import asyncio import json import logging -import re import shutil from datetime import datetime, timezone from pathlib import Path -import httpx from pydantic import ValidationError -from sqlalchemy import select, text +from sqlalchemy import select from sqlalchemy.orm import Session, joinedload from app.config import settings @@ -25,216 +22,31 @@ from app.models import ( SummaryStatus, TaskLock, ) +from app.services.image_extractor import extract_images_from_source +from app.services.pdf_downloader import ( + PdfDownloadError, + cleanup_tmp, + download_pdf, + paper_dir, +) +from app.services.pi_client import ( + JsonNotFoundError, + PiProcessError, + PiTimeoutError, + call_pi, + extract_json, + write_meta_json, +) from app.services.schemas import ( SummarySchema, assess_quality, classify_validation_error, flatten_for_db, ) +from app.utils import PAPERS_DIR, release_lock logger = logging.getLogger(__name__) -# ── 自定义异常 ────────────────────────────────────────────────────────── - - -class PdfDownloadError(Exception): - pass - - -class PiTimeoutError(Exception): - pass - - -class PiProcessError(Exception): - def __init__(self, returncode: int, stderr: str): - self.returncode = returncode - self.stderr = stderr - super().__init__(f"pi exited with code {returncode}: {stderr[:500]}") - - -class JsonNotFoundError(Exception): - pass - - -# ── 路径工具 ──────────────────────────────────────────────────────────── - -_DATA_DIR = Path("data") -_PAPERS_DIR = _DATA_DIR / "papers" -_TMP_DIR = _DATA_DIR / "tmp" - - -def _paper_dir(arxiv_id: str) -> Path: - return _PAPERS_DIR / arxiv_id - - -def _tmp_dir(arxiv_id: str) -> Path: - return _TMP_DIR / arxiv_id - - -# ── PDF 下载 ──────────────────────────────────────────────────────────── - - -async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path: - """下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。""" - if not pdf_url: - raise PdfDownloadError(f"no pdf_url for {arxiv_id}") - - tmp = _tmp_dir(arxiv_id) - tmp.mkdir(parents=True, exist_ok=True) - dest = tmp / "paper.pdf" - - transport = None - if settings.http_proxy: - transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy) - - try: - async with httpx.AsyncClient( - timeout=settings.HTTP_TIMEOUT_SECONDS, - headers={"User-Agent": settings.HTTP_USER_AGENT}, - transport=transport, - follow_redirects=True, - ) as client: - resp = await client.get(pdf_url) - resp.raise_for_status() - dest.write_bytes(resp.content) - except Exception as exc: - raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc - - logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size) - return dest - - -# ── meta.json ─────────────────────────────────────────────────────────── - - -def _write_meta_json(paper: Paper) -> Path: - """写入 data/papers/{arxiv_id}/meta.json,返回路径。""" - d = _paper_dir(paper.arxiv_id) - d.mkdir(parents=True, exist_ok=True) - meta_path = d / "meta.json" - - authors = [a.name for a in paper.authors] - tags = [t.tag for t in paper.tags] - meta = { - "arxiv_id": paper.arxiv_id, - "title_en": paper.title_en, - "abstract": paper.abstract or "", - "published_at": paper.published_at.isoformat() if paper.published_at else None, - "authors": authors, - "tags": tags, - "upvotes": paper.upvotes, - } - meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8") - return meta_path - - -# ── pi CLI 调用 ──────────────────────────────────────────────────────── - - -async def _call_pi(meta_path: Path, pdf_path: Path) -> str: - """调用 pi CLI 非交互模式,返回 stdout 文本。""" - cmd = [ - settings.PI_BIN, - "-p", - "--no-tools", - "--skill", - settings.SUMMARY_SKILL, - "请深度解读以下论文,并按指定 JSON schema 输出:", - f"@{meta_path}", - f"@{pdf_path}", - ] - logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4])) - - proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - try: - stdout, stderr = await asyncio.wait_for( - proc.communicate(), - timeout=settings.SUMMARY_TIMEOUT_SECONDS, - ) - except asyncio.TimeoutError: - proc.kill() - await proc.wait() - raise PiTimeoutError( - f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s" - ) - - if proc.returncode != 0: - raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace")) - - return stdout.decode("utf-8", errors="replace") - - -def paper_id_from_path(meta_path: Path) -> str: - """从 meta.json 路径反推 arxiv_id。""" - return meta_path.parent.name - - -# ── JSON 提取 ────────────────────────────────────────────────────────── - - -def _extract_json(raw_output: str) -> dict: - """从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。""" - # 策略 1:整体直接解析 - stripped = raw_output.strip() - try: - result = json.loads(stripped) - if isinstance(result, dict) and "title_zh" in result: - return result - except json.JSONDecodeError: - pass - - # 策略 2:提取 ```json ... ``` 代码块 - fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL) - for match in fence_pattern.finditer(raw_output): - try: - result = json.loads(match.group(1).strip()) - if isinstance(result, dict) and "title_zh" in result: - return result - except json.JSONDecodeError: - continue - - # 策略 3:匹配包含 title_zh 的最大 {...} 块 - brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL) - # 先尝试一层嵌套;如果没命中再用更宽松的策略 - for match in brace_pattern.finditer(raw_output): - try: - return json.loads(match.group(0)) - except json.JSONDecodeError: - continue - - # 更宽松:找到最大的 { ... } 平衡块 - best = None - best_len = 0 - for i, ch in enumerate(raw_output): - if ch != "{": - continue - depth = 0 - for j in range(i, len(raw_output)): - if raw_output[j] == "{": - depth += 1 - elif raw_output[j] == "}": - depth -= 1 - if depth == 0: - candidate = raw_output[i : j + 1] - if len(candidate) > best_len: - try: - parsed = json.loads(candidate) - if isinstance(parsed, dict): - best = parsed - best_len = len(candidate) - except json.JSONDecodeError: - pass - break - - if best is not None: - return best - - raise JsonNotFoundError("no JSON object found in pi output") - # ── 错误分类 ──────────────────────────────────────────────────────────── @@ -284,6 +96,8 @@ def _update_summary_in_db( raw_output: str, ) -> None: """将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。""" + from sqlalchemy import text + now = datetime.now(timezone.utc) # 1. paper_summaries:upsert @@ -298,9 +112,9 @@ def _update_summary_in_db( # 2. papers 表 paper.title_zh = schema.title_zh paper.summary_quality = quality - paper_dir = _paper_dir(paper.arxiv_id) - paper.summary_path = str(paper_dir / "summary.json") - paper.raw_output_path = str(paper_dir / "raw_output.txt") + p_dir = paper_dir(paper.arxiv_id) + paper.summary_path = str(p_dir / "summary.json") + paper.raw_output_path = str(p_dir / "raw_output.txt") # 3. AI 标签 existing_tag_names = {t.tag for t in paper.tags} @@ -332,7 +146,7 @@ def _update_summary_in_db( def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None: """保存 summary.json 和 raw_output.txt。""" - d = _paper_dir(arxiv_id) + d = paper_dir(arxiv_id) d.mkdir(parents=True, exist_ok=True) (d / "summary.json").write_text( schema.model_dump_json(ensure_ascii=False, indent=2), @@ -343,143 +157,11 @@ def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None: def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None: """仅保存 raw_output.txt(失败时)。""" - d = _paper_dir(arxiv_id) + d = paper_dir(arxiv_id) d.mkdir(parents=True, exist_ok=True) (d / "raw_output.txt").write_text(raw_output, encoding="utf-8") -def _cleanup_tmp(arxiv_id: str) -> None: - """清理 data/tmp/{arxiv_id}/ 目录。""" - tmp = _tmp_dir(arxiv_id) - if tmp.exists(): - try: - shutil.rmtree(tmp) - logger.debug("Cleaned tmp: %s", arxiv_id) - except Exception: - logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True) - - -# ── LaTeX 图片提取(Phase 5)─────────────────────────────────────────── - -_INCLUDEGRAPHICS_RE = re.compile( - r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE -) -_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"} - - -async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int: - """从 LaTeX 源码中提取图片文件。 - - 流程: - 1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/ - 2. 扫描 .tex 文件中的 \\includegraphics - 3. 复制图片到 data/papers/{arxiv_id}/images/ - 4. 清理源码临时文件 - - Returns: - 提取的图片数量 - """ - tmp_source = _tmp_dir(arxiv_id) / "source" - images_dest = _paper_dir(arxiv_id) / "images" - - try: - # 下载源码 zip(如果还没下载) - if not tmp_source.exists(): - source_url = f"https://arxiv.org/e-print/{arxiv_id}" - await _download_source_zip(arxiv_id, source_url, tmp_source) - - if not tmp_source.exists(): - return 0 - - # 扫描 .tex 文件,收集图片路径 - image_paths: set[str] = set() - for tex_file in tmp_source.rglob("*.tex"): - try: - content = tex_file.read_text(encoding="utf-8", errors="replace") - for match in _INCLUDEGRAPHICS_RE.finditer(content): - img_path = match.group(1).strip() - image_paths.add(img_path) - except Exception: - continue - - if not image_paths: - return 0 - - # 查找并复制图片 - images_dest.mkdir(parents=True, exist_ok=True) - copied = 0 - for img_rel in image_paths: - # 尝试在源码目录中找到文件 - for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"): - candidate = tmp_source / (img_rel + ext) - if candidate.is_file(): - dest_name = candidate.name - # 避免文件名冲突 - dest = images_dest / dest_name - if dest.exists(): - stem = dest.stem - suffix = dest.suffix - dest = images_dest / f"{stem}_{copied}{suffix}" - shutil.copy2(candidate, dest) - copied += 1 - break - - if copied > 0: - logger.info("Extracted %d images from source for %s", copied, arxiv_id) - return copied - - except Exception: - logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True) - return 0 - - -async def _download_source_zip( - arxiv_id: str, source_url: str, dest_dir: Path -) -> None: - """下载 arXiv 源码并解压。""" - import zipfile - - dest_dir.mkdir(parents=True, exist_ok=True) - zip_path = _tmp_dir(arxiv_id) / "source.zip" - - transport = None - if settings.http_proxy: - transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy) - - try: - async with httpx.AsyncClient( - timeout=settings.HTTP_TIMEOUT_SECONDS, - headers={"User-Agent": settings.HTTP_USER_AGENT}, - transport=transport, - follow_redirects=True, - ) as client: - resp = await client.get(source_url) - resp.raise_for_status() - zip_path.write_bytes(resp.content) - except Exception as exc: - logger.debug("Failed to download source for %s: %s", arxiv_id, exc) - return - - try: - with zipfile.ZipFile(zip_path, "r") as zf: - zf.extractall(dest_dir) - logger.debug("Extracted source for %s", arxiv_id) - except zipfile.BadZipFile: - # 可能是 tar.gz - import tarfile - try: - with tarfile.open(zip_path, "r:*") as tf: - tf.extractall(dest_dir) - logger.debug("Extracted source (tar) for %s", arxiv_id) - except Exception: - logger.warning("Cannot extract source for %s", arxiv_id) - except Exception: - logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True) - finally: - if zip_path.exists(): - zip_path.unlink() - - # ── 单篇总结 ──────────────────────────────────────────────────────────── @@ -491,6 +173,8 @@ async def summarize_one( force: bool = False, ) -> dict: """总结单篇论文的完整流程。""" + import asyncio + arxiv_id = paper.arxiv_id # 获取或创建 summary_status @@ -520,6 +204,8 @@ async def summarize_one( async def _do_summarize_one(db: Session, paper: Paper) -> dict: """实际的单篇总结执行(在 semaphore 保护下)。""" + import asyncio + arxiv_id = paper.arxiv_id status = paper.summary_status now = datetime.now(timezone.utc) @@ -532,16 +218,16 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict: raw_output = "" try: # 写 meta.json - meta_path = _write_meta_json(paper) + meta_path = write_meta_json(paper) # 下载 PDF - await _download_pdf(arxiv_id, paper.pdf_url) + await download_pdf(arxiv_id, paper.pdf_url) # 调用 pi - raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf") + raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf") # 提取 JSON - json_data = _extract_json(raw_output) + json_data = extract_json(raw_output) # Pydantic 校验 schema = SummarySchema.model_validate(json_data) @@ -564,7 +250,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict: # Phase 5: LaTeX 图片提取(可选增强,失败不影响总结) try: - await _extract_images_from_source(arxiv_id) + await extract_images_from_source(arxiv_id) except Exception: logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True) @@ -625,7 +311,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict: } finally: - _cleanup_tmp(arxiv_id) + cleanup_tmp(arxiv_id) # ── 单篇入口 ──────────────────────────────────────────────────────────── @@ -690,6 +376,8 @@ async def summarize_batch( _session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。 """ + import asyncio + now = datetime.now(timezone.utc) # TaskLock 防重入 @@ -741,7 +429,7 @@ async def summarize_batch( log_entry.papers_found = 0 log_entry.papers_new = 0 log_entry.completed_at = datetime.now(timezone.utc) - _release_lock(db, lock) + release_lock(db, lock) return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0} # 并发控制 @@ -813,15 +501,4 @@ async def summarize_batch( return {"status": "failed", "error": str(exc)} finally: - _release_lock(db, lock) - - -def _release_lock(db: Session, lock: TaskLock) -> None: - """释放 TaskLock。""" - try: - lock.status = "finished" - lock.released_at = datetime.now(timezone.utc) - db.commit() - except Exception: - db.rollback() - logger.warning("Failed to release summarize lock", exc_info=True) + release_lock(db, lock) diff --git a/app/services/trends.py b/app/services/trends.py new file mode 100644 index 0000000..51e7824 --- /dev/null +++ b/app/services/trends.py @@ -0,0 +1,81 @@ +"""趋势统计服务 — 按日论文数量、热门标签、Upvotes 分布、总结完成率。""" + +from __future__ import annotations + +from datetime import date, timedelta + +from sqlalchemy import text +from sqlalchemy.orm import Session + + +def get_trends_data(db: Session) -> dict: + """从 DB 聚合趋势数据。""" + thirty_days_ago = (date.today() - timedelta(days=30)).isoformat() + + # 1. 按日论文数量(近 30 天) + daily_rows = db.execute(text(""" + SELECT paper_date, COUNT(*) as cnt + FROM papers + WHERE paper_date >= :start_date + GROUP BY paper_date + ORDER BY paper_date ASC + """), {"start_date": thirty_days_ago}).fetchall() + daily_counts = [ + {"date": str(row[0]), "count": row[1]} + for row in daily_rows + ] + + # 2. 热门标签 Top 20 + tag_rows = db.execute(text(""" + SELECT tag, COUNT(*) as cnt + FROM paper_tags + GROUP BY tag + ORDER BY cnt DESC + LIMIT 20 + """)).fetchall() + top_tags = [ + {"tag": row[0], "count": row[1]} + for row in tag_rows + ] + + # 3. Upvotes 分布 + upvote_rows = db.execute(text(""" + SELECT + CASE + WHEN upvotes >= 100 THEN '100+' + WHEN upvotes >= 50 THEN '50-99' + WHEN upvotes >= 20 THEN '20-49' + WHEN upvotes >= 10 THEN '10-19' + WHEN upvotes >= 5 THEN '5-9' + ELSE '0-4' + END as bucket, + COUNT(*) as cnt + FROM papers + GROUP BY bucket + ORDER BY MIN(upvotes) DESC + """)).fetchall() + upvotes_dist = [ + {"range": row[0], "count": row[1]} + for row in upvote_rows + ] + + # 4. 总结完成率 + summary_rows = db.execute(text(""" + SELECT + COALESCE(ss.status, 'none') as status, + COUNT(*) as cnt + FROM papers p + LEFT JOIN summary_status ss ON ss.paper_id = p.id + GROUP BY status + """)).fetchall() + summary_completion = [ + {"status": row[0], "count": row[1]} + for row in summary_rows + ] + + return { + "daily_counts": daily_counts, + "top_tags": top_tags, + "upvotes_dist": upvotes_dist, + "summary_completion": summary_completion, + } diff --git a/app/services/user_data.py b/app/services/user_data.py index 77cb9c9..fd93b18 100644 --- a/app/services/user_data.py +++ b/app/services/user_data.py @@ -1,12 +1,13 @@ -"""用户数据服务 — 收藏、阅读状态、个人笔记。无账号体系,数据写入本地 SQLite。""" +"""用户数据服务 — 收藏、阅读状态、个人笔记、阅读列表查询。无账号体系,数据写入本地 SQLite。""" from __future__ import annotations from datetime import datetime, timezone -from sqlalchemy.orm import Session +from sqlalchemy import or_ +from sqlalchemy.orm import Session, joinedload -from app.models import Paper, UserBookmark, UserNote, UserReadingStatus +from app.models import Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus # ── 收藏 ────────────────────────────────────────────────────────────── @@ -113,3 +114,47 @@ def save_note(db: Session, arxiv_id: str, content: str) -> dict: "content": content, "updated_at": now.isoformat(), } + + +# ── 阅读列表 ────────────────────────────────────────────────────────── + + +def query_reading_list( + db: Session, + filter_type: str, + tag: str | None, +) -> list[Paper]: + """根据筛选条件查询阅读列表。""" + # 基础:有任意用户数据的论文 + base = db.query(Paper).filter( + or_( + Paper.bookmark.has(), + Paper.reading_status.has(), + Paper.note.has(), + ) + ) + + # 应用筛选 + if filter_type == "has_note": + base = base.filter(Paper.note.has()) + elif filter_type in ("unread", "skimmed", "read_summary", "read_full"): + base = base.filter( + Paper.reading_status.has(UserReadingStatus.status == filter_type) + ) + + # 应用标签 + if tag: + base = base.filter(Paper.tags.any(PaperTag.tag == tag)) + + return ( + base.options( + joinedload(Paper.authors), + joinedload(Paper.tags), + joinedload(Paper.summary_status), + joinedload(Paper.bookmark), + joinedload(Paper.reading_status), + joinedload(Paper.note), + ) + .order_by(Paper.paper_date.desc(), Paper.upvotes.desc()) + .all() + ) diff --git a/app/utils.py b/app/utils.py new file mode 100644 index 0000000..63c794a --- /dev/null +++ b/app/utils.py @@ -0,0 +1,73 @@ +"""公共工具 — 消除各模块间的重复代码。""" + +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from zoneinfo import ZoneInfo + +import httpx +from fastapi.templating import Jinja2Templates + +from app.config import settings + +# ── 路径常量 ────────────────────────────────────────────────────────── + +DATA_DIR = Path("data") +PAPERS_DIR = DATA_DIR / "papers" +TMP_DIR = DATA_DIR / "tmp" + +# ── 模板单例 ────────────────────────────────────────────────────────── + +templates = Jinja2Templates(directory="app/templates") + + +# ── 时区工具 ────────────────────────────────────────────────────────── + + +def today_str() -> str: + """当前日期字符串(按 APP_TIMEZONE)。""" + tz = ZoneInfo(settings.APP_TIMEZONE) + return datetime.now(tz).strftime("%Y-%m-%d") + + +# ── 锁释放 ──────────────────────────────────────────────────────────── + + +def release_lock(db, lock) -> None: + """释放 TaskLock。""" + try: + lock.status = "finished" + lock.released_at = datetime.now(timezone.utc) + db.commit() + except Exception: + db.rollback() + + +# ── HTTP 客户端工厂 ─────────────────────────────────────────────────── + + +def make_http_client(*, sync: bool = False, follow_redirects: bool = False, **kwargs) -> httpx.AsyncClient | httpx.Client: + """创建带 proxy 和默认配置的 httpx 客户端。 + + Args: + sync: True 返回同步 Client,False 返回 AsyncClient + follow_redirects: 是否跟随重定向 + **kwargs: 覆盖默认参数 + """ + defaults: dict = { + "timeout": settings.HTTP_TIMEOUT_SECONDS, + "headers": {"User-Agent": settings.HTTP_USER_AGENT}, + "follow_redirects": follow_redirects, + } + if settings.http_proxy: + defaults["transport"] = ( + httpx.HTTPTransport(proxy=settings.http_proxy) + if sync + else httpx.AsyncHTTPTransport(proxy=settings.http_proxy) + ) + defaults.update(kwargs) + + if sync: + return httpx.Client(**defaults) + return httpx.AsyncClient(**defaults) diff --git a/tests/conftest.py b/tests/conftest.py index 61ef8aa..ef4895b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,13 +15,13 @@ from sqlalchemy.pool import StaticPool from app.database import get_db from app.main import create_app +from app.database import init_db from app.models import ( Paper, PaperAuthor, PaperSummary, PaperTag, SummaryStatus, - init_db, ) diff --git a/tests/test_admin_phase4.py b/tests/test_admin_phase4.py index f95702e..305e608 100644 --- a/tests/test_admin_phase4.py +++ b/tests/test_admin_phase4.py @@ -141,7 +141,7 @@ class TestCleanupTmp: old_mtime = time.time() - 25 * 3600 os.utime(old_dir, (old_mtime, old_mtime)) - monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir) + monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir) from app.services.cleaner import cleanup_tmp result = cleanup_tmp() @@ -158,7 +158,7 @@ class TestCleanupTmp: recent_dir.mkdir() (recent_dir / "paper.pdf").write_text("fake pdf") - monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir) + monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir) from app.services.cleaner import cleanup_tmp result = cleanup_tmp() @@ -168,7 +168,7 @@ class TestCleanupTmp: def test_cleanup_empty_dir(self, tmp_path, monkeypatch): """data/tmp/ 不存在时安全返回。""" - monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_path / "nonexistent") + monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent") from app.services.cleaner import cleanup_tmp result = cleanup_tmp() assert result["scanned"] == 0 @@ -187,7 +187,7 @@ class TestCleanupTmp: recent_dir = tmp_dir / "2401.new" recent_dir.mkdir() - monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir) + monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir) from app.services.cleaner import cleanup_tmp result = cleanup_tmp() @@ -318,7 +318,7 @@ class TestDeletePapersByDateRange: (papers_dir / "2401.10001").mkdir() (papers_dir / "2401.10001" / "meta.json").write_text("{}") - monkeypatch.setattr("app.services.cleaner._PAPERS_DIR", papers_dir) + monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir) result = await delete_papers_by_date_range( db_session, diff --git a/tests/test_phase5.py b/tests/test_phase5.py index cf73132..1ed2f68 100644 --- a/tests/test_phase5.py +++ b/tests/test_phase5.py @@ -125,10 +125,9 @@ class TestEmbedderInit: """CHROMA_ENABLED=false 时不初始化。""" monkeypatch.setattr(settings, "CHROMA_ENABLED", False) import app.services.embedder as emb - emb._client = None - emb._collection = None + emb._chroma.reset() emb.init_chroma() - assert emb._client is None + assert emb._chroma._client is None def test_chroma_init_success(self, monkeypatch, tmp_path): """CHROMA_ENABLED=true 时初始化成功。""" @@ -136,23 +135,20 @@ class TestEmbedderInit: monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma")) import app.services.embedder as emb - emb._client = None - emb._collection = None + emb._chroma.reset() emb.init_chroma() - assert emb._client is not None - assert emb._collection is not None + assert emb._chroma._client is not None + assert emb._chroma._collection is not None # 清理 - emb._client = None - emb._collection = None + emb._chroma.reset() def test_get_collection_returns_none_when_disabled(self, monkeypatch): """CHROMA_ENABLED=false 时 get_collection 返回 None。""" monkeypatch.setattr(settings, "CHROMA_ENABLED", False) import app.services.embedder as emb - emb._client = None - emb._collection = None + emb._chroma.reset() assert emb.get_collection() is None @@ -163,8 +159,7 @@ class TestEmbedderIndexing: """CHROMA_ENABLED=false 时 index_paper 返回 False。""" monkeypatch.setattr(settings, "CHROMA_ENABLED", False) import app.services.embedder as emb - emb._client = None - emb._collection = None + emb._chroma.reset() assert emb.index_paper("test-id") is False def test_index_paper_no_api_config(self, monkeypatch, tmp_path): @@ -175,22 +170,19 @@ class TestEmbedderIndexing: monkeypatch.setattr(settings, "EMBED_MODEL", "") import app.services.embedder as emb - emb._client = None - emb._collection = None + emb._chroma.reset() emb.init_chroma() result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"}) assert result is False - emb._client = None - emb._collection = None + emb._chroma.reset() def test_index_batch_disabled(self, monkeypatch): """CHROMA_ENABLED=false 时 index_batch 返回全失败。""" monkeypatch.setattr(settings, "CHROMA_ENABLED", False) import app.services.embedder as emb - emb._client = None - emb._collection = None + emb._chroma.reset() result = emb.index_batch(["a", "b"]) assert result["success"] == 0 assert result["failed"] == 2 @@ -206,16 +198,14 @@ class TestEmbedderIndexing: """CHROMA_ENABLED=false 时 delete_paper 返回 False。""" monkeypatch.setattr(settings, "CHROMA_ENABLED", False) import app.services.embedder as emb - emb._client = None - emb._collection = None + emb._chroma.reset() assert emb.delete_paper("test-id") is False def test_search_similar_disabled(self, monkeypatch): """CHROMA_ENABLED=false 时 search_similar 返回空列表。""" monkeypatch.setattr(settings, "CHROMA_ENABLED", False) import app.services.embedder as emb - emb._client = None - emb._collection = None + emb._chroma.reset() assert emb.search_similar("test query") == [] @@ -427,8 +417,8 @@ class TestTrendsDashboard: from unittest.mock import patch as upatch import app.routes.trends as trends_mod - # monkeypatch _get_trends_data 中的 date.today - with upatch("app.routes.trends.date") as mock_date: + # monkeypatch get_trends_data 中的 date.today + with upatch("app.services.trends.date") as mock_date: mock_date.today.return_value = date(2024, 1, 20) mock_date.side_effect = lambda *a, **kw: date(*a, **kw) @@ -528,15 +518,17 @@ class TestImageExtraction: @pytest.mark.asyncio async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path): """源码目录不存在时返回 0。""" - monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x) - monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x) - from app.services.summarizer import _extract_images_from_source - result = await _extract_images_from_source("2401.99999") + monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x) + monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x) + from app.services.image_extractor import extract_images_from_source + result = await extract_images_from_source("2401.99999") assert result == 0 @pytest.mark.asyncio async def test_extract_images_from_tex(self, monkeypatch, tmp_path): """从 .tex 文件中提取图片。""" + from app.services.image_extractor import extract_images_from_source + tmp_source = tmp_path / "tmp" / "2401.00001" / "source" tmp_source.mkdir(parents=True) @@ -559,11 +551,16 @@ class TestImageExtraction: (tmp_source / "main.tex").write_text(tex_content) papers_dir = tmp_path / "papers" / "2401.00001" - monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x) - monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x) + monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x) + monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x) - from app.services.summarizer import _extract_images_from_source - result = await _extract_images_from_source("2401.00001") + # Mock download_source_zip to avoid real network call (source dir already exists) + async def _noop_download(*args, **kwargs): + pass + + monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download) + + result = await extract_images_from_source("2401.00001") assert result == 2 dest_images = papers_dir / "images" @@ -574,15 +571,22 @@ class TestImageExtraction: @pytest.mark.asyncio async def test_extract_images_empty_tex(self, monkeypatch, tmp_path): """.tex 文件无图片时返回 0。""" + from app.services.image_extractor import extract_images_from_source + tmp_source = tmp_path / "tmp" / "2401.00002" / "source" tmp_source.mkdir(parents=True) (tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}") - monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x) - monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x) + monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x) + monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x) - from app.services.summarizer import _extract_images_from_source - result = await _extract_images_from_source("2401.00002") + # Mock download_source_zip to avoid real network call + async def _noop_download(*args, **kwargs): + pass + + monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download) + + result = await extract_images_from_source("2401.00002") assert result == 0 @@ -644,8 +648,7 @@ class TestGracefulDegradation: """CHROMA 关闭时删除论文正常工作。""" monkeypatch.setattr(settings, "CHROMA_ENABLED", False) import app.services.embedder as emb - emb._client = None - emb._collection = None + emb._chroma.reset() from app.services.cleaner import delete_papers_by_date_range result = await delete_papers_by_date_range( diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py index ebe8b8f..35cdb79 100644 --- a/tests/test_summarizer.py +++ b/tests/test_summarizer.py @@ -27,14 +27,7 @@ from app.services.schemas import ( flatten_for_db, ) from app.services.summarizer import ( - JsonNotFoundError, - PdfDownloadError, - PiProcessError, - PiTimeoutError, - _call_pi, _classify_error, - _cleanup_tmp, - _extract_json, _save_files, _save_raw_output_only, _update_summary_in_db, @@ -42,6 +35,17 @@ from app.services.summarizer import ( summarize_one, summarize_single, ) +from app.services.pi_client import ( + JsonNotFoundError, + PiProcessError, + PiTimeoutError, + call_pi as _call_pi, + extract_json as _extract_json, +) +from app.services.pdf_downloader import ( + PdfDownloadError, + cleanup_tmp as _cleanup_tmp, +) # ═══════════════════════════════════════════════════════════════════════ @@ -287,7 +291,7 @@ class TestFileOperations: def test_save_files(self, tmp_path, sample_summary_dict): schema = SummarySchema.model_validate(sample_summary_dict) - with patch("app.services.summarizer._PAPERS_DIR", tmp_path): + with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid): _save_files("2401.12345", schema, "raw output text") paper_dir = tmp_path / "2401.12345" @@ -297,7 +301,7 @@ class TestFileOperations: assert saved["title_zh"] == "测试论文中文标题" def test_save_raw_output_only(self, tmp_path): - with patch("app.services.summarizer._PAPERS_DIR", tmp_path): + with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid): _save_raw_output_only("2401.12345", "raw output") paper_dir = tmp_path / "2401.12345" assert (paper_dir / "raw_output.txt").exists() @@ -307,13 +311,13 @@ class TestFileOperations: tmp_paper = tmp_path / "2401.12345" tmp_paper.mkdir() (tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake") - with patch("app.services.summarizer._TMP_DIR", tmp_path): + with patch("app.services.pdf_downloader.TMP_DIR", tmp_path): _cleanup_tmp("2401.12345") assert not tmp_paper.exists() def test_cleanup_tmp_nonexistent(self, tmp_path): """清理不存在的目录不报错。""" - with patch("app.services.summarizer._TMP_DIR", tmp_path): + with patch("app.services.pdf_downloader.TMP_DIR", tmp_path): _cleanup_tmp("nonexistent") # 不抛异常 @@ -329,9 +333,11 @@ class TestSummarizeOneFlow: def _patch_paths(self, tmp_path): """将 data 目录重定向到 tmp_path。""" with ( - patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"), - patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"), - patch("app.services.summarizer._DATA_DIR", tmp_path), + patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid), + patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"), + patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"), + patch("app.utils.PAPERS_DIR", tmp_path / "papers"), + patch("app.utils.TMP_DIR", tmp_path / "tmp"), ): yield @@ -341,8 +347,8 @@ class TestSummarizeOneFlow: ): """pending → processing → done 全流程。""" with ( - patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), - patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output), + patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output), ): result = await summarize_one(db_session, sample_paper) @@ -374,7 +380,7 @@ class TestSummarizeOneFlow: """PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。""" with ( patch( - "app.services.summarizer._download_pdf", + "app.services.summarizer.download_pdf", new_callable=AsyncMock, side_effect=PdfDownloadError("network error"), ), @@ -392,9 +398,9 @@ class TestSummarizeOneFlow: async def test_pi_timeout(self, db_session, sample_paper, _patch_paths): """pi 超时 → timeout 错误,retry_count 递增。""" with ( - patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( - "app.services.summarizer._call_pi", + "app.services.summarizer.call_pi", new_callable=AsyncMock, side_effect=PiTimeoutError("timeout after 300s"), ), @@ -409,9 +415,9 @@ class TestSummarizeOneFlow: async def test_json_not_found(self, db_session, sample_paper, _patch_paths): """pi 输出无 JSON → json_not_found。""" with ( - patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( - "app.services.summarizer._call_pi", + "app.services.summarizer.call_pi", new_callable=AsyncMock, return_value="No JSON in this output at all.", ), @@ -436,9 +442,9 @@ class TestSummarizeOneFlow: bad_output = f"```json\n{bad_json}\n```" with ( - patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( - "app.services.summarizer._call_pi", + "app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=bad_output, ), @@ -464,9 +470,9 @@ class TestSummarizeOneFlow: ): """失败时仍保存 raw_output.txt。""" with ( - patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), patch( - "app.services.summarizer._call_pi", + "app.services.summarizer.call_pi", new_callable=AsyncMock, return_value="Some output without JSON", ), @@ -483,8 +489,8 @@ class TestSummarizeOneFlow: ): """成功后清理 tmp 目录。""" with ( - patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), - patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output), + patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output), ): await summarize_one(db_session, sample_paper) @@ -498,7 +504,7 @@ class TestSummarizeOneFlow: """失败后也清理 tmp 目录。""" with ( patch( - "app.services.summarizer._download_pdf", + "app.services.summarizer.download_pdf", new_callable=AsyncMock, side_effect=PdfDownloadError("fail"), ), @@ -529,9 +535,11 @@ class TestBatchSummarize: @pytest.fixture def _patch_paths(self, tmp_path): with ( - patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"), - patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"), - patch("app.services.summarizer._DATA_DIR", tmp_path), + patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid), + patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"), + patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"), + patch("app.utils.PAPERS_DIR", tmp_path / "papers"), + patch("app.utils.TMP_DIR", tmp_path / "tmp"), ): yield @@ -561,8 +569,8 @@ class TestBatchSummarize: _TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False) with ( - patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), - patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output), + patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output), ): result = await summarize_batch( db_session, _session_factory=_TestSession @@ -612,8 +620,8 @@ class TestBatchSummarize: return mock_pi_output with ( - patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), - patch("app.services.summarizer._call_pi", side_effect=_mock_call_pi), + patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer.call_pi", side_effect=_mock_call_pi), ): result = await summarize_batch( db_session, _session_factory=_TestSession @@ -646,8 +654,8 @@ class TestBatchSummarize: _TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False) with ( - patch("app.services.summarizer._download_pdf", new_callable=AsyncMock), - patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output), + patch("app.services.summarizer.download_pdf", new_callable=AsyncMock), + patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output), ): await summarize_batch( db_session, _session_factory=_TestSession