feat: add search and user data routes, services, and tests
This commit is contained in:
@@ -0,0 +1,230 @@
|
||||
"""FTS5 全文搜索服务 — 关键词 + 标签筛选,命中片段高亮,分页。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.models import Paper
|
||||
|
||||
# ── 输入清洗 ──────────────────────────────────────────────────────────
|
||||
|
||||
# FTS5 查询语法中的特殊字符,用户输入时需要移除
|
||||
_FTS5_SPECIAL = re.compile(r'["{}()^+:]')
|
||||
|
||||
|
||||
def _sanitize_query(raw: str) -> str:
|
||||
"""清洗用户输入,生成安全的 FTS5 MATCH 表达式。
|
||||
|
||||
- 移除 FTS5 特殊字符
|
||||
- 按空白拆分为 token,用 AND 连接
|
||||
- 空字符串返回 None
|
||||
"""
|
||||
cleaned = _FTS5_SPECIAL.sub("", raw.strip())
|
||||
tokens = cleaned.split()
|
||||
if not tokens:
|
||||
return None
|
||||
return " AND ".join(tokens)
|
||||
|
||||
|
||||
# ── 核心搜索 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def search_papers(
    db: Session,
    *,
    query: str | None = None,
    tag: str | None = None,
    sort: str = "relevance",
    page: int = 1,
    page_size: int = 20,
) -> dict:
    """Search papers by keyword and/or tag, one page at a time.

    Args:
        db: Database session used by the query helpers.
        query: Raw keyword input; it is cleaned by ``_sanitize_query``.
        tag: Optional tag that results must carry.
        sort: Sort mode, forwarded unchanged to the query helpers.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.

    Returns:
        A dict of the form::

            {
                "results": list[Paper],
                "snippets": dict[int, dict],  # paper_id -> {title_zh, abstract}
                "total": int,
                "page": int,
                "total_pages": int,
            }

        When neither a usable keyword nor a tag is given, the result is empty
        (``total`` and ``total_pages`` are 0) and the database is not queried.
    """
    # Clamp paging input: a page below 1 would yield a negative OFFSET, and a
    # page size of 0 would make the helpers divide by zero when they compute
    # ``total_pages``.
    page = max(page, 1)
    page_size = max(page_size, 1)

    match_expr = _sanitize_query(query) if query else None

    # No keyword and no tag: nothing to search for.
    if not match_expr and not tag:
        return {
            "results": [],
            "snippets": {},
            "total": 0,
            "page": page,
            "total_pages": 0,
        }

    offset = (page - 1) * page_size

    if not match_expr:
        return _search_tag_only(db, tag, sort, page, page_size, offset)

    # Optional SQL fragments that restrict the keyword search to one tag.
    # The tag value itself always travels as a bound parameter.
    tag_join = ""
    tag_where = ""
    tag_params: dict = {}
    if tag:
        tag_join = "JOIN paper_tags pt ON pt.paper_id = p.id"
        tag_where = "AND pt.tag = :tag"
        tag_params["tag"] = tag

    return _search_with_fts(
        db, match_expr, tag_join, tag_where, tag_params,
        sort, page, page_size, offset,
    )
|
||||
|
||||
|
||||
def _search_with_fts(
    db: Session,
    match_expr: str,
    tag_join: str,
    tag_where: str,
    tag_params: dict,
    sort: str,
    page: int,
    page_size: int,
    offset: int,
) -> dict:
    """Run the keyword (FTS5 MATCH) search, optionally narrowed to one tag.

    Args:
        db: Database session.
        match_expr: FTS5 MATCH expression, bound as ``:query``.
        tag_join: Extra JOIN fragment (may be empty) spliced into both queries.
        tag_where: Extra WHERE fragment (may be empty) spliced into both queries.
        tag_params: Bind parameters required by the tag fragments.
        sort: ``"relevance"`` orders by ``bm25(papers_fts)``; any other value
            orders by paper date, then upvotes, both descending.
        page: Page number, echoed back in the result.
        page_size: Rows per page (``LIMIT``).
        offset: Rows to skip (``OFFSET``).

    Returns:
        Dict with ``results``, ``snippets``, ``total``, ``page`` and
        ``total_pages``.
    """
    if sort == "relevance":
        order_clause = "bm25(papers_fts)"
    else:
        order_clause = "p.paper_date DESC, p.upvotes DESC"

    bind = {"query": match_expr, "limit": page_size, "offset": offset, **tag_params}

    # Page query: id, rank and highlighted snippets for the requested slice.
    # NOTE(review): snippet() column indexes 1 and 2 are presumed to be
    # title_zh and abstract — confirm against the papers_fts schema.
    page_stmt = text(f"""
        SELECT
            p.id,
            papers_fts.rank,
            snippet(papers_fts, 1, '<mark>', '</mark>', '...', 32) AS snippet_title_zh,
            snippet(papers_fts, 2, '<mark>', '</mark>', '...', 32) AS snippet_abstract
        FROM papers_fts
        JOIN papers p ON p.id = papers_fts.rowid
        {tag_join}
        WHERE papers_fts MATCH :query
        {tag_where}
        ORDER BY {order_clause}
        LIMIT :limit OFFSET :offset
    """)
    hits = db.execute(page_stmt, bind).fetchall()

    # Count query: total number of matches across all pages.
    total_stmt = text(f"""
        SELECT COUNT(DISTINCT papers_fts.rowid)
        FROM papers_fts
        JOIN papers p ON p.id = papers_fts.rowid
        {tag_join}
        WHERE papers_fts MATCH :query
        {tag_where}
    """)
    total = db.execute(total_stmt, bind).scalar() or 0

    paper_ids = []
    snippets = {}
    rank_by_id = {}
    for paper_id, rank, title_snippet, abstract_snippet in hits:
        paper_ids.append(paper_id)
        snippets[paper_id] = {"title_zh": title_snippet, "abstract": abstract_snippet}
        rank_by_id[paper_id] = rank

    papers = _load_papers_by_ids(db, paper_ids, sort, rank_by_id)
    total_pages = math.ceil(total / page_size) if total else 0

    return {
        "results": papers,
        "snippets": snippets,
        "total": total,
        "page": page,
        "total_pages": total_pages,
    }
|
||||
|
||||
|
||||
def _search_tag_only(
    db: Session,
    tag: str,
    sort: str,
    page: int,
    page_size: int,
    offset: int,
) -> dict:
    """List the papers carrying ``tag`` when no keyword was supplied.

    Args:
        db: Database session.
        tag: Tag to filter on, bound as ``:tag``.
        sort: Accepted for signature parity with the keyword search but not
            used: without a keyword there is no relevance score, so results
            are always ordered by paper date, then upvotes, both descending.
        page: Page number, echoed back in the result.
        page_size: Rows per page (``LIMIT``).
        offset: Rows to skip (``OFFSET``).

    Returns:
        Dict with ``results``, an empty ``snippets`` mapping, ``total``,
        ``page`` and ``total_pages``.
    """
    # The ordering is a constant, so the statement needs no string formatting.
    rows_sql = text("""
        SELECT p.id
        FROM papers p
        JOIN paper_tags pt ON pt.paper_id = p.id
        WHERE pt.tag = :tag
        ORDER BY p.paper_date DESC, p.upvotes DESC
        LIMIT :limit OFFSET :offset
    """)
    rows = db.execute(
        rows_sql, {"tag": tag, "limit": page_size, "offset": offset}
    ).fetchall()

    count_sql = text("""
        SELECT COUNT(DISTINCT p.id)
        FROM papers p
        JOIN paper_tags pt ON pt.paper_id = p.id
        WHERE pt.tag = :tag
    """)
    total = db.execute(count_sql, {"tag": tag}).scalar() or 0

    paper_ids = [row[0] for row in rows]
    papers = _load_papers_by_ids(db, paper_ids)

    return {
        "results": papers,
        "snippets": {},
        "total": total,
        "page": page,
        "total_pages": math.ceil(total / page_size) if total else 0,
    }
|
||||
|
||||
|
||||
def _load_papers_by_ids(
|
||||
db: Session,
|
||||
paper_ids: list[int],
|
||||
sort: str | None = None,
|
||||
rank_map: dict[int, float] | None = None,
|
||||
) -> list[Paper]:
|
||||
"""根据 ID 列表加载完整 ORM 对象,保持原始排序。"""
|
||||
if not paper_ids:
|
||||
return []
|
||||
|
||||
papers = (
|
||||
db.query(Paper)
|
||||
.filter(Paper.id.in_(paper_ids))
|
||||
.options(
|
||||
joinedload(Paper.authors),
|
||||
joinedload(Paper.tags),
|
||||
joinedload(Paper.summary_status),
|
||||
joinedload(Paper.bookmark),
|
||||
joinedload(Paper.reading_status),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 按 FTS rank / tag-only 原始顺序排列
|
||||
id_order = {pid: idx for idx, pid in enumerate(paper_ids)}
|
||||
papers.sort(key=lambda p: id_order.get(p.id, 0))
|
||||
return papers
|
||||
|
||||
|
||||
# ── 辅助查询 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_all_tags(db: Session) -> list[str]:
    """Return every distinct tag stored in ``paper_tags``, ordered by tag."""
    statement = text("SELECT DISTINCT tag FROM paper_tags ORDER BY tag")
    records = db.execute(statement).fetchall()
    return [record[0] for record in records]
|
||||
@@ -0,0 +1,115 @@
|
||||
"""用户数据服务 — 收藏、阅读状态、个人笔记。无账号体系,数据写入本地 SQLite。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Paper, UserBookmark, UserNote, UserReadingStatus
|
||||
|
||||
# ── 收藏 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
    """Flip the bookmark state of the paper identified by ``arxiv_id``.

    Returns:
        ``{"bookmarked": bool, "arxiv_id": str}`` describing the new state,
        or ``{"error": "not_found"}`` when no paper has that arXiv id.
    """
    paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
    if not paper:
        return {"error": "not_found"}

    current = db.query(UserBookmark).filter(UserBookmark.paper_id == paper.id).first()
    if not current:
        # Not bookmarked yet: create the bookmark row.
        db.add(
            UserBookmark(
                paper_id=paper.id,
                created_at=datetime.now(timezone.utc),
            )
        )
        now_bookmarked = True
    else:
        # Already bookmarked: remove the row.
        db.delete(current)
        now_bookmarked = False

    db.commit()
    return {"bookmarked": now_bookmarked, "arxiv_id": arxiv_id}
