diff --git a/app/database.py b/app/database.py
index 095e014..40e63cf 100644
--- a/app/database.py
+++ b/app/database.py
@@ -1,6 +1,6 @@
"""数据库引擎、会话工厂、初始化。"""
-from sqlalchemy import event, create_engine
+from sqlalchemy import event, create_engine, text
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from app.config import settings
@@ -10,6 +10,27 @@ class Base(DeclarativeBase):
pass
+# ── FTS5 和索引 DDL(与 ORM 模型分开管理)───────────────────────────────
+
+FTS5_CREATE_SQL = """
+CREATE VIRTUAL TABLE IF NOT EXISTS papers_fts USING fts5(
+ title_en,
+ title_zh,
+ abstract,
+ authors,
+ tags,
+ summary_text,
+ tokenize='unicode61'
+);
+"""
+
+FTS5_TRIGGER_INDEX = """
+-- partial index for task_locks running
+CREATE UNIQUE INDEX IF NOT EXISTS uq_task_locks_running
+ON task_locks(task, lock_key) WHERE status = 'running';
+"""
+
+
def _make_engine():
"""创建 SQLite 引擎,启用 foreign_keys。"""
engine = create_engine(
@@ -39,3 +60,14 @@ def get_db():
yield db
finally:
db.close()
+
+
+def init_db(engine):
+ """创建所有 ORM 表 + FTS5 虚拟表。"""
+ from app.models import Base # noqa: F811 — 避免循环导入,延迟导入
+
+ Base.metadata.create_all(engine)
+ with engine.connect() as conn:
+ conn.execute(text(FTS5_CREATE_SQL))
+ conn.execute(text(FTS5_TRIGGER_INDEX))
+ conn.commit()
diff --git a/app/main.py b/app/main.py
index 0ab932b..33ceaaf 100644
--- a/app/main.py
+++ b/app/main.py
@@ -2,14 +2,13 @@
import logging
import os
+from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
-from starlette.staticfiles import StaticFiles as StarletteStaticFiles
from app.config import settings
-from app.database import engine
-from app.models import init_db
+from app.database import engine, init_db
from app.routes.admin import router as admin_router
from app.routes.compare import router as compare_router
from app.routes.pages import router as pages_router
@@ -24,11 +23,30 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ """应用生命周期管理:启动与关闭。"""
+ # ── startup ──
+ from app.services.scheduler import start_scheduler
+ from app.services.embedder import init_chroma
+
+ start_scheduler()
+ init_chroma()
+
+ yield
+
+ # ── shutdown ──
+ from app.services.scheduler import stop_scheduler
+
+ stop_scheduler()
+
+
def create_app() -> FastAPI:
app = FastAPI(
title="HF Daily Papers",
description="HuggingFace Daily Papers — 中文论文导览站",
version="0.1.0",
+ lifespan=lifespan,
)
# 确保数据目录存在
@@ -65,23 +83,6 @@ def create_app() -> FastAPI:
app.include_router(trends_router)
app.include_router(compare_router)
- # 调度器(Phase 4)
- @app.on_event("startup")
- async def _start_scheduler():
- from app.services.scheduler import start_scheduler
- start_scheduler()
-
- # Phase 5: 初始化 ChromaDB
- @app.on_event("startup")
- async def _init_chroma():
- from app.services.embedder import init_chroma
- init_chroma()
-
- @app.on_event("shutdown")
- async def _stop_scheduler():
- from app.services.scheduler import stop_scheduler
- stop_scheduler()
-
return app
diff --git a/app/models.py b/app/models.py
index 589c3f8..eb5006a 100644
--- a/app/models.py
+++ b/app/models.py
@@ -1,4 +1,4 @@
-"""SQLAlchemy ORM 模型 — papers, authors, tags, summaries, FTS5, logs, locks, user data。"""
+"""SQLAlchemy ORM 模型 — papers, authors, tags, summaries, user data, logs, locks。"""
from datetime import date, datetime
@@ -13,7 +13,6 @@ from sqlalchemy import (
String,
Text,
UniqueConstraint,
- text,
)
from sqlalchemy.orm import relationship
@@ -204,32 +203,3 @@ class DataDeleteJob(Base):
error = Column(Text)
started_at = Column(DateTime, nullable=False)
completed_at = Column(DateTime)
-
-
-# ── FTS5 索引初始化 SQL(普通虚拟表,由应用层维护)──────────────────────
-FTS5_CREATE_SQL = """
-CREATE VIRTUAL TABLE IF NOT EXISTS papers_fts USING fts5(
- title_en,
- title_zh,
- abstract,
- authors,
- tags,
- summary_text,
- tokenize='unicode61'
-);
-"""
-
-FTS5_TRIGGER_INDEX = """
--- partial index for task_locks running
-CREATE UNIQUE INDEX IF NOT EXISTS uq_task_locks_running
-ON task_locks(task, lock_key) WHERE status = 'running';
-"""
-
-
-def init_db(engine):
- """创建所有 ORM 表 + FTS5 虚拟表。"""
- Base.metadata.create_all(engine)
- with engine.connect() as conn:
- conn.execute(text(FTS5_CREATE_SQL))
- conn.execute(text(FTS5_TRIGGER_INDEX))
- conn.commit()
diff --git a/app/routes/admin.py b/app/routes/admin.py
index a1e20eb..bac16d2 100644
--- a/app/routes/admin.py
+++ b/app/routes/admin.py
@@ -6,7 +6,6 @@ from datetime import date, datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
-from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -17,10 +16,10 @@ from app.models import CrawlLog, DataDeleteJob, TaskLock
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
from app.services.crawler import crawl_daily
from app.services.summarizer import summarize_batch, summarize_single
+from app.utils import release_lock, templates, today_str
router = APIRouter(prefix="/admin", tags=["admin"])
security = HTTPBearer()
-templates = Jinja2Templates(directory="app/templates")
async def verify_admin(
@@ -32,7 +31,7 @@ async def verify_admin(
return credentials.credentials
-# ── 请求模型 ──────────────────────────────────────────────────────────────
+# ── 请求模型 ──────────────────────────────────────────────────────────
class DeleteRequest(BaseModel):
@@ -49,7 +48,7 @@ class DeleteRequest(BaseModel):
return v
-# ── 抓取 ──────────────────────────────────────────────────────────────────
+# ── 抓取 ──────────────────────────────────────────────────────────────
@router.post("/crawl")
@@ -59,12 +58,7 @@ async def admin_crawl(
date: str | None = Query(None, description="YYYY-MM-DD,默认今天"),
):
"""手动抓取指定日期,默认今天。"""
- # 计算 target_date
- from zoneinfo import ZoneInfo
-
- tz = ZoneInfo(settings.APP_TIMEZONE)
- today = datetime.now(tz).strftime("%Y-%m-%d")
- target_date = date or today
+ target_date = date or today_str()
# TaskLock 防重入
now = datetime.now(timezone.utc)
@@ -88,10 +82,10 @@ async def admin_crawl(
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
finally:
- _release_lock(db, lock)
+ release_lock(db, lock)
-# ── 总结 ──────────────────────────────────────────────────────────────────
+# ── 总结 ──────────────────────────────────────────────────────────────
@router.post("/summarize")
@@ -119,7 +113,7 @@ async def admin_summarize_single(
return result
-# ── 清理 ──────────────────────────────────────────────────────────────────
+# ── 清理 ──────────────────────────────────────────────────────────────
@router.post("/cleanup")
@@ -155,7 +149,7 @@ async def admin_cleanup(
raise HTTPException(status_code=500, detail=str(exc))
-# ── 删除 ──────────────────────────────────────────────────────────────────
+# ── 删除 ──────────────────────────────────────────────────────────────
@router.post("/delete")
@@ -177,7 +171,7 @@ async def admin_delete(
return result
-# ── 日志 ──────────────────────────────────────────────────────────────────
+# ── 日志 ──────────────────────────────────────────────────────────────
@router.get("/logs")
@@ -189,7 +183,6 @@ async def admin_logs(
per_page: int = Query(20, ge=1, le=100),
):
"""查看任务日志(CrawlLog + DataDeleteJob)。"""
- # 查询 crawl_logs
crawl_logs = (
db.execute(
select(CrawlLog)
@@ -201,7 +194,6 @@ async def admin_logs(
.all()
)
- # 查询 delete_jobs
delete_jobs = (
db.execute(
select(DataDeleteJob)
@@ -223,16 +215,3 @@ async def admin_logs(
"per_page": per_page,
},
)
-
-
-# ── 工具函数 ──────────────────────────────────────────────────────────────
-
-
-def _release_lock(db: Session, lock: TaskLock) -> None:
- """释放 TaskLock。"""
- try:
- lock.status = "finished"
- lock.released_at = datetime.now(timezone.utc)
- db.commit()
- except Exception:
- db.rollback()
diff --git a/app/routes/compare.py b/app/routes/compare.py
index 8c418c3..dbca7f6 100644
--- a/app/routes/compare.py
+++ b/app/routes/compare.py
@@ -2,19 +2,14 @@
from __future__ import annotations
-import logging
-
from fastapi import APIRouter, Depends, HTTPException, Query, Request
-from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
from app.database import get_db
from app.models import Paper
-
-logger = logging.getLogger(__name__)
+from app.utils import templates
router = APIRouter()
-templates = Jinja2Templates(directory="app/templates")
@router.get("/compare")
diff --git a/app/routes/pages.py b/app/routes/pages.py
index a034433..b249046 100644
--- a/app/routes/pages.py
+++ b/app/routes/pages.py
@@ -3,34 +3,27 @@
from __future__ import annotations
import logging
-from datetime import date, datetime, timedelta
+from datetime import date, timedelta
from pathlib import Path
-from zoneinfo import ZoneInfo
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
-from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import Paper
+from app.utils import templates, today_str
logger = logging.getLogger(__name__)
router = APIRouter()
-templates = Jinja2Templates(directory="app/templates")
-
-
-def _today() -> str:
- tz = ZoneInfo(settings.APP_TIMEZONE)
- return datetime.now(tz).strftime("%Y-%m-%d")
@router.get("/")
def index(request: Request):
"""重定向到 /day/{today}。"""
- return RedirectResponse(url=f"/day/{_today()}")
+ return RedirectResponse(url=f"/day/{today_str()}")
@router.get("/day/{date_str}")
@@ -43,7 +36,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
prev_day = (target - timedelta(days=1)).isoformat()
next_day = (target + timedelta(days=1)).isoformat()
- today_str = _today()
+ today = today_str()
papers = (
db.query(Paper)
@@ -74,7 +67,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
"current_date": date_str,
"prev_day": prev_day,
"next_day": next_day,
- "today": today_str,
+ "today": today,
"available_dates": available_dates,
"page_title": f"{date_str} 论文列表",
},
@@ -146,16 +139,9 @@ def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict
return []
try:
- from app.services.embedder import search_similar
-
- # 用论文的 arxiv_id 从 ChromaDB 查询
- col = None
- try:
- from app.services.embedder import get_collection
- col = get_collection()
- except Exception:
- return []
+ from app.services.embedder import get_collection
+ col = get_collection()
if col is None:
return []
diff --git a/app/routes/search.py b/app/routes/search.py
index cc23ece..0ec43be 100644
--- a/app/routes/search.py
+++ b/app/routes/search.py
@@ -2,14 +2,11 @@
from __future__ import annotations
-import math
-from datetime import date, datetime, timedelta, timezone
-from zoneinfo import ZoneInfo
+from datetime import date, timedelta
from xml.sax.saxutils import escape
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import Response
-from fastapi.templating import Jinja2Templates
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload
@@ -17,9 +14,10 @@ from app.config import settings
from app.database import get_db
from app.models import Paper, PaperTag, UserReadingStatus
from app.services.searcher import get_all_tags, search_papers
+from app.services.user_data import query_reading_list
+from app.utils import templates, today_str
router = APIRouter()
-templates = Jinja2Templates(directory="app/templates")
# ── 搜索页 ────────────────────────────────────────────────────────────
@@ -56,7 +54,7 @@ def search_page(
"total_pages": result["total_pages"],
"all_tags": all_tags,
"page_title": f"搜索: {q}" if q else "搜索",
- "today": _today_str(),
+ "today": today_str(),
},
)
@@ -114,7 +112,7 @@ def reading_list_page(
db: Session = Depends(get_db),
):
"""阅读列表页面。"""
- papers = _query_reading_list(db, filter, tag or None)
+ papers = query_reading_list(db, filter, tag or None)
all_tags = get_all_tags(db)
return templates.TemplateResponse(
@@ -126,54 +124,11 @@ def reading_list_page(
"current_tag": tag,
"all_tags": all_tags,
"page_title": "阅读列表",
- "today": _today_str(),
+ "today": today_str(),
},
)
-def _query_reading_list(
- db: Session,
- filter_type: str,
- tag: str | None,
-) -> list[Paper]:
- """根据筛选条件查询阅读列表。"""
- from sqlalchemy import or_
-
- # 基础:有任意用户数据的论文
- base = db.query(Paper).filter(
- or_(
- Paper.bookmark.has(),
- Paper.reading_status.has(),
- Paper.note.has(),
- )
- )
-
- # 应用筛选
- if filter_type == "has_note":
- base = base.filter(Paper.note.has())
- elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
- base = base.filter(
- Paper.reading_status.has(UserReadingStatus.status == filter_type)
- )
-
- # 应用标签
- if tag:
- base = base.filter(Paper.tags.any(PaperTag.tag == tag))
-
- return (
- base.options(
- joinedload(Paper.authors),
- joinedload(Paper.tags),
- joinedload(Paper.summary_status),
- joinedload(Paper.bookmark),
- joinedload(Paper.reading_status),
- joinedload(Paper.note),
- )
- .order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
- .all()
- )
-
-
# ── RSS Feed ──────────────────────────────────────────────────────────
@@ -216,7 +171,7 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
lines.append(f"
{escape(channel_title)}")
lines.append(f" {escape(base_url)}")
lines.append(" HuggingFace Daily Papers — 中文论文导览站")
- lines.append(f" zh-CN")
+ lines.append(" zh-CN")
for paper in papers:
title_text = paper.title_zh or paper.title_en
@@ -245,12 +200,3 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
lines.append(" ")
lines.append("")
return "\n".join(lines)
-
-
-# ── 工具函数 ──────────────────────────────────────────────────────────
-
-
-def _today_str() -> str:
- """当前日期字符串(按 APP_TIMEZONE)。"""
- tz = ZoneInfo(settings.APP_TIMEZONE)
- return datetime.now(tz).strftime("%Y-%m-%d")
diff --git a/app/routes/trends.py b/app/routes/trends.py
index 80e6ce5..84e40f5 100644
--- a/app/routes/trends.py
+++ b/app/routes/trends.py
@@ -2,34 +2,27 @@
from __future__ import annotations
-import logging
-from datetime import date, timedelta
-
from fastapi import APIRouter, Depends, Request
-from fastapi.templating import Jinja2Templates
-from sqlalchemy import func, text
from sqlalchemy.orm import Session
-from app.config import settings
from app.database import get_db
-
-logger = logging.getLogger(__name__)
+from app.services.trends import get_trends_data
+from app.utils import templates, today_str
router = APIRouter()
-templates = Jinja2Templates(directory="app/templates")
@router.get("/trends")
def trends_page(request: Request, db: Session = Depends(get_db)):
"""趋势看板页面。"""
- stats = _get_trends_data(db)
+ stats = get_trends_data(db)
return templates.TemplateResponse(
request,
"trends.html",
{
"page_title": "趋势看板",
"stats": stats,
- "today": _today_str(),
+ "today": today_str(),
},
)
@@ -37,84 +30,4 @@ def trends_page(request: Request, db: Session = Depends(get_db)):
@router.get("/api/stats/trends")
def trends_api(db: Session = Depends(get_db)):
"""趋势数据 JSON API。"""
- return _get_trends_data(db)
-
-
-def _get_trends_data(db: Session) -> dict:
- """从 DB 聚合趋势数据。"""
- thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
-
- # 1. 按日论文数量(近 30 天)
- daily_rows = db.execute(text("""
- SELECT paper_date, COUNT(*) as cnt
- FROM papers
- WHERE paper_date >= :start_date
- GROUP BY paper_date
- ORDER BY paper_date ASC
- """), {"start_date": thirty_days_ago}).fetchall()
- daily_counts = [
- {"date": str(row[0]), "count": row[1]}
- for row in daily_rows
- ]
-
- # 2. 热门标签 Top 20
- tag_rows = db.execute(text("""
- SELECT tag, COUNT(*) as cnt
- FROM paper_tags
- GROUP BY tag
- ORDER BY cnt DESC
- LIMIT 20
- """)).fetchall()
- top_tags = [
- {"tag": row[0], "count": row[1]}
- for row in tag_rows
- ]
-
- # 3. Upvotes 分布
- upvote_rows = db.execute(text("""
- SELECT
- CASE
- WHEN upvotes >= 100 THEN '100+'
- WHEN upvotes >= 50 THEN '50-99'
- WHEN upvotes >= 20 THEN '20-49'
- WHEN upvotes >= 10 THEN '10-19'
- WHEN upvotes >= 5 THEN '5-9'
- ELSE '0-4'
- END as bucket,
- COUNT(*) as cnt
- FROM papers
- GROUP BY bucket
- ORDER BY MIN(upvotes) DESC
- """)).fetchall()
- upvotes_dist = [
- {"range": row[0], "count": row[1]}
- for row in upvote_rows
- ]
-
- # 4. 总结完成率
- summary_rows = db.execute(text("""
- SELECT
- COALESCE(ss.status, 'none') as status,
- COUNT(*) as cnt
- FROM papers p
- LEFT JOIN summary_status ss ON ss.paper_id = p.id
- GROUP BY status
- """)).fetchall()
- summary_completion = [
- {"status": row[0], "count": row[1]}
- for row in summary_rows
- ]
-
- return {
- "daily_counts": daily_counts,
- "top_tags": top_tags,
- "upvotes_dist": upvotes_dist,
- "summary_completion": summary_completion,
- }
-
-
-def _today_str() -> str:
- from datetime import datetime
- from zoneinfo import ZoneInfo
- tz = ZoneInfo(settings.APP_TIMEZONE)
- return datetime.now(tz).strftime("%Y-%m-%d")
+ return get_trends_data(db)
diff --git a/app/services/cleaner.py b/app/services/cleaner.py
index f5f4c9b..5916a43 100644
--- a/app/services/cleaner.py
+++ b/app/services/cleaner.py
@@ -16,13 +16,10 @@ from app.models import (
Paper,
TaskLock,
)
+from app.utils import PAPERS_DIR, TMP_DIR
logger = logging.getLogger(__name__)
-_DATA_DIR = Path("data")
-_TMP_DIR = _DATA_DIR / "tmp"
-_PAPERS_DIR = _DATA_DIR / "papers"
-
# 临时文件最大保留时间(小时)
_MAX_TMP_AGE_HOURS = 24
@@ -39,7 +36,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
Returns:
清理统计 {"scanned": int, "removed": int, "errors": list[str]}
"""
- if not _TMP_DIR.exists():
+ if not TMP_DIR.exists():
return {"scanned": 0, "removed": 0, "errors": []}
now = datetime.now(timezone.utc)
@@ -48,7 +45,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
removed = 0
errors: list[str] = []
- for entry in _TMP_DIR.iterdir():
+ for entry in TMP_DIR.iterdir():
if not entry.is_dir():
continue
scanned += 1
@@ -147,13 +144,13 @@ async def delete_papers_by_date_range(
logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True)
# 2. 删除本地文件 data/papers/{arxiv_id}/
- paper_dir = _PAPERS_DIR / arxiv_id
+ paper_dir = PAPERS_DIR / arxiv_id
if paper_dir.exists():
shutil.rmtree(paper_dir)
logger.debug("Removed paper dir: %s", paper_dir)
# 3. 删除临时文件 data/tmp/{arxiv_id}/
- tmp_dir = _TMP_DIR / arxiv_id
+ tmp_dir = TMP_DIR / arxiv_id
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
logger.debug("Removed tmp dir: %s", tmp_dir)
diff --git a/app/services/crawler.py b/app/services/crawler.py
index f0dd124..371c90e 100644
--- a/app/services/crawler.py
+++ b/app/services/crawler.py
@@ -16,6 +16,7 @@ from app.models import (
PaperTag,
SummaryStatus,
)
+from app.utils import make_http_client
logger = logging.getLogger(__name__)
@@ -34,15 +35,7 @@ async def fetch_daily(target_date: str, top_n: int | None = None) -> list[dict]:
url = f"{settings.HF_API_BASE}/daily_papers"
params = {"date": target_date}
- transport = None
- if settings.http_proxy:
- transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
-
- async with httpx.AsyncClient(
- timeout=settings.HTTP_TIMEOUT_SECONDS,
- headers={"User-Agent": settings.HTTP_USER_AGENT},
- transport=transport,
- ) as client:
+ async with make_http_client() as client:
for attempt in range(1, settings.HTTP_MAX_RETRIES + 1):
try:
logger.info("Fetching HF Daily Papers: date=%s attempt=%d", target_date, attempt)
diff --git a/app/services/embedder.py b/app/services/embedder.py
index f153251..015a288 100644
--- a/app/services/embedder.py
+++ b/app/services/embedder.py
@@ -5,8 +5,6 @@ from __future__ import annotations
import logging
from pathlib import Path
-import httpx
-from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
@@ -14,66 +12,82 @@ from app.models import Paper
logger = logging.getLogger(__name__)
-# ── 单例客户端和 collection ─────────────────────────────────────────────
-_client = None
-_collection = None
+
+# ── ChromaDB 管理器(替代全局可变状态)──────────────────────────────────
-def _chroma_dir() -> Path:
- return Path(settings.CHROMA_DIR)
+class ChromaManager:
+ """封装 ChromaDB 客户端和 collection 的生命周期。"""
+
+ def __init__(self) -> None:
+ self._client = None
+ self._collection = None
+
+ def init(self) -> None:
+ """CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
+ if not settings.CHROMA_ENABLED:
+ logger.debug("ChromaDB disabled, skip init")
+ return
+
+ if self._client is not None:
+ return
+
+ try:
+ import chromadb
+
+ chroma_path = Path(settings.CHROMA_DIR)
+ chroma_path.mkdir(parents=True, exist_ok=True)
+
+ self._client = chromadb.PersistentClient(path=str(chroma_path))
+ self._collection = self._get_or_create_collection()
+ logger.info("ChromaDB initialized at %s", chroma_path)
+ except Exception:
+ logger.exception("Failed to initialize ChromaDB")
+ self._client = None
+ self._collection = None
+
+ def _get_or_create_collection(self):
+ """获取或创建 papers_embeddings collection。"""
+ try:
+ col = self._client.get_collection("papers_embeddings")
+ logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
+ return col
+ except Exception:
+ pass
+
+ col = self._client.create_collection(
+ name="papers_embeddings",
+ metadata={"hnsw:space": "cosine"},
+ )
+ logger.info("ChromaDB collection 'papers_embeddings' created")
+ return col
+
+ def get_collection(self):
+ """返回当前 collection,未初始化则自动初始化。"""
+ if not settings.CHROMA_ENABLED:
+ return None
+ if self._collection is None:
+ self.init()
+ return self._collection
+
+ def reset(self) -> None:
+ """重置状态(供测试使用)。"""
+ self._client = None
+ self._collection = None
+
+
+# 模块级单例
+_chroma = ChromaManager()
def init_chroma() -> None:
- """CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
- global _client, _collection
- if not settings.CHROMA_ENABLED:
- logger.debug("ChromaDB disabled, skip init")
- return
-
- if _client is not None:
- return
-
- try:
- import chromadb
-
- chroma_path = _chroma_dir()
- chroma_path.mkdir(parents=True, exist_ok=True)
-
- _client = chromadb.PersistentClient(path=str(chroma_path))
- _collection = _get_or_create_collection()
- logger.info("ChromaDB initialized at %s", chroma_path)
- except Exception:
- logger.exception("Failed to initialize ChromaDB")
- _client = None
- _collection = None
-
-
-def _get_or_create_collection():
- """获取或创建 papers_embeddings collection,维度不匹配时记录日志并跳过。"""
- import chromadb
-
- try:
- col = _client.get_collection("papers_embeddings")
- logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
- return col
- except Exception:
- pass
-
- col = _client.create_collection(
- name="papers_embeddings",
- metadata={"hnsw:space": "cosine"},
- )
- logger.info("ChromaDB collection 'papers_embeddings' created")
- return col
+ """初始化 ChromaDB(供 lifespan 调用)。"""
+ _chroma.init()
def get_collection():
"""返回当前 collection,未初始化则返回 None。"""
- if not settings.CHROMA_ENABLED:
- return None
- if _collection is None:
- init_chroma()
- return _collection
+ return _chroma.get_collection()
# ── Embedding API 调用 ──────────────────────────────────────────────────
@@ -90,6 +104,8 @@ def _get_embedding(text: str) -> list[float] | None:
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
return None
+ from app.utils import make_http_client
+
url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
headers = {"Content-Type": "application/json"}
if settings.EMBED_API_KEY:
@@ -101,7 +117,7 @@ def _get_embedding(text: str) -> list[float] | None:
}
try:
- with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
+ with make_http_client(sync=True) as client:
resp = client.post(url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
diff --git a/app/services/image_extractor.py b/app/services/image_extractor.py
new file mode 100644
index 0000000..11028c6
--- /dev/null
+++ b/app/services/image_extractor.py
@@ -0,0 +1,83 @@
+"""LaTeX 图片提取 — 从 arXiv 源码中扫描 \\includegraphics 并提取图片文件。"""
+
+from __future__ import annotations
+
+import logging
+import re
+import shutil
+from pathlib import Path
+
+from app.services.pdf_downloader import download_source_zip, paper_dir, tmp_dir
+
+logger = logging.getLogger(__name__)
+
+_INCLUDEGRAPHICS_RE = re.compile(
+ r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
+)
+_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
+
+
+async def extract_images_from_source(arxiv_id: str) -> int:
+ """从 LaTeX 源码中提取图片文件。
+
+ 流程:
+ 1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
+ 2. 扫描 .tex 文件中的 \\includegraphics
+ 3. 复制图片到 data/papers/{arxiv_id}/images/
+ 4. 清理源码临时文件
+
+ Returns:
+ 提取的图片数量
+ """
+ tmp_source = tmp_dir(arxiv_id) / "source"
+ images_dest = paper_dir(arxiv_id) / "images"
+
+ try:
+ # 下载源码 zip(如果还没下载)
+ if not tmp_source.exists():
+ source_url = f"https://arxiv.org/e-print/{arxiv_id}"
+ await download_source_zip(arxiv_id, source_url, tmp_source)
+
+ if not tmp_source.exists():
+ return 0
+
+ # 扫描 .tex 文件,收集图片路径
+ image_paths: set[str] = set()
+ for tex_file in tmp_source.rglob("*.tex"):
+ try:
+ content = tex_file.read_text(encoding="utf-8", errors="replace")
+ for match in _INCLUDEGRAPHICS_RE.finditer(content):
+ img_path = match.group(1).strip()
+ image_paths.add(img_path)
+ except Exception:
+ continue
+
+ if not image_paths:
+ return 0
+
+ # 查找并复制图片
+ images_dest.mkdir(parents=True, exist_ok=True)
+ copied = 0
+ for img_rel in image_paths:
+ # 尝试在源码目录中找到文件
+ for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
+ candidate = tmp_source / (img_rel + ext)
+ if candidate.is_file():
+ dest_name = candidate.name
+ # 避免文件名冲突
+ dest = images_dest / dest_name
+ if dest.exists():
+ stem = dest.stem
+ suffix = dest.suffix
+ dest = images_dest / f"{stem}_{copied}{suffix}"
+ shutil.copy2(candidate, dest)
+ copied += 1
+ break
+
+ if copied > 0:
+ logger.info("Extracted %d images from source for %s", copied, arxiv_id)
+ return copied
+
+ except Exception:
+ logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
+ return 0
diff --git a/app/services/pdf_downloader.py b/app/services/pdf_downloader.py
new file mode 100644
index 0000000..cfb0d5a
--- /dev/null
+++ b/app/services/pdf_downloader.py
@@ -0,0 +1,105 @@
+"""PDF 下载与源码下载 — 从 arXiv 下载论文 PDF 和 LaTeX 源码包。"""
+
+from __future__ import annotations
+
+import logging
+import shutil
+import zipfile
+from pathlib import Path
+
+from app.utils import PAPERS_DIR, TMP_DIR, make_http_client
+
+logger = logging.getLogger(__name__)
+
+
+# ── 自定义异常 ──────────────────────────────────────────────────────────
+
+
+class PdfDownloadError(Exception):
+ pass
+
+
+# ── 路径工具 ────────────────────────────────────────────────────────────
+
+
+def paper_dir(arxiv_id: str) -> Path:
+ return PAPERS_DIR / arxiv_id
+
+
+def tmp_dir(arxiv_id: str) -> Path:
+ return TMP_DIR / arxiv_id
+
+
+# ── PDF 下载 ────────────────────────────────────────────────────────────
+
+
+async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
+ """下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
+ if not pdf_url:
+ raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
+
+ dest_dir = tmp_dir(arxiv_id)
+ dest_dir.mkdir(parents=True, exist_ok=True)
+ dest = dest_dir / "paper.pdf"
+
+ try:
+ async with make_http_client(follow_redirects=True) as client:
+ resp = await client.get(pdf_url)
+ resp.raise_for_status()
+ dest.write_bytes(resp.content)
+ except Exception as exc:
+ raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
+
+ logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
+ return dest
+
+
+# ── 源码下载 ────────────────────────────────────────────────────────────
+
+
+async def download_source_zip(arxiv_id: str, source_url: str, dest_dir: Path) -> None:
+ """下载 arXiv 源码并解压。"""
+ dest_dir.mkdir(parents=True, exist_ok=True)
+ zip_path = tmp_dir(arxiv_id) / "source.zip"
+
+ try:
+ async with make_http_client(follow_redirects=True) as client:
+ resp = await client.get(source_url)
+ resp.raise_for_status()
+ zip_path.write_bytes(resp.content)
+ except Exception as exc:
+ logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
+ return
+
+ try:
+ with zipfile.ZipFile(zip_path, "r") as zf:
+ zf.extractall(dest_dir)
+ logger.debug("Extracted source for %s", arxiv_id)
+ except zipfile.BadZipFile:
+ # 可能是 tar.gz
+ import tarfile
+ try:
+ with tarfile.open(zip_path, "r:*") as tf:
+ tf.extractall(dest_dir, filter="data")
+ logger.debug("Extracted source (tar) for %s", arxiv_id)
+ except Exception:
+ logger.warning("Cannot extract source for %s", arxiv_id)
+ except Exception:
+ logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
+ finally:
+ if zip_path.exists():
+ zip_path.unlink()
+
+
+# ── 临时文件清理 ────────────────────────────────────────────────────────
+
+
+def cleanup_tmp(arxiv_id: str) -> None:
+ """清理 data/tmp/{arxiv_id}/ 目录。"""
+ td = tmp_dir(arxiv_id)
+ if td.exists():
+ try:
+ shutil.rmtree(td)
+ logger.debug("Cleaned tmp: %s", arxiv_id)
+ except Exception:
+ logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
diff --git a/app/services/pi_client.py b/app/services/pi_client.py
new file mode 100644
index 0000000..b7aae2b
--- /dev/null
+++ b/app/services/pi_client.py
@@ -0,0 +1,160 @@
+"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import re
+from pathlib import Path
+
+from app.config import settings
+
+logger = logging.getLogger(__name__)
+
+
+# ── 自定义异常 ──────────────────────────────────────────────────────────
+
+
+class PiTimeoutError(Exception):
+ pass
+
+
+class PiProcessError(Exception):
+ def __init__(self, returncode: int, stderr: str):
+ self.returncode = returncode
+ self.stderr = stderr
+ super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
+
+
+class JsonNotFoundError(Exception):
+ pass
+
+
+# ── meta.json ───────────────────────────────────────────────────────────
+
+
+def write_meta_json(paper) -> Path:
+ """写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
+ from app.services.pdf_downloader import paper_dir
+
+ d = paper_dir(paper.arxiv_id)
+ d.mkdir(parents=True, exist_ok=True)
+ meta_path = d / "meta.json"
+
+ authors = [a.name for a in paper.authors]
+ tags = [t.tag for t in paper.tags]
+ meta = {
+ "arxiv_id": paper.arxiv_id,
+ "title_en": paper.title_en,
+ "abstract": paper.abstract or "",
+ "published_at": paper.published_at.isoformat() if paper.published_at else None,
+ "authors": authors,
+ "tags": tags,
+ "upvotes": paper.upvotes,
+ }
+ meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
+ return meta_path
+
+
+# ── pi CLI 调用 ────────────────────────────────────────────────────────
+
+
+async def call_pi(meta_path: Path, pdf_path: Path) -> str:
+ """调用 pi CLI 非交互模式,返回 stdout 文本。"""
+ arxiv_id = meta_path.parent.name
+ cmd = [
+ settings.PI_BIN,
+ "-p",
+ "--no-tools",
+ "--skill",
+ settings.SUMMARY_SKILL,
+ "请深度解读以下论文,并按指定 JSON schema 输出:",
+ f"@{meta_path}",
+ f"@{pdf_path}",
+ ]
+ logger.info("Calling pi for %s", arxiv_id)
+
+ proc = await asyncio.create_subprocess_exec(
+ *cmd,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ )
+ try:
+ stdout, stderr = await asyncio.wait_for(
+ proc.communicate(),
+ timeout=settings.SUMMARY_TIMEOUT_SECONDS,
+ )
+ except asyncio.TimeoutError:
+ proc.kill()
+ await proc.wait()
+ raise PiTimeoutError(
+ f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
+ )
+
+ if proc.returncode != 0:
+ raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
+
+ return stdout.decode("utf-8", errors="replace")
+
+
+# ── JSON 提取 ──────────────────────────────────────────────────────────
+
+
+def extract_json(raw_output: str) -> dict:
+ """从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
+ # 策略 1:整体直接解析
+ stripped = raw_output.strip()
+ try:
+ result = json.loads(stripped)
+ if isinstance(result, dict) and "title_zh" in result:
+ return result
+ except json.JSONDecodeError:
+ pass
+
+ # 策略 2:提取 ```json ... ``` 代码块
+ fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
+ for match in fence_pattern.finditer(raw_output):
+ try:
+ result = json.loads(match.group(1).strip())
+ if isinstance(result, dict) and "title_zh" in result:
+ return result
+ except json.JSONDecodeError:
+ continue
+
+ # 策略 3:匹配包含 title_zh 的最大 {...} 块
+ brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
+ for match in brace_pattern.finditer(raw_output):
+ try:
+ return json.loads(match.group(0))
+ except json.JSONDecodeError:
+ continue
+
+ # 更宽松:找到最大的 { ... } 平衡块
+ best = None
+ best_len = 0
+ for i, ch in enumerate(raw_output):
+ if ch != "{":
+ continue
+ depth = 0
+ for j in range(i, len(raw_output)):
+ if raw_output[j] == "{":
+ depth += 1
+ elif raw_output[j] == "}":
+ depth -= 1
+ if depth == 0:
+ candidate = raw_output[i : j + 1]
+ if len(candidate) > best_len:
+ try:
+ parsed = json.loads(candidate)
+ if isinstance(parsed, dict):
+ best = parsed
+ best_len = len(candidate)
+ except json.JSONDecodeError:
+ pass
+ break
+
+ if best is not None:
+ return best
+
+ raise JsonNotFoundError("no JSON object found in pi output")
diff --git a/app/services/summarizer.py b/app/services/summarizer.py
index 523eb18..b3450d8 100644
--- a/app/services/summarizer.py
+++ b/app/services/summarizer.py
@@ -1,18 +1,15 @@
-"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。"""
+"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。"""
from __future__ import annotations
-import asyncio
import json
import logging
-import re
import shutil
from datetime import datetime, timezone
from pathlib import Path
-import httpx
from pydantic import ValidationError
-from sqlalchemy import select, text
+from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
@@ -25,216 +22,31 @@ from app.models import (
SummaryStatus,
TaskLock,
)
+from app.services.image_extractor import extract_images_from_source
+from app.services.pdf_downloader import (
+ PdfDownloadError,
+ cleanup_tmp,
+ download_pdf,
+ paper_dir,
+)
+from app.services.pi_client import (
+ JsonNotFoundError,
+ PiProcessError,
+ PiTimeoutError,
+ call_pi,
+ extract_json,
+ write_meta_json,
+)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
+from app.utils import PAPERS_DIR, release_lock
logger = logging.getLogger(__name__)
-# ── 自定义异常 ──────────────────────────────────────────────────────────
-
-
-class PdfDownloadError(Exception):
- pass
-
-
-class PiTimeoutError(Exception):
- pass
-
-
-class PiProcessError(Exception):
- def __init__(self, returncode: int, stderr: str):
- self.returncode = returncode
- self.stderr = stderr
- super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
-
-
-class JsonNotFoundError(Exception):
- pass
-
-
-# ── 路径工具 ────────────────────────────────────────────────────────────
-
-_DATA_DIR = Path("data")
-_PAPERS_DIR = _DATA_DIR / "papers"
-_TMP_DIR = _DATA_DIR / "tmp"
-
-
-def _paper_dir(arxiv_id: str) -> Path:
- return _PAPERS_DIR / arxiv_id
-
-
-def _tmp_dir(arxiv_id: str) -> Path:
- return _TMP_DIR / arxiv_id
-
-
-# ── PDF 下载 ────────────────────────────────────────────────────────────
-
-
-async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
- """下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
- if not pdf_url:
- raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
-
- tmp = _tmp_dir(arxiv_id)
- tmp.mkdir(parents=True, exist_ok=True)
- dest = tmp / "paper.pdf"
-
- transport = None
- if settings.http_proxy:
- transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
-
- try:
- async with httpx.AsyncClient(
- timeout=settings.HTTP_TIMEOUT_SECONDS,
- headers={"User-Agent": settings.HTTP_USER_AGENT},
- transport=transport,
- follow_redirects=True,
- ) as client:
- resp = await client.get(pdf_url)
- resp.raise_for_status()
- dest.write_bytes(resp.content)
- except Exception as exc:
- raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
-
- logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
- return dest
-
-
-# ── meta.json ───────────────────────────────────────────────────────────
-
-
-def _write_meta_json(paper: Paper) -> Path:
- """写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
- d = _paper_dir(paper.arxiv_id)
- d.mkdir(parents=True, exist_ok=True)
- meta_path = d / "meta.json"
-
- authors = [a.name for a in paper.authors]
- tags = [t.tag for t in paper.tags]
- meta = {
- "arxiv_id": paper.arxiv_id,
- "title_en": paper.title_en,
- "abstract": paper.abstract or "",
- "published_at": paper.published_at.isoformat() if paper.published_at else None,
- "authors": authors,
- "tags": tags,
- "upvotes": paper.upvotes,
- }
- meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
- return meta_path
-
-
-# ── pi CLI 调用 ────────────────────────────────────────────────────────
-
-
-async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
- """调用 pi CLI 非交互模式,返回 stdout 文本。"""
- cmd = [
- settings.PI_BIN,
- "-p",
- "--no-tools",
- "--skill",
- settings.SUMMARY_SKILL,
- "请深度解读以下论文,并按指定 JSON schema 输出:",
- f"@{meta_path}",
- f"@{pdf_path}",
- ]
- logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4]))
-
- proc = await asyncio.create_subprocess_exec(
- *cmd,
- stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE,
- )
- try:
- stdout, stderr = await asyncio.wait_for(
- proc.communicate(),
- timeout=settings.SUMMARY_TIMEOUT_SECONDS,
- )
- except asyncio.TimeoutError:
- proc.kill()
- await proc.wait()
- raise PiTimeoutError(
- f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
- )
-
- if proc.returncode != 0:
- raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
-
- return stdout.decode("utf-8", errors="replace")
-
-
-def paper_id_from_path(meta_path: Path) -> str:
- """从 meta.json 路径反推 arxiv_id。"""
- return meta_path.parent.name
-
-
-# ── JSON 提取 ──────────────────────────────────────────────────────────
-
-
-def _extract_json(raw_output: str) -> dict:
- """从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
- # 策略 1:整体直接解析
- stripped = raw_output.strip()
- try:
- result = json.loads(stripped)
- if isinstance(result, dict) and "title_zh" in result:
- return result
- except json.JSONDecodeError:
- pass
-
- # 策略 2:提取 ```json ... ``` 代码块
- fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
- for match in fence_pattern.finditer(raw_output):
- try:
- result = json.loads(match.group(1).strip())
- if isinstance(result, dict) and "title_zh" in result:
- return result
- except json.JSONDecodeError:
- continue
-
- # 策略 3:匹配包含 title_zh 的最大 {...} 块
- brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
- # 先尝试一层嵌套;如果没命中再用更宽松的策略
- for match in brace_pattern.finditer(raw_output):
- try:
- return json.loads(match.group(0))
- except json.JSONDecodeError:
- continue
-
- # 更宽松:找到最大的 { ... } 平衡块
- best = None
- best_len = 0
- for i, ch in enumerate(raw_output):
- if ch != "{":
- continue
- depth = 0
- for j in range(i, len(raw_output)):
- if raw_output[j] == "{":
- depth += 1
- elif raw_output[j] == "}":
- depth -= 1
- if depth == 0:
- candidate = raw_output[i : j + 1]
- if len(candidate) > best_len:
- try:
- parsed = json.loads(candidate)
- if isinstance(parsed, dict):
- best = parsed
- best_len = len(candidate)
- except json.JSONDecodeError:
- pass
- break
-
- if best is not None:
- return best
-
- raise JsonNotFoundError("no JSON object found in pi output")
-
# ── 错误分类 ────────────────────────────────────────────────────────────
@@ -284,6 +96,8 @@ def _update_summary_in_db(
raw_output: str,
) -> None:
"""将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。"""
+ from sqlalchemy import text
+
now = datetime.now(timezone.utc)
# 1. paper_summaries:upsert
@@ -298,9 +112,9 @@ def _update_summary_in_db(
# 2. papers 表
paper.title_zh = schema.title_zh
paper.summary_quality = quality
- paper_dir = _paper_dir(paper.arxiv_id)
- paper.summary_path = str(paper_dir / "summary.json")
- paper.raw_output_path = str(paper_dir / "raw_output.txt")
+ p_dir = paper_dir(paper.arxiv_id)
+ paper.summary_path = str(p_dir / "summary.json")
+ paper.raw_output_path = str(p_dir / "raw_output.txt")
# 3. AI 标签
existing_tag_names = {t.tag for t in paper.tags}
@@ -332,7 +146,7 @@ def _update_summary_in_db(
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
"""保存 summary.json 和 raw_output.txt。"""
- d = _paper_dir(arxiv_id)
+ d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
@@ -343,143 +157,11 @@ def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
"""仅保存 raw_output.txt(失败时)。"""
- d = _paper_dir(arxiv_id)
+ d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
-def _cleanup_tmp(arxiv_id: str) -> None:
- """清理 data/tmp/{arxiv_id}/ 目录。"""
- tmp = _tmp_dir(arxiv_id)
- if tmp.exists():
- try:
- shutil.rmtree(tmp)
- logger.debug("Cleaned tmp: %s", arxiv_id)
- except Exception:
- logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
-
-
-# ── LaTeX 图片提取(Phase 5)───────────────────────────────────────────
-
-_INCLUDEGRAPHICS_RE = re.compile(
- r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
-)
-_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
-
-
-async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
- """从 LaTeX 源码中提取图片文件。
-
- 流程:
- 1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
- 2. 扫描 .tex 文件中的 \\includegraphics
- 3. 复制图片到 data/papers/{arxiv_id}/images/
- 4. 清理源码临时文件
-
- Returns:
- 提取的图片数量
- """
- tmp_source = _tmp_dir(arxiv_id) / "source"
- images_dest = _paper_dir(arxiv_id) / "images"
-
- try:
- # 下载源码 zip(如果还没下载)
- if not tmp_source.exists():
- source_url = f"https://arxiv.org/e-print/{arxiv_id}"
- await _download_source_zip(arxiv_id, source_url, tmp_source)
-
- if not tmp_source.exists():
- return 0
-
- # 扫描 .tex 文件,收集图片路径
- image_paths: set[str] = set()
- for tex_file in tmp_source.rglob("*.tex"):
- try:
- content = tex_file.read_text(encoding="utf-8", errors="replace")
- for match in _INCLUDEGRAPHICS_RE.finditer(content):
- img_path = match.group(1).strip()
- image_paths.add(img_path)
- except Exception:
- continue
-
- if not image_paths:
- return 0
-
- # 查找并复制图片
- images_dest.mkdir(parents=True, exist_ok=True)
- copied = 0
- for img_rel in image_paths:
- # 尝试在源码目录中找到文件
- for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
- candidate = tmp_source / (img_rel + ext)
- if candidate.is_file():
- dest_name = candidate.name
- # 避免文件名冲突
- dest = images_dest / dest_name
- if dest.exists():
- stem = dest.stem
- suffix = dest.suffix
- dest = images_dest / f"{stem}_{copied}{suffix}"
- shutil.copy2(candidate, dest)
- copied += 1
- break
-
- if copied > 0:
- logger.info("Extracted %d images from source for %s", copied, arxiv_id)
- return copied
-
- except Exception:
- logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
- return 0
-
-
-async def _download_source_zip(
- arxiv_id: str, source_url: str, dest_dir: Path
-) -> None:
- """下载 arXiv 源码并解压。"""
- import zipfile
-
- dest_dir.mkdir(parents=True, exist_ok=True)
- zip_path = _tmp_dir(arxiv_id) / "source.zip"
-
- transport = None
- if settings.http_proxy:
- transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
-
- try:
- async with httpx.AsyncClient(
- timeout=settings.HTTP_TIMEOUT_SECONDS,
- headers={"User-Agent": settings.HTTP_USER_AGENT},
- transport=transport,
- follow_redirects=True,
- ) as client:
- resp = await client.get(source_url)
- resp.raise_for_status()
- zip_path.write_bytes(resp.content)
- except Exception as exc:
- logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
- return
-
- try:
- with zipfile.ZipFile(zip_path, "r") as zf:
- zf.extractall(dest_dir)
- logger.debug("Extracted source for %s", arxiv_id)
- except zipfile.BadZipFile:
- # 可能是 tar.gz
- import tarfile
- try:
- with tarfile.open(zip_path, "r:*") as tf:
- tf.extractall(dest_dir)
- logger.debug("Extracted source (tar) for %s", arxiv_id)
- except Exception:
- logger.warning("Cannot extract source for %s", arxiv_id)
- except Exception:
- logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
- finally:
- if zip_path.exists():
- zip_path.unlink()
-
-
# ── 单篇总结 ────────────────────────────────────────────────────────────
@@ -491,6 +173,8 @@ async def summarize_one(
force: bool = False,
) -> dict:
"""总结单篇论文的完整流程。"""
+ import asyncio
+
arxiv_id = paper.arxiv_id
# 获取或创建 summary_status
@@ -520,6 +204,8 @@ async def summarize_one(
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
+ import asyncio
+
arxiv_id = paper.arxiv_id
status = paper.summary_status
now = datetime.now(timezone.utc)
@@ -532,16 +218,16 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
raw_output = ""
try:
# 写 meta.json
- meta_path = _write_meta_json(paper)
+ meta_path = write_meta_json(paper)
# 下载 PDF
- await _download_pdf(arxiv_id, paper.pdf_url)
+ await download_pdf(arxiv_id, paper.pdf_url)
# 调用 pi
- raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf")
+ raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf")
# 提取 JSON
- json_data = _extract_json(raw_output)
+ json_data = extract_json(raw_output)
# Pydantic 校验
schema = SummarySchema.model_validate(json_data)
@@ -564,7 +250,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
try:
- await _extract_images_from_source(arxiv_id)
+ await extract_images_from_source(arxiv_id)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
@@ -625,7 +311,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
}
finally:
- _cleanup_tmp(arxiv_id)
+ cleanup_tmp(arxiv_id)
# ── 单篇入口 ────────────────────────────────────────────────────────────
@@ -690,6 +376,8 @@ async def summarize_batch(
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
+ import asyncio
+
now = datetime.now(timezone.utc)
# TaskLock 防重入
@@ -741,7 +429,7 @@ async def summarize_batch(
log_entry.papers_found = 0
log_entry.papers_new = 0
log_entry.completed_at = datetime.now(timezone.utc)
- _release_lock(db, lock)
+ release_lock(db, lock)
return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}
# 并发控制
@@ -813,15 +501,4 @@ async def summarize_batch(
return {"status": "failed", "error": str(exc)}
finally:
- _release_lock(db, lock)
-
-
-def _release_lock(db: Session, lock: TaskLock) -> None:
- """释放 TaskLock。"""
- try:
- lock.status = "finished"
- lock.released_at = datetime.now(timezone.utc)
- db.commit()
- except Exception:
- db.rollback()
- logger.warning("Failed to release summarize lock", exc_info=True)
+ release_lock(db, lock)
diff --git a/app/services/trends.py b/app/services/trends.py
new file mode 100644
index 0000000..51e7824
--- /dev/null
+++ b/app/services/trends.py
@@ -0,0 +1,81 @@
+"""趋势统计服务 — 按日论文数量、热门标签、Upvotes 分布、总结完成率。"""
+
+from __future__ import annotations
+
+from datetime import date, timedelta
+
+from sqlalchemy import text
+from sqlalchemy.orm import Session
+
+
+def get_trends_data(db: Session) -> dict:
+ """从 DB 聚合趋势数据。"""
+ thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
+
+ # 1. 按日论文数量(近 30 天)
+ daily_rows = db.execute(text("""
+ SELECT paper_date, COUNT(*) as cnt
+ FROM papers
+ WHERE paper_date >= :start_date
+ GROUP BY paper_date
+ ORDER BY paper_date ASC
+ """), {"start_date": thirty_days_ago}).fetchall()
+ daily_counts = [
+ {"date": str(row[0]), "count": row[1]}
+ for row in daily_rows
+ ]
+
+ # 2. 热门标签 Top 20
+ tag_rows = db.execute(text("""
+ SELECT tag, COUNT(*) as cnt
+ FROM paper_tags
+ GROUP BY tag
+ ORDER BY cnt DESC
+ LIMIT 20
+ """)).fetchall()
+ top_tags = [
+ {"tag": row[0], "count": row[1]}
+ for row in tag_rows
+ ]
+
+ # 3. Upvotes 分布
+ upvote_rows = db.execute(text("""
+ SELECT
+ CASE
+ WHEN upvotes >= 100 THEN '100+'
+ WHEN upvotes >= 50 THEN '50-99'
+ WHEN upvotes >= 20 THEN '20-49'
+ WHEN upvotes >= 10 THEN '10-19'
+ WHEN upvotes >= 5 THEN '5-9'
+ ELSE '0-4'
+ END as bucket,
+ COUNT(*) as cnt
+ FROM papers
+ GROUP BY bucket
+ ORDER BY MIN(upvotes) DESC
+ """)).fetchall()
+ upvotes_dist = [
+ {"range": row[0], "count": row[1]}
+ for row in upvote_rows
+ ]
+
+ # 4. 总结完成率
+ summary_rows = db.execute(text("""
+ SELECT
+ COALESCE(ss.status, 'none') as status,
+ COUNT(*) as cnt
+ FROM papers p
+ LEFT JOIN summary_status ss ON ss.paper_id = p.id
+ GROUP BY status
+ """)).fetchall()
+ summary_completion = [
+ {"status": row[0], "count": row[1]}
+ for row in summary_rows
+ ]
+
+ return {
+ "daily_counts": daily_counts,
+ "top_tags": top_tags,
+ "upvotes_dist": upvotes_dist,
+ "summary_completion": summary_completion,
+ }
diff --git a/app/services/user_data.py b/app/services/user_data.py
index 77cb9c9..fd93b18 100644
--- a/app/services/user_data.py
+++ b/app/services/user_data.py
@@ -1,12 +1,13 @@
-"""用户数据服务 — 收藏、阅读状态、个人笔记。无账号体系,数据写入本地 SQLite。"""
+"""用户数据服务 — 收藏、阅读状态、个人笔记、阅读列表查询。无账号体系,数据写入本地 SQLite。"""
from __future__ import annotations
from datetime import datetime, timezone
-from sqlalchemy.orm import Session
+from sqlalchemy import or_
+from sqlalchemy.orm import Session, joinedload
-from app.models import Paper, UserBookmark, UserNote, UserReadingStatus
+from app.models import Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
# ── 收藏 ──────────────────────────────────────────────────────────────
@@ -113,3 +114,47 @@ def save_note(db: Session, arxiv_id: str, content: str) -> dict:
"content": content,
"updated_at": now.isoformat(),
}
+
+
+# ── 阅读列表 ──────────────────────────────────────────────────────────
+
+
+def query_reading_list(
+ db: Session,
+ filter_type: str,
+ tag: str | None,
+) -> list[Paper]:
+ """根据筛选条件查询阅读列表。"""
+ # 基础:有任意用户数据的论文
+ base = db.query(Paper).filter(
+ or_(
+ Paper.bookmark.has(),
+ Paper.reading_status.has(),
+ Paper.note.has(),
+ )
+ )
+
+ # 应用筛选
+ if filter_type == "has_note":
+ base = base.filter(Paper.note.has())
+ elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
+ base = base.filter(
+ Paper.reading_status.has(UserReadingStatus.status == filter_type)
+ )
+
+ # 应用标签
+ if tag:
+ base = base.filter(Paper.tags.any(PaperTag.tag == tag))
+
+ return (
+ base.options(
+ joinedload(Paper.authors),
+ joinedload(Paper.tags),
+ joinedload(Paper.summary_status),
+ joinedload(Paper.bookmark),
+ joinedload(Paper.reading_status),
+ joinedload(Paper.note),
+ )
+ .order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
+ .all()
+ )
diff --git a/app/utils.py b/app/utils.py
new file mode 100644
index 0000000..63c794a
--- /dev/null
+++ b/app/utils.py
@@ -0,0 +1,73 @@
+"""公共工具 — 消除各模块间的重复代码。"""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from pathlib import Path
+from zoneinfo import ZoneInfo
+
+import httpx
+from fastapi.templating import Jinja2Templates
+
+from app.config import settings
+
+# ── 路径常量 ──────────────────────────────────────────────────────────
+
+DATA_DIR = Path("data")
+PAPERS_DIR = DATA_DIR / "papers"
+TMP_DIR = DATA_DIR / "tmp"
+
+# ── 模板单例 ──────────────────────────────────────────────────────────
+
+templates = Jinja2Templates(directory="app/templates")
+
+
+# ── 时区工具 ──────────────────────────────────────────────────────────
+
+
+def today_str() -> str:
+ """当前日期字符串(按 APP_TIMEZONE)。"""
+ tz = ZoneInfo(settings.APP_TIMEZONE)
+ return datetime.now(tz).strftime("%Y-%m-%d")
+
+
+# ── 锁释放 ────────────────────────────────────────────────────────────
+
+
+def release_lock(db, lock) -> None:
+ """释放 TaskLock。"""
+ try:
+ lock.status = "finished"
+ lock.released_at = datetime.now(timezone.utc)
+ db.commit()
+ except Exception:
+ db.rollback()
+
+
+# ── HTTP 客户端工厂 ───────────────────────────────────────────────────
+
+
+def make_http_client(*, sync: bool = False, follow_redirects: bool = False, **kwargs) -> httpx.AsyncClient | httpx.Client:
+ """创建带 proxy 和默认配置的 httpx 客户端。
+
+ Args:
+ sync: True 返回同步 Client,False 返回 AsyncClient
+ follow_redirects: 是否跟随重定向
+ **kwargs: 覆盖默认参数
+ """
+ defaults: dict = {
+ "timeout": settings.HTTP_TIMEOUT_SECONDS,
+ "headers": {"User-Agent": settings.HTTP_USER_AGENT},
+ "follow_redirects": follow_redirects,
+ }
+ if settings.http_proxy:
+ defaults["transport"] = (
+ httpx.HTTPTransport(proxy=settings.http_proxy)
+ if sync
+ else httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
+ )
+ defaults.update(kwargs)
+
+ if sync:
+ return httpx.Client(**defaults)
+ return httpx.AsyncClient(**defaults)
diff --git a/tests/conftest.py b/tests/conftest.py
index 61ef8aa..ef4895b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,13 +15,13 @@ from sqlalchemy.pool import StaticPool
from app.database import get_db
from app.main import create_app
+from app.database import init_db
from app.models import (
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
- init_db,
)
diff --git a/tests/test_admin_phase4.py b/tests/test_admin_phase4.py
index f95702e..305e608 100644
--- a/tests/test_admin_phase4.py
+++ b/tests/test_admin_phase4.py
@@ -141,7 +141,7 @@ class TestCleanupTmp:
old_mtime = time.time() - 25 * 3600
os.utime(old_dir, (old_mtime, old_mtime))
- monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir)
+ monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
@@ -158,7 +158,7 @@ class TestCleanupTmp:
recent_dir.mkdir()
(recent_dir / "paper.pdf").write_text("fake pdf")
- monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir)
+ monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
@@ -168,7 +168,7 @@ class TestCleanupTmp:
def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
"""data/tmp/ 不存在时安全返回。"""
- monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_path / "nonexistent")
+ monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent")
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
assert result["scanned"] == 0
@@ -187,7 +187,7 @@ class TestCleanupTmp:
recent_dir = tmp_dir / "2401.new"
recent_dir.mkdir()
- monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir)
+ monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
@@ -318,7 +318,7 @@ class TestDeletePapersByDateRange:
(papers_dir / "2401.10001").mkdir()
(papers_dir / "2401.10001" / "meta.json").write_text("{}")
- monkeypatch.setattr("app.services.cleaner._PAPERS_DIR", papers_dir)
+ monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir)
result = await delete_papers_by_date_range(
db_session,
diff --git a/tests/test_phase5.py b/tests/test_phase5.py
index cf73132..1ed2f68 100644
--- a/tests/test_phase5.py
+++ b/tests/test_phase5.py
@@ -125,10 +125,9 @@ class TestEmbedderInit:
"""CHROMA_ENABLED=false 时不初始化。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
emb.init_chroma()
- assert emb._client is None
+ assert emb._chroma._client is None
def test_chroma_init_success(self, monkeypatch, tmp_path):
"""CHROMA_ENABLED=true 时初始化成功。"""
@@ -136,23 +135,20 @@ class TestEmbedderInit:
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
import app.services.embedder as emb
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
emb.init_chroma()
- assert emb._client is not None
- assert emb._collection is not None
+ assert emb._chroma._client is not None
+ assert emb._chroma._collection is not None
# 清理
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
def test_get_collection_returns_none_when_disabled(self, monkeypatch):
"""CHROMA_ENABLED=false 时 get_collection 返回 None。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
assert emb.get_collection() is None
@@ -163,8 +159,7 @@ class TestEmbedderIndexing:
"""CHROMA_ENABLED=false 时 index_paper 返回 False。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
assert emb.index_paper("test-id") is False
def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
@@ -175,22 +170,19 @@ class TestEmbedderIndexing:
monkeypatch.setattr(settings, "EMBED_MODEL", "")
import app.services.embedder as emb
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
emb.init_chroma()
result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
assert result is False
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
def test_index_batch_disabled(self, monkeypatch):
"""CHROMA_ENABLED=false 时 index_batch 返回全失败。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
result = emb.index_batch(["a", "b"])
assert result["success"] == 0
assert result["failed"] == 2
@@ -206,16 +198,14 @@ class TestEmbedderIndexing:
"""CHROMA_ENABLED=false 时 delete_paper 返回 False。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
assert emb.delete_paper("test-id") is False
def test_search_similar_disabled(self, monkeypatch):
"""CHROMA_ENABLED=false 时 search_similar 返回空列表。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
assert emb.search_similar("test query") == []
@@ -427,8 +417,8 @@ class TestTrendsDashboard:
from unittest.mock import patch as upatch
import app.routes.trends as trends_mod
- # monkeypatch _get_trends_data 中的 date.today
- with upatch("app.routes.trends.date") as mock_date:
+ # monkeypatch get_trends_data 中的 date.today
+ with upatch("app.services.trends.date") as mock_date:
mock_date.today.return_value = date(2024, 1, 20)
mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
@@ -528,15 +518,17 @@ class TestImageExtraction:
@pytest.mark.asyncio
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
"""源码目录不存在时返回 0。"""
- monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
- monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
- from app.services.summarizer import _extract_images_from_source
- result = await _extract_images_from_source("2401.99999")
+ monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x)
+ monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x)
+ from app.services.image_extractor import extract_images_from_source
+ result = await extract_images_from_source("2401.99999")
assert result == 0
@pytest.mark.asyncio
async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
"""从 .tex 文件中提取图片。"""
+ from app.services.image_extractor import extract_images_from_source
+
tmp_source = tmp_path / "tmp" / "2401.00001" / "source"
tmp_source.mkdir(parents=True)
@@ -559,11 +551,16 @@ class TestImageExtraction:
(tmp_source / "main.tex").write_text(tex_content)
papers_dir = tmp_path / "papers" / "2401.00001"
- monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
- monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
+ monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
+ monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
- from app.services.summarizer import _extract_images_from_source
- result = await _extract_images_from_source("2401.00001")
+ # Mock download_source_zip to avoid real network call (source dir already exists)
+ async def _noop_download(*args, **kwargs):
+ pass
+
+ monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
+
+ result = await extract_images_from_source("2401.00001")
assert result == 2
dest_images = papers_dir / "images"
@@ -574,15 +571,22 @@ class TestImageExtraction:
@pytest.mark.asyncio
async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
""".tex 文件无图片时返回 0。"""
+ from app.services.image_extractor import extract_images_from_source
+
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
tmp_source.mkdir(parents=True)
(tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
- monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
- monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
+ monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
+ monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
- from app.services.summarizer import _extract_images_from_source
- result = await _extract_images_from_source("2401.00002")
+ # Mock download_source_zip to avoid real network call
+ async def _noop_download(*args, **kwargs):
+ pass
+
+ monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
+
+ result = await extract_images_from_source("2401.00002")
assert result == 0
@@ -644,8 +648,7 @@ class TestGracefulDegradation:
"""CHROMA 关闭时删除论文正常工作。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
- emb._client = None
- emb._collection = None
+ emb._chroma.reset()
from app.services.cleaner import delete_papers_by_date_range
result = await delete_papers_by_date_range(
diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py
index ebe8b8f..35cdb79 100644
--- a/tests/test_summarizer.py
+++ b/tests/test_summarizer.py
@@ -27,14 +27,7 @@ from app.services.schemas import (
flatten_for_db,
)
from app.services.summarizer import (
- JsonNotFoundError,
- PdfDownloadError,
- PiProcessError,
- PiTimeoutError,
- _call_pi,
_classify_error,
- _cleanup_tmp,
- _extract_json,
_save_files,
_save_raw_output_only,
_update_summary_in_db,
@@ -42,6 +35,17 @@ from app.services.summarizer import (
summarize_one,
summarize_single,
)
+from app.services.pi_client import (
+ JsonNotFoundError,
+ PiProcessError,
+ PiTimeoutError,
+ call_pi as _call_pi,
+ extract_json as _extract_json,
+)
+from app.services.pdf_downloader import (
+ PdfDownloadError,
+ cleanup_tmp as _cleanup_tmp,
+)
# ═══════════════════════════════════════════════════════════════════════
@@ -287,7 +291,7 @@ class TestFileOperations:
def test_save_files(self, tmp_path, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
- with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
+ with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
_save_files("2401.12345", schema, "raw output text")
paper_dir = tmp_path / "2401.12345"
@@ -297,7 +301,7 @@ class TestFileOperations:
assert saved["title_zh"] == "测试论文中文标题"
def test_save_raw_output_only(self, tmp_path):
- with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
+ with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
_save_raw_output_only("2401.12345", "raw output")
paper_dir = tmp_path / "2401.12345"
assert (paper_dir / "raw_output.txt").exists()
@@ -307,13 +311,13 @@ class TestFileOperations:
tmp_paper = tmp_path / "2401.12345"
tmp_paper.mkdir()
(tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake")
- with patch("app.services.summarizer._TMP_DIR", tmp_path):
+ with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
_cleanup_tmp("2401.12345")
assert not tmp_paper.exists()
def test_cleanup_tmp_nonexistent(self, tmp_path):
"""清理不存在的目录不报错。"""
- with patch("app.services.summarizer._TMP_DIR", tmp_path):
+ with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
_cleanup_tmp("nonexistent") # 不抛异常
@@ -329,9 +333,11 @@ class TestSummarizeOneFlow:
def _patch_paths(self, tmp_path):
"""将 data 目录重定向到 tmp_path。"""
with (
- patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
- patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
- patch("app.services.summarizer._DATA_DIR", tmp_path),
+ patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
+ patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
+ patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
+ patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
+ patch("app.utils.TMP_DIR", tmp_path / "tmp"),
):
yield
@@ -341,8 +347,8 @@ class TestSummarizeOneFlow:
):
"""pending → processing → done 全流程。"""
with (
- patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
- patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
+ patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
+ patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
result = await summarize_one(db_session, sample_paper)
@@ -374,7 +380,7 @@ class TestSummarizeOneFlow:
"""PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。"""
with (
patch(
- "app.services.summarizer._download_pdf",
+ "app.services.summarizer.download_pdf",
new_callable=AsyncMock,
side_effect=PdfDownloadError("network error"),
),
@@ -392,9 +398,9 @@ class TestSummarizeOneFlow:
async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
"""pi 超时 → timeout 错误,retry_count 递增。"""
with (
- patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
+ patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
- "app.services.summarizer._call_pi",
+ "app.services.summarizer.call_pi",
new_callable=AsyncMock,
side_effect=PiTimeoutError("timeout after 300s"),
),
@@ -409,9 +415,9 @@ class TestSummarizeOneFlow:
async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
"""pi 输出无 JSON → json_not_found。"""
with (
- patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
+ patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
- "app.services.summarizer._call_pi",
+ "app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value="No JSON in this output at all.",
),
@@ -436,9 +442,9 @@ class TestSummarizeOneFlow:
bad_output = f"```json\n{bad_json}\n```"
with (
- patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
+ patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
- "app.services.summarizer._call_pi",
+ "app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value=bad_output,
),
@@ -464,9 +470,9 @@ class TestSummarizeOneFlow:
):
"""失败时仍保存 raw_output.txt。"""
with (
- patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
+ patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
- "app.services.summarizer._call_pi",
+ "app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value="Some output without JSON",
),
@@ -483,8 +489,8 @@ class TestSummarizeOneFlow:
):
"""成功后清理 tmp 目录。"""
with (
- patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
- patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
+ patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
+ patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
await summarize_one(db_session, sample_paper)
@@ -498,7 +504,7 @@ class TestSummarizeOneFlow:
"""失败后也清理 tmp 目录。"""
with (
patch(
- "app.services.summarizer._download_pdf",
+ "app.services.summarizer.download_pdf",
new_callable=AsyncMock,
side_effect=PdfDownloadError("fail"),
),
@@ -529,9 +535,11 @@ class TestBatchSummarize:
@pytest.fixture
def _patch_paths(self, tmp_path):
with (
- patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
- patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
- patch("app.services.summarizer._DATA_DIR", tmp_path),
+ patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
+ patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
+ patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
+ patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
+ patch("app.utils.TMP_DIR", tmp_path / "tmp"),
):
yield
@@ -561,8 +569,8 @@ class TestBatchSummarize:
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
with (
- patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
- patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
+ patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
+ patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
result = await summarize_batch(
db_session, _session_factory=_TestSession
@@ -612,8 +620,8 @@ class TestBatchSummarize:
return mock_pi_output
with (
- patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
- patch("app.services.summarizer._call_pi", side_effect=_mock_call_pi),
+ patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
+ patch("app.services.summarizer.call_pi", side_effect=_mock_call_pi),
):
result = await summarize_batch(
db_session, _session_factory=_TestSession
@@ -646,8 +654,8 @@ class TestBatchSummarize:
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
with (
- patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
- patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
+ patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
+ patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
await summarize_batch(
db_session, _session_factory=_TestSession