85c4cfb9e8
- Add image_extractor, pdf_downloader, pi_client, trends services - Add shared utils module - Refactor summarizer, embedder, routes for cleaner separation - Update tests to match new service structure
331 lines
11 KiB
Python
331 lines
11 KiB
Python
"""ChromaDB 语义搜索服务 — embedding 生成、索引管理、相似查询。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
from sqlalchemy.orm import Session, joinedload
|
|
|
|
from app.config import settings
|
|
from app.models import Paper
|
|
|
|
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
# ── ChromaDB 管理器(替代全局可变状态)──────────────────────────────────
|
|
|
|
|
|
class ChromaManager:
    """Own the lifecycle of the ChromaDB client and the papers collection.

    State lives on an instance instead of module-level mutable globals.
    Initialization is lazy and is a no-op while ``settings.CHROMA_ENABLED``
    is false.
    """

    # Name of the collection holding one embedding per paper.
    _COLLECTION_NAME = "papers_embeddings"

    def __init__(self) -> None:
        # Both stay None until init() succeeds; a failed init() resets them.
        self._client = None
        self._collection = None

    def init(self) -> None:
        """Create the persistent client and collection when CHROMA_ENABLED is true.

        Returns immediately if ChromaDB is disabled or a client already exists.
        Any failure is logged and leaves the manager uninitialized, so a later
        ``get_collection()`` call will try again.
        """
        if not settings.CHROMA_ENABLED:
            logger.debug("ChromaDB disabled, skip init")
            return

        if self._client is not None:
            return

        try:
            # Imported lazily; presumably so chromadb is only required when enabled.
            import chromadb

            chroma_path = Path(settings.CHROMA_DIR)
            chroma_path.mkdir(parents=True, exist_ok=True)

            self._client = chromadb.PersistentClient(path=str(chroma_path))
            self._collection = self._get_or_create_collection()
            logger.info("ChromaDB initialized at %s", chroma_path)
        except Exception:
            logger.exception("Failed to initialize ChromaDB")
            self.reset()

    def _get_or_create_collection(self):
        """Return the papers collection, creating it with cosine distance if absent.

        Uses the client's atomic get-or-create call instead of "try get, on any
        error create", which hid genuine failures (I/O, permissions, corrupt
        store) behind a misleading create error.
        """
        col = self._client.get_or_create_collection(
            name=self._COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )
        logger.info(
            "ChromaDB collection '%s' ready, count=%d",
            self._COLLECTION_NAME,
            col.count(),
        )
        return col

    def get_collection(self):
        """Return the current collection, initializing lazily on first use.

        Returns None when ChromaDB is disabled or initialization failed.
        """
        if not settings.CHROMA_ENABLED:
            return None
        if self._collection is None:
            self.init()
        return self._collection

    def reset(self) -> None:
        """Drop the client and collection references (also used by tests)."""
        self._client = None
        self._collection = None
|
|
|
|
|
|
# Module-level singleton: the one ChromaManager shared by the helpers below.
_chroma: ChromaManager = ChromaManager()
|
|
|
|
|
|
def init_chroma() -> None:
    """Initialize the shared ChromaDB manager (meant to be called from the app lifespan)."""
    manager = _chroma
    manager.init()
|
|
|
|
|
|
def get_collection():
    """Return the shared ChromaDB collection, initializing it on first use.

    Returns None when ChromaDB is disabled via ``settings.CHROMA_ENABLED`` or
    when initialization fails; it does not merely report "not yet initialized".
    """
    return _chroma.get_collection()
|
|
|
|
|
|
# ── Embedding API 调用 ──────────────────────────────────────────────────
|
|
|
|
|
|
def _get_embedding(text: str) -> list[float] | None:
    """Request an embedding vector for *text* from the configured embedding service.

    Sends ``POST {EMBED_API_BASE}/v1/embeddings`` with ``model=EMBED_MODEL``
    and reads ``data[0].embedding`` from the JSON reply. When
    ``EMBED_DIMENSIONS`` is positive, a vector of any other length is rejected.

    Returns:
        The vector, or None (with a log record) when the service is not
        configured, the request/parsing fails, or the length check fails.
    """
    if not (settings.EMBED_API_BASE and settings.EMBED_MODEL):
        logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
        return None

    from app.utils import make_http_client

    endpoint = settings.EMBED_API_BASE.rstrip("/") + "/v1/embeddings"
    request_headers = {"Content-Type": "application/json"}
    api_key = settings.EMBED_API_KEY
    if api_key:
        request_headers["Authorization"] = f"Bearer {api_key}"
    request_body = {"model": settings.EMBED_MODEL, "input": text}

    try:
        with make_http_client(sync=True) as http:
            response = http.post(endpoint, json=request_body, headers=request_headers)
            response.raise_for_status()
            reply = response.json()

        embedding = reply["data"][0]["embedding"]
        expected_dim = settings.EMBED_DIMENSIONS
        # A non-positive EMBED_DIMENSIONS disables the length check.
        if expected_dim <= 0 or len(embedding) == expected_dim:
            return embedding

        logger.warning(
            "Embedding dimension mismatch: expected=%d got=%d, skip",
            expected_dim,
            len(embedding),
        )
        return None
    except Exception:
        logger.exception("Embedding API call failed")
        return None
|
|
|
|
|
|
# ── 索引文本拼接 ────────────────────────────────────────────────────────
|
|
|
|
|
|
def _build_index_text(paper: Paper, summary_dict: dict | None) -> str:
|
|
"""拼接高信号字段为一段文本用于 embedding。"""
|
|
parts: list[str] = []
|
|
|
|
if paper.title_zh:
|
|
parts.append(paper.title_zh)
|
|
if paper.title_en:
|
|
parts.append(paper.title_en)
|
|
if paper.tags:
|
|
tag_str = " ".join(t.tag for t in paper.tags)
|
|
if tag_str:
|
|
parts.append(tag_str)
|
|
|
|
if summary_dict:
|
|
for key in ("one_line", "motivation_problem", "method_key_idea"):
|
|
val = summary_dict.get(key)
|
|
if val:
|
|
parts.append(val)
|
|
|
|
return " ".join(parts)
|
|
|
|
|
|
# ── 单篇索引 ────────────────────────────────────────────────────────────
|
|
|
|
|
|
# Keys of a caller-supplied ``texts_dict`` that are stored as Chroma metadata
# only; they are not part of the text that gets embedded.
_TEXTS_DICT_METADATA_KEYS = frozenset({"arxiv_id", "paper_date"})


def _index_fields_from_db(paper_id: str) -> tuple[str, str, str, str] | None:
    """Load a paper and return ``(arxiv_id, index_text, title_zh, paper_date)``.

    Returns None when no paper has ``arxiv_id == paper_id``. Every field is
    read before the session is closed.
    """
    from app.database import SessionLocal

    db = SessionLocal()
    try:
        paper = (
            db.query(Paper)
            .filter(Paper.arxiv_id == paper_id)
            .options(joinedload(Paper.tags), joinedload(Paper.summary))
            .first()
        )
        if paper is None:
            return None

        summary_dict = None
        if paper.summary:
            summary_dict = {
                "one_line": paper.summary.one_line or "",
                "motivation_problem": paper.summary.motivation_problem or "",
                "method_key_idea": paper.summary.method_key_idea or "",
            }
        return (
            paper.arxiv_id,
            _build_index_text(paper, summary_dict),
            paper.title_zh or "",
            paper.paper_date.isoformat() if paper.paper_date else "",
        )
    finally:
        db.close()


def _index_fields_from_texts_dict(paper_id: str, texts_dict: dict) -> tuple[str, str, str, str]:
    """Build ``(arxiv_id, index_text, title_zh, paper_date)`` from caller-supplied fields.

    Non-empty string values are embedded in insertion order. A list/tuple
    value (e.g. a list of tag names) contributes its non-empty strings joined
    by spaces. ``arxiv_id`` and ``paper_date`` are metadata only, which keeps
    the embedded text consistent with the DB path. Missing or None metadata
    falls back to ``paper_id`` / ``""`` because Chroma ids and metadata
    values cannot be None.
    """
    parts: list[str] = []
    for key, value in texts_dict.items():
        if key in _TEXTS_DICT_METADATA_KEYS:
            continue
        if isinstance(value, (list, tuple)):
            value = " ".join(item for item in value if isinstance(item, str) and item)
        if isinstance(value, str) and value:
            parts.append(value)

    arxiv_id = texts_dict.get("arxiv_id") or paper_id
    title_zh = texts_dict.get("title_zh") or ""
    paper_date = texts_dict.get("paper_date") or ""
    if hasattr(paper_date, "isoformat"):
        # Accept date/datetime objects; the DB path also stores isoformat().
        paper_date = paper_date.isoformat()
    return arxiv_id, " ".join(parts), title_zh, paper_date


def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
    """Write one paper's embedding into the ChromaDB semantic index.

    Args:
        paper_id: arxiv_id of the paper.
        texts_dict: optional pre-assembled fields (title_zh, title_en, tags,
            one_line, ...). ``arxiv_id`` / ``paper_date`` keys, if present,
            are stored as metadata and are not embedded. When None, the
            paper is loaded from the DB.

    Returns:
        True on success. False when ChromaDB is unavailable, the paper is
        missing, the index text is empty, embedding fails, or any error is
        raised (errors are logged).
    """
    col = get_collection()
    if col is None:
        return False

    try:
        if texts_dict is None:
            fields = _index_fields_from_db(paper_id)
            if fields is None:
                logger.warning("Paper %s not found for indexing", paper_id)
                return False
        else:
            fields = _index_fields_from_texts_dict(paper_id, texts_dict)
        arxiv_id, index_text, title_zh, paper_date = fields

        if not index_text.strip():
            logger.warning("Empty index text for %s, skip", paper_id)
            return False

        vec = _get_embedding(index_text)
        if vec is None:
            return False

        col.upsert(
            ids=[arxiv_id],
            embeddings=[vec],
            metadatas=[{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}],
        )
        logger.info("Indexed paper %s in ChromaDB", arxiv_id)
        return True

    except Exception:
        logger.exception("Failed to index paper %s in ChromaDB", paper_id)
        return False
|
|
|
|
|
|
# ── 批量索引 ────────────────────────────────────────────────────────────
|
|
|
|
|
|
def index_batch(paper_ids: list[str]) -> dict:
    """Index papers one at a time; one paper failing does not stop the others.

    Returns:
        ``{"total": int, "success": int, "failed": int}``. When ChromaDB is
        unavailable nothing is attempted and every paper counts as failed.
    """
    if not paper_ids:
        return {"total": 0, "success": 0, "failed": 0}

    total = len(paper_ids)
    if get_collection() is None:
        return {"total": total, "success": 0, "failed": total}

    succeeded = sum(1 for pid in paper_ids if index_paper(pid))
    failed = total - succeeded

    logger.info("Batch index: total=%d success=%d failed=%d", total, succeeded, failed)
    return {"total": total, "success": succeeded, "failed": failed}
|
|
|
|
|
|
# ── 删除 ────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def delete_paper(paper_id: str) -> bool:
    """Remove a paper's entry from the ChromaDB index.

    Args:
        paper_id: arxiv_id of the paper, used as the Chroma document id.

    Returns:
        True when the delete call succeeded. False when ChromaDB is
        unavailable or the call raised (the error is logged).
    """
    collection = get_collection()
    if collection is None:
        return False

    try:
        collection.delete(ids=[paper_id])
        logger.info("Deleted paper %s from ChromaDB", paper_id)
    except Exception:
        logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
        return False
    return True
|
|
|
|
|
|
# ── 相似查询 ────────────────────────────────────────────────────────────
|
|
|
|
|
|
def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
    """Semantic similarity search over the indexed papers.

    Args:
        query_text: free-text query; empty or whitespace-only text yields [].
        top_k: maximum number of hits; a non-positive value yields [].

    Returns:
        ``[{"arxiv_id": str, "distance": float}, ...]`` in Chroma's result
        order. Empty when ChromaDB is unavailable, the collection is empty,
        embedding fails, or the query raises (errors are logged).
    """
    # Cheap guards first: skip a pointless embedding API round-trip and a
    # non-positive n_results.
    if not query_text or not query_text.strip() or top_k <= 0:
        return []

    col = get_collection()
    if col is None:
        return []

    try:
        # Count once and reuse it. An empty collection cannot produce hits,
        # so the embedding call is skipped entirely.
        count = col.count()
        if count == 0:
            return []

        vec = _get_embedding(query_text)
        if vec is None:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["distances"],  # metadatas are not used by callers of this result
        )

        id_rows = results["ids"]
        if not id_rows or not id_rows[0]:
            return []
        ids = id_rows[0]

        distance_rows = results["distances"]
        distances = distance_rows[0] if distance_rows else [0.0] * len(ids)
        return [
            {"arxiv_id": arxiv_id, "distance": dist}
            for arxiv_id, dist in zip(ids, distances)
        ]

    except Exception:
        logger.exception("ChromaDB search_similar failed")
        return []
|