"""ChromaDB 语义搜索服务 — embedding 生成、索引管理、相似查询。""" from __future__ import annotations import logging from pathlib import Path from sqlalchemy.orm import Session, joinedload from app.config import settings from app.models import Paper logger = logging.getLogger(__name__) # ── ChromaDB 管理器(替代全局可变状态)────────────────────────────────── class ChromaManager: """封装 ChromaDB 客户端和 collection 的生命周期。""" def __init__(self) -> None: self._client = None self._collection = None def init(self) -> None: """CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。""" if not settings.CHROMA_ENABLED: logger.debug("ChromaDB disabled, skip init") return if self._client is not None: return try: import chromadb chroma_path = Path(settings.CHROMA_DIR) chroma_path.mkdir(parents=True, exist_ok=True) self._client = chromadb.PersistentClient(path=str(chroma_path)) self._collection = self._get_or_create_collection() logger.info("ChromaDB initialized at %s", chroma_path) except Exception: logger.exception("Failed to initialize ChromaDB") self._client = None self._collection = None def _get_or_create_collection(self): """获取或创建 papers_embeddings collection。""" try: col = self._client.get_collection("papers_embeddings") logger.info( "ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count() ) return col except Exception: pass col = self._client.create_collection( name="papers_embeddings", metadata={"hnsw:space": "cosine"}, ) logger.info("ChromaDB collection 'papers_embeddings' created") return col def get_collection(self): """返回当前 collection,未初始化则自动初始化。""" if not settings.CHROMA_ENABLED: return None if self._collection is None: self.init() return self._collection def reset(self) -> None: """重置状态(供测试使用)。""" self._client = None self._collection = None # 模块级单例 _chroma = ChromaManager() def init_chroma() -> None: """初始化 ChromaDB(供 lifespan 调用)。""" _chroma.init() def get_collection(): """返回当前 collection,未初始化则返回 None。""" return _chroma.get_collection() # ── Embedding API 调用 
def _get_embedding(text: str) -> list[float] | None:
    """Embed *text* via the ``/v1/embeddings`` endpoint at ``EMBED_API_BASE``.

    Posts ``{"model": EMBED_MODEL, "input": text}`` (with a Bearer token when
    ``EMBED_API_KEY`` is set). When ``EMBED_DIMENSIONS`` is positive, the
    returned vector must have exactly that length.

    Returns:
        The embedding vector, or None when embedding is not configured, the
        request fails, the response is malformed, or the dimension mismatches.
        Failures are logged, never raised.
    """
    if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
        logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
        return None

    from app.utils import make_http_client

    url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
    headers = {"Content-Type": "application/json"}
    if settings.EMBED_API_KEY:
        headers["Authorization"] = f"Bearer {settings.EMBED_API_KEY}"
    payload = {"model": settings.EMBED_MODEL, "input": text}

    try:
        with make_http_client(sync=True) as client:
            resp = client.post(url, json=payload, headers=headers)
            resp.raise_for_status()
            data = resp.json()
        vec = data["data"][0]["embedding"]
        # Treat an unset (None) dimension like 0, i.e. "don't validate".
        expected_dim = settings.EMBED_DIMENSIONS or 0
        if expected_dim > 0 and len(vec) != expected_dim:
            logger.warning(
                "Embedding dimension mismatch: expected=%d got=%d, skip",
                expected_dim,
                len(vec),
            )
            return None
        return vec
    except Exception:
        logger.exception("Embedding API call failed")
        return None


# ── Index text construction ─────────────────────────────────────────────────

# Summary fields (in order) that contribute to a paper's embedding text.
_SUMMARY_INDEX_KEYS = ("one_line", "motivation_problem", "method_key_idea")

# texts_dict keys that only feed ChromaDB metadata and are never embedded,
# mirroring the DB path where arxiv_id / paper_date are not part of the text.
_METADATA_ONLY_KEYS = frozenset({"arxiv_id", "paper_date"})


def _build_index_text(paper: Paper, summary_dict: dict | None) -> str:
    """Join a paper's high-signal fields into one string for embedding.

    Order: ``title_zh``, ``title_en``, the space-joined tag names, then the
    summary's ``one_line``, ``motivation_problem`` and ``method_key_idea``.
    Empty values (including empty/None tag names) are skipped.
    """
    parts: list[str] = [paper.title_zh, paper.title_en]
    if paper.tags:
        parts.append(" ".join(t.tag for t in paper.tags if t.tag))
    if summary_dict:
        parts.extend(summary_dict.get(key) for key in _SUMMARY_INDEX_KEYS)
    return " ".join(p for p in parts if p)


# ── Single-paper indexing ───────────────────────────────────────────────────


def _load_index_payload(paper_id: str) -> tuple[str, str, str, str] | None:
    """Load *paper_id* from the DB and return (index_text, arxiv_id, title_zh, paper_date).

    Returns None (after logging a warning) when the paper does not exist.
    The session is always closed.
    """
    from app.database import SessionLocal

    db = SessionLocal()
    try:
        paper = (
            db.query(Paper)
            .filter(Paper.arxiv_id == paper_id)
            .options(joinedload(Paper.tags), joinedload(Paper.summary))
            .first()
        )
        if not paper:
            logger.warning("Paper %s not found for indexing", paper_id)
            return None

        summary_dict = None
        if paper.summary:
            summary_dict = {
                key: getattr(paper.summary, key) or "" for key in _SUMMARY_INDEX_KEYS
            }
        # Everything is read while the session is still open.
        return (
            _build_index_text(paper, summary_dict),
            paper.arxiv_id,
            paper.title_zh or "",
            paper.paper_date.isoformat() if paper.paper_date else "",
        )
    finally:
        db.close()


