Files

349 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""ChromaDB 语义搜索服务 — embedding 生成、索引管理、相似查询。"""
from __future__ import annotations
import logging
import threading
from pathlib import Path
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from app.config import settings
from app.models import Paper
logger = logging.getLogger(__name__)
# ── ChromaDB 管理器(替代全局可变状态)──────────────────────────────────
class ChromaManager:
    """Own the lifecycle of the ChromaDB client and its collection.

    Construction of the client/collection is serialized through
    ``self._lock``. Post-processing runs ``index_paper`` from several worker
    threads (via ``asyncio.to_thread``); without serialization those workers
    would build clients concurrently, which (per the original author's notes)
    trips a race in chromadb 1.5.x's class-level ``SharedSystemClient`` cache
    and surfaces as ``KeyError: '<persist_dir>'``.

    The lock is an ``RLock`` because ``index_paper`` holds it and then calls
    ``get_collection()``, which may call ``init()`` and take the lock again on
    the same thread; a plain ``Lock`` would deadlock there.
    """

    def __init__(self) -> None:
        self._lock = threading.RLock()
        self._client = None
        self._collection = None

    def init(self) -> None:
        """Create the persistent client and collection when CHROMA_ENABLED.

        The readiness checks look at ``_collection`` rather than ``_client``:
        ``_collection`` is assigned last, so a non-None value means both
        attributes are ready. Checking ``_client`` instead would let a
        concurrent caller observe the window where the client exists but the
        collection does not, and ``get_collection()`` would hand back None.

        Failures are logged and leave both attributes as None, so a later
        call retries the initialization.
        """
        if not settings.CHROMA_ENABLED:
            logger.debug("ChromaDB disabled, skip init")
            return
        # Lock-free fast path for the common already-initialized case.
        if self._collection is not None:
            return
        with self._lock:
            # Re-check: another thread may have finished while we waited.
            if self._collection is not None:
                return
            try:
                import chromadb

                chroma_path = Path(settings.CHROMA_DIR)
                chroma_path.mkdir(parents=True, exist_ok=True)
                self._client = chromadb.PersistentClient(path=str(chroma_path))
                self._collection = self._get_or_create_collection()
                logger.info("ChromaDB initialized at %s", chroma_path)
            except Exception:
                logger.exception("Failed to initialize ChromaDB")
                self._client = None
                self._collection = None

    def _get_or_create_collection(self):
        """Return the ``papers_embeddings`` collection, creating it if needed.

        Reads ``self._client``; callers must have set it first.
        """
        try:
            col = self._client.get_collection("papers_embeddings")
            logger.info(
                "ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count()
            )
            return col
        except Exception:
            # Any lookup failure falls through to creation. Keep a trace of
            # the cause instead of dropping it silently.
            logger.debug(
                "ChromaDB collection 'papers_embeddings' not loaded, creating",
                exc_info=True,
            )
        col = self._client.create_collection(
            name="papers_embeddings",
            metadata={"hnsw:space": "cosine"},
        )
        logger.info("ChromaDB collection 'papers_embeddings' created")
        return col

    def get_collection(self):
        """Return the collection, initializing lazily on first use.

        Returns None when ChromaDB is disabled or initialization failed.
        """
        if not settings.CHROMA_ENABLED:
            return None
        if self._collection is None:
            self.init()
        return self._collection

    def reset(self) -> None:
        """Drop the cached client and collection (intended for tests).

        Takes the lock so a reset cannot interleave with an in-flight
        ``init()`` on another thread.
        """
        with self._lock:
            self._client = None
            self._collection = None
# Module-level singleton used by the helper functions in this module.
_chroma = ChromaManager()
def init_chroma() -> None:
    """Initialize ChromaDB on the module singleton (meant for the app lifespan hook)."""
    _chroma.init()
def get_collection():
    """Return the singleton's collection, initializing it lazily if needed.

    Returns None when ChromaDB is disabled or initialization failed.
    """
    return _chroma.get_collection()
# ── Embedding API 调用 ──────────────────────────────────────────────────
def _get_embedding(text: str) -> list[float] | None:
    """Fetch an embedding vector for *text* from the configured API.

    Sends ``POST {EMBED_API_BASE}/v1/embeddings`` with ``model=EMBED_MODEL``.
    When ``EMBED_DIMENSIONS`` is positive, the returned vector must have
    exactly that length.

    Returns:
        The embedding, or None when the service is not configured, the
        dimension check fails, or the request raises (the error is logged).

    This function takes no lock, so workers can request embeddings in
    parallel.
    """
    api_base = settings.EMBED_API_BASE
    model_name = settings.EMBED_MODEL
    if not (api_base and model_name):
        logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
        return None

    from app.utils import make_http_client

    endpoint = api_base.rstrip("/") + "/v1/embeddings"
    request_headers = {"Content-Type": "application/json"}
    api_key = settings.EMBED_API_KEY
    if api_key:
        request_headers["Authorization"] = f"Bearer {api_key}"
    body = {"model": model_name, "input": text}

    try:
        with make_http_client(sync=True) as http:
            response = http.post(endpoint, json=body, headers=request_headers)
            response.raise_for_status()
            embedding = response.json()["data"][0]["embedding"]
        expected_dim = settings.EMBED_DIMENSIONS
        # A non-positive EMBED_DIMENSIONS disables the length check.
        if expected_dim > 0 and len(embedding) != expected_dim:
            logger.warning(
                "Embedding dimension mismatch: expected=%d got=%d, skip",
                expected_dim,
                len(embedding),
            )
            return None
        return embedding
    except Exception:
        logger.exception("Embedding API call failed")
        return None
# ── 索引文本拼接 ────────────────────────────────────────────────────────
def _build_index_text(paper: Paper, summary_dict: dict | None) -> str:
    """Join the high-signal fields of a paper into one string for embedding.

    Order: Chinese title, English title, space-joined tag names, then the
    ``one_line``, ``motivation_problem`` and ``method_key_idea`` summary
    entries. Falsy fields are skipped; fragments are separated by a space.
    """
    fragments: list[str] = [
        title for title in (paper.title_zh, paper.title_en) if title
    ]

    tags = paper.tags
    if tags:
        joined_tags = " ".join(entry.tag for entry in tags)
        if joined_tags:
            fragments.append(joined_tags)

    if summary_dict:
        summary_keys = ("one_line", "motivation_problem", "method_key_idea")
        fragments.extend(
            value for value in map(summary_dict.get, summary_keys) if value
        )

    return " ".join(fragments)
# ── 单篇索引 ────────────────────────────────────────────────────────────
def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
    """Upsert one paper into the ChromaDB semantic index.

    Args:
        paper_id: The paper's arxiv_id.
        texts_dict: Optional pre-collected fields (title_zh, title_en,
            one_line, ...). Every non-empty string value is joined into the
            text to embed; ``arxiv_id``, ``title_zh`` and ``paper_date`` are
            also stored as metadata. When None, the paper is loaded from the
            database instead.

    Returns:
        True on success; False when ChromaDB is disabled, the paper is
        missing, the text is empty, the embedding fails, or any step raises
        (the error is logged).

    Concurrency: the remote embedding call runs outside the lock so workers
    proceed in parallel; collection access (including first-time init) is
    serialized under ``_chroma._lock``.
    """
    if not settings.CHROMA_ENABLED:
        return False
    try:
        if texts_dict is None:
            # No pre-built fields were passed: load the paper from the DB.
            from app.database import SessionLocal

            db = SessionLocal()
            try:
                paper = (
                    db.execute(
                        select(Paper)
                        .where(Paper.arxiv_id == paper_id)
                        .options(joinedload(Paper.tags), joinedload(Paper.summary))
                    )
                    .unique()
                    .scalar_one_or_none()
                )
                if not paper:
                    logger.warning("Paper %s not found for indexing", paper_id)
                    return False
                summary_dict = None
                if paper.summary:
                    summary_dict = {
                        "one_line": paper.summary.one_line or "",
                        "motivation_problem": paper.summary.motivation_problem or "",
                        "method_key_idea": paper.summary.method_key_idea or "",
                    }
                index_text = _build_index_text(paper, summary_dict)
                arxiv_id = paper.arxiv_id
                title_zh = paper.title_zh or ""
                paper_date = paper.paper_date.isoformat() if paper.paper_date else ""
            finally:
                db.close()
        else:
            # NOTE(review): string values under metadata keys such as
            # "arxiv_id" / "paper_date" also end up in the embedded text here;
            # confirm with callers whether that is intended.
            index_text = " ".join(
                v for v in texts_dict.values() if isinstance(v, str) and v
            )
            # Normalize metadata the same way the DB branch does, so a None
            # title, a date object, or an explicit None id is never handed to
            # the upsert as a raw value.
            arxiv_id = texts_dict.get("arxiv_id") or paper_id
            title_zh = texts_dict.get("title_zh") or ""
            raw_date = texts_dict.get("paper_date")
            if not raw_date:
                paper_date = ""
            elif hasattr(raw_date, "isoformat"):
                paper_date = raw_date.isoformat()
            else:
                paper_date = str(raw_date)

        if not index_text.strip():
            logger.warning("Empty index text for %s, skip", paper_id)
            return False

        vec = _get_embedding(index_text)  # remote HTTP call, outside the lock
        if vec is None:
            return False

        with _chroma._lock:  # serialize collection access (and first-time init)
            col = _chroma.get_collection()
            if col is None:
                return False
            col.upsert(
                ids=[arxiv_id],
                embeddings=[vec],
                metadatas=[
                    {
                        "arxiv_id": arxiv_id,
                        "title_zh": title_zh,
                        "paper_date": paper_date,
                    }
                ],
            )
        logger.info("Indexed paper %s in ChromaDB", arxiv_id)
        return True
    except Exception:
        logger.exception("Failed to index paper %s in ChromaDB", paper_id)
        return False
# ── 删除 ────────────────────────────────────────────────────────────────
def delete_paper(paper_id: str) -> bool:
    """Remove a paper's entry from the ChromaDB index.

    Args:
        paper_id: The paper's arxiv_id, used as the ChromaDB record id.

    Returns:
        True when the delete call succeeded; False when ChromaDB is
        disabled, no collection is available, or the delete raised (the
        error is logged).
    """
    if not settings.CHROMA_ENABLED:
        return False

    with _chroma._lock:
        collection = _chroma.get_collection()
        if collection is None:
            return False
        try:
            collection.delete(ids=[paper_id])
            logger.info("Deleted paper %s from ChromaDB", paper_id)
        except Exception:
            logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
            return False
        return True
# ── 相似查询 ────────────────────────────────────────────────────────────
def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
    """Run a semantic similarity search against the ChromaDB index.

    Args:
        query_text: Text to embed and search with.
        top_k: Maximum number of results to return. A non-positive value
            yields an empty list without calling the embedding API.

    Returns:
        ``[{"arxiv_id": str, "distance": float}, ...]`` in the order returned
        by the query. Empty when ChromaDB is disabled, the collection is
        unavailable or empty, the embedding fails, or any step raises (the
        error is logged).

    Concurrency: the remote embedding call runs outside the lock; only the
    collection count and query run under ``_chroma._lock``.
    """
    if not settings.CHROMA_ENABLED:
        return []
    if top_k <= 0:
        # Nothing can be returned; skip the remote embedding call entirely.
        return []
    try:
        vec = _get_embedding(query_text)  # remote HTTP call, outside the lock
        if vec is None:
            return []

        with _chroma._lock:
            col = _chroma.get_collection()
            if col is None:
                return []
            # Count once: it bounds n_results and detects an empty index.
            total = col.count()
            if total == 0:
                return []
            results = col.query(
                query_embeddings=[vec],
                n_results=min(top_k, total),
                include=["metadatas", "distances"],
            )

        # The result is a plain mapping, so post-process it outside the lock.
        if not results["ids"] or not results["ids"][0]:
            return []
        ids = results["ids"][0]
        distances = results["distances"][0] if results["distances"] else None
        return [
            {
                "arxiv_id": arxiv_id,
                "distance": distances[i] if distances is not None else 0.0,
            }
            for i, arxiv_id in enumerate(ids)
        ]
    except Exception:
        logger.exception("ChromaDB search_similar failed")
        return []