"""Search service: FTS5 keyword search plus ChromaDB semantic search.

Provides hit snippets for keyword matches and page-based pagination.
"""

from __future__ import annotations

import logging
import math
import re

from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload

from app.config import settings
from app.models import Paper

logger = logging.getLogger(__name__)


# ── Input sanitising ──────────────────────────────────────────────────

# Characters with special meaning in FTS5 query syntax; stripped from user input.
_FTS5_SPECIAL = re.compile(r'["{}()^+:]')

# An FTS5 "bareword": ASCII letters/digits, underscore, the substitute
# character (0x1A) and any non-ASCII code point.  Anything else has to be
# written as a double-quoted string or MATCH raises a syntax error.
_FTS5_BAREWORD = re.compile(r"[0-9A-Za-z_\x1a\u0080-\U0010ffff]+")

# Case-sensitive operator keywords that FTS5 does not accept as barewords.
_FTS5_OPERATORS = frozenset({"AND", "OR", "NOT"})


def _fts5_term(token: str) -> str:
    """Render one whitespace-delimited user token as a safe FTS5 term.

    Plain barewords are returned unchanged.  Tokens containing other
    characters (``self-attention``, ``node.js``) or equal to an operator
    keyword are wrapped in double quotes so FTS5 treats them as literal
    strings.  A trailing ``*`` is kept outside the quotes as a prefix marker.

    The caller must not pass tokens that consist only of ``*``.
    """
    stem = token.rstrip("*")
    prefix_marker = "*" if len(stem) != len(token) else ""
    if _FTS5_BAREWORD.fullmatch(stem) and stem not in _FTS5_OPERATORS:
        return stem + prefix_marker
    # Embedded double quotes are escaped SQL-style by doubling them.
    return '"' + stem.replace('"', '""') + '"' + prefix_marker


def _sanitize_query(raw: str) -> str | None:
    """Turn raw user input into a safe FTS5 MATCH expression.

    - Removes the characters matched by ``_FTS5_SPECIAL``.
    - Splits on whitespace and joins the resulting terms with ``AND``.
    - Quotes any term that FTS5 would not accept as a bareword.
    - Drops tokens made only of ``*`` (nothing to search for).

    Returns:
        The MATCH expression, or ``None`` when no searchable term remains.
    """
    cleaned = _FTS5_SPECIAL.sub("", raw.strip())
    terms = [_fts5_term(tok) for tok in cleaned.split() if tok.strip("*")]
    if not terms:
        return None
    return " AND ".join(terms)


# ── Shared helpers ────────────────────────────────────────────────────

def _page_result(
    results: list,
    snippets: dict,
    total: int,
    page: int,
    page_size: int,
    distances: dict | None = None,
) -> dict:
    """Assemble the response dict shared by every search path."""
    has_pages = bool(total) and page_size > 0
    return {
        "results": results,
        "snippets": snippets,
        "total": total,
        "page": page,
        "total_pages": math.ceil(total / page_size) if has_pages else 0,
        "distances": distances if distances is not None else {},
    }


def _tag_filter(tag: str | None) -> tuple[str, str, dict]:
    """Return the (JOIN, WHERE, bind-params) fragments for an optional tag filter."""
    if not tag:
        return "", "", {}
    return (
        "JOIN paper_tags pt ON pt.paper_id = p.id",
        "AND pt.tag = :tag",
        {"tag": tag},
    )


def _paper_eager_options() -> tuple:
    """Eager-load options applied whenever full ``Paper`` objects are fetched."""
    return (
        joinedload(Paper.authors),
        joinedload(Paper.tags),
        joinedload(Paper.summary_status),
        joinedload(Paper.bookmark),
        joinedload(Paper.reading_status),
    )


# ── Core search ───────────────────────────────────────────────────────