def _payload_from_texts(paper_id: str, texts_dict: dict) -> tuple[str, str, str, str]:
    """Build (index_text, arxiv_id, title_zh, paper_date) from caller-supplied fields.

    Non-empty string values are embedded in dict order, and lists/tuples of
    strings (e.g. tags) are space-joined. The metadata-only keys ``arxiv_id``
    and ``paper_date`` are excluded from the text. Missing or empty metadata
    falls back to *paper_id* / ``""``.
    """
    parts: list[str] = []
    for key, value in texts_dict.items():
        if key in _METADATA_ONLY_KEYS:
            continue
        if isinstance(value, (list, tuple)):
            value = " ".join(v for v in value if isinstance(v, str) and v)
        if isinstance(value, str) and value:
            parts.append(value)

    paper_date = texts_dict.get("paper_date") or ""
    if not isinstance(paper_date, str):
        # Accept date/datetime objects; ChromaDB metadata needs a primitive.
        paper_date = (
            paper_date.isoformat() if hasattr(paper_date, "isoformat") else str(paper_date)
        )
    return (
        " ".join(parts),
        texts_dict.get("arxiv_id") or paper_id,
        texts_dict.get("title_zh") or "",
        paper_date,
    )


def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
    """Write a single paper into the ChromaDB semantic index (upsert).

    Args:
        paper_id: The paper's arxiv_id.
        texts_dict: Optional pre-extracted fields (e.g. title_zh, title_en,
            tags, one_line, plus metadata arxiv_id / paper_date). When None,
            the paper is loaded from the DB.

    Returns:
        True when the paper was indexed; False when ChromaDB is unavailable,
        the paper is missing, its text is empty, embedding failed, or any
        other error occurred (errors are logged, not raised).
    """
    col = get_collection()
    if col is None:
        return False

    try:
        if texts_dict is None:
            payload = _load_index_payload(paper_id)
        else:
            payload = _payload_from_texts(paper_id, texts_dict)
        if payload is None:
            return False
        index_text, arxiv_id, title_zh, paper_date = payload

        if not index_text.strip():
            logger.warning("Empty index text for %s, skip", paper_id)
            return False

        vec = _get_embedding(index_text)
        if vec is None:
            return False

        col.upsert(
            ids=[arxiv_id],
            embeddings=[vec],
            metadatas=[
                {"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}
            ],
        )
        logger.info("Indexed paper %s in ChromaDB", arxiv_id)
        return True
    except Exception:
        logger.exception("Failed to index paper %s in ChromaDB", paper_id)
        return False


# ── Batch indexing ──────────────────────────────────────────────────────────


def index_batch(paper_ids: list[str]) -> dict:
    """Index several papers; a failure on one does not stop the others.

    Returns:
        ``{"total": int, "success": int, "failed": int}``. When ChromaDB is
        unavailable, every id counts as failed.
    """
    total = len(paper_ids)
    if not paper_ids:
        return {"total": 0, "success": 0, "failed": 0}

    if get_collection() is None:
        return {"total": total, "success": 0, "failed": total}

    success = sum(1 for pid in paper_ids if index_paper(pid))
    failed = total - success
    logger.info("Batch index: total=%d success=%d failed=%d", total, success, failed)
    return {"total": total, "success": success, "failed": failed}


# ── Deletion ────────────────────────────────────────────────────────────────


def delete_paper(paper_id: str) -> bool:
    """Remove a paper's entry from the ChromaDB index.

    Args:
        paper_id: The paper's arxiv_id (used as the ChromaDB id).

    Returns:
        True when the delete call succeeded; False when ChromaDB is
        unavailable or the call raised (the error is logged).
    """
    col = get_collection()
    if col is None:
        return False
    try:
        col.delete(ids=[paper_id])
    except Exception:
        logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
        return False
    logger.info("Deleted paper %s from ChromaDB", paper_id)
    return True


# ── Similarity search ───────────────────────────────────────────────────────


def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
    """Semantic similarity search over the indexed papers.

    Args:
        query_text: Free-text query; blank queries return no results.
        top_k: Maximum number of results; non-positive values return [].

    Returns:
        ``[{"arxiv_id": str, "distance": float}, ...]`` in the order ChromaDB
        returns them, or [] when unavailable, empty, or on error (logged).
    """
    if top_k <= 0 or not query_text or not query_text.strip():
        return []

    col = get_collection()
    if col is None:
        return []

    try:
        # Query count once; an empty collection needs no embedding call and
        # n_results must not exceed the number of stored items.
        count = col.count()
        if count == 0:
            return []

        vec = _get_embedding(query_text)
        if vec is None:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["distances"],
        )
        ids = results["ids"]
        if not ids or not ids[0]:
            return []

        distances = results["distances"][0] if results["distances"] else None
        return [
            {
                "arxiv_id": arxiv_id,
                "distance": distances[i] if distances is not None else 0.0,
            }
            for i, arxiv_id in enumerate(ids[0])
        ]
    except Exception:
        logger.exception("ChromaDB search_similar failed")
        return []