|
||||
|
||||
|
||||
# ── 阅读状态 ──────────────────────────────────────────────────────────
|
||||
|
||||
# Reading states accepted by set_reading_status().
VALID_STATUSES = {
    "unread",
    "skimmed",
    "read_summary",
    "read_full",
}


def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
    """Store the reading status of the paper identified by ``arxiv_id``.

    Args:
        db: Database session.
        arxiv_id: arXiv identifier of the paper.
        status: One of the values in ``VALID_STATUSES``.

    Returns:
        ``{"arxiv_id": ..., "status": ...}`` on success;
        ``{"error": "invalid_status", "valid": [...]}`` for an unknown status
        (checked before any database access);
        ``{"error": "not_found"}`` when no paper has that arXiv id.
    """
    if status not in VALID_STATUSES:
        return {"error": "invalid_status", "valid": sorted(VALID_STATUSES)}

    paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
    if not paper:
        return {"error": "not_found"}

    now = datetime.now(timezone.utc)
    record = (
        db.query(UserReadingStatus)
        .filter(UserReadingStatus.paper_id == paper.id)
        .first()
    )
    if not record:
        # First status for this paper: insert a new row.
        record = UserReadingStatus(paper_id=paper.id, status=status, updated_at=now)
        db.add(record)
    else:
        # Row already exists: update it in place.
        record.status = status
        record.updated_at = now

    db.commit()
    return {"arxiv_id": arxiv_id, "status": status}