def search_papers(
    db: Session,
    *,
    query: str | None = None,
    tag: str | None = None,
    sort: str = "relevance",
    page: int = 1,
    page_size: int = 20,
    mode: str = "keyword",
) -> dict:
    """Search papers in ``keyword`` (FTS5) or ``semantic`` (ChromaDB) mode.

    Semantic mode is used only when ``mode == "semantic"``,
    ``settings.CHROMA_ENABLED`` is truthy and ``query`` is non-empty;
    every other combination goes through the keyword path.

    ``page`` and ``page_size`` are clamped to a minimum of 1.

    Returns::

        {
            "results": list[Paper],
            "snippets": dict[int, dict],    # paper_id → {title_zh, abstract}
            "total": int,
            "page": int,
            "total_pages": int,
            "distances": dict[str, float],  # arxiv_id → distance (semantic only)
        }
    """
    # Guard against non-positive paging values, which would otherwise produce
    # a negative slice start (semantic) or a division by zero (total_pages).
    page = max(page, 1)
    page_size = max(page_size, 1)

    if mode == "semantic" and settings.CHROMA_ENABLED and query:
        return _search_semantic(db, query, tag, sort, page, page_size)

    return _search_keyword(db, query, tag, sort, page, page_size)


def _search_keyword(
    db: Session,
    query: str | None,
    tag: str | None,
    sort: str,
    page: int,
    page_size: int,
) -> dict:
    """Keyword path: FTS5 MATCH when there is a usable query, else tag-only.

    Returns an empty result when there is neither a usable query nor a tag.
    """
    match_expr = _sanitize_query(query) if query else None

    if not match_expr and not tag:
        return _page_result([], {}, 0, page, page_size)

    offset = (page - 1) * page_size

    if match_expr:
        tag_join, tag_where, tag_params = _tag_filter(tag)
        return _search_with_fts(
            db, match_expr, tag_join, tag_where, tag_params,
            sort, page, page_size, offset,
        )

    return _search_tag_only(db, tag, sort, page, page_size, offset)


def _search_with_fts(
    db: Session,
    match_expr: str,
    tag_join: str,
    tag_where: str,
    tag_params: dict,
    sort: str,
    page: int,
    page_size: int,
    offset: int,
) -> dict:
    """Run the FTS5 MATCH search used when a keyword expression is present.

    ``tag_join`` / ``tag_where`` are SQL fragments produced inside this module
    (never user input); user-supplied values travel only as bind parameters.
    """
    params = {"query": match_expr, "limit": page_size, "offset": offset}
    params.update(tag_params)

    # Any sort other than "relevance" orders by date, then upvotes.
    order = (
        "bm25(papers_fts)"
        if sort == "relevance"
        else "p.paper_date DESC, p.upvotes DESC"
    )

    # Main query: id + rank + snippets for the requested page.
    # NOTE(review): the snippet() open/close markers are empty strings, so
    # matched terms are not wrapped in any highlight markup — confirm intended.
    rows_sql = text(f"""
        SELECT p.id,
               papers_fts.rank,
               snippet(papers_fts, 1, '', '', '...', 32) AS snippet_title_zh,
               snippet(papers_fts, 2, '', '', '...', 32) AS snippet_abstract
        FROM papers_fts
        JOIN papers p ON p.id = papers_fts.rowid
        {tag_join}
        WHERE papers_fts MATCH :query
        {tag_where}
        ORDER BY {order}
        LIMIT :limit OFFSET :offset
    """)
    fts_rows = db.execute(rows_sql, params).fetchall()

    # Count query: total number of distinct matching papers.
    count_sql = text(f"""
        SELECT COUNT(DISTINCT papers_fts.rowid)
        FROM papers_fts
        JOIN papers p ON p.id = papers_fts.rowid
        {tag_join}
        WHERE papers_fts MATCH :query
        {tag_where}
    """)
    total = db.execute(count_sql, params).scalar() or 0

    paper_ids = [row[0] for row in fts_rows]
    rank_map = {row[0]: row[1] for row in fts_rows}
    snippets = {
        row[0]: {"title_zh": row[2], "abstract": row[3]}
        for row in fts_rows
    }

    papers = _load_papers_by_ids(db, paper_ids, sort, rank_map)

    return _page_result(papers, snippets, total, page, page_size)


