refactor: restructure services and add image/pdf extraction utilities

- Add image_extractor, pdf_downloader, pi_client, trends services
- Add shared utils module
- Refactor summarizer, embedder, routes for cleaner separation
- Update tests to match new service structure
This commit is contained in:
2026-06-06 00:00:55 +08:00
parent ba9afa212c
commit 85c4cfb9e8
22 changed files with 843 additions and 780 deletions
+70 -54
View File
@@ -5,8 +5,6 @@ from __future__ import annotations
import logging
from pathlib import Path
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
@@ -14,66 +12,82 @@ from app.models import Paper
logger = logging.getLogger(__name__)
# ── 单例客户端和 collection ─────────────────────────────────────────────
_client = None
_collection = None
# ── ChromaDB 管理器(替代全局可变状态)──────────────────────────────────
def _chroma_dir() -> Path:
return Path(settings.CHROMA_DIR)
class ChromaManager:
"""封装 ChromaDB 客户端和 collection 的生命周期。"""
def __init__(self) -> None:
self._client = None
self._collection = None
def init(self) -> None:
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
if not settings.CHROMA_ENABLED:
logger.debug("ChromaDB disabled, skip init")
return
if self._client is not None:
return
try:
import chromadb
chroma_path = Path(settings.CHROMA_DIR)
chroma_path.mkdir(parents=True, exist_ok=True)
self._client = chromadb.PersistentClient(path=str(chroma_path))
self._collection = self._get_or_create_collection()
logger.info("ChromaDB initialized at %s", chroma_path)
except Exception:
logger.exception("Failed to initialize ChromaDB")
self._client = None
self._collection = None
def _get_or_create_collection(self):
"""获取或创建 papers_embeddings collection。"""
try:
col = self._client.get_collection("papers_embeddings")
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
return col
except Exception:
pass
col = self._client.create_collection(
name="papers_embeddings",
metadata={"hnsw:space": "cosine"},
)
logger.info("ChromaDB collection 'papers_embeddings' created")
return col
def get_collection(self):
"""返回当前 collection,未初始化则自动初始化。"""
if not settings.CHROMA_ENABLED:
return None
if self._collection is None:
self.init()
return self._collection
def reset(self) -> None:
"""重置状态(供测试使用)。"""
self._client = None
self._collection = None
# 模块级单例
_chroma = ChromaManager()
def init_chroma() -> None:
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection"""
global _client, _collection
if not settings.CHROMA_ENABLED:
logger.debug("ChromaDB disabled, skip init")
return
if _client is not None:
return
try:
import chromadb
chroma_path = _chroma_dir()
chroma_path.mkdir(parents=True, exist_ok=True)
_client = chromadb.PersistentClient(path=str(chroma_path))
_collection = _get_or_create_collection()
logger.info("ChromaDB initialized at %s", chroma_path)
except Exception:
logger.exception("Failed to initialize ChromaDB")
_client = None
_collection = None
def _get_or_create_collection():
"""获取或创建 papers_embeddings collection,维度不匹配时记录日志并跳过。"""
import chromadb
try:
col = _client.get_collection("papers_embeddings")
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
return col
except Exception:
pass
col = _client.create_collection(
name="papers_embeddings",
metadata={"hnsw:space": "cosine"},
)
logger.info("ChromaDB collection 'papers_embeddings' created")
return col
"""初始化 ChromaDB(供 lifespan 调用)"""
_chroma.init()
def get_collection():
"""返回当前 collection,未初始化则返回 None。"""
if not settings.CHROMA_ENABLED:
return None
if _collection is None:
init_chroma()
return _collection
return _chroma.get_collection()
# ── Embedding API 调用 ──────────────────────────────────────────────────
@@ -90,6 +104,8 @@ def _get_embedding(text: str) -> list[float] | None:
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
return None
from app.utils import make_http_client
url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
headers = {"Content-Type": "application/json"}
if settings.EMBED_API_KEY:
@@ -101,7 +117,7 @@ def _get_embedding(text: str) -> list[float] | None:
}
try:
with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
with make_http_client(sync=True) as client:
resp = client.post(url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()