"""页面路由 — 首页、日期页、论文详情。""" from __future__ import annotations import logging from datetime import date, timedelta from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.responses import RedirectResponse from sqlalchemy.orm import Session, joinedload from app.config import settings from app.database import get_db from app.models import Paper from app.utils import templates, today_str logger = logging.getLogger(__name__) router = APIRouter() @router.get("/") def index(request: Request): """重定向到 /day/{today}。""" return RedirectResponse(url=f"/day/{today_str()}") @router.get("/day/{date_str}") def day_page(date_str: str, request: Request, db: Session = Depends(get_db)): """指定日期论文列表。""" try: target = date.fromisoformat(date_str) except ValueError: raise HTTPException(status_code=404, detail="Invalid date format") prev_day = (target - timedelta(days=1)).isoformat() next_day = (target + timedelta(days=1)).isoformat() today = today_str() papers = ( db.query(Paper) .filter(Paper.paper_date == date_str) .options( joinedload(Paper.authors), joinedload(Paper.tags), joinedload(Paper.summary_status), joinedload(Paper.bookmark), ) .order_by(Paper.upvotes.desc()) .all() ) dates_raw = ( db.query(Paper.paper_date) .distinct() .order_by(Paper.paper_date.desc()) .limit(30) .all() ) available_dates = [ d[0].isoformat() if isinstance(d[0], date) else str(d[0]) for d in dates_raw ] return templates.TemplateResponse( request, "index.html", { "papers": papers, "current_date": date_str, "prev_day": prev_day, "next_day": next_day, "today": today, "available_dates": available_dates, "page_title": f"{date_str} 论文列表", }, ) @router.get("/paper/{arxiv_id}") def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db)): """论文详情页。""" paper = ( db.query(Paper) .filter(Paper.arxiv_id == arxiv_id) .options( joinedload(Paper.authors), joinedload(Paper.tags), joinedload(Paper.summary), joinedload(Paper.summary_status), joinedload(Paper.bookmark), joinedload(Paper.reading_status), joinedload(Paper.note), ) .first() ) if not paper: raise HTTPException(status_code=404, detail="Paper not found") summary_state = "none" if paper.summary_status: summary_state = paper.summary_status.status # Phase 5: 相似论文推荐 similar_papers = _get_similar_papers(db, arxiv_id, top_k=6) # Phase 5: 图片画廊 images = _get_paper_images(arxiv_id) return templates.TemplateResponse( request, "detail.html", { "paper": paper, "summary_state": summary_state, "similar_papers": similar_papers, "paper_images": images, "chroma_enabled": settings.CHROMA_ENABLED, "page_title": paper.title_zh or paper.title_en, }, ) # ── 相似论文 API (Phase 5) ──────────────────────────────────────────── @router.get("/api/similar/{arxiv_id}") def similar_api( arxiv_id: str, top_k: int = Query(default=5, ge=1, le=20), db: Session = Depends(get_db), ): """返回与指定论文相似的论文列表(JSON)。""" similar = _get_similar_papers(db, arxiv_id, top_k=top_k + 1) # 排除自身 items = [s for s in similar if s["arxiv_id"] != arxiv_id][:top_k] return {"results": items} def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]: """从 ChromaDB 获取相似论文,返回 [{arxiv_id, title_zh, distance, paper_date}, ...]。""" if not settings.CHROMA_ENABLED: return [] try: from app.services.embedder import get_collection col = get_collection() if col is None: return [] # 获取当前论文的 embedding result = col.get(ids=[arxiv_id], include=["embeddings"]) if not result["embeddings"] or not result["embeddings"][0]: return [] vec = result["embeddings"][0] count = col.count() if count == 0: return [] results = col.query( query_embeddings=[vec], n_results=min(top_k, count), include=["metadatas", "distances"], ) if not results["ids"] or not results["ids"][0]: return [] # 从 DB 加载论文信息 similar_ids = results["ids"][0] distances = ( results["distances"][0] if results["distances"] else [0.0] * len(similar_ids) ) # 排除自身 papers_info = {} for i, sid in enumerate(similar_ids): if sid != arxiv_id: papers_info[sid] = distances[i] if not papers_info: return [] papers = ( db.query(Paper) .filter(Paper.arxiv_id.in_(list(papers_info.keys()))) .options(joinedload(Paper.tags)) .all() ) items = [] for p in papers: items.append( { "arxiv_id": p.arxiv_id, "title_zh": p.title_zh or p.title_en, "distance": papers_info.get(p.arxiv_id, 0.0), "paper_date": p.paper_date.isoformat() if p.paper_date else "", "tags": [t.tag for t in p.tags[:3]], } ) # 按距离排序 items.sort(key=lambda x: x["distance"]) return items except Exception: logger.exception("Failed to get similar papers for %s", arxiv_id) return [] # ── 图片画廊 (Phase 5) ──────────────────────────────────────────────── def _get_paper_images(arxiv_id: str) -> list[dict]: """获取论文提取的图片列表。""" images_dir = Path("data/papers") / arxiv_id / "images" if not images_dir.exists(): return [] images = [] for img_file in sorted(images_dir.iterdir()): if img_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg"): images.append( { "url": f"/papers/{arxiv_id}/images/{img_file.name}", "name": img_file.name, } ) return images