Files
daily-paper/app/services/searcher.py
T
Rain-Bus 21f16e6756 feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
2026-06-13 13:16:47 +08:00

319 lines
8.8 KiB
Python

"""搜索服务 — FTS5 关键词搜索 + ChromaDB 语义搜索,命中片段高亮,分页。"""
from __future__ import annotations
import logging
import math
import re
from sqlalchemy import select, text
from sqlalchemy.orm import Session
from app.config import settings
from app.models import PAPER_FULL_LOAD, Paper
logger = logging.getLogger(__name__)
# ── Input sanitization ────────────────────────────────────────────────
# Characters that carry meaning in the FTS5 query grammar; they are deleted
# from user input before the MATCH expression is assembled.
_FTS5_SPECIAL = re.compile(r'["{}()^+:]')


def _sanitize_query(raw: str) -> str | None:
    """Turn raw user input into a safe FTS5 MATCH expression.

    Steps:
      * strip surrounding whitespace and delete the characters matched by
        ``_FTS5_SPECIAL``;
      * split on whitespace into tokens;
      * drop tokens that contain no alphanumeric character (for example a
        lone ``-`` or ``*``), since they cannot contribute a search term;
      * wrap every remaining token in double quotes so FTS5 parses it as a
        string instead of a bareword; a trailing ``*`` is kept outside the
        quotes so prefix queries still work;
      * join the tokens with ``AND``.

    Quoting matters because FTS5 barewords may not contain punctuation such
    as ``-``, ``.`` or ``'``, and the upper-case words ``AND``/``OR``/``NOT``
    are operators; unquoted, input like ``self-attention`` or
    ``cats AND dogs`` would produce a MATCH syntax error.

    Args:
        raw: The query string exactly as typed by the user.

    Returns:
        The MATCH expression, or ``None`` when no usable token remains.
    """
    cleaned = _FTS5_SPECIAL.sub("", raw.strip())
    terms: list[str] = []
    for token in cleaned.split():
        core = token.rstrip("*")
        if not any(ch.isalnum() for ch in core):
            continue
        # Double quotes were already removed by _FTS5_SPECIAL, so the token
        # can be wrapped directly without further escaping.
        quoted = f'"{core}"'
        terms.append(f"{quoted}*" if core != token else quoted)
    if not terms:
        return None
    return " AND ".join(terms)
# ── 核心搜索 ──────────────────────────────────────────────────────────
def search_papers(
    db: Session,
    *,
    query: str | None = None,
    tag: str | None = None,
    sort: str = "relevance",
    page: int = 1,
    page_size: int = 20,
    mode: str = "keyword",
) -> dict:
    """Search papers by keyword (FTS5) or by semantic similarity (ChromaDB).

    Args:
        db: Database session used for all queries.
        query: Free-text query; ``None`` or empty means "no keyword".
        tag: Optional tag filter.
        sort: ``"relevance"`` or any other value for date/upvote ordering
            (interpreted by the keyword search helper).
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Results per page; values below 1 are treated as 1.
        mode: ``"semantic"`` selects the semantic path when
            ``settings.CHROMA_ENABLED`` is set and ``query`` is non-empty;
            anything else uses keyword search.

    Returns:
        A dict of the form::

            {
                "results": list[Paper],
                "snippets": dict[int, dict],    # paper_id -> {title_zh, abstract}
                "total": int,
                "page": int,
                "total_pages": int,
                "distances": dict[str, float],  # arxiv_id -> distance (semantic only)
            }
    """
    # Clamp pagination input: a page below 1 would give a negative offset and
    # a page size of 0 would make the total_pages division fail downstream.
    page = max(page, 1)
    page_size = max(page_size, 1)

    # Semantic mode
    if mode == "semantic" and settings.CHROMA_ENABLED and query:
        return _search_semantic(db, query, tag, sort, page, page_size)

    # Keyword mode (default)
    match_expr = _sanitize_query(query) if query else None

    # Neither a usable keyword nor a tag: nothing to search for.
    if not match_expr and not tag:
        return {
            "results": [],
            "snippets": {},
            "total": 0,
            "page": page,
            "total_pages": 0,
            "distances": {},
        }

    offset = (page - 1) * page_size

    if not match_expr:
        return _search_tag_only(db, tag, sort, page, page_size, offset)

    # Optional JOIN / WHERE fragments and bind parameters for the tag filter.
    tag_join = ""
    tag_where = ""
    tag_params: dict = {}
    if tag:
        tag_join = "JOIN paper_tags pt ON pt.paper_id = p.id"
        tag_where = "AND pt.tag = :tag"
        tag_params["tag"] = tag

    return _search_with_fts(
        db,
        match_expr,
        tag_join,
        tag_where,
        tag_params,
        sort,
        page,
        page_size,
        offset,
    )
