feat: add compare, trends routes, embedder service, and phase5 tests

This commit is contained in:
2026-06-05 23:32:06 +08:00
parent 2cfd1a8a9f
commit ba9afa212c
17 changed files with 2122 additions and 27 deletions
+7
View File
@@ -139,6 +139,13 @@ async def delete_papers_by_date_range(
{"paper_id": paper_id},
)
# 1.5 Phase 5: 从 ChromaDB 删除语义索引
try:
from app.services.embedder import delete_paper
delete_paper(arxiv_id)
except Exception:
logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True)
# 2. 删除本地文件 data/papers/{arxiv_id}/
paper_dir = _PAPERS_DIR / arxiv_id
if paper_dir.exists():
+314
View File
@@ -0,0 +1,314 @@
"""ChromaDB 语义搜索服务 — embedding 生成、索引管理、相似查询。"""
from __future__ import annotations
import logging
from pathlib import Path
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.models import Paper
logger = logging.getLogger(__name__)
# ── Singleton client and collection ─────────────────────────────────────
# Module-level cache filled in by init_chroma(). Both names stay None until
# initialization succeeds and are reset to None when initialization fails.
_client = None
_collection = None
def _chroma_dir() -> Path:
    """Return the configured ChromaDB storage location as a ``Path``."""
    configured_dir = settings.CHROMA_DIR
    return Path(configured_dir)
def init_chroma() -> None:
    """Create the persistent ChromaDB client and collection when CHROMA_ENABLED is true.

    Does nothing when ChromaDB is disabled or a client already exists. Any
    failure is logged and leaves both module-level singletons set to None.
    """
    global _client, _collection

    chroma_enabled = settings.CHROMA_ENABLED
    if not chroma_enabled:
        logger.debug("ChromaDB disabled, skip init")
        return

    already_initialized = _client is not None
    if already_initialized:
        return

    try:
        import chromadb

        storage_dir = _chroma_dir()
        storage_dir.mkdir(parents=True, exist_ok=True)
        # The client has to be bound before the collection helper runs,
        # because that helper reads the module-level ``_client``.
        _client = chromadb.PersistentClient(path=str(storage_dir))
        _collection = _get_or_create_collection()
        logger.info("ChromaDB initialized at %s", storage_dir)
    except Exception:
        logger.exception("Failed to initialize ChromaDB")
        _client = _collection = None
def _get_or_create_collection():
    """Return the ``papers_embeddings`` collection, creating it when the lookup fails.

    An existing collection is reused unchanged. When ``get_collection`` raises
    (presumably because the collection does not exist yet), a new collection
    configured for cosine distance is created. Embedding dimensions are not
    validated here.

    Returns:
        The ChromaDB collection object.

    Raises:
        Exception: anything raised by ``create_collection`` or ``count``;
            the caller (``init_chroma``) handles it.
    """
    name = "papers_embeddings"
    try:
        # Keep the try body to the lookup only, so a failure in the logging /
        # count below is not mistaken for "collection missing".
        col = _client.get_collection(name)
    except Exception:
        # Expected on first run; logged instead of silently discarded so an
        # unexpected lookup failure remains diagnosable.
        logger.debug(
            "ChromaDB collection '%s' could not be loaded, creating it",
            name,
            exc_info=True,
        )
        col = _client.create_collection(
            name=name,
            metadata={"hnsw:space": "cosine"},
        )
        logger.info("ChromaDB collection 'papers_embeddings' created")
        return col

    logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
    return col
def get_collection():
    """Return the active collection, initializing it lazily.

    Returns None when ChromaDB is disabled or initialization did not succeed.
    """
    if settings.CHROMA_ENABLED:
        if _collection is None:
            init_chroma()
        # Read the global again: init_chroma() may have just populated it.
        return _collection
    return None
