feat: add search and user data routes, services, and tests

This commit is contained in:
2026-06-05 22:53:27 +08:00
parent 29e6797c12
commit 1538d564f6
14 changed files with 1633 additions and 13 deletions
+249
View File
@@ -0,0 +1,249 @@
"""搜索、阅读列表、RSS Feed 路由。"""
from __future__ import annotations
import math
from datetime import date, datetime, timedelta, timezone
from zoneinfo import ZoneInfo
from xml.sax.saxutils import escape
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import Response
from fastapi.templating import Jinja2Templates
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import Paper, PaperTag, UserReadingStatus
from app.services.searcher import get_all_tags, search_papers
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
# ── 搜索页 ────────────────────────────────────────────────────────────
@router.get("/search")
def search_page(
    request: Request,
    q: str = Query(default=""),
    tag: str = Query(default=""),
    sort: str = Query(default="relevance"),
    page: int = Query(default=1, ge=1),
    db: Session = Depends(get_db),
):
    """Render the HTML search page.

    Args:
        request: Incoming request, handed to the template renderer.
        q: Free-text query; an empty string is passed on as ``None``.
        tag: Tag filter; an empty string is passed on as ``None``.
        sort: Sort mode, forwarded unchanged to ``search_papers``.
        page: 1-based page number (validated ``>= 1`` by FastAPI).
        db: Database session.

    Returns:
        The rendered ``search.html`` template response.
    """
    found = search_papers(
        db,
        query=q if q else None,
        tag=tag if tag else None,
        sort=sort,
        page=page,
    )
    tag_choices = get_all_tags(db)

    # Echo the raw query parameters back so the form keeps its state.
    context = {"query": q, "tag": tag, "sort": sort}
    # Copy the search outcome fields the template consumes.
    for key in ("results", "snippets", "total", "page", "total_pages"):
        context[key] = found[key]
    context["all_tags"] = tag_choices
    context["page_title"] = f"搜索: {q}" if q else "搜索"
    context["today"] = _today_str()

    return templates.TemplateResponse(request, "search.html", context)
# ── 搜索 JSON API ─────────────────────────────────────────────────────
@router.get("/api/search")
def search_api(
    q: str = Query(default=""),
    tag: str = Query(default=""),
    sort: str = Query(default="relevance"),
    page: int = Query(default=1, ge=1),
    db: Session = Depends(get_db),
):
    """Return search results as JSON.

    Takes the same query parameters as the HTML search page. Each result
    is flattened into a plain dict, with highlighted snippets looked up
    by paper id (absent snippets serialize as ``None``).

    Returns:
        A dict with ``results``, ``total``, ``page`` and ``total_pages``.
    """
    found = search_papers(
        db,
        query=q if q else None,
        tag=tag if tag else None,
        sort=sort,
        page=page,
    )

    def to_item(paper):
        """Serialize one paper together with its snippet, if any."""
        hit = found["snippets"].get(paper.id, {})
        published = paper.paper_date
        return {
            "arxiv_id": paper.arxiv_id,
            "title_en": paper.title_en,
            "title_zh": paper.title_zh,
            "paper_date": published.isoformat() if published else None,
            "upvotes": paper.upvotes,
            "tags": [entry.tag for entry in paper.tags],
            "authors": [author.name for author in paper.authors],
            "snippet_title_zh": hit.get("title_zh"),
            "snippet_abstract": hit.get("abstract"),
        }

    return {
        "results": [to_item(paper) for paper in found["results"]],
        "total": found["total"],
        "page": found["page"],
        "total_pages": found["total_pages"],
    }
# ── 阅读列表 ──────────────────────────────────────────────────────────
@router.get("/reading-list")
def reading_list_page(
    request: Request,
    filter: str = Query(default="all"),
    tag: str = Query(default=""),
    db: Session = Depends(get_db),
):
    """Render the reading-list page.

    Args:
        request: Incoming request, handed to the template renderer.
        filter: Reading-list filter name, forwarded to
            ``_query_reading_list`` and echoed to the template.
        tag: Optional tag filter; an empty string means "no tag filter".
        db: Database session.

    Returns:
        The rendered ``reading_list.html`` template response.
    """
    selected_tag = tag if tag else None
    entries = _query_reading_list(db, filter, selected_tag)
    tag_choices = get_all_tags(db)

    context = dict(
        papers=entries,
        current_filter=filter,
        current_tag=tag,
        all_tags=tag_choices,
        page_title="阅读列表",
        today=_today_str(),
    )
    return templates.TemplateResponse(request, "reading_list.html", context)
