209 lines
6.8 KiB
Python
209 lines
6.8 KiB
Python
"""搜索、阅读列表、RSS Feed 路由。"""
|
|
|
|
from __future__ import annotations

from datetime import date, datetime, timedelta, timezone
from email.utils import format_datetime
from xml.sax.saxutils import escape

from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import Response
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload

from app.config import settings
from app.database import get_db
from app.models import Paper, PaperTag, UserReadingStatus
from app.services.searcher import get_all_tags, search_papers
from app.services.user_data import query_reading_list
from app.utils import templates, today_str
|
|
|
|
# Router shared by every endpoint in this module (search page, JSON search
# API, reading list, RSS feed); each one registers itself via @router.get.
# NOTE(review): presumably mounted with app.include_router elsewhere — verify.
router = APIRouter()
|
|
|
|
|
|
# ── 搜索页 ────────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.get("/search")
def search_page(
    request: Request,
    q: str = Query(default=""),
    tag: str = Query(default=""),
    sort: str = Query(default="relevance"),
    mode: str = Query(default="keyword"),
    page: int = Query(default=1, ge=1),
    db: Session = Depends(get_db),
):
    """Render the HTML search page (``keyword`` or ``semantic`` mode).

    Empty ``q``/``tag`` strings are passed to ``search_papers`` as ``None``
    (no filter). The raw request parameters are echoed back to
    ``search.html`` next to the results, snippets, optional distances,
    pagination data and the full tag list.
    """
    found = search_papers(
        db,
        query=q if q else None,
        tag=tag if tag else None,
        sort=sort,
        page=page,
        mode=mode,
    )
    tag_catalog = get_all_tags(db)

    context = {
        "query": q,
        "tag": tag,
        "sort": sort,
        "mode": mode,
        "chroma_enabled": settings.CHROMA_ENABLED,
        "results": found["results"],
        "snippets": found["snippets"],
        # Only semantic searches are expected to report distances.
        "distances": found.get("distances", {}),
        "total": found["total"],
        "page": found["page"],
        "total_pages": found["total_pages"],
        "all_tags": tag_catalog,
        "page_title": "搜索" if not q else f"搜索: {q}",
        "today": today_str(),
    }
    return templates.TemplateResponse(request, "search.html", context)
|
|
|
|
|
|
# ── 搜索 JSON API ─────────────────────────────────────────────────────
|
|
|
|
|
|
@router.get("/api/search")
def search_api(
    q: str = Query(default=""),
    tag: str = Query(default=""),
    sort: str = Query(default="relevance"),
    mode: str = Query(default="keyword"),
    page: int = Query(default=1, ge=1),
    db: Session = Depends(get_db),
):
    """Return search results as JSON (``keyword`` or ``semantic`` mode).

    Each result holds the paper's core metadata plus snippet fields looked up
    by ``paper.id``. A ``distance`` field is added only when the search
    reported one for that paper's ``arxiv_id``. Pagination totals are
    returned alongside the result list.
    """
    found = search_papers(
        db,
        query=q if q else None,
        tag=tag if tag else None,
        sort=sort,
        page=page,
        mode=mode,
    )
    distance_map = found.get("distances", {})

    def serialize(paper):
        # Serialize one paper into the JSON shape exposed by this endpoint.
        hit = found["snippets"].get(paper.id, {})
        record = {
            "arxiv_id": paper.arxiv_id,
            "title_en": paper.title_en,
            "title_zh": paper.title_zh,
            "paper_date": (
                paper.paper_date.isoformat() if paper.paper_date else None
            ),
            "upvotes": paper.upvotes,
            "tags": [paper_tag.tag for paper_tag in paper.tags],
            "authors": [author.name for author in paper.authors],
            "snippet_title_zh": hit.get("title_zh"),
            "snippet_abstract": hit.get("abstract"),
        }
        if paper.arxiv_id in distance_map:
            record["distance"] = distance_map[paper.arxiv_id]
        return record

    return {
        "results": [serialize(paper) for paper in found["results"]],
        "total": found["total"],
        "page": found["page"],
        "total_pages": found["total_pages"],
    }