# ── Embedding API 调用 ──────────────────────────────────────────────────
def _get_embedding(text: str) -> list[float] | None:
    """Request an embedding vector for *text* from the configured embedding API.

    Sends ``POST {EMBED_API_BASE}/v1/embeddings`` with ``model=EMBED_MODEL``.
    When ``EMBED_DIMENSIONS`` is positive, the returned vector must have
    exactly that length.

    Returns:
        The embedding vector, or None when the API is not configured, the
        request fails, or the dimension check fails. Each case is logged.
    """
    if not (settings.EMBED_API_BASE and settings.EMBED_MODEL):
        logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
        return None

    endpoint = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
    request_headers = {"Content-Type": "application/json"}
    if settings.EMBED_API_KEY:
        # The bearer token is only attached when a key is configured.
        request_headers["Authorization"] = f"Bearer {settings.EMBED_API_KEY}"
    body = {"model": settings.EMBED_MODEL, "input": text}

    try:
        with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as http:
            response = http.post(endpoint, json=body, headers=request_headers)
            response.raise_for_status()
            vector = response.json()["data"][0]["embedding"]

        if settings.EMBED_DIMENSIONS > 0 and len(vector) != settings.EMBED_DIMENSIONS:
            logger.warning(
                "Embedding dimension mismatch: expected=%d got=%d, skip",
                settings.EMBED_DIMENSIONS,
                len(vector),
            )
            return None
        return vector
    except Exception:
        logger.exception("Embedding API call failed")
        return None
# ── 索引文本拼接 ────────────────────────────────────────────────────────
def _build_index_text(paper: Paper, summary_dict: dict | None) -> str:
    """Join a paper's high-signal fields into one space-separated string for embedding.

    Order: Chinese title, English title, tags, then the ``one_line``,
    ``motivation_problem`` and ``method_key_idea`` summary fields. Empty or
    missing fields are left out.
    """
    segments: list[str] = [
        title for title in (paper.title_zh, paper.title_en) if title
    ]

    joined_tags = " ".join(t.tag for t in paper.tags) if paper.tags else ""
    if joined_tags:
        segments.append(joined_tags)

    if summary_dict:
        summary_keys = ("one_line", "motivation_problem", "method_key_idea")
        segments.extend(
            value for value in map(summary_dict.get, summary_keys) if value
        )

    return " ".join(segments)
# ── 单篇索引 ────────────────────────────────────────────────────────────
def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
    """Write a single paper into the ChromaDB semantic index.

    Args:
        paper_id: The paper's arxiv_id.
        texts_dict: Optional pre-assembled fields (title_zh, title_en, tags,
            one_line, ...). Its ``arxiv_id`` and ``paper_date`` entries are
            used only as stored metadata; every other non-empty string value
            is embedded. When None, the paper is loaded from the database.

    Returns:
        True on success. False when ChromaDB is unavailable, the paper is not
        found, there is no text to embed, the embedding call returns nothing,
        or any error occurs (errors are logged, never raised).
    """
    col = get_collection()
    if col is None:
        return False

    try:
        if texts_dict is None:
            # No pre-assembled fields: load the paper (with tags and summary)
            # from the database.
            from app.database import SessionLocal

            db = SessionLocal()
            try:
                paper = (
                    db.query(Paper)
                    .filter(Paper.arxiv_id == paper_id)
                    .options(joinedload(Paper.tags), joinedload(Paper.summary))
                    .first()
                )
                if not paper:
                    logger.warning("Paper %s not found for indexing", paper_id)
                    return False

                summary_dict = None
                if paper.summary:
                    summary_dict = {
                        "one_line": paper.summary.one_line or "",
                        "motivation_problem": paper.summary.motivation_problem or "",
                        "method_key_idea": paper.summary.method_key_idea or "",
                    }
                index_text = _build_index_text(paper, summary_dict)
                # Copy plain values out before the session is closed.
                arxiv_id = paper.arxiv_id
                title_zh = paper.title_zh or ""
                paper_date = paper.paper_date.isoformat() if paper.paper_date else ""
            finally:
                db.close()
        else:
            # The id and the date are metadata, not content: embedding them
            # would add noise and make this path disagree with the DB path,
            # which embeds only titles, tags and summary fields.
            metadata_keys = ("arxiv_id", "paper_date")
            index_text = " ".join(
                value
                for key, value in texts_dict.items()
                if key not in metadata_keys and isinstance(value, str) and value
            )
            # ``or`` also covers keys that are present but None, so None never
            # ends up as an id or a metadata value.
            arxiv_id = texts_dict.get("arxiv_id") or paper_id
            title_zh = texts_dict.get("title_zh") or ""
            paper_date = texts_dict.get("paper_date") or ""

        if not index_text.strip():
            logger.warning("Empty index text for %s, skip", paper_id)
            return False

        vec = _get_embedding(index_text)
        if vec is None:
            return False

        col.upsert(
            ids=[arxiv_id],
            embeddings=[vec],
            metadatas=[{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}],
        )
        logger.info("Indexed paper %s in ChromaDB", arxiv_id)
        return True
    except Exception:
        logger.exception("Failed to index paper %s in ChromaDB", paper_id)
        return False
