349 lines
12 KiB
Python
349 lines
12 KiB
Python
"""ChromaDB 语义搜索服务 — embedding 生成、索引管理、相似查询。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import threading
|
||
from pathlib import Path
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.orm import joinedload
|
||
|
||
from app.config import settings
|
||
from app.models import Paper
|
||
|
||
logger = logging.getLogger(__name__)  # module-scoped logger used by every function in this file
|
||
|
||
|
||
# ── ChromaDB 管理器(替代全局可变状态)──────────────────────────────────
|
||
|
||
|
||
class ChromaManager:
    """Own the lifecycle of the ChromaDB client and its collection.

    All client/collection access is serialised through ``self._lock``.
    Post-processing runs ``index_paper`` from several workers via
    ``asyncio.to_thread``; without serialisation those workers would open
    connections concurrently and trip the race in chromadb 1.5.x's class-level
    ``SharedSystemClient`` cache (``_create_system_if_not_exists`` takes no lock
    and the refcount release pops the key), which surfaces as
    ``KeyError: '<persist_dir>'``.

    The lock is an ``RLock`` because ``index_paper`` holds it and then reaches
    ``init()`` again through ``get_collection()``; re-entry on the same thread
    must not deadlock.
    """

    def __init__(self) -> None:
        self._lock = threading.RLock()
        # Both stay None until init() succeeds; init() publishes them together.
        self._client = None
        self._collection = None

    def init(self) -> None:
        """Create the persistent client and collection when CHROMA_ENABLED is true.

        The unlocked check is a fast path for the already-initialised case; the
        check is repeated after acquiring the lock so that only one thread ever
        calls ``PersistentClient``.

        The client and collection are built in locals and published only after
        both exist, collection first. ``_client is not None`` therefore implies
        that ``_collection`` is usable, so a thread taking the unlocked fast
        path can never observe a half-initialised manager.

        Failures are logged and leave the manager uninitialised; nothing is
        raised to the caller.
        """
        if not settings.CHROMA_ENABLED:
            logger.debug("ChromaDB disabled, skip init")
            return

        if self._client is not None:
            return

        with self._lock:
            # Re-check: another thread may have finished init while we waited.
            if self._client is not None:
                return

            try:
                import chromadb

                chroma_path = Path(settings.CHROMA_DIR)
                chroma_path.mkdir(parents=True, exist_ok=True)

                client = chromadb.PersistentClient(path=str(chroma_path))
                collection = self._get_or_create_collection(client)
            except Exception:
                logger.exception("Failed to initialize ChromaDB")
                self._client = None
                self._collection = None
                return

            # Publish order matters for the unlocked fast path above.
            self._collection = collection
            self._client = client
            logger.info("ChromaDB initialized at %s", chroma_path)

    def _get_or_create_collection(self, client=None):
        """Load the ``papers_embeddings`` collection, creating it if loading fails.

        Args:
            client: ChromaDB client to use. Defaults to ``self._client`` so the
                previous zero-argument call form keeps working.

        Returns:
            The loaded or newly created collection.
        """
        if client is None:
            client = self._client

        try:
            col = client.get_collection("papers_embeddings")
        except Exception:
            # Any load failure is treated as "does not exist yet" (the exact
            # exception type is not pinned here). Keep the traceback at debug
            # level rather than discarding it.
            logger.debug(
                "ChromaDB collection 'papers_embeddings' could not be loaded; creating it",
                exc_info=True,
            )
        else:
            # count() sits outside the try so a failure here is not mistaken
            # for a missing collection.
            logger.info(
                "ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count()
            )
            return col

        col = client.create_collection(
            name="papers_embeddings",
            metadata={"hnsw:space": "cosine"},
        )
        logger.info("ChromaDB collection 'papers_embeddings' created")
        return col

    def get_collection(self):
        """Return the collection, initialising on first use.

        Returns None when ChromaDB is disabled or initialisation failed.
        """
        if not settings.CHROMA_ENABLED:
            return None
        if self._collection is None:
            self.init()
        return self._collection

    def reset(self) -> None:
        """Drop the cached client and collection (intended for tests)."""
        # Taken under the lock so a reset cannot interleave with an in-flight init().
        with self._lock:
            self._client = None
            self._collection = None
|
||
|
||
|
||
# 模块级单例
|
||
_chroma = ChromaManager()  # single shared manager used by the module-level helper functions
|
||
|
||
|
||
def init_chroma() -> None:
    """Initialise ChromaDB through the shared manager (meant to be called from the app lifespan)."""
    _chroma.init()
|
||
|
||
|
||
def get_collection():
    """Return the shared collection from the module-level manager.

    Delegates to ``_chroma.get_collection()`` and returns whatever it returns.
    """
    collection = _chroma.get_collection()
    return collection
|
||
|
||
|
||
# ── Embedding API 调用 ──────────────────────────────────────────────────
|
||
|
||
|
||
def _get_embedding(text: str) -> list[float] | None:
    """Fetch an embedding vector for *text* from the configured embedding API.

    Sends ``POST {EMBED_API_BASE}/v1/embeddings`` with ``model=EMBED_MODEL``.
    When ``EMBED_DIMENSIONS`` is positive, the returned vector's length must
    match it.

    This function touches no ChromaDB state, so callers run it outside the
    Chroma lock and several workers can embed in parallel.

    Returns:
        The embedding vector, or None when the API is not configured, the
        request or response parsing fails, or the dimension check fails. Each
        of those cases is logged.
    """
    if not (settings.EMBED_API_BASE and settings.EMBED_MODEL):
        logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
        return None

    from app.utils import make_http_client

    endpoint = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"

    request_headers = {"Content-Type": "application/json"}
    if settings.EMBED_API_KEY:
        request_headers["Authorization"] = f"Bearer {settings.EMBED_API_KEY}"

    body = dict(model=settings.EMBED_MODEL, input=text)

    try:
        with make_http_client(sync=True) as http:
            response = http.post(endpoint, json=body, headers=request_headers)
            response.raise_for_status()
            parsed = response.json()

        embedding = parsed["data"][0]["embedding"]

        # A non-positive EMBED_DIMENSIONS disables the length check.
        wrong_size = (
            settings.EMBED_DIMENSIONS > 0
            and len(embedding) != settings.EMBED_DIMENSIONS
        )
        if wrong_size:
            logger.warning(
                "Embedding dimension mismatch: expected=%d got=%d, skip",
                settings.EMBED_DIMENSIONS,
                len(embedding),
            )
            return None
    except Exception:
        logger.exception("Embedding API call failed")
        return None

    return embedding
|
||
|
||
|
||
# ── 索引文本拼接 ────────────────────────────────────────────────────────
|
||
|
||
|
||
def _build_index_text(paper: Paper, summary_dict: dict | None) -> str:
|
||
"""拼接高信号字段为一段文本用于 embedding。"""
|
||
parts: list[str] = []
|
||
|
||
if paper.title_zh:
|
||
parts.append(paper.title_zh)
|
||
if paper.title_en:
|
||
parts.append(paper.title_en)
|
||
if paper.tags:
|
||
tag_str = " ".join(t.tag for t in paper.tags)
|
||
if tag_str:
|
||
parts.append(tag_str)
|
||
|
||
if summary_dict:
|
||
for key in ("one_line", "motivation_problem", "method_key_idea"):
|
||
val = summary_dict.get(key)
|
||
if val:
|
||
parts.append(val)
|
||
|
||
return " ".join(parts)
|
||
|
||
|
||
# ── 单篇索引 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
# Keys of ``texts_dict`` that carry metadata only; they must not be embedded as text.
_INDEX_METADATA_ONLY_KEYS = frozenset({"arxiv_id", "paper_date"})


def _load_index_fields_from_db(paper_id: str) -> tuple[str, str, str, str] | None:
    """Load one paper from the DB as ``(index_text, arxiv_id, title_zh, paper_date)``; None if absent."""
    from app.database import SessionLocal

    db = SessionLocal()
    try:
        paper = (
            db.execute(
                select(Paper)
                .where(Paper.arxiv_id == paper_id)
                .options(joinedload(Paper.tags), joinedload(Paper.summary))
            )
            .unique()
            .scalar_one_or_none()
        )
        if not paper:
            logger.warning("Paper %s not found for indexing", paper_id)
            return None

        summary_dict = None
        if paper.summary:
            summary_dict = {
                "one_line": paper.summary.one_line or "",
                "motivation_problem": paper.summary.motivation_problem or "",
                "method_key_idea": paper.summary.method_key_idea or "",
            }

        # Read every attribute before the session closes.
        return (
            _build_index_text(paper, summary_dict),
            paper.arxiv_id,
            paper.title_zh or "",
            paper.paper_date.isoformat() if paper.paper_date else "",
        )
    finally:
        db.close()


def _index_fields_from_dict(paper_id: str, texts_dict: dict) -> tuple[str, str, str, str]:
    """Derive ``(index_text, arxiv_id, title_zh, paper_date)`` from a caller-supplied dict."""
    # Only non-empty string values are embedded. The id and date are metadata,
    # and the DB path does not embed them either, so they are excluded here.
    index_text = " ".join(
        value
        for key, value in texts_dict.items()
        if key not in _INDEX_METADATA_ONLY_KEYS and isinstance(value, str) and value
    )

    # Normalise metadata the same way the DB path does: an explicit None falls
    # back to a default, and a date-like object is rendered via isoformat().
    arxiv_id = texts_dict.get("arxiv_id") or paper_id
    title_zh = texts_dict.get("title_zh") or ""
    raw_date = texts_dict.get("paper_date")
    if not raw_date:
        paper_date = ""
    elif isinstance(raw_date, str):
        paper_date = raw_date
    elif callable(getattr(raw_date, "isoformat", None)):
        paper_date = raw_date.isoformat()
    else:
        paper_date = str(raw_date)

    return index_text, arxiv_id, title_zh, paper_date


def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
    """Write one paper into the ChromaDB semantic index.

    Args:
        paper_id: The paper's arxiv_id.
        texts_dict: Optional pre-assembled fields (title_zh, title_en, tags,
            one_line, ...). Its non-empty string values are joined into the
            embedded text, except ``arxiv_id`` and ``paper_date``, which are
            used only as metadata. When None, the paper is loaded from the DB.

    Returns:
        True on success. False when ChromaDB is disabled, the paper is not
        found, the index text is empty, the embedding call fails, the
        collection is unavailable, or any exception occurs (it is logged and
        not re-raised).

    Concurrency: the remote embedding call runs outside the lock so workers can
    embed in parallel; collection access, including first-time init, is
    serialised under ``_chroma._lock``.
    """
    if not settings.CHROMA_ENABLED:
        return False

    try:
        if texts_dict is None:
            fields = _load_index_fields_from_db(paper_id)
            if fields is None:
                return False
        else:
            fields = _index_fields_from_dict(paper_id, texts_dict)
        index_text, arxiv_id, title_zh, paper_date = fields

        if not index_text.strip():
            logger.warning("Empty index text for %s, skip", paper_id)
            return False

        vec = _get_embedding(index_text)  # remote HTTP, deliberately outside the lock
        if vec is None:
            return False

        with _chroma._lock:  # serialise collection access (includes first-time init)
            col = _chroma.get_collection()
            if col is None:
                return False
            col.upsert(
                ids=[arxiv_id],
                embeddings=[vec],
                metadatas=[
                    {
                        "arxiv_id": arxiv_id,
                        "title_zh": title_zh,
                        "paper_date": paper_date,
                    }
                ],
            )
        logger.info("Indexed paper %s in ChromaDB", arxiv_id)
        return True

    except Exception:
        logger.exception("Failed to index paper %s in ChromaDB", paper_id)
        return False
|
||
|
||
|
||
# ── 删除 ────────────────────────────────────────────────────────────────
|
||
|
||
|
||
def delete_paper(paper_id: str) -> bool:
    """Remove one paper's entry from the ChromaDB index.

    Args:
        paper_id: The paper's arxiv_id, used as the Chroma document id.

    Returns:
        True when the delete call succeeded. False when ChromaDB is disabled,
        the collection is unavailable, or the delete raised (the error is
        logged and not re-raised).
    """
    if not settings.CHROMA_ENABLED:
        return False

    # Collection lookup and the delete itself both run under the manager lock.
    with _chroma._lock:
        collection = _chroma.get_collection()
        if collection is None:
            return False

        try:
            collection.delete(ids=[paper_id])
        except Exception:
            logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
            return False
        else:
            logger.info("Deleted paper %s from ChromaDB", paper_id)
            return True
|
||
|
||
|
||
# ── 相似查询 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
    """Run a semantic similarity search against the ChromaDB index.

    Args:
        query_text: Text to embed and search with.
        top_k: Maximum number of hits to return. A non-positive value yields an
            empty list without calling the embedding API.

    Returns:
        ``[{"arxiv_id": str, "distance": float}, ...]`` in the order Chroma
        returns them. The list is empty when ChromaDB is disabled, the
        embedding call fails, the collection is unavailable or empty, or any
        exception occurs (it is logged and not re-raised).

    Concurrency: the remote embedding call runs outside the lock; the
    collection query runs under ``_chroma._lock``.
    """
    if not settings.CHROMA_ENABLED:
        return []

    try:
        # Checked inside the try so a bad top_k still returns [] instead of
        # raising, and before the embedding call so it costs no remote request.
        if top_k <= 0:
            return []

        vec = _get_embedding(query_text)  # remote HTTP, deliberately outside the lock
        if vec is None:
            return []

        with _chroma._lock:
            col = _chroma.get_collection()
            if col is None:
                return []

            # Count once; an empty collection has nothing to return.
            total = col.count()
            if total <= 0:
                return []

            results = col.query(
                query_embeddings=[vec],
                n_results=min(top_k, total),
                include=["metadatas", "distances"],
            )

        # Parsing the result is pure data handling, so it does not need the lock.
        ids = results["ids"]
        if not ids or not ids[0]:
            return []

        distances = results["distances"]
        return [
            {
                "arxiv_id": arxiv_id,
                "distance": distances[0][i] if distances else 0.0,
            }
            for i, arxiv_id in enumerate(ids[0])
        ]

    except Exception:
        logger.exception("ChromaDB search_similar failed")
        return []
|