def _search_with_fts(
    db: Session,
    match_expr: str,
    tag_join: str,
    tag_where: str,
    tag_params: dict,
    sort: str,
    page: int,
    page_size: int,
    offset: int,
) -> dict:
    """Run an FTS5 MATCH search and return one page of results with snippets.

    Args:
        db: Database session.
        match_expr: Expression bound to ``papers_fts MATCH :query``.
        tag_join: SQL JOIN fragment for the tag filter, or ``""``.
        tag_where: SQL ``AND ...`` fragment for the tag filter, or ``""``.
        tag_params: Bind parameters required by the tag fragments.
        sort: ``"relevance"`` orders by ``bm25(papers_fts)``; any other value
            orders by paper date, then upvotes, both descending.
        page: Page number, echoed back in the result.
        page_size: Row limit for the page query.
        offset: Row offset for the page query.

    Returns:
        The search result dict with ``results``, ``snippets``, ``total``,
        ``page``, ``total_pages`` and an empty ``distances`` mapping.
    """
    if sort == "relevance":
        order = "bm25(papers_fts)"
    else:
        order = "p.paper_date DESC, p.upvotes DESC"

    # Tag parameters are merged last, matching a dict.update() on the base.
    bind = {"query": match_expr, "limit": page_size, "offset": offset, **tag_params}

    # Page query: id, rank and highlighted snippets for the current page.
    page_stmt = text(f"""
        SELECT
        p.id,
        papers_fts.rank,
        snippet(papers_fts, 1, '<mark>', '</mark>', '...', 32) AS snippet_title_zh,
        snippet(papers_fts, 2, '<mark>', '</mark>', '...', 32) AS snippet_abstract
        FROM papers_fts
        JOIN papers p ON p.id = papers_fts.rowid
        {tag_join}
        WHERE papers_fts MATCH :query
        {tag_where}
        ORDER BY {order}
        LIMIT :limit OFFSET :offset
    """)
    hits = db.execute(page_stmt, bind).fetchall()

    # Count query: number of distinct matching rows across all pages.
    total_stmt = text(f"""
        SELECT COUNT(DISTINCT papers_fts.rowid)
        FROM papers_fts
        JOIN papers p ON p.id = papers_fts.rowid
        {tag_join}
        WHERE papers_fts MATCH :query
        {tag_where}
    """)
    total = db.execute(total_stmt, bind).scalar() or 0

    paper_ids: list[int] = []
    snippets: dict = {}
    rank_by_id: dict = {}
    for hit in hits:
        pid = hit[0]
        paper_ids.append(pid)
        rank_by_id[pid] = hit[1]
        snippets[pid] = {"title_zh": hit[2], "abstract": hit[3]}

    papers = _load_papers_by_ids(db, paper_ids, sort, rank_by_id)

    total_pages = math.ceil(total / page_size) if total else 0
    return {
        "results": papers,
        "snippets": snippets,
        "total": total,
        "page": page,
        "total_pages": total_pages,
        "distances": {},
    }
def _search_semantic(
    db: Session,
    query: str,
    tag: str | None,
    sort: str,
    page: int,
    page_size: int,
) -> dict:
    """Semantic search via the embedder, falling back to keyword search.

    The embedder is asked for ``page_size * 3`` candidates; matching papers
    are loaded from the database, optionally filtered by tag, ordered by the
    candidate ranking and then paginated in memory. ``total`` therefore
    counts only papers among those candidates.

    If the embedder raises or returns no candidates, the request is served
    as in keyword mode: FTS5 search when the sanitized query is usable,
    tag-only search when only a tag remains, otherwise an empty result.
    The raw query is never handed to FTS5 unsanitized.

    Args:
        db: Database session.
        query: The user's free-text query.
        tag: Optional tag filter.
        sort: Sort mode; only used by the keyword fallback.
        page: 1-based page number.
        page_size: Results per page.

    Returns:
        The search result dict; ``distances`` maps arxiv_id to distance for
        every candidate on the semantic path and is empty on fallback.
    """
    try:
        # Imported lazily so a missing or broken embedder degrades to the
        # keyword fallback instead of failing the request.
        from app.services.embedder import search_similar

        top_k = page_size * 3  # over-fetch so tag filtering still leaves rows
        candidates = search_similar(query, top_k=top_k)
    except Exception:
        logger.exception("Semantic search failed, falling back to keyword")
        candidates = []

    if not candidates:
        match_expr = _sanitize_query(query)
        offset = (page - 1) * page_size
        if match_expr:
            return _search_with_fts(
                db,
                match_expr,
                "JOIN paper_tags pt ON pt.paper_id = p.id" if tag else "",
                "AND pt.tag = :tag" if tag else "",
                {"tag": tag} if tag else {},
                sort,
                page,
                page_size,
                offset,
            )
        if tag:
            return _search_tag_only(db, tag, sort, page, page_size, offset)
        return {
            "results": [],
            "snippets": {},
            "total": 0,
            "page": page,
            "total_pages": 0,
            "distances": {},
        }

    # Load full ORM objects for the candidate arxiv ids.
    arxiv_ids = [c["arxiv_id"] for c in candidates]
    distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
    stmt = select(Paper).where(Paper.arxiv_id.in_(arxiv_ids)).options(*PAPER_FULL_LOAD)
    if tag:
        stmt = stmt.where(Paper.tags.any(tag=tag))
    loaded = db.execute(stmt).unique().scalars().all()

    # Restore the candidate order returned by the embedder; anything not in
    # the candidate list sorts last.
    id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
    papers = sorted(loaded, key=lambda p: id_order.get(p.arxiv_id, len(arxiv_ids)))

    # Paginate in memory.
    total = len(papers)
    start = (page - 1) * page_size
    page_papers = papers[start : start + page_size]
    return {
        "results": page_papers,
        "snippets": {},
        "total": total,
        "page": page,
        "total_pages": math.ceil(total / page_size) if total else 0,
        "distances": distance_map,
    }
