Files

240 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Embedder / Chroma 服务测试 — 初始化、索引、embedding API。"""
from __future__ import annotations
import threading
import time
from unittest.mock import MagicMock, patch
from app.config import settings
# ═══════════════════════════════════════════════════════════════════════
# 初始化
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderInit:
    """Initialization tests for ``app.services.embedder``."""

    def test_chroma_disabled_skip_init(self, monkeypatch):
        """With CHROMA_ENABLED=False, init_chroma() leaves the client unset."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        emb.init_chroma()
        assert emb._chroma._client is None

    def test_chroma_init_success(self, monkeypatch, tmp_path):
        """With CHROMA_ENABLED=True, init_chroma() sets client and collection."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._chroma.reset()
        try:
            emb.init_chroma()
            assert emb._chroma._client is not None
            assert emb._chroma._collection is not None
        finally:
            # Clean up even when init or an assertion fails, so the shared
            # ``emb._chroma`` object is not left initialized against this
            # test's temporary directory.
            emb._chroma.reset()

    def test_get_collection_returns_none_when_disabled(self, monkeypatch):
        """With CHROMA_ENABLED=False, get_collection() returns None."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        assert emb.get_collection() is None
# ═══════════════════════════════════════════════════════════════════════
# 索引
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderIndexing:
    """Indexing tests for ``app.services.embedder``."""

    def test_index_paper_disabled(self, monkeypatch):
        """With CHROMA_ENABLED=False, index_paper() returns False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        assert emb.index_paper("test-id") is False

    def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
        """With EMBED_API_BASE / EMBED_MODEL empty, index_paper() returns False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb

        emb._chroma.reset()
        try:
            emb.init_chroma()
            result = emb.index_paper(
                "test-id", {"title_zh": "测试", "title_en": "Test"}
            )
            assert result is False
        finally:
            # Clean up even when init or the assertion fails, so the shared
            # ``emb._chroma`` object does not stay initialized for later tests.
            emb._chroma.reset()

    def test_delete_paper_disabled(self, monkeypatch):
        """With CHROMA_ENABLED=False, delete_paper() returns False."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        assert emb.delete_paper("test-id") is False

    def test_search_similar_disabled(self, monkeypatch):
        """With CHROMA_ENABLED=False, search_similar() returns an empty list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._chroma.reset()
        assert emb.search_similar("test query") == []
# ═══════════════════════════════════════════════════════════════════════
# Embedding API
# ═══════════════════════════════════════════════════════════════════════
class TestEmbeddingApi:
    """Tests for ``_get_embedding``."""

    def test_no_api_base_returns_none(self, monkeypatch):
        """With EMBED_API_BASE empty, _get_embedding() returns None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb

        assert emb._get_embedding("test") is None

    def test_dimension_mismatch_returns_none(self, monkeypatch):
        """A 64-dim response with EMBED_DIMENSIONS=128 yields None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb

        mock_resp = MagicMock()
        mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]}
        mock_resp.raise_for_status = MagicMock()

        with patch("httpx.Client") as mock_client:
            # The context manager must yield a *client* whose post() returns
            # the prepared response. Yielding the response itself would make
            # client.post(...) an auto-created child mock, so the 64-dim
            # payload above would never be read and the test would pass
            # without exercising the dimension check.
            # NOTE(review): assumes the embedder uses
            # ``with httpx.Client(...) as client: client.post(...)``, which is
            # the shape the API-failure test in this class relies on.
            fake_client = MagicMock()
            fake_client.post.return_value = mock_resp
            mock_client.return_value.__enter__ = MagicMock(return_value=fake_client)
            mock_client.return_value.__exit__ = MagicMock(return_value=False)
            result = emb._get_embedding("test")

        assert result is None

    def test_api_failure_returns_none(self, monkeypatch):
        """When the HTTP call raises, _get_embedding() returns None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb

        with patch("httpx.Client") as mock_client:
            fake_client = MagicMock()
            fake_client.post.side_effect = Exception("timeout")
            mock_client.return_value.__enter__ = MagicMock(return_value=fake_client)
            # __exit__ returns False so the mocked context manager does not
            # swallow the exception raised by post().
            mock_client.return_value.__exit__ = MagicMock(return_value=False)
            result = emb._get_embedding("test")

        assert result is None
# ═══════════════════════════════════════════════════════════════════════
# 并发安全:init() 双重检查锁 + 集合访问串行化
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderConcurrency:
    """Safety of concurrent index_paper calls.

    Post-processing invokes index_paper from multiple workers via
    asyncio.to_thread.
    """

    def test_init_serialized_under_concurrency(self, monkeypatch, tmp_path):
        """Concurrent init() calls PersistentClient only once.

        Regression test for the chromadb SharedSystemClient cache race.

        Reproduces the crash condition: 10 threads call init() at the same
        time, and the fake PersistentClient deliberately sleeps to widen the
        connection window. Before the fix, several threads could enter
        _create_system_if_not_exists together and mutate the class-level cache
        concurrently; after the fix (double-checked locking) only the thread
        that wins the lock creates the client.
        """
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._chroma.reset()
        counter = {"n": 0}
        counter_lock = threading.Lock()

        def fake_persistent_client(path):
            """Count the call, stall briefly, and return a mock client."""
            with counter_lock:
                counter["n"] += 1
            time.sleep(0.05)  # widen the connection window to amplify the race
            client = MagicMock()
            # Make get_collection fail so the create path is taken.
            client.get_collection.side_effect = Exception("not exist")
            client.create_collection.return_value = MagicMock()
            return client

        try:
            with patch(
                "chromadb.PersistentClient", side_effect=fake_persistent_client
            ):
                threads = [
                    threading.Thread(target=emb._chroma.init) for _ in range(10)
                ]
                for t in threads:
                    t.start()
                for t in threads:
                    t.join()

            assert counter["n"] == 1, f"PersistentClient 应只调一次,实际 {counter['n']}"
            assert emb._chroma._client is not None
        finally:
            # Clean up even on failure so the mock client does not stay on the
            # shared ``emb._chroma`` object for later tests.
            emb._chroma.reset()

    def test_index_paper_concurrent_no_error(self, monkeypatch, tmp_path):
        """Concurrent index_paper calls all succeed.

        Embedding runs in parallel outside the lock, while writes to the
        collection are serialized.
        """
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        import app.services.embedder as emb

        emb._chroma.reset()
        errors: list[BaseException] = []

        def worker(i: int) -> None:
            """Index one paper and record any exception for the main thread."""
            try:
                emb.index_paper(
                    f"id-{i}",
                    {"arxiv_id": f"id-{i}", "title_zh": f"标题{i}"},
                )
            except BaseException as exc:  # noqa: BLE001 - collect every error
                errors.append(exc)

        try:
            # Skip init and inject a mock client / collection directly.
            emb._chroma._client = MagicMock()
            col = MagicMock()
            col.count.return_value = 0
            emb._chroma._collection = col

            with patch.object(emb, "_get_embedding", return_value=[0.1, 0.2, 0.3]):
                threads = [
                    threading.Thread(target=worker, args=(i,)) for i in range(10)
                ]
                for t in threads:
                    t.start()
                for t in threads:
                    # Join inside the patch so it stays active while workers run.
                    t.join()

            assert errors == []
            assert col.upsert.call_count == 10
        finally:
            # The mocks were injected by hand, so always remove them again,
            # even when an assertion above fails.
            emb._chroma.reset()