feat: add search and user data routes, services, and tests
This commit is contained in:
@@ -0,0 +1,249 @@
|
||||
"""搜索、阅读列表、RSS Feed 路由。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from zoneinfo import ZoneInfo
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import Response
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models import Paper, PaperTag, UserReadingStatus
|
||||
from app.services.searcher import get_all_tags, search_papers
|
||||
|
||||
# Router that the search, reading-list and RSS endpoints below register on.
router = APIRouter()

# Template loader for the HTML pages rendered by this module.
_TEMPLATE_DIR = "app/templates"
templates = Jinja2Templates(directory=_TEMPLATE_DIR)
|
||||
|
||||
|
||||
# ── 搜索页 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/search")
def search_page(
    request: Request,
    q: str = Query(default=""),
    tag: str = Query(default=""),
    sort: str = Query(default="relevance"),
    page: int = Query(default=1, ge=1),
    db: Session = Depends(get_db),
):
    """Render the HTML search page.

    Empty ``q`` / ``tag`` values are forwarded to ``search_papers`` as
    ``None``; the raw strings are still echoed back to the template.
    """
    outcome = search_papers(
        db,
        query=q if q else None,
        tag=tag if tag else None,
        sort=sort,
        page=page,
    )
    tag_choices = get_all_tags(db)

    # Page title includes the query only when one was supplied.
    if q:
        heading = f"搜索: {q}"
    else:
        heading = "搜索"

    context = {
        "query": q,
        "tag": tag,
        "sort": sort,
        "results": outcome["results"],
        "snippets": outcome["snippets"],
        "total": outcome["total"],
        "page": outcome["page"],
        "total_pages": outcome["total_pages"],
        "all_tags": tag_choices,
        "page_title": heading,
        "today": _today_str(),
    }
    return templates.TemplateResponse(request, "search.html", context)
|
||||
|
||||
|
||||
# ── 搜索 JSON API ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/api/search")
def search_api(
    q: str = Query(default=""),
    tag: str = Query(default=""),
    sort: str = Query(default="relevance"),
    page: int = Query(default=1, ge=1),
    db: Session = Depends(get_db),
):
    """Return search results as a JSON-serializable dict.

    The response carries the serialized papers under ``results`` plus the
    ``total``, ``page`` and ``total_pages`` values reported by
    ``search_papers``.
    """
    found = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page)

    def _serialize(paper):
        # Snippets are keyed by paper id; a paper without one yields None fields.
        hit = found["snippets"].get(paper.id, {})
        published = paper.paper_date
        return {
            "arxiv_id": paper.arxiv_id,
            "title_en": paper.title_en,
            "title_zh": paper.title_zh,
            "paper_date": published.isoformat() if published else None,
            "upvotes": paper.upvotes,
            "tags": [entry.tag for entry in paper.tags],
            "authors": [person.name for person in paper.authors],
            "snippet_title_zh": hit.get("title_zh"),
            "snippet_abstract": hit.get("abstract"),
        }

    payload = {"results": [_serialize(paper) for paper in found["results"]]}
    for field in ("total", "page", "total_pages"):
        payload[field] = found[field]
    return payload
|
||||
|
||||
|
||||
# ── 阅读列表 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/reading-list")
def reading_list_page(
    request: Request,
    filter: str = Query(default="all"),
    tag: str = Query(default=""),
    db: Session = Depends(get_db),
):
    """Render the reading-list HTML page.

    ``filter`` and ``tag`` are passed to ``_query_reading_list`` (an empty
    tag becomes ``None``) and are also echoed back to the template as the
    current selection.
    """
    selected_tag = tag if tag else None
    entries = _query_reading_list(db, filter, selected_tag)
    tag_choices = get_all_tags(db)

    context = dict(
        papers=entries,
        current_filter=filter,
        current_tag=tag,
        all_tags=tag_choices,
        page_title="阅读列表",
        today=_today_str(),
    )
    return templates.TemplateResponse(request, "reading_list.html", context)
|
||||
|
||||
|
||||
def _query_reading_list(
    db: Session,
    filter_type: str,
    tag: str | None,
) -> list[Paper]:
    """Query the papers shown on the reading list.

    Args:
        db: Session used to run the query.
        filter_type: ``"has_note"`` keeps only papers with a note; one of
            ``"unread"``, ``"skimmed"``, ``"read_summary"``, ``"read_full"``
            keeps only papers whose reading status equals that value; any
            other value applies no extra filter.
        tag: When truthy, keep only papers carrying this tag.

    Returns:
        Matching papers ordered by ``paper_date`` then ``upvotes``, both
        descending, with the related rows eagerly loaded.
    """
    from sqlalchemy import or_

    status_values = ("unread", "skimmed", "read_summary", "read_full")

    # Baseline: the paper has a bookmark, a reading status, or a note.
    conditions = [
        or_(
            Paper.bookmark.has(),
            Paper.reading_status.has(),
            Paper.note.has(),
        )
    ]

    # Optional narrowing by the selected filter.
    if filter_type == "has_note":
        conditions.append(Paper.note.has())
    elif filter_type in status_values:
        conditions.append(
            Paper.reading_status.has(UserReadingStatus.status == filter_type)
        )

    # Optional narrowing by tag.
    if tag:
        conditions.append(Paper.tags.any(PaperTag.tag == tag))

    query = db.query(Paper)
    for condition in conditions:
        query = query.filter(condition)

    eager_loads = (
        joinedload(Paper.authors),
        joinedload(Paper.tags),
        joinedload(Paper.summary_status),
        joinedload(Paper.bookmark),
        joinedload(Paper.reading_status),
        joinedload(Paper.note),
    )
    query = query.options(*eager_loads)
    query = query.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
    return query.all()
|
||||
|
||||
|
||||
# ── RSS Feed ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/rss.xml")
def rss_feed(
    tag: str = Query(default=""),
    db: Session = Depends(get_db),
):
    """Serve an RSS 2.0 feed of papers dated within the last 7 days.

    Args:
        tag: Optional tag name; when non-empty only papers carrying this
            tag are included and the tag is passed on to the XML builder.
        db: Session used to run the query.

    Returns:
        A ``Response`` with the XML document and ``application/xml`` media
        type.
    """
    # Take "today" in the configured application timezone (the same source
    # _today_str uses) instead of the server's local timezone, so the 7-day
    # window does not shift by a day when the server runs in another zone.
    today = datetime.now(ZoneInfo(settings.APP_TIMEZONE)).date()
    seven_days_ago = today - timedelta(days=7)

    query = db.query(Paper).filter(Paper.paper_date >= seven_days_ago)
    if tag:
        query = query.filter(Paper.tags.any(PaperTag.tag == tag))

    query = query.options(
        joinedload(Paper.authors),
        joinedload(Paper.tags),
        joinedload(Paper.summary),
    ).order_by(Paper.paper_date.desc(), Paper.upvotes.desc())

    papers = query.all()
    xml = _generate_rss_xml(papers, settings.BASE_URL, tag or None)
    return Response(content=xml, media_type="application/xml")
|
||||
|
||||
|
||||
def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> str:
|
||||
"""生成 RSS 2.0 XML。"""
|
||||
lines = ['<?xml version="1.0" encoding="UTF-8"?>']
|
||||
lines.append('<rss version="2.0">')
|
||||
lines.append(" <channel>")
|
||||
|
||||
channel_title = "HF Daily Papers"
|
||||
if tag:
|
||||
channel_title += f" — {tag}"
|
||||
lines.append(f" <title>{escape(channel_title)}</title>")
|
||||
lines.append(f" <link>{escape(base_url)}</link>")
|
||||
lines.append(" <description>HuggingFace Daily Papers — 中文论文导览站</description>")
|
||||
lines.append(f" <language>zh-CN</language>")
|
||||
|
||||
for paper in papers:
|
||||
title_text = paper.title_zh or paper.title_en
|
||||
link = f"{base_url}/paper/{paper.arxiv_id}"
|
||||
|
||||
desc = ""
|
||||
if paper.summary and paper.summary.one_line:
|
||||
desc = paper.summary.one_line
|
||||
elif paper.abstract:
|
||||
desc = paper.abstract[:500]
|
||||
|
||||
pub_date = ""
|
||||
if paper.paper_date:
|
||||
# RFC 822 格式
|
||||
pub_date = paper.paper_date.strftime("%a, %d %b %Y 00:00:00 +0800")
|
||||
|
||||
lines.append(" <item>")
|
||||
lines.append(f" <title>{escape(title_text)}</title>")
|
||||
lines.append(f" <link>{escape(link)}</link>")
|
||||
lines.append(f" <description>{escape(desc)}</description>")
|
||||
if pub_date:
|
||||
lines.append(f" <pubDate>{pub_date}</pubDate>")
|
||||
lines.append(f" <guid>{escape(link)}</guid>")
|
||||
lines.append(" </item>")
|
||||
|
||||
lines.append(" </channel>")
|
||||
lines.append("</rss>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ── 工具函数 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _today_str() -> str:
    """Return the current date as ``YYYY-MM-DD`` in the ``APP_TIMEZONE`` zone."""
    now_local = datetime.now(ZoneInfo(settings.APP_TIMEZONE))
    return now_local.date().isoformat()
|
||||
Reference in New Issue
Block a user