feat: add concurrency safety, caption detection, admin enhancements, and performance improvements
This commit is contained in:
+78
-40
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -18,14 +19,27 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChromaManager:
|
||||
"""封装 ChromaDB 客户端和 collection 的生命周期。"""
|
||||
"""封装 ChromaDB 客户端和 collection 的生命周期。
|
||||
|
||||
所有客户端/集合访问经 ``self._lock`` 串行化:后处理经 ``asyncio.to_thread``
|
||||
多 worker 并发调 ``index_paper``,若不串行化会并发建连,触发 chromadb 1.5.x
|
||||
``SharedSystemClient`` 类级缓存的并发竞争(``_create_system_if_not_exists``
|
||||
无锁 + refcount release 弹 key)→ ``KeyError: '<persist_dir>'``。
|
||||
锁用 RLock:``index_paper`` 持锁后经 ``get_collection()`` 间接再调 ``init()``,
|
||||
同线程可重入,不死锁。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.RLock()
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
def init(self) -> None:
|
||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
|
||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。
|
||||
|
||||
双重检查锁串行化首次建连 —— 外层快判已建好就直接返回(快路径),抢到锁后再
|
||||
判一次(防并发下另一线程已建好),确保 ``PersistentClient`` 全进程只调一次。
|
||||
"""
|
||||
if not settings.CHROMA_ENABLED:
|
||||
logger.debug("ChromaDB disabled, skip init")
|
||||
return
|
||||
@@ -33,19 +47,23 @@ class ChromaManager:
|
||||
if self._client is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
with self._lock:
|
||||
if self._client is not None: # 双重检查:抢到锁后可能已被别的线程建好
|
||||
return
|
||||
|
||||
chroma_path = Path(settings.CHROMA_DIR)
|
||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
import chromadb
|
||||
|
||||
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
||||
self._collection = self._get_or_create_collection()
|
||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize ChromaDB")
|
||||
self._client = None
|
||||
self._collection = None
|
||||
chroma_path = Path(settings.CHROMA_DIR)
|
||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
||||
self._collection = self._get_or_create_collection()
|
||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize ChromaDB")
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
def _get_or_create_collection(self):
|
||||
"""获取或创建 papers_embeddings collection。"""
|
||||
@@ -102,6 +120,8 @@ def _get_embedding(text: str) -> list[float] | None:
|
||||
POST /v1/embeddings, model=EMBED_MODEL
|
||||
校验返回向量长度 == EMBED_DIMENSIONS
|
||||
失败时返回 None 并记录日志。
|
||||
|
||||
纯远程 HTTP 调用、线程安全 —— 留在锁外,让多 worker 并行调。
|
||||
"""
|
||||
if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
|
||||
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
|
||||
@@ -177,9 +197,11 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败或跳过。
|
||||
|
||||
并发设计:远程 embedding 调用在锁外(多 worker 并行),chroma 集合访问
|
||||
(含首次 init)在 ``_chroma._lock`` 内串行化。
|
||||
"""
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -227,17 +249,25 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
||||
logger.warning("Empty index text for %s, skip", paper_id)
|
||||
return False
|
||||
|
||||
vec = _get_embedding(index_text)
|
||||
vec = _get_embedding(index_text) # 远程 HTTP,锁外并行
|
||||
if vec is None:
|
||||
return False
|
||||
|
||||
col.upsert(
|
||||
ids=[arxiv_id],
|
||||
embeddings=[vec],
|
||||
metadatas=[
|
||||
{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}
|
||||
],
|
||||
)
|
||||
with _chroma._lock: # 串行化集合访问(首次含 init)
|
||||
col = _chroma.get_collection()
|
||||
if col is None:
|
||||
return False
|
||||
col.upsert(
|
||||
ids=[arxiv_id],
|
||||
embeddings=[vec],
|
||||
metadatas=[
|
||||
{
|
||||
"arxiv_id": arxiv_id,
|
||||
"title_zh": title_zh,
|
||||
"paper_date": paper_date,
|
||||
}
|
||||
],
|
||||
)
|
||||
logger.info("Indexed paper %s in ChromaDB", arxiv_id)
|
||||
return True
|
||||
|
||||
@@ -255,17 +285,20 @@ def delete_paper(paper_id: str) -> bool:
|
||||
Args:
|
||||
paper_id: arxiv_id
|
||||
"""
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return False
|
||||
|
||||
try:
|
||||
col.delete(ids=[paper_id])
|
||||
logger.info("Deleted paper %s from ChromaDB", paper_id)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
|
||||
return False
|
||||
with _chroma._lock:
|
||||
col = _chroma.get_collection()
|
||||
if col is None:
|
||||
return False
|
||||
try:
|
||||
col.delete(ids=[paper_id])
|
||||
logger.info("Deleted paper %s from ChromaDB", paper_id)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
|
||||
return False
|
||||
|
||||
|
||||
# ── 相似查询 ────────────────────────────────────────────────────────────
|
||||
@@ -280,21 +313,26 @@ def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
|
||||
|
||||
Returns:
|
||||
[{"arxiv_id": str, "distance": float}, ...]
|
||||
|
||||
并发设计:远程 embedding 在锁外,集合查询在 ``_chroma._lock`` 内。
|
||||
"""
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return []
|
||||
|
||||
try:
|
||||
vec = _get_embedding(query_text)
|
||||
vec = _get_embedding(query_text) # 远程 HTTP,锁外
|
||||
if vec is None:
|
||||
return []
|
||||
|
||||
results = col.query(
|
||||
query_embeddings=[vec],
|
||||
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
|
||||
include=["metadatas", "distances"],
|
||||
)
|
||||
with _chroma._lock:
|
||||
col = _chroma.get_collection()
|
||||
if col is None:
|
||||
return []
|
||||
results = col.query(
|
||||
query_embeddings=[vec],
|
||||
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
|
||||
include=["metadatas", "distances"],
|
||||
)
|
||||
|
||||
if not results["ids"] or not results["ids"][0]:
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user