feat: add compare, trends routes, embedder service, and phase5 tests
This commit is contained in:
+128
-1
@@ -1,9 +1,13 @@
|
||||
"""页面路由 — 首页、日期页、论文详情。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
@@ -12,6 +16,8 @@ from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models import Paper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
@@ -99,11 +105,132 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
|
||||
if paper.summary_status:
|
||||
summary_state = paper.summary_status.status
|
||||
|
||||
# Phase 5: 相似论文推荐
|
||||
similar_papers = _get_similar_papers(db, arxiv_id, top_k=6)
|
||||
|
||||
# Phase 5: 图片画廊
|
||||
images = _get_paper_images(arxiv_id)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request, "detail.html",
|
||||
{
|
||||
"paper": paper,
|
||||
"summary_state": summary_state,
|
||||
"similar_papers": similar_papers,
|
||||
"paper_images": images,
|
||||
"chroma_enabled": settings.CHROMA_ENABLED,
|
||||
"page_title": paper.title_zh or paper.title_en,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── 相似论文 API (Phase 5) ────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/api/similar/{arxiv_id}")
def similar_api(
    arxiv_id: str,
    top_k: int = Query(default=5, ge=1, le=20),
    db: Session = Depends(get_db),
):
    """Return the papers most similar to ``arxiv_id`` as a JSON object.

    Args:
        arxiv_id: ID of the reference paper, taken from the URL path.
        top_k: Maximum number of similar papers to return (1 to 20, default 5).
        db: Database session injected by FastAPI.

    Returns:
        ``{"results": [...]}`` holding at most ``top_k`` entries produced by
        ``_get_similar_papers``, never including the reference paper itself.
    """
    # Ask for one extra neighbour so that top_k entries can remain once the
    # reference paper has been dropped from the candidates.
    candidates = _get_similar_papers(db, arxiv_id, top_k=top_k + 1)

    others = []
    for entry in candidates:
        if entry["arxiv_id"] != arxiv_id:
            others.append(entry)

    return {"results": others[:top_k]}
|
||||
|
||||
|
||||
def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
    """Look up the papers nearest to ``arxiv_id`` in the ChromaDB collection.

    The embedding stored under ``arxiv_id`` is used as the query vector. The
    reference paper itself is removed from the hits, so a caller that wants
    N *other* papers should request ``top_k = N + 1``. Hits without a matching
    ``Paper`` row in the database are omitted.

    Args:
        db: SQLAlchemy session used to load the matching ``Paper`` rows.
        arxiv_id: ID of the reference paper; also used as its ChromaDB ID.
        top_k: Number of neighbours requested from ChromaDB (self included).

    Returns:
        A list of dicts sorted by ascending distance, each with the keys
        ``arxiv_id``, ``title_zh`` (falls back to ``title_en``), ``distance``,
        ``paper_date`` (ISO string, or ``""`` when unset) and ``tags`` (at
        most three tag names). The list is empty when ``CHROMA_ENABLED`` is
        off, when no collection or embedding is available, or when any step
        fails (the failure is logged, never raised).
    """
    if not settings.CHROMA_ENABLED:
        return []

    try:
        # Imported inside the function, as in the original code; an import or
        # connection failure is handled by the logged handler below.
        from app.services.embedder import get_collection

        col = get_collection()
        if col is None:
            return []

        # Fetch the reference paper's own embedding.
        stored = col.get(ids=[arxiv_id], include=["embeddings"])
        embeddings = stored["embeddings"]
        # Use explicit None/len checks instead of truthiness: embeddings may
        # be NumPy arrays, whose truth value is ambiguous and raises.
        if embeddings is None or len(embeddings) == 0:
            return []
        vec = embeddings[0]
        if vec is None or len(vec) == 0:
            return []

        count = col.count()
        if count == 0:
            return []

        results = col.query(
            query_embeddings=[vec],
            # Never request more neighbours than the collection holds.
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )
        if not results["ids"] or not results["ids"][0]:
            return []

        similar_ids = results["ids"][0]
        all_distances = results["distances"]
        if all_distances is None or len(all_distances) == 0:
            distances = [0.0] * len(similar_ids)
        else:
            distances = all_distances[0]

        # Map each neighbour to its distance, leaving out the paper itself.
        distance_by_id = {
            sid: float(dist)
            for sid, dist in zip(similar_ids, distances)
            if sid != arxiv_id
        }
        if not distance_by_id:
            return []

        # Load the display data for the neighbours from the database.
        papers = (
            db.query(Paper)
            .filter(Paper.arxiv_id.in_(list(distance_by_id)))
            .options(joinedload(Paper.tags))
            .all()
        )

        items = [
            {
                "arxiv_id": p.arxiv_id,
                "title_zh": p.title_zh or p.title_en,
                "distance": distance_by_id.get(p.arxiv_id, 0.0),
                "paper_date": p.paper_date.isoformat() if p.paper_date else "",
                "tags": [t.tag for t in p.tags[:3]],
            }
            for p in papers
        ]

        # The database returns rows in arbitrary order; restore nearest-first.
        items.sort(key=lambda item: item["distance"])
        return items

    except Exception:
        # Best-effort feature: similar papers must never break the page.
        logger.exception("Failed to get similar papers for %s", arxiv_id)
        return []
|
||||
|
||||
|
||||
# ── 图片画廊 (Phase 5) ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_paper_images(arxiv_id: str) -> list[dict]:
    """List the image files extracted for a paper.

    Images are looked up in ``data/papers/<arxiv_id>/images``, relative to
    the current working directory.

    Args:
        arxiv_id: ID of the paper whose images should be listed.

    Returns:
        A list of ``{"url": ..., "name": ...}`` dicts, ordered by file path,
        one per regular file with a ``.png``, ``.jpg``, ``.jpeg``, ``.gif`` or
        ``.svg`` suffix (case-insensitive). The list is empty when the images
        directory is missing or is not a directory, or when ``arxiv_id`` would
        point outside ``data/papers``.
    """
    # Refuse IDs that would escape the papers directory.
    # NOTE(review): the ID appears to originate from a URL path -- confirm.
    id_path = Path(arxiv_id)
    if id_path.is_absolute() or ".." in id_path.parts:
        return []

    images_dir = Path("data/papers") / arxiv_id / "images"
    # is_dir() rather than exists(): a stray regular file named "images"
    # would otherwise make iterdir() raise NotADirectoryError.
    if not images_dir.is_dir():
        return []

    image_suffixes = {".png", ".jpg", ".jpeg", ".gif", ".svg"}
    return [
        {
            "url": f"/papers/{arxiv_id}/images/{img_file.name}",
            "name": img_file.name,
        }
        for img_file in sorted(images_dir.iterdir())
        # Skip sub-directories that merely carry an image-like suffix.
        if img_file.is_file() and img_file.suffix.lower() in image_suffixes
    ]
|
||||
|
||||
Reference in New Issue
Block a user