# ── 批量索引 ────────────────────────────────────────────────────────────
def index_batch(paper_ids: list[str]) -> dict:
    """Index several papers; one paper failing does not stop the rest.

    Returns:
        {"total": int, "success": int, "failed": int}
    """
    if not paper_ids:
        return {"total": 0, "success": 0, "failed": 0}

    total = len(paper_ids)
    if get_collection() is None:
        # Without a collection nothing can be indexed: report all as failed.
        return {"total": total, "success": 0, "failed": total}

    success = sum(1 for pid in paper_ids if index_paper(pid))
    failed = total - success
    logger.info("Batch index: total=%d success=%d failed=%d", total, success, failed)
    return {"total": total, "success": success, "failed": failed}
# ── 删除 ────────────────────────────────────────────────────────────────
def delete_paper(paper_id: str) -> bool:
    """Remove a paper's entry from the ChromaDB index.

    Args:
        paper_id: The paper's arxiv_id.

    Returns:
        True when the delete call completed; False when ChromaDB is
        unavailable or the delete raised (the error is logged).
    """
    collection = get_collection()
    if collection is not None:
        try:
            collection.delete(ids=[paper_id])
            logger.info("Deleted paper %s from ChromaDB", paper_id)
            return True
        except Exception:
            logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
    return False
# ── 相似查询 ────────────────────────────────────────────────────────────
def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
    """Return the papers whose embeddings are closest to *query_text*.

    Args:
        query_text: Free-text query to embed.
        top_k: Maximum number of results to return.

    Returns:
        ``[{"arxiv_id": str, "distance": float}, ...]`` in the order returned
        by ChromaDB. An empty list when ChromaDB is unavailable, the query is
        blank, ``top_k`` is not positive, the collection is empty, the
        embedding call returns nothing, or any error occurs (logged).
    """
    col = get_collection()
    if col is None:
        return []
    # Nothing to search for: skip the embedding API round-trip entirely.
    if top_k <= 0 or not query_text or not query_text.strip():
        return []

    try:
        # Count once; an empty collection cannot match anything, so return
        # before paying for an embedding.
        indexed = col.count()
        if indexed <= 0:
            return []

        vec = _get_embedding(query_text)
        if vec is None:
            return []

        results = col.query(
            query_embeddings=[vec],
            # Never ask for more neighbours than there are indexed items.
            n_results=min(top_k, indexed),
            # Only ids (always returned) and distances are consumed below.
            include=["distances"],
        )

        id_rows = results["ids"]
        if not id_rows or not id_rows[0]:
            return []
        ids = id_rows[0]

        distance_rows = results["distances"]
        distances = distance_rows[0] if distance_rows else [0.0] * len(ids)
        return [
            {"arxiv_id": arxiv_id, "distance": dist}
            for arxiv_id, dist in zip(ids, distances)
        ]
    except Exception:
        logger.exception("ChromaDB search_similar failed")
        return []