def _search_tag_only(
    db: Session,
    tag: str,
    sort: str,
    page: int,
    page_size: int,
    offset: int,
) -> dict:
    """List papers carrying ``tag`` when no keyword was supplied.

    Results are always ordered by paper date and then upvotes, both
    descending; the ``sort`` argument is accepted but not consulted.

    Args:
        db: Database session.
        tag: Tag to filter on.
        sort: Unused here.
        page: Page number, echoed back in the result.
        page_size: Row limit for the page query.
        offset: Row offset for the page query.

    Returns:
        The search result dict with empty ``snippets`` and ``distances``.
    """
    order = "p.paper_date DESC, p.upvotes DESC"

    # Page query: ids of the papers on the requested page.
    page_stmt = text(f"""
        SELECT p.id
        FROM papers p
        JOIN paper_tags pt ON pt.paper_id = p.id
        WHERE pt.tag = :tag
        ORDER BY {order}
        LIMIT :limit OFFSET :offset
    """)
    page_bind = {"tag": tag, "limit": page_size, "offset": offset}
    id_rows = db.execute(page_stmt, page_bind).fetchall()

    # Count query: distinct papers with this tag across all pages.
    total_stmt = text("""
        SELECT COUNT(DISTINCT p.id)
        FROM papers p
        JOIN paper_tags pt ON pt.paper_id = p.id
        WHERE pt.tag = :tag
    """)
    total = db.execute(total_stmt, {"tag": tag}).scalar() or 0

    papers = _load_papers_by_ids(db, [r[0] for r in id_rows])

    if total:
        total_pages = math.ceil(total / page_size)
    else:
        total_pages = 0
    return {
        "results": papers,
        "snippets": {},
        "total": total,
        "page": page,
        "total_pages": total_pages,
        "distances": {},
    }
def _load_papers_by_ids(
    db: Session,
    paper_ids: list[int],
    sort: str | None = None,
    rank_map: dict[int, float] | None = None,
) -> list[Paper]:
    """Load full ``Paper`` objects for ``paper_ids``, keeping the given order.

    Args:
        db: Database session.
        paper_ids: Ids in the order the caller wants them returned.
        sort: Accepted for call-site symmetry; not used here.
        rank_map: Accepted for call-site symmetry; not used here.

    Returns:
        The loaded papers arranged by each id's position in ``paper_ids``;
        an empty list when ``paper_ids`` is empty.
    """
    if not paper_ids:
        return []

    stmt = select(Paper).where(Paper.id.in_(paper_ids)).options(*PAPER_FULL_LOAD)
    loaded = db.execute(stmt).unique().scalars().all()

    # The IN (...) query does not preserve order, so re-apply the caller's
    # ordering (FTS rank or tag-only ordering) by position in paper_ids.
    position = dict(zip(paper_ids, range(len(paper_ids))))
    loaded.sort(key=lambda paper: position.get(paper.id, 0))
    return loaded
# ── 辅助查询 ──────────────────────────────────────────────────────────
def get_all_tags(db: Session) -> list[str]:
    """Return every distinct tag in ``paper_tags``, ordered by tag."""
    stmt = text("SELECT DISTINCT tag FROM paper_tags ORDER BY tag")
    tag_rows = db.execute(stmt).fetchall()
    tags = []
    for tag_row in tag_rows:
        tags.append(tag_row[0])
    return tags