|
|
|
|
|
|
# ── 阅读列表 ──────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.get("/reading-list")
def reading_list_page(
    request: Request,
    filter: str = Query(default="all"),  # noqa: A002 — public query-param name
    tag: str = Query(default=""),
    db: Session = Depends(get_db),
):
    """Render the reading-list page.

    ``filter`` and ``tag`` are forwarded to ``query_reading_list``; an empty
    ``tag`` means no tag filter. Both values are also echoed back to
    ``reading_list.html`` as ``current_filter`` / ``current_tag``.
    """
    tag_filter = tag if tag else None
    entries = query_reading_list(db, filter, tag_filter)
    tag_catalog = get_all_tags(db)

    return templates.TemplateResponse(
        request,
        "reading_list.html",
        dict(
            papers=entries,
            current_filter=filter,
            current_tag=tag,
            all_tags=tag_catalog,
            page_title="阅读列表",
            today=today_str(),
        ),
    )
|
|
|
|
|
|
# ── RSS Feed ──────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.get("/rss.xml")
def rss_feed(
    tag: str = Query(default=""),
    db: Session = Depends(get_db),
):
    """Serve an RSS 2.0 feed of recent papers.

    Includes papers whose ``paper_date`` is on or after today minus 7 days,
    newest first, with ties broken by upvotes.

    Args:
        tag: Optional tag name. When non-empty, only papers carrying this tag
            are included and the tag is appended to the channel title.
        db: Database session injected via ``get_db``.

    Returns:
        An ``application/xml`` response containing the rendered feed.
    """
    seven_days_ago = date.today() - timedelta(days=7)

    # Only the scalar ``summary`` relationship is eager-loaded, because the
    # XML renderer reads ``summary.one_line`` and nothing from authors or
    # tags. Joined-loading those two collections too produced one row per
    # (author x tag) pair for every paper, all of it discarded.
    query = (
        db.query(Paper)
        .filter(Paper.paper_date >= seven_days_ago)
        .options(joinedload(Paper.summary))
        .order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
    )

    if tag:
        # ``any()`` filters through an EXISTS subquery and does not load tags.
        query = query.filter(Paper.tags.any(PaperTag.tag == tag))

    papers = query.all()
    xml = _generate_rss_xml(papers, settings.BASE_URL, tag or None)
    return Response(content=xml, media_type="application/xml")
|
|
|
|
|
|
def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> str:
|
|
"""生成 RSS 2.0 XML。"""
|
|
lines = ['<?xml version="1.0" encoding="UTF-8"?>']
|
|
lines.append('<rss version="2.0">')
|
|
lines.append(" <channel>")
|
|
|
|
channel_title = "HF Daily Papers"
|
|
if tag:
|
|
channel_title += f" — {tag}"
|
|
lines.append(f" <title>{escape(channel_title)}</title>")
|
|
lines.append(f" <link>{escape(base_url)}</link>")
|
|
lines.append(
|
|
" <description>HuggingFace Daily Papers — 中文论文导览站</description>"
|
|
)
|
|
lines.append(" <language>zh-CN</language>")
|
|
|
|
for paper in papers:
|
|
title_text = paper.title_zh or paper.title_en
|
|
link = f"{base_url}/paper/{paper.arxiv_id}"
|
|
|
|
desc = ""
|
|
if paper.summary and paper.summary.one_line:
|
|
desc = paper.summary.one_line
|
|
elif paper.abstract:
|
|
desc = paper.abstract[:500]
|
|
|
|
pub_date = ""
|
|
if paper.paper_date:
|
|
# RFC 822 格式
|
|
pub_date = paper.paper_date.strftime("%a, %d %b %Y 00:00:00 +0800")
|
|
|
|
lines.append(" <item>")
|
|
lines.append(f" <title>{escape(title_text)}</title>")
|
|
lines.append(f" <link>{escape(link)}</link>")
|
|
lines.append(f" <description>{escape(desc)}</description>")
|
|
if pub_date:
|
|
lines.append(f" <pubDate>{pub_date}</pubDate>")
|
|
lines.append(f" <guid>{escape(link)}</guid>")
|
|
lines.append(" </item>")
|
|
|
|
lines.append(" </channel>")
|
|
lines.append("</rss>")
|
|
return "\n".join(lines)
|