def _query_reading_list(
    db: Session,
    filter_type: str,
    tag: str | None,
) -> list[Paper]:
    """Query the papers that belong on the reading list.

    A paper qualifies when it has a bookmark, a reading status, or a note.
    ``filter_type`` narrows this to papers with a note (``"has_note"``) or
    with one specific reading status; any other value applies no extra
    filter. ``tag`` optionally restricts the result to one tag.

    Returns:
        Matching papers with their related rows eagerly loaded, ordered
        by paper date and then upvotes, both descending.
    """
    from sqlalchemy import or_

    status_filters = ("unread", "skimmed", "read_summary", "read_full")

    # Baseline: papers carrying any kind of user data.
    has_user_data = or_(
        Paper.bookmark.has(),
        Paper.reading_status.has(),
        Paper.note.has(),
    )
    query = db.query(Paper).filter(has_user_data)

    # Narrow by the requested filter.
    if filter_type == "has_note":
        query = query.filter(Paper.note.has())
    elif filter_type in status_filters:
        matches_status = Paper.reading_status.has(
            UserReadingStatus.status == filter_type
        )
        query = query.filter(matches_status)

    # Narrow by tag.
    if tag:
        query = query.filter(Paper.tags.any(PaperTag.tag == tag))

    eager_loads = (
        joinedload(Paper.authors),
        joinedload(Paper.tags),
        joinedload(Paper.summary_status),
        joinedload(Paper.bookmark),
        joinedload(Paper.reading_status),
        joinedload(Paper.note),
    )
    query = query.options(*eager_loads)
    query = query.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
    return query.all()
# ── RSS Feed ──────────────────────────────────────────────────────────
@router.get("/rss.xml")
def rss_feed(
    tag: str = Query(default=""),
    db: Session = Depends(get_db),
):
    """Serve an RSS 2.0 feed of papers from the last 7 days.

    Args:
        tag: Optional tag filter; an empty string means "all tags".
        db: Database session.

    Returns:
        An ``application/xml`` response containing the feed.
    """
    # Use the application's configured timezone for "today", the same way
    # the page routes do, so the 7-day window does not shift with the
    # server's local clock.
    today = datetime.now(ZoneInfo(settings.APP_TIMEZONE)).date()
    seven_days_ago = today - timedelta(days=7)

    query = db.query(Paper).filter(Paper.paper_date >= seven_days_ago)
    if tag:
        query = query.filter(Paper.tags.any(PaperTag.tag == tag))

    papers = (
        query.options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary),
        )
        .order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
        .all()
    )

    xml = _generate_rss_xml(papers, settings.BASE_URL, tag or None)
    return Response(content=xml, media_type="application/xml")
def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> str:
"""生成 RSS 2.0 XML。"""
lines = ['<?xml version="1.0" encoding="UTF-8"?>']
lines.append('<rss version="2.0">')
lines.append(" <channel>")
channel_title = "HF Daily Papers"
if tag:
channel_title += f"{tag}"
lines.append(f" <title>{escape(channel_title)}</title>")
lines.append(f" <link>{escape(base_url)}</link>")
lines.append(" <description>HuggingFace Daily Papers — 中文论文导览站</description>")
lines.append(f" <language>zh-CN</language>")
for paper in papers:
title_text = paper.title_zh or paper.title_en
link = f"{base_url}/paper/{paper.arxiv_id}"
desc = ""
if paper.summary and paper.summary.one_line:
desc = paper.summary.one_line
elif paper.abstract:
desc = paper.abstract[:500]
pub_date = ""
if paper.paper_date:
# RFC 822 格式
pub_date = paper.paper_date.strftime("%a, %d %b %Y 00:00:00 +0800")
lines.append(" <item>")
lines.append(f" <title>{escape(title_text)}</title>")
lines.append(f" <link>{escape(link)}</link>")
lines.append(f" <description>{escape(desc)}</description>")
if pub_date:
lines.append(f" <pubDate>{pub_date}</pubDate>")
lines.append(f" <guid>{escape(link)}</guid>")
lines.append(" </item>")
lines.append(" </channel>")
lines.append("</rss>")
return "\n".join(lines)
# ── 工具函数 ──────────────────────────────────────────────────────────
def _today_str() -> str:
    """Return the current date as ``YYYY-MM-DD`` in ``settings.APP_TIMEZONE``."""
    now = datetime.now(ZoneInfo(settings.APP_TIMEZONE))
    return f"{now:%Y-%m-%d}"
