feat: add compare, trends routes, embedder service, and phase5 tests
This commit is contained in:
@@ -0,0 +1,115 @@
|
||||
"""论文对比页路由 — 多篇论文结构化字段并排对比。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Paper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
# Maximum number of papers that can be compared on one page.
_MAX_COMPARE_PAPERS = 5

# (field key, display label) pairs, one per comparison row, in display order.
_COMPARE_FIELDS = [
    ("title_zh", "中文标题"),
    ("title_en", "英文标题"),
    ("one_line", "一句话摘要"),
    ("difficulty", "难度"),
    ("motivation_problem", "研究问题"),
    ("motivation_goal", "研究目标"),
    ("motivation_gap", "研究差距"),
    ("method_overview", "方法概述"),
    ("method_key_idea", "关键思路"),
    ("method_novelty", "新颖性"),
    ("results", "实验结果"),
    ("limitations", "局限与改进"),
]

# Row keys read from the Paper itself; every other key is read from
# ``paper.summary``.
_PAPER_LEVEL_FIELDS = ("title_zh", "title_en")

# Row keys that fall back to a raw JSON attribute of the summary when the
# attribute named by the key is missing or empty.
_JSON_FALLBACKS = {
    "results": "results_main_json",
    "limitations": "limitations_json",
}


def _compare_cell(paper, field_key):
    """Return the display value of ``field_key`` for one paper ("" if absent)."""
    if field_key in _PAPER_LEVEL_FIELDS:
        return getattr(paper, field_key, None) or ""

    summary = paper.summary
    if not summary:
        return ""

    value = getattr(summary, field_key, None) or ""
    fallback_attr = _JSON_FALLBACKS.get(field_key)
    if fallback_attr and not value:
        # The JSON attribute is shown as-is, without any formatting.
        value = getattr(summary, fallback_attr) or ""
    return value


def _compare_response(request, ids, papers, rows, error):
    """Render ``compare.html`` with the same context keys on every code path."""
    return templates.TemplateResponse(
        request,
        "compare.html",
        {
            "page_title": "论文对比",
            "papers": papers,
            "rows": rows,
            "ids_param": ids,
            "error": error,
        },
    )


@router.get("/compare")
def compare_page(
    request: Request,
    ids: str = Query(default="", description="逗号分隔的 arxiv_id,最多 5 篇"),
    db: Session = Depends(get_db),
):
    """Render the side-by-side paper comparison page.

    ``GET /compare?ids=id1,id2,id3``

    Args:
        request: Incoming request, passed through to the template.
        ids: Comma-separated arxiv IDs. Blank entries and repeated IDs are
            ignored; only the first five distinct IDs are used.
        db: Database session.

    Returns:
        The rendered ``compare.html`` response. Its context always contains
        ``page_title``, ``papers``, ``rows``, ``ids_param`` and ``error``.
        ``error`` is set only when ``ids`` is non-empty but contains no
        usable ID.
    """
    if not ids:
        return _compare_response(request, ids, [], [], None)

    # De-duplicate before applying the limit so that repeated IDs cannot
    # crowd distinct papers out of the five available slots.
    # ``dict.fromkeys`` keeps the first-seen order.
    stripped = (part.strip() for part in ids.split(","))
    arxiv_ids = list(dict.fromkeys(part for part in stripped if part))
    arxiv_ids = arxiv_ids[:_MAX_COMPARE_PAPERS]

    if not arxiv_ids:
        return _compare_response(request, ids, [], [], "请提供有效的论文 ID")

    papers = (
        db.query(Paper)
        .filter(Paper.arxiv_id.in_(arxiv_ids))
        .options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary),
            joinedload(Paper.summary_status),
        )
        .all()
    )

    # Present the papers in the order they were requested. IDs that match no
    # row are simply missing from ``papers``.
    position = {aid: idx for idx, aid in enumerate(arxiv_ids)}
    papers.sort(key=lambda p: position.get(p.arxiv_id, len(arxiv_ids)))

    # One row per field, one cell per paper (same order as ``papers``).
    rows = [
        {
            "key": field_key,
            "label": field_label,
            "cells": [_compare_cell(paper, field_key) for paper in papers],
        }
        for field_key, field_label in _COMPARE_FIELDS
    ]

    return _compare_response(request, ids, papers, rows, None)
|
||||
+128
-1
@@ -1,9 +1,13 @@
|
||||
"""页面路由 — 首页、日期页、论文详情。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
@@ -12,6 +16,8 @@ from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models import Paper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
@@ -99,11 +105,132 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
|
||||
if paper.summary_status:
|
||||
summary_state = paper.summary_status.status
|
||||
|
||||
# Phase 5: 相似论文推荐
|
||||
similar_papers = _get_similar_papers(db, arxiv_id, top_k=6)
|
||||
|
||||
# Phase 5: 图片画廊
|
||||
images = _get_paper_images(arxiv_id)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request, "detail.html",
|
||||
{
|
||||
"paper": paper,
|
||||
"summary_state": summary_state,
|
||||
"similar_papers": similar_papers,
|
||||
"paper_images": images,
|
||||
"chroma_enabled": settings.CHROMA_ENABLED,
|
||||
"page_title": paper.title_zh or paper.title_en,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── 相似论文 API (Phase 5) ────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/api/similar/{arxiv_id}")
def similar_api(
    arxiv_id: str,
    top_k: int = Query(default=5, ge=1, le=20),
    db: Session = Depends(get_db),
):
    """Return the papers most similar to ``arxiv_id`` as a JSON object.

    The response has the shape ``{"results": [...]}`` and holds at most
    ``top_k`` entries, none of which is the queried paper itself.
    """
    # Ask for one candidate more than requested, then drop the queried paper.
    candidates = _get_similar_papers(db, arxiv_id, top_k=top_k + 1)

    matches = []
    for entry in candidates:
        if entry["arxiv_id"] == arxiv_id:
            continue
        matches.append(entry)

    return {"results": matches[:top_k]}
|
||||
|
||||
|
||||
def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
    """Look up papers similar to ``arxiv_id`` through the ChromaDB collection.

    The stored embedding of ``arxiv_id`` is used as the query vector. The
    queried paper itself is removed from the hits, so at most ``top_k``
    entries (and usually ``top_k - 1``) are returned.

    Args:
        db: Database session used to load the matching ``Paper`` rows.
        arxiv_id: ID of the paper whose neighbours are wanted.
        top_k: Number of nearest neighbours requested from the collection
            (capped at the collection size).

    Returns:
        A list of dicts with the keys ``arxiv_id``, ``title_zh`` (falls back
        to the English title), ``distance``, ``paper_date`` (ISO string or
        ``""``) and ``tags`` (at most three), sorted by ascending distance.
        An empty list is returned when ChromaDB is disabled, the paper has no
        stored embedding, nothing else matches, or any error occurs. This is
        a best-effort feature and never raises.
    """
    if not settings.CHROMA_ENABLED:
        return []

    try:
        try:
            from app.services.embedder import get_collection

            col = get_collection()
        except Exception:
            # Best effort: an unavailable vector store must not break the
            # page, but the failure should still be visible in the logs.
            logger.warning(
                "ChromaDB collection unavailable; no similar papers for %s",
                arxiv_id,
                exc_info=True,
            )
            return []

        if col is None:
            return []

        # Fetch the stored embedding of the current paper.
        result = col.get(ids=[arxiv_id], include=["embeddings"])
        embeddings = result["embeddings"]
        # Use explicit None/len checks instead of truthiness: some chromadb
        # versions return embeddings as NumPy arrays, whose truth value is
        # ambiguous and raises ValueError.
        if embeddings is None or len(embeddings) == 0:
            return []
        vec = embeddings[0]
        if vec is None or len(vec) == 0:
            return []

        count = col.count()
        if count == 0:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )

        ids_batches = results["ids"]
        if ids_batches is None or len(ids_batches) == 0 or len(ids_batches[0]) == 0:
            return []
        similar_ids = ids_batches[0]

        dist_batches = results["distances"]
        if dist_batches is not None and len(dist_batches) > 0:
            distances = dist_batches[0]
        else:
            distances = [0.0] * len(similar_ids)

        # Map each hit to its distance, leaving out the queried paper.
        distance_by_id = {
            sid: dist
            for sid, dist in zip(similar_ids, distances)
            if sid != arxiv_id
        }
        if not distance_by_id:
            return []

        # Load display information for the hits from the database.
        papers = (
            db.query(Paper)
            .filter(Paper.arxiv_id.in_(list(distance_by_id)))
            .options(joinedload(Paper.tags))
            .all()
        )

        items = [
            {
                "arxiv_id": p.arxiv_id,
                "title_zh": p.title_zh or p.title_en,
                "distance": distance_by_id.get(p.arxiv_id, 0.0),
                "paper_date": p.paper_date.isoformat() if p.paper_date else "",
                "tags": [t.tag for t in p.tags[:3]],
            }
            for p in papers
        ]

        # Closest matches first.
        items.sort(key=lambda item: item["distance"])
        return items

    except Exception:
        logger.exception("Failed to get similar papers for %s", arxiv_id)
        return []
