feat: add compare, trends routes, embedder service, and phase5 tests

This commit is contained in:
2026-06-05 23:32:06 +08:00
parent 2cfd1a8a9f
commit ba9afa212c
17 changed files with 2122 additions and 27 deletions
+115
View File
@@ -0,0 +1,115 @@
"""论文对比页路由 — 多篇论文结构化字段并排对比。"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
from app.database import get_db
from app.models import Paper
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/compare")
def compare_page(
    request: Request,
    ids: str = Query(default="", description="逗号分隔的 arxiv_id,最多 5 篇"),
    db: Session = Depends(get_db),
):
    """Render the side-by-side paper comparison page.

    Handles ``GET /compare?ids=id1,id2,id3``.

    Args:
        request: Incoming request, forwarded to the template renderer.
        ids: Comma-separated arxiv IDs. Blank entries are ignored, repeated
            IDs are collapsed, and only the first five distinct IDs are used.
        db: SQLAlchemy session used to load the requested papers.

    Returns:
        The rendered ``compare.html`` template. The context always carries
        ``page_title``, ``papers``, ``rows``, ``ids_param`` and ``error`` so
        the template sees the same keys on every code path.
    """
    max_papers = 5

    def render(papers, rows, error):
        # Single place that builds the template context, so the early-exit
        # paths cannot drift from the success path.
        return templates.TemplateResponse(
            request,
            "compare.html",
            {
                "page_title": "论文对比",
                "papers": papers,
                "rows": rows,
                "ids_param": ids,
                "error": error,
            },
        )

    if not ids:
        return render([], [], None)

    # dict.fromkeys drops repeated IDs while keeping first-seen order, so a
    # duplicated ID no longer uses up one of the limited comparison slots.
    cleaned = (part.strip() for part in ids.split(","))
    arxiv_ids = list(dict.fromkeys(part for part in cleaned if part))[:max_papers]
    if not arxiv_ids:
        return render([], [], "请提供有效的论文 ID")

    papers = (
        db.query(Paper)
        .filter(Paper.arxiv_id.in_(arxiv_ids))
        .options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary),
            joinedload(Paper.summary_status),
        )
        .all()
    )

    # Present papers in the order the caller asked for them; anything that
    # does not match a requested ID exactly sorts to the end.
    id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
    papers.sort(key=lambda p: id_order.get(p.arxiv_id, len(arxiv_ids)))

    # (attribute name, row label) pairs, in display order.
    compare_fields = [
        ("title_zh", "中文标题"),
        ("title_en", "英文标题"),
        ("one_line", "一句话摘要"),
        ("difficulty", "难度"),
        ("motivation_problem", "研究问题"),
        ("motivation_goal", "研究目标"),
        ("motivation_gap", "研究差距"),
        ("method_overview", "方法概述"),
        ("method_key_idea", "关键思路"),
        ("method_novelty", "新颖性"),
        ("results", "实验结果"),
        ("limitations", "局限与改进"),
    ]

    rows = []
    for field_key, field_label in compare_fields:
        cells = []
        for paper in papers:
            if field_key in ("title_zh", "title_en"):
                # Titles live on the paper itself, not on its summary.
                val = getattr(paper, field_key, None) or ""
            elif paper.summary:
                val = getattr(paper.summary, field_key, None) or ""
                # Fall back to the raw JSON columns when the plain
                # attribute is missing or empty.
                if field_key == "results" and not val:
                    val = paper.summary.results_main_json or ""
                if field_key == "limitations" and not val:
                    val = paper.summary.limitations_json or ""
            else:
                val = ""
            cells.append(val)
        rows.append({"key": field_key, "label": field_label, "cells": cells})

    return render(papers, rows, None)
+128 -1
View File
@@ -1,9 +1,13 @@
"""页面路由 — 首页、日期页、论文详情。"""
from __future__ import annotations
import logging
from datetime import date, datetime, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
@@ -12,6 +16,8 @@ from app.config import settings
from app.database import get_db
from app.models import Paper
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@@ -99,11 +105,132 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
if paper.summary_status:
summary_state = paper.summary_status.status
# Phase 5: 相似论文推荐
similar_papers = _get_similar_papers(db, arxiv_id, top_k=6)
# Phase 5: 图片画廊
images = _get_paper_images(arxiv_id)
return templates.TemplateResponse(
request, "detail.html",
{
"paper": paper,
"summary_state": summary_state,
"similar_papers": similar_papers,
"paper_images": images,
"chroma_enabled": settings.CHROMA_ENABLED,
"page_title": paper.title_zh or paper.title_en,
},
)
# ── 相似论文 API (Phase 5) ────────────────────────────────────────────
@router.get("/api/similar/{arxiv_id}")
def similar_api(
    arxiv_id: str,
    top_k: int = Query(default=5, ge=1, le=20),
    db: Session = Depends(get_db),
):
    """Return papers similar to ``arxiv_id`` as a JSON-serialisable dict.

    One extra candidate is requested so that ``top_k`` entries remain after
    the queried paper itself is filtered out.
    """
    candidates = _get_similar_papers(db, arxiv_id, top_k=top_k + 1)
    others = [entry for entry in candidates if entry["arxiv_id"] != arxiv_id]
    return {"results": others[:top_k]}
def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
    """Look up papers whose stored embedding is closest to ``arxiv_id``'s.

    Args:
        db: SQLAlchemy session used to load the matching ``Paper`` rows.
        arxiv_id: ID of the reference paper; it is excluded from the result.
        top_k: Number of neighbours requested from the vector collection.
            The reference paper may be among them, so fewer may be returned.

    Returns:
        A list of dicts with keys ``arxiv_id``, ``title_zh``, ``distance``,
        ``paper_date`` and ``tags``, sorted by ascending distance. The list
        is empty when the feature is disabled, the collection is
        unavailable, the paper has no stored embedding, or any error occurs
        (this is a best-effort lookup and never raises).
    """
    if not settings.CHROMA_ENABLED:
        return []
    try:
        try:
            from app.services.embedder import get_collection

            col = get_collection()
        except Exception:
            # The vector store is optional; record why it is unavailable
            # instead of hiding the failure completely.
            logger.warning(
                "Vector collection unavailable while looking up %s",
                arxiv_id,
                exc_info=True,
            )
            return []
        if col is None:
            return []

        # Fetch the stored embedding of the reference paper.
        result = col.get(ids=[arxiv_id], include=["embeddings"])
        embeddings = result["embeddings"]
        # Explicit None/len checks: the embeddings may come back as array
        # objects whose truth value is ambiguous, so `not embeddings` is
        # not safe here.
        if embeddings is None or len(embeddings) == 0:
            return []
        vec = embeddings[0]
        if vec is None or len(vec) == 0:
            return []

        count = col.count()
        if count == 0:
            return []

        results = col.query(
            query_embeddings=[vec],
            # Never ask for more neighbours than the collection holds.
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )
        if not results["ids"] or not results["ids"][0]:
            return []

        similar_ids = results["ids"][0]
        raw_distances = results["distances"]
        if raw_distances is not None and len(raw_distances) > 0:
            distances = raw_distances[0]
        else:
            distances = [0.0] * len(similar_ids)

        # Map neighbour ID -> distance, leaving out the reference paper.
        # float() keeps the values plain-Python for JSON serialisation.
        papers_info = {
            sid: float(dist)
            for sid, dist in zip(similar_ids, distances)
            if sid != arxiv_id
        }
        if not papers_info:
            return []

        # Enrich the neighbour IDs with metadata from the relational DB.
        papers = (
            db.query(Paper)
            .filter(Paper.arxiv_id.in_(list(papers_info)))
            .options(joinedload(Paper.tags))
            .all()
        )
        items = [
            {
                "arxiv_id": p.arxiv_id,
                "title_zh": p.title_zh or p.title_en,
                "distance": papers_info.get(p.arxiv_id, 0.0),
                "paper_date": p.paper_date.isoformat() if p.paper_date else "",
                "tags": [t.tag for t in p.tags[:3]],
            }
            for p in papers
        ]
        # Closest first.
        items.sort(key=lambda item: item["distance"])
        return items
    except Exception:
        logger.exception("Failed to get similar papers for %s", arxiv_id)
        return []
# ── 图片画廊 (Phase 5) ────────────────────────────────────────────────
def _get_paper_images(arxiv_id: str) -> list[dict]:
"""获取论文提取的图片列表。"""
images_dir = Path("data/papers") / arxiv_id / "images"
if not images_dir.exists():
return []
images = []
for img_file in sorted(images_dir.iterdir()):
if img_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg"):
images.append({
"url": f"/papers/{arxiv_id}/images/{img_file.name}",
"name": img_file.name,
})
return images
+24 -17
View File
@@ -31,11 +31,12 @@ def search_page(
q: str = Query(default=""),
tag: str = Query(default=""),
sort: str = Query(default="relevance"),
mode: str = Query(default="keyword"),
page: int = Query(default=1, ge=1),
db: Session = Depends(get_db),
):
"""搜索页面。"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page)
"""搜索页面,支持 keyword 和 semantic 模式"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode)
all_tags = get_all_tags(db)
return templates.TemplateResponse(
@@ -45,8 +46,11 @@ def search_page(
"query": q,
"tag": tag,
"sort": sort,
"mode": mode,
"chroma_enabled": settings.CHROMA_ENABLED,
"results": result["results"],
"snippets": result["snippets"],
"distances": result.get("distances", {}),
"total": result["total"],
"page": result["page"],
"total_pages": result["total_pages"],
@@ -65,28 +69,31 @@ def search_api(
q: str = Query(default=""),
tag: str = Query(default=""),
sort: str = Query(default="relevance"),
mode: str = Query(default="keyword"),
page: int = Query(default=1, ge=1),
db: Session = Depends(get_db),
):
"""搜索 JSON API。"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page)
"""搜索 JSON API,支持 keyword 和 semantic 模式"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode)
distances = result.get("distances", {})
items = []
for paper in result["results"]:
snippet = result["snippets"].get(paper.id, {})
items.append(
{
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"title_zh": paper.title_zh,
"paper_date": paper.paper_date.isoformat() if paper.paper_date else None,
"upvotes": paper.upvotes,
"tags": [t.tag for t in paper.tags],
"authors": [a.name for a in paper.authors],
"snippet_title_zh": snippet.get("title_zh"),
"snippet_abstract": snippet.get("abstract"),
}
)
item = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"title_zh": paper.title_zh,
"paper_date": paper.paper_date.isoformat() if paper.paper_date else None,
"upvotes": paper.upvotes,
"tags": [t.tag for t in paper.tags],
"authors": [a.name for a in paper.authors],
"snippet_title_zh": snippet.get("title_zh"),
"snippet_abstract": snippet.get("abstract"),
}
if paper.arxiv_id in distances:
item["distance"] = distances[paper.arxiv_id]
items.append(item)
return {
"results": items,
+120
View File
@@ -0,0 +1,120 @@
"""趋势看板路由 — 论文统计图表页面和数据 API。"""
from __future__ import annotations
import logging
from datetime import date, timedelta
from fastapi import APIRouter, Depends, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy import func, text
from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/trends")
def trends_page(request: Request, db: Session = Depends(get_db)):
    """Render the trends dashboard page with aggregated paper statistics."""
    context = {
        "page_title": "趋势看板",
        "stats": _get_trends_data(db),
        "today": _today_str(),
    }
    return templates.TemplateResponse(request, "trends.html", context)
@router.get("/api/stats/trends")
def trends_api(db: Session = Depends(get_db)):
    """Serve the aggregated trends statistics as JSON."""
    payload = _get_trends_data(db)
    return payload
def _get_trends_data(db: Session) -> dict:
    """Aggregate the statistics shown on the trends dashboard.

    Args:
        db: SQLAlchemy session used to run the aggregate queries.

    Returns:
        A dict with four lists:

        * ``daily_counts`` - ``{"date", "count"}`` per ``paper_date`` over
          the last 30 days, oldest first.
        * ``top_tags`` - ``{"tag", "count"}`` for the 20 most used tags.
        * ``upvotes_dist`` - ``{"range", "count"}`` per upvote bucket.
        * ``summary_completion`` - ``{"status", "count"}`` per summary
          status, with ``'none'`` for papers lacking a status row.
    """
    # Anchor the window on the application's configured timezone (the same
    # "today" the page displays) rather than on the server's local date.
    today = date.fromisoformat(_today_str())
    thirty_days_ago = (today - timedelta(days=30)).isoformat()

    # 1. Papers per day over the last 30 days.
    daily_rows = db.execute(
        text("""
            SELECT paper_date, COUNT(*) as cnt
            FROM papers
            WHERE paper_date >= :start_date
            GROUP BY paper_date
            ORDER BY paper_date ASC
        """),
        {"start_date": thirty_days_ago},
    ).fetchall()
    daily_counts = [{"date": str(row[0]), "count": row[1]} for row in daily_rows]

    # 2. Top 20 tags by usage.
    tag_rows = db.execute(
        text("""
            SELECT tag, COUNT(*) as cnt
            FROM paper_tags
            GROUP BY tag
            ORDER BY cnt DESC
            LIMIT 20
        """)
    ).fetchall()
    top_tags = [{"tag": row[0], "count": row[1]} for row in tag_rows]

    # 3. Upvote distribution, ordered from the highest bucket down.
    upvote_rows = db.execute(
        text("""
            SELECT
                CASE
                    WHEN upvotes >= 100 THEN '100+'
                    WHEN upvotes >= 50 THEN '50-99'
                    WHEN upvotes >= 20 THEN '20-49'
                    WHEN upvotes >= 10 THEN '10-19'
                    WHEN upvotes >= 5 THEN '5-9'
                    ELSE '0-4'
                END as bucket,
                COUNT(*) as cnt
            FROM papers
            GROUP BY bucket
            ORDER BY MIN(upvotes) DESC
        """)
    ).fetchall()
    upvotes_dist = [{"range": row[0], "count": row[1]} for row in upvote_rows]

    # 4. Summary completion: papers grouped by summary status.
    summary_rows = db.execute(
        text("""
            SELECT
                COALESCE(ss.status, 'none') as status,
                COUNT(*) as cnt
            FROM papers p
            LEFT JOIN summary_status ss ON ss.paper_id = p.id
            GROUP BY status
        """)
    ).fetchall()
    summary_completion = [
        {"status": row[0], "count": row[1]} for row in summary_rows
    ]

    return {
        "daily_counts": daily_counts,
        "top_tags": top_tags,
        "upvotes_dist": upvotes_dist,
        "summary_completion": summary_completion,
    }
def _today_str() -> str:
    """Return the current date as ``YYYY-MM-DD`` in the app's timezone."""
    from datetime import datetime
    from zoneinfo import ZoneInfo

    now = datetime.now(ZoneInfo(settings.APP_TIMEZONE))
    return now.strftime("%Y-%m-%d")