+103
View File
@@ -0,0 +1,103 @@
"""用户数据 JSON API — 收藏、阅读状态、笔记。"""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.services.user_data import (
get_note,
save_note,
set_reading_status,
toggle_bookmark,
)
router = APIRouter(prefix="/api", tags=["user-data"])
# ── 请求模型 ──────────────────────────────────────────────────────────
class ReadingStatusRequest(BaseModel):
    """JSON body accepted by the reading-status update endpoint."""

    # Requested status; forwarded unchanged to ``set_reading_status``,
    # which reports unsupported values as an ``invalid_status`` error.
    status: str
class NoteRequest(BaseModel):
    """JSON body accepted by the note save endpoint."""

    # Note text; forwarded unchanged to ``save_note``.
    content: str
# ── 收藏 ──────────────────────────────────────────────────────────────
@router.post("/bookmark/{arxiv_id}")
def bookmark_toggle(arxiv_id: str, request: Request, db: Session = Depends(get_db)):
    """Toggle the bookmark on a paper.

    Plain requests receive the service result as JSON. Requests carrying
    the ``HX-Request`` header (HTMX) receive a replacement bookmark button
    as an HTML fragment instead.

    Args:
        arxiv_id: arXiv id of the paper, taken from the URL path.
        request: Incoming request; only its ``HX-Request`` header is read.
        db: Database session.

    Raises:
        HTTPException: 404 when the service result contains ``error``.
    """
    from html import escape

    result = toggle_bookmark(db, arxiv_id)
    if "error" in result:
        raise HTTPException(status_code=404, detail=result["error"])

    # Non-HTMX callers get the raw JSON result.
    if not request.headers.get("HX-Request"):
        return result

    bookmarked = result["bookmarked"]
    # Filled star when bookmarked, hollow star otherwise, so the button
    # has visible content that reflects the new state.
    star = "★" if bookmarked else "☆"
    active_class = " active" if bookmarked else ""
    # The id comes from the URL path; escape it before placing it inside
    # HTML attribute values.
    safe_id = escape(arxiv_id, quote=True)
    # NOTE(review): hx-swap="outerHTML" replaces the #user-data-<id> target
    # with this button, which carries no id of its own — confirm against
    # the template that later clicks can still resolve the target.
    html = (
        f'<button class="btn-bookmark{active_class}" '
        f'hx-post="/api/bookmark/{safe_id}" '
        f'hx-target="#user-data-{safe_id}" '
        f'hx-swap="outerHTML">'
        f"{star}</button>"
    )
    return HTMLResponse(content=html)
# ── 阅读状态 ──────────────────────────────────────────────────────────
@router.post("/reading-status/{arxiv_id}")
def reading_status_update(
    arxiv_id: str,
    body: ReadingStatusRequest,
    db: Session = Depends(get_db),
):
    """Update the reading status of a paper.

    Args:
        arxiv_id: arXiv id of the paper, taken from the URL path.
        body: Request body carrying the new status.
        db: Database session.

    Returns:
        The service result when it contains no ``error`` key.

    Raises:
        HTTPException: 404 for ``not_found``, 422 for ``invalid_status``,
            and 500 for any other error code reported by the service.
    """
    result = set_reading_status(db, arxiv_id, body.status)
    if "error" not in result:
        return result

    error = result["error"]
    if error == "not_found":
        raise HTTPException(status_code=404, detail="Paper not found")
    if error == "invalid_status":
        raise HTTPException(
            status_code=422,
            detail=f"Invalid status. Valid: {result['valid']}",
        )
    # An unrecognised error code must not be returned as a 200 response.
    raise HTTPException(status_code=500, detail=f"Unexpected error: {error}")
# ── 笔记 ──────────────────────────────────────────────────────────────
@router.get("/note/{arxiv_id}")
def note_get(arxiv_id: str, db: Session = Depends(get_db)):
    """Fetch the note for a paper.

    Raises:
        HTTPException: 404 when ``get_note`` returns ``None``.
    """
    note = get_note(db, arxiv_id)
    if note is not None:
        return note
    raise HTTPException(status_code=404, detail="Paper not found")
@router.post("/note/{arxiv_id}")
def note_save(arxiv_id: str, body: NoteRequest, db: Session = Depends(get_db)):
    """Save the note for a paper.

    Raises:
        HTTPException: 404, with the service's error value as detail, when
            the result of ``save_note`` contains an ``error`` key.
    """
    outcome = save_note(db, arxiv_id, body.content)
    if "error" not in outcome:
        return outcome
    raise HTTPException(status_code=404, detail=outcome["error"])