"""Embedder / Chroma service tests — initialization, indexing, embedding API."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import threading
|
||
import time
|
||
from unittest.mock import MagicMock, patch
|
||
|
||
|
||
from app.config import settings
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
# Initialization
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestEmbedderInit:
    """Initialization tests for ``app.services.embedder``."""

    def test_chroma_disabled_skip_init(self, monkeypatch):
        """With CHROMA_ENABLED=False, ``init_chroma()`` leaves the client unset."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        emb.init_chroma()
        assert emb._chroma._client is None

    def test_chroma_init_success(self, monkeypatch, tmp_path):
        """With CHROMA_ENABLED=True, ``init_chroma()`` sets client and collection."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))

        import app.services.embedder as emb

        emb._chroma.reset()
        try:
            emb.init_chroma()

            assert emb._chroma._client is not None
            assert emb._chroma._collection is not None
        finally:
            # Cleanup must run even when an assertion fails; otherwise the
            # module-level state leaks into later tests.
            emb._chroma.reset()

    def test_get_collection_returns_none_when_disabled(self, monkeypatch):
        """With CHROMA_ENABLED=False, ``get_collection()`` returns None."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        assert emb.get_collection() is None
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
# Indexing
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestEmbedderIndexing:
    """Indexing tests for ``app.services.embedder``."""

    def test_index_paper_disabled(self, monkeypatch):
        """With CHROMA_ENABLED=False, ``index_paper`` returns False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        assert emb.index_paper("test-id") is False

    def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
        """Without EMBED_API_BASE / EMBED_MODEL, ``index_paper`` returns False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")

        import app.services.embedder as emb

        emb._chroma.reset()
        try:
            emb.init_chroma()

            result = emb.index_paper(
                "test-id", {"title_zh": "测试", "title_en": "Test"}
            )
            assert result is False
        finally:
            # Always drop the initialised state, even if the assertion fails,
            # so later tests start from a clean module.
            emb._chroma.reset()

    def test_delete_paper_disabled(self, monkeypatch):
        """With CHROMA_ENABLED=False, ``delete_paper`` returns False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        assert emb.delete_paper("test-id") is False

    def test_search_similar_disabled(self, monkeypatch):
        """With CHROMA_ENABLED=False, ``search_similar`` returns an empty list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        assert emb.search_similar("test query") == []
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
# Embedding API
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestEmbeddingApi:
    """Tests for ``_get_embedding``."""

    def test_no_api_base_returns_none(self, monkeypatch):
        """With an empty EMBED_API_BASE, ``_get_embedding`` returns None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb

        assert emb._get_embedding("test") is None

    def test_dimension_mismatch_returns_none(self, monkeypatch):
        """A vector whose length differs from EMBED_DIMENSIONS yields None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)

        import app.services.embedder as emb

        # 64-dim payload against a configured dimension of 128.
        mock_resp = MagicMock()
        mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]}
        mock_resp.raise_for_status = MagicMock()

        with patch("httpx.Client") as mock_client:
            # The context manager must yield a *client* whose post() returns
            # the prepared response. Yielding the response itself would mean
            # the 64-dim payload is never seen by the code under test.
            # NOTE(review): assumes the embedder uses
            # `with httpx.Client(...) as client: client.post(...)`, as the
            # wiring in test_api_failure_returns_none implies — confirm.
            mock_client.return_value.__enter__ = MagicMock()
            mock_client.return_value.__exit__ = MagicMock(return_value=False)
            mock_client.return_value.__enter__.return_value.post.return_value = (
                mock_resp
            )
            result = emb._get_embedding("test")

        assert result is None

    def test_api_failure_returns_none(self, monkeypatch):
        """An exception raised by the HTTP call yields None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)

        import app.services.embedder as emb

        with patch("httpx.Client") as mock_client:
            mock_client.return_value.__enter__ = MagicMock()
            mock_client.return_value.__exit__ = MagicMock(return_value=False)
            mock_client.return_value.__enter__.return_value.post.side_effect = (
                Exception("timeout")
            )
            result = emb._get_embedding("test")

        assert result is None
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════
# Concurrency safety: init() double-checked lock + serialized collection access
# ═══════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestEmbedderConcurrency:
    """Safety of ``index_paper`` when called concurrently from worker threads.

    Per the original author's note, post-processing invokes ``index_paper``
    from multiple workers via ``asyncio.to_thread``.
    """

    def test_init_serialized_under_concurrency(self, monkeypatch, tmp_path):
        """Concurrent ``init()`` must call ``PersistentClient`` exactly once.

        Ten threads call ``init()`` at the same time while the fake
        ``PersistentClient`` sleeps to widen the connection window. Per the
        original author's note, this guards a fix for a chromadb
        ``SharedSystemClient`` cache race: before the fix several threads could
        build a client concurrently; after it (a double-checked lock) only the
        thread that wins the lock does.
        """
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._chroma.reset()

        counter = {"n": 0}
        counter_lock = threading.Lock()

        def fake_persistent_client(path):
            with counter_lock:
                counter["n"] += 1
            time.sleep(0.05)  # widen the connection window to amplify races
            client = MagicMock()
            client.get_collection.side_effect = Exception(
                "not exist"
            )  # force the create-collection path
            client.create_collection.return_value = MagicMock()
            return client

        try:
            with patch(
                "chromadb.PersistentClient", side_effect=fake_persistent_client
            ):
                threads = [
                    threading.Thread(target=emb._chroma.init) for _ in range(10)
                ]
                for t in threads:
                    t.start()
                for t in threads:
                    t.join()

            assert counter["n"] == 1, f"PersistentClient 应只调一次,实际 {counter['n']}"
            assert emb._chroma._client is not None
        finally:
            # Drop the mock client even when an assertion fails so it cannot
            # leak into later tests.
            emb._chroma.reset()

    def test_index_paper_concurrent_no_error(self, monkeypatch, tmp_path):
        """Ten concurrent ``index_paper`` calls raise nothing and upsert ten times."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._chroma.reset()
        try:
            # Skip init and inject a mock client / collection directly.
            emb._chroma._client = MagicMock()
            col = MagicMock()
            col.count.return_value = 0
            emb._chroma._collection = col

            with patch.object(emb, "_get_embedding", return_value=[0.1, 0.2, 0.3]):
                errors: list[BaseException] = []

                def worker(i: int) -> None:
                    try:
                        emb.index_paper(
                            f"id-{i}", {"arxiv_id": f"id-{i}", "title_zh": f"标题{i}"}
                        )
                    except BaseException as exc:  # noqa: BLE001 — collect every error
                        errors.append(exc)

                threads = [
                    threading.Thread(target=worker, args=(i,)) for i in range(10)
                ]
                for t in threads:
                    t.start()
                for t in threads:
                    t.join()

            assert errors == []
            assert col.upsert.call_count == 10
        finally:
            # Remove the injected mocks regardless of the test outcome.
            emb._chroma.reset()
|