def _search_semantic(
    db: Session,
    query: str,
    tag: str | None,
    sort: str,
    page: int,
    page_size: int,
) -> dict:
    """ChromaDB semantic search; falls back to the keyword path on failure.

    Results are ordered by the candidate order returned from
    ``search_similar``; ``sort`` is only used by the keyword fallback.
    Only ``page_size * 3`` candidates are requested, so ``total`` is bounded
    by that number.
    """
    try:
        # Imported lazily, inside the try, so an unavailable embedder also
        # triggers the keyword fallback.
        from app.services.embedder import search_similar

        top_k = page_size * 3  # over-fetch to leave room for tag filtering
        candidates = search_similar(query, top_k=top_k)
    except Exception:
        logger.exception("Semantic search failed, falling back to keyword")
        candidates = []

    if not candidates:
        # Reuse the regular keyword path so the raw, unsanitised query is
        # never handed to MATCH when sanitising leaves nothing usable.
        return _search_keyword(db, query, tag, sort, page, page_size)

    # Load the full ORM rows for the candidate arxiv ids.
    arxiv_ids = [c["arxiv_id"] for c in candidates]
    distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}

    papers_query = (
        db.query(Paper)
        .filter(Paper.arxiv_id.in_(arxiv_ids))
        .options(*_paper_eager_options())
    )
    if tag:
        papers_query = papers_query.filter(Paper.tags.any(tag=tag))

    papers = papers_query.all()

    # Restore the candidate order, which the IN (...) query does not preserve.
    id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
    papers.sort(key=lambda p: id_order.get(p.arxiv_id, 999))

    # Paginate in memory.
    total = len(papers)
    start = max(page - 1, 0) * page_size
    page_papers = papers[start:start + page_size]

    return _page_result(page_papers, {}, total, page, page_size, distance_map)


def _search_tag_only(
    db: Session,
    tag: str,
    sort: str,
    page: int,
    page_size: int,
    offset: int,
) -> dict:
    """Tag-only listing (no keyword).

    ``sort`` is accepted for signature parity but has no effect: without a
    keyword there is no relevance score, so results are always ordered by
    date, then upvotes.
    """
    order = "p.paper_date DESC, p.upvotes DESC"

    rows_sql = text(f"""
        SELECT p.id
        FROM papers p
        JOIN paper_tags pt ON pt.paper_id = p.id
        WHERE pt.tag = :tag
        ORDER BY {order}
        LIMIT :limit OFFSET :offset
    """)
    rows = db.execute(
        rows_sql, {"tag": tag, "limit": page_size, "offset": offset}
    ).fetchall()

    count_sql = text("""
        SELECT COUNT(DISTINCT p.id)
        FROM papers p
        JOIN paper_tags pt ON pt.paper_id = p.id
        WHERE pt.tag = :tag
    """)
    total = db.execute(count_sql, {"tag": tag}).scalar() or 0

    paper_ids = [row[0] for row in rows]
    papers = _load_papers_by_ids(db, paper_ids)

    return _page_result(papers, {}, total, page, page_size)


def _load_papers_by_ids(
    db: Session,
    paper_ids: list[int],
    sort: str | None = None,
    rank_map: dict[int, float] | None = None,
) -> list[Paper]:
    """Load full ORM objects for ``paper_ids``, preserving the given order.

    ``sort`` and ``rank_map`` are accepted for call compatibility but are not
    used; ordering is taken solely from the position of each id in
    ``paper_ids``.
    """
    if not paper_ids:
        return []

    papers = (
        db.query(Paper)
        .filter(Paper.id.in_(paper_ids))
        .options(*_paper_eager_options())
        .all()
    )

    # Restore the incoming order (FTS rank or tag-only order); the IN (...)
    # query does not preserve it.
    id_order = {pid: idx for idx, pid in enumerate(paper_ids)}
    papers.sort(key=lambda p: id_order.get(p.id, 0))
    return papers


# ── Auxiliary queries ─────────────────────────────────────────────────

def get_all_tags(db: Session) -> list[str]:
    """Return every distinct tag in ``paper_tags``, ordered by tag."""
    rows = db.execute(
        text("SELECT DISTINCT tag FROM paper_tags ORDER BY tag")
    ).fetchall()
    return [row[0] for row in rows]