|
||||
|
||||
|
||||
# ── 图片画廊 (Phase 5) ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_paper_images(arxiv_id: str) -> list[dict]:
|
||||
"""获取论文提取的图片列表。"""
|
||||
images_dir = Path("data/papers") / arxiv_id / "images"
|
||||
if not images_dir.exists():
|
||||
return []
|
||||
|
||||
images = []
|
||||
for img_file in sorted(images_dir.iterdir()):
|
||||
if img_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg"):
|
||||
images.append({
|
||||
"url": f"/papers/{arxiv_id}/images/{img_file.name}",
|
||||
"name": img_file.name,
|
||||
})
|
||||
return images
|
||||
|
||||
+24
-17
@@ -31,11 +31,12 @@ def search_page(
|
||||
q: str = Query(default=""),
|
||||
tag: str = Query(default=""),
|
||||
sort: str = Query(default="relevance"),
|
||||
mode: str = Query(default="keyword"),
|
||||
page: int = Query(default=1, ge=1),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""搜索页面。"""
|
||||
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page)
|
||||
"""搜索页面,支持 keyword 和 semantic 模式。"""
|
||||
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode)
|
||||
all_tags = get_all_tags(db)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
@@ -45,8 +46,11 @@ def search_page(
|
||||
"query": q,
|
||||
"tag": tag,
|
||||
"sort": sort,
|
||||
"mode": mode,
|
||||
"chroma_enabled": settings.CHROMA_ENABLED,
|
||||
"results": result["results"],
|
||||
"snippets": result["snippets"],
|
||||
"distances": result.get("distances", {}),
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"total_pages": result["total_pages"],
|
||||
@@ -65,28 +69,31 @@ def search_api(
|
||||
q: str = Query(default=""),
|
||||
tag: str = Query(default=""),
|
||||
sort: str = Query(default="relevance"),
|
||||
mode: str = Query(default="keyword"),
|
||||
page: int = Query(default=1, ge=1),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""搜索 JSON API。"""
|
||||
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page)
|
||||
"""搜索 JSON API,支持 keyword 和 semantic 模式。"""
|
||||
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode)
|
||||
distances = result.get("distances", {})
|
||||
|
||||
items = []
|
||||
for paper in result["results"]:
|
||||
snippet = result["snippets"].get(paper.id, {})
|
||||
items.append(
|
||||
{
|
||||
"arxiv_id": paper.arxiv_id,
|
||||
"title_en": paper.title_en,
|
||||
"title_zh": paper.title_zh,
|
||||
"paper_date": paper.paper_date.isoformat() if paper.paper_date else None,
|
||||
"upvotes": paper.upvotes,
|
||||
"tags": [t.tag for t in paper.tags],
|
||||
"authors": [a.name for a in paper.authors],
|
||||
"snippet_title_zh": snippet.get("title_zh"),
|
||||
"snippet_abstract": snippet.get("abstract"),
|
||||
}
|
||||
)
|
||||
item = {
|
||||
"arxiv_id": paper.arxiv_id,
|
||||
"title_en": paper.title_en,
|
||||
"title_zh": paper.title_zh,
|
||||
"paper_date": paper.paper_date.isoformat() if paper.paper_date else None,
|
||||
"upvotes": paper.upvotes,
|
||||
"tags": [t.tag for t in paper.tags],
|
||||
"authors": [a.name for a in paper.authors],
|
||||
"snippet_title_zh": snippet.get("title_zh"),
|
||||
"snippet_abstract": snippet.get("abstract"),
|
||||
}
|
||||
if paper.arxiv_id in distances:
|
||||
item["distance"] = distances[paper.arxiv_id]
|
||||
items.append(item)
|
||||
|
||||
return {
|
||||
"results": items,
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
"""趋势看板路由 — 论文统计图表页面和数据 API。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy import func, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
@router.get("/trends")
def trends_page(request: Request, db: Session = Depends(get_db)):
    """Render the trends dashboard page (``trends.html``).

    The template receives the aggregated statistics as ``stats`` and the
    current date string as ``today``.
    """
    context = {
        "page_title": "趋势看板",
        "stats": _get_trends_data(db),
        "today": _today_str(),
    }
    return templates.TemplateResponse(request, "trends.html", context)
|
||||
|
||||
|
||||
@router.get("/api/stats/trends")
def trends_api(db: Session = Depends(get_db)):
    """Return the same aggregated statistics the trends page shows, as the response body."""
    payload = _get_trends_data(db)
    return payload
|
||||
|
||||
|
||||
def _get_trends_data(db: Session) -> dict:
    """Aggregate the statistics shown on the trends dashboard.

    Args:
        db: Database session used to run the aggregation queries.

    Returns:
        A dict with four lists:

        * ``daily_counts``: ``{"date", "count"}`` per ``paper_date`` on or
          after the date 30 days before today, oldest first. Dates with no
          papers are absent.
        * ``top_tags``: ``{"tag", "count"}`` for the 20 most frequent tags.
        * ``upvotes_dist``: ``{"range", "count"}`` per upvote bucket,
          highest bucket first.
        * ``summary_completion``: ``{"status", "count"}`` per summary
          status; papers without a status row are counted as ``"none"``.
    """
    # Anchor the window on the same "today" the page displays (configured
    # app timezone, via _today_str) instead of the server-local date, so the
    # two cannot disagree around midnight.
    # NOTE(review): assumes paper_date values are dates in that same
    # timezone — confirm against the ingestion code.
    today = date.fromisoformat(_today_str())
    start_date = (today - timedelta(days=30)).isoformat()

    # 1. Papers per day over the recent window.
    daily_rows = db.execute(
        text("""
            SELECT paper_date, COUNT(*) as cnt
            FROM papers
            WHERE paper_date >= :start_date
            GROUP BY paper_date
            ORDER BY paper_date ASC
        """),
        {"start_date": start_date},
    ).fetchall()
    daily_counts = [{"date": str(row[0]), "count": row[1]} for row in daily_rows]

    # 2. Top 20 tags by number of uses.
    tag_rows = db.execute(
        text("""
            SELECT tag, COUNT(*) as cnt
            FROM paper_tags
            GROUP BY tag
            ORDER BY cnt DESC
            LIMIT 20
        """)
    ).fetchall()
    top_tags = [{"tag": row[0], "count": row[1]} for row in tag_rows]

    # 3. Upvote distribution. Ordering by MIN(upvotes) puts the highest
    #    bucket first without relying on the bucket label text.
    upvote_rows = db.execute(
        text("""
            SELECT
                CASE
                    WHEN upvotes >= 100 THEN '100+'
                    WHEN upvotes >= 50 THEN '50-99'
                    WHEN upvotes >= 20 THEN '20-49'
                    WHEN upvotes >= 10 THEN '10-19'
                    WHEN upvotes >= 5 THEN '5-9'
                    ELSE '0-4'
                END as bucket,
                COUNT(*) as cnt
            FROM papers
            GROUP BY bucket
            ORDER BY MIN(upvotes) DESC
        """)
    ).fetchall()
    upvotes_dist = [{"range": row[0], "count": row[1]} for row in upvote_rows]

    # 4. Summary completion. The LEFT JOIN keeps papers that have no
    #    summary_status row; they are reported as 'none'.
    summary_rows = db.execute(
        text("""
            SELECT
                COALESCE(ss.status, 'none') as status,
                COUNT(*) as cnt
            FROM papers p
            LEFT JOIN summary_status ss ON ss.paper_id = p.id
            GROUP BY status
        """)
    ).fetchall()
    summary_completion = [
        {"status": row[0], "count": row[1]} for row in summary_rows
    ]

    return {
        "daily_counts": daily_counts,
        "top_tags": top_tags,
        "upvotes_dist": upvotes_dist,
        "summary_completion": summary_completion,
    }
|
||||
|
||||
|
||||
def _today_str() -> str:
    """Return the current date as ``YYYY-MM-DD`` in ``settings.APP_TIMEZONE``."""
    from datetime import datetime
    from zoneinfo import ZoneInfo

    now_local = datetime.now(ZoneInfo(settings.APP_TIMEZONE))
    return now_local.strftime("%Y-%m-%d")
|
||||
Reference in New Issue
Block a user