Files

209 lines
6.8 KiB
Python

"""搜索、阅读列表、RSS Feed 路由。"""
from __future__ import annotations
from datetime import date, timedelta
from xml.sax.saxutils import escape
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import Response
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import Paper, PaperTag
from app.services.searcher import get_all_tags, search_papers
from app.services.user_data import query_reading_list
from app.utils import templates, today_str
# Shared router on which every endpoint in this module (search page, search
# JSON API, reading list, RSS feed) is registered via its decorator.
router = APIRouter()
# ── Search page ───────────────────────────────────────────────────────
@router.get("/search")
def search_page(
    request: Request,
    q: str = Query(default=""),
    tag: str = Query(default=""),
    sort: str = Query(default="relevance"),
    mode: str = Query(default="keyword"),
    page: int = Query(default=1, ge=1),
    db: Session = Depends(get_db),
):
    """Render the HTML search page.

    The lookup itself is delegated to ``search_papers``; ``mode`` selects
    between ``keyword`` and ``semantic`` search. Empty ``q``/``tag`` strings
    are forwarded as ``None`` so the searcher treats them as "no filter".
    """
    found = search_papers(
        db,
        query=q if q else None,
        tag=tag if tag else None,
        sort=sort,
        page=page,
        mode=mode,
    )
    tags_for_filter = get_all_tags(db)

    heading = "搜索"
    if q:
        heading = f"搜索: {q}"

    context = dict(
        query=q,
        tag=tag,
        sort=sort,
        mode=mode,
        chroma_enabled=settings.CHROMA_ENABLED,
        results=found["results"],
        snippets=found["snippets"],
        distances=found.get("distances", {}),
        total=found["total"],
        page=found["page"],
        total_pages=found["total_pages"],
        all_tags=tags_for_filter,
        page_title=heading,
        today=today_str(),
    )
    return templates.TemplateResponse(request, "search.html", context)
# ── Search JSON API ───────────────────────────────────────────────────
@router.get("/api/search")
def search_api(
    q: str = Query(default=""),
    tag: str = Query(default=""),
    sort: str = Query(default="relevance"),
    mode: str = Query(default="keyword"),
    page: int = Query(default=1, ge=1),
    db: Session = Depends(get_db),
):
    """Return search results as JSON (``keyword`` or ``semantic`` mode).

    Each result carries basic paper metadata plus highlight snippets; a
    ``distance`` field is added only for papers present in the searcher's
    ``distances`` mapping (looked up by ``arxiv_id``).
    """
    outcome = search_papers(
        db,
        query=q if q else None,
        tag=tag if tag else None,
        sort=sort,
        page=page,
        mode=mode,
    )
    distance_by_id = outcome.get("distances", {})

    def _serialize(paper):
        # Snippets are keyed by the paper's primary key, distances by arxiv_id.
        hl = outcome["snippets"].get(paper.id, {})
        published = paper.paper_date
        record = {
            "arxiv_id": paper.arxiv_id,
            "title_en": paper.title_en,
            "title_zh": paper.title_zh,
            "paper_date": published.isoformat() if published else None,
            "upvotes": paper.upvotes,
            "tags": [pt.tag for pt in paper.tags],
            "authors": [au.name for au in paper.authors],
            "snippet_title_zh": hl.get("title_zh"),
            "snippet_abstract": hl.get("abstract"),
        }
        if paper.arxiv_id in distance_by_id:
            record["distance"] = distance_by_id[paper.arxiv_id]
        return record

    serialized = [_serialize(p) for p in outcome["results"]]
    return {
        "results": serialized,
        "total": outcome["total"],
        "page": outcome["page"],
        "total_pages": outcome["total_pages"],
    }
# ── Reading list ──────────────────────────────────────────────────────
@router.get("/reading-list")
def reading_list_page(
    request: Request,
    filter: str = Query(default="all"),
    tag: str = Query(default=""),
    db: Session = Depends(get_db),
):
    """Render the reading-list page.

    ``filter`` (the public query-parameter name, so it is kept despite
    shadowing the builtin) and the optional ``tag`` are passed straight to
    ``query_reading_list``; an empty tag is forwarded as ``None``.
    """
    tag_filter = tag if tag else None
    entries = query_reading_list(db, filter, tag_filter)
    context = {
        "papers": entries,
        "current_filter": filter,
        "current_tag": tag,
        "all_tags": get_all_tags(db),
        "page_title": "阅读列表",
        "today": today_str(),
    }
    return templates.TemplateResponse(request, "reading_list.html", context)
# ── RSS Feed ──────────────────────────────────────────────────────────
@router.get("/rss.xml")
def rss_feed(
    tag: str = Query(default=""),
    db: Session = Depends(get_db),
):
    """Serve an RSS 2.0 feed of papers dated within the last 7 days.

    Papers are ordered newest first, then by upvotes. When ``tag`` is
    non-empty, only papers carrying that tag are included. Authors, tags
    and summary are eager-loaded so the XML builder does not lazy-load.
    """
    cutoff = date.today() - timedelta(days=7)
    stmt = select(Paper).where(Paper.paper_date >= cutoff)
    if tag:
        stmt = stmt.where(Paper.tags.any(PaperTag.tag == tag))
    eager_loads = (
        joinedload(Paper.authors),
        joinedload(Paper.tags),
        joinedload(Paper.summary),
    )
    stmt = stmt.options(*eager_loads).order_by(
        Paper.paper_date.desc(), Paper.upvotes.desc()
    )
    # unique() is required when joined-eager-loading collections.
    rows = db.execute(stmt).unique().scalars().all()
    body = _generate_rss_xml(rows, settings.BASE_URL, tag if tag else None)
    return Response(content=body, media_type="application/xml")
def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> str:
"""生成 RSS 2.0 XML。"""
lines = ['<?xml version="1.0" encoding="UTF-8"?>']
lines.append('<rss version="2.0">')
lines.append(" <channel>")
channel_title = "HF Daily Papers"
if tag:
channel_title += f"{tag}"
lines.append(f" <title>{escape(channel_title)}</title>")
lines.append(f" <link>{escape(base_url)}</link>")
lines.append(
" <description>HuggingFace Daily Papers — 中文论文导览站</description>"
)
lines.append(" <language>zh-CN</language>")
for paper in papers:
title_text = paper.title_zh or paper.title_en
link = f"{base_url}/paper/{paper.arxiv_id}"
desc = ""
if paper.summary and paper.summary.one_line:
desc = paper.summary.one_line
elif paper.abstract:
desc = paper.abstract[:500]
pub_date = ""
if paper.paper_date:
# RFC 822 格式
pub_date = paper.paper_date.strftime("%a, %d %b %Y 00:00:00 +0800")
lines.append(" <item>")
lines.append(f" <title>{escape(title_text)}</title>")
lines.append(f" <link>{escape(link)}</link>")
lines.append(f" <description>{escape(desc)}</description>")
if pub_date:
lines.append(f" <pubDate>{pub_date}</pubDate>")
lines.append(f" <guid>{escape(link)}</guid>")
lines.append(" </item>")
lines.append(" </channel>")
lines.append("</rss>")
return "\n".join(lines)