+84 -2
View File
@@ -1,15 +1,19 @@
"""FTS5 全文搜索服务 — 关键词 + 标签筛选,命中片段高亮,分页。"""
"""搜索服务 — FTS5 关键词搜索 + ChromaDB 语义搜索,命中片段高亮,分页。"""
from __future__ import annotations
import logging
import math
import re
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.models import Paper
logger = logging.getLogger(__name__)
# ── 输入清洗 ──────────────────────────────────────────────────────────
# FTS5 查询语法中的特殊字符,用户输入时需要移除
@@ -41,8 +45,9 @@ def search_papers(
sort: str = "relevance",
page: int = 1,
page_size: int = 20,
mode: str = "keyword",
) -> dict:
"""FTS5 搜索论文。
"""搜索论文,支持 keyword (FTS5) 和 semantic (ChromaDB) 两种模式
返回::
{
@@ -51,8 +56,14 @@ def search_papers(
"total": int,
"page": int,
"total_pages": int,
"distances": dict[str, float], # arxiv_id → distance (仅 semantic)
}
"""
# ── semantic 模式 ──
if mode == "semantic" and settings.CHROMA_ENABLED and query:
return _search_semantic(db, query, tag, sort, page, page_size)
# ── keyword 模式(默认)──
match_expr = _sanitize_query(query) if query else None
# ── 无关键词 + 无标签 → 空结果 ──
@@ -63,6 +74,7 @@ def search_papers(
"total": 0,
"page": page,
"total_pages": 0,
"distances": {},
}
# ── 构建条件性 JOIN 和 WHERE 片段 ──
@@ -146,6 +158,75 @@ def _search_with_fts(
"total": total,
"page": page,
"total_pages": math.ceil(total / page_size) if total else 0,
"distances": {},
}
def _search_semantic(
    db: Session,
    query: str,
    tag: str | None,
    sort: str,
    page: int,
    page_size: int,
) -> dict:
    """Run a ChromaDB semantic search, falling back to FTS5 when it yields nothing.

    ``sort`` is only forwarded to the keyword fallback; semantic results are
    ordered by the rank returned from the vector search.

    Returns:
        The same dict shape as the keyword search, with ``distances`` mapping
        each candidate arxiv_id to its distance and ``snippets`` left empty.
    """
    try:
        from app.services.embedder import search_similar

        # Over-fetch (3x the page size) so tag filtering still leaves rows.
        hits = search_similar(query, top_k=page_size * 3)
    except Exception:
        logger.exception("Semantic search failed, falling back to keyword")
        hits = []

    if not hits:
        # Fall back to the FTS5 keyword search.
        fts_query = _sanitize_query(query) or query
        if tag:
            join_sql = "JOIN paper_tags pt ON pt.paper_id = p.id"
            where_sql = "AND pt.tag = :tag"
            params = {"tag": tag}
        else:
            join_sql, where_sql, params = "", "", {}
        return _search_with_fts(
            db,
            fts_query,
            join_sql,
            where_sql,
            params,
            sort,
            page,
            page_size,
            (page - 1) * page_size,
        )

    # Keep the candidate order and remember each candidate's distance.
    ranked_ids = [hit["arxiv_id"] for hit in hits]
    distance_map = {hit["arxiv_id"]: hit["distance"] for hit in hits}

    # Load the full rows for the candidates, with their relations.
    eager_loads = (
        joinedload(Paper.authors),
        joinedload(Paper.tags),
        joinedload(Paper.summary_status),
        joinedload(Paper.bookmark),
        joinedload(Paper.reading_status),
    )
    papers_query = (
        db.query(Paper)
        .filter(Paper.arxiv_id.in_(ranked_ids))
        .options(*eager_loads)
    )
    if tag:
        papers_query = papers_query.filter(Paper.tags.any(tag=tag))

    # Restore the semantic ranking, since the database returns rows in its own order.
    rank_of = {arxiv_id: position for position, arxiv_id in enumerate(ranked_ids)}
    ordered = sorted(
        papers_query.all(),
        key=lambda paper: rank_of.get(paper.arxiv_id, 999),
    )

    # Paginate in memory.
    total = len(ordered)
    offset = (page - 1) * page_size
    return {
        "results": ordered[offset:offset + page_size],
        "snippets": {},
        "total": total,
        "page": page,
        "total_pages": math.ceil(total / page_size) if total else 0,
        "distances": distance_map,
    }
@@ -187,6 +268,7 @@ def _search_tag_only(
"total": total,
"page": page,
"total_pages": math.ceil(total / page_size) if total else 0,
"distances": {},
}
+145
View File
@@ -359,6 +359,127 @@ def _cleanup_tmp(arxiv_id: str) -> None:
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
# ── LaTeX image extraction (Phase 5) ────────────────────────────────────
# Matches \includegraphics with an optional [...] options group and captures
# the text between the braces (the image path as written in the .tex file).
_INCLUDEGRAPHICS_RE = re.compile(
    r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
# Image file extensions.
# NOTE(review): not referenced in the visible part of this file — confirm it
# is used elsewhere, or reconcile it with the extension list tried during copy.
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
    """Copy the images referenced by a paper's LaTeX source into its paper directory.

    Steps:
        1. Download and unpack the arXiv source into ``tmp_source`` unless
           that directory already exists.
        2. Scan every ``.tex`` file for ``\\includegraphics`` targets.
        3. Copy each target that resolves to a file inside the source tree
           into the ``images`` sub-directory of ``_paper_dir(arxiv_id)``.

    The unpacked source is left in place; this function does not remove it.

    Args:
        arxiv_id: The paper's arXiv identifier.
        tmp_source: Directory holding (or receiving) the unpacked source.
            Defaults to ``_tmp_dir(arxiv_id) / "source"``.

    Returns:
        The number of images copied; 0 on any failure (logged, never raised).
    """
    # Honour an explicitly supplied directory instead of overwriting it.
    if tmp_source is None:
        tmp_source = _tmp_dir(arxiv_id) / "source"
    images_dest = _paper_dir(arxiv_id) / "images"

    try:
        # Download the source archive only if it has not been unpacked yet.
        if not tmp_source.exists():
            source_url = f"https://arxiv.org/e-print/{arxiv_id}"
            await _download_source_zip(arxiv_id, source_url, tmp_source)
        if not tmp_source.exists():
            return 0

        # Collect the image paths referenced by the .tex files.
        image_paths: set[str] = set()
        for tex_file in tmp_source.rglob("*.tex"):
            try:
                content = tex_file.read_text(encoding="utf-8", errors="replace")
            except OSError:
                logger.debug("Cannot read %s, skipping", tex_file, exc_info=True)
                continue
            image_paths.update(
                match.group(1).strip()
                for match in _INCLUDEGRAPHICS_RE.finditer(content)
            )
        if not image_paths:
            return 0

        source_root = tmp_source.resolve()
        images_dest.mkdir(parents=True, exist_ok=True)
        copied = 0
        # Sorted so the collision suffixes below are reproducible across runs.
        for img_rel in sorted(image_paths):
            # LaTeX allows the extension to be omitted, so try the bare path
            # first and then the common image extensions.
            for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
                try:
                    candidate = (tmp_source / (img_rel + ext)).resolve()
                    # The path comes from untrusted LaTeX: refuse anything
                    # (``..``, absolute paths, symlinks) that ends up outside
                    # the unpacked source tree.
                    candidate.relative_to(source_root)
                except (OSError, RuntimeError, ValueError):
                    continue
                if not candidate.is_file():
                    continue

                dest = images_dest / candidate.name
                if dest.exists():
                    # Avoid clobbering an image that has the same file name.
                    dest = images_dest / f"{dest.stem}_{copied}{dest.suffix}"
                shutil.copy2(candidate, dest)
                copied += 1
                break

        if copied > 0:
            logger.info("Extracted %d images from source for %s", copied, arxiv_id)
        return copied
    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
        return 0
async def _download_source_zip(
    arxiv_id: str, source_url: str, dest_dir: Path
) -> None:
    """Download an arXiv source archive and unpack it into ``dest_dir``.

    The archive is first opened as a zip; if it is not a zip it is retried as
    a tar archive (any compression). ``dest_dir`` is created only after the
    download succeeded, so a failed download leaves no empty directory behind.
    The downloaded archive file is always removed before returning.

    Args:
        arxiv_id: The paper's arXiv identifier (used for the temp location
            and in log messages).
        source_url: URL of the source archive.
        dest_dir: Directory to unpack into.

    Download and extraction failures are logged and swallowed.
    """
    import tarfile
    import zipfile

    zip_path = _tmp_dir(arxiv_id) / "source.zip"
    zip_path.parent.mkdir(parents=True, exist_ok=True)

    transport = None
    if settings.http_proxy:
        transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)

    try:
        try:
            async with httpx.AsyncClient(
                timeout=settings.HTTP_TIMEOUT_SECONDS,
                headers={"User-Agent": settings.HTTP_USER_AGENT},
                transport=transport,
                follow_redirects=True,
            ) as client:
                resp = await client.get(source_url)
                resp.raise_for_status()
                zip_path.write_bytes(resp.content)
        except Exception as exc:
            logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
            return

        dest_dir.mkdir(parents=True, exist_ok=True)
        try:
            with zipfile.ZipFile(zip_path, "r") as zf:
                zf.extractall(dest_dir)
            logger.debug("Extracted source for %s", arxiv_id)
        except zipfile.BadZipFile:
            # Not a zip: probably a (compressed) tar archive.
            try:
                # Use the stdlib "data" extraction filter where this Python
                # provides it, in addition to the manual member screening.
                extra = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
                with tarfile.open(zip_path, "r:*") as tf:
                    tf.extractall(
                        dest_dir,
                        members=_safe_tar_members(tf, dest_dir),
                        **extra,
                    )
                logger.debug("Extracted source (tar) for %s", arxiv_id)
            except Exception:
                logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
        except Exception:
            logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
    finally:
        # Also removes a partially written archive after a failed download.
        if zip_path.exists():
            zip_path.unlink()


def _safe_tar_members(tf, dest_dir: Path):
    """Yield only regular files and directories that unpack inside ``dest_dir``."""
    root = dest_dir.resolve()
    for member in tf.getmembers():
        # The archive is untrusted: drop links, devices and other special
        # entries, and anything whose path would land outside the target.
        if not (member.isfile() or member.isdir()):
            continue
        try:
            (root / member.name).resolve().relative_to(root)
        except (OSError, RuntimeError, ValueError):
            continue
        yield member
# ── 单篇总结 ────────────────────────────────────────────────────────────
@@ -441,6 +562,30 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
status.raw_output_saved = True
db.commit()
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
try:
await _extract_images_from_source(arxiv_id)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
# Phase 5: 同步写入语义索引(失败仅 log)
try:
from app.services.embedder import index_paper
texts_dict = {
"arxiv_id": arxiv_id,
"title_zh": schema.title_zh or "",
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"one_line": schema.one_line or "",
"motivation_problem": schema.motivation_problem or "",
"method_key_idea": schema.method_key_idea or "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
index_paper(arxiv_id, texts_dict)
except Exception:
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}