|
||||
|
||||
|
||||
# ── 笔记 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_note(db: Session, arxiv_id: str) -> dict | None:
    """Fetch the personal note attached to the paper ``arxiv_id``.

    Returns:
        ``None`` when no paper has that arXiv id. Otherwise a dict with
        ``arxiv_id``, ``content`` and ``updated_at`` (ISO-8601 string or
        ``None``); a paper without a note yields empty content and a ``None``
        timestamp.
    """
    paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
    if not paper:
        return None

    note = db.query(UserNote).filter(UserNote.paper_id == paper.id).first()
    if note:
        stamp = note.updated_at.isoformat() if note.updated_at else None
        return {
            "arxiv_id": arxiv_id,
            "content": note.content,
            "updated_at": stamp,
        }

    # The paper exists but has no note yet.
    return {"arxiv_id": arxiv_id, "content": "", "updated_at": None}
|
||||
|
||||
|
||||
def save_note(db: Session, arxiv_id: str, content: str) -> dict:
    """Create or overwrite the personal note for the paper ``arxiv_id``.

    Returns:
        ``{"arxiv_id", "content", "updated_at"}`` where ``updated_at`` is the
        ISO-8601 form of the UTC timestamp just written, or
        ``{"error": "not_found"}`` when no paper has that arXiv id.
    """
    paper = db.query(Paper).filter(Paper.arxiv_id == arxiv_id).first()
    if not paper:
        return {"error": "not_found"}

    now = datetime.now(timezone.utc)
    note = db.query(UserNote).filter(UserNote.paper_id == paper.id).first()
    if not note:
        # No note yet: insert one with matching created/updated timestamps.
        note = UserNote(
            paper_id=paper.id,
            content=content,
            created_at=now,
            updated_at=now,
        )
        db.add(note)
    else:
        # Note already exists: overwrite its text and refresh the timestamp.
        note.content = content
        note.updated_at = now

    db.commit()
    return {"arxiv_id": arxiv_id, "content": content, "updated_at": now.isoformat()}
|
||||
Reference in New Issue
Block a user