feat: add compare, trends routes, embedder service, and phase5 tests

This commit is contained in:
2026-06-05 23:32:06 +08:00
parent 2cfd1a8a9f
commit ba9afa212c
17 changed files with 2122 additions and 27 deletions
+16
View File
@@ -5,13 +5,16 @@ import os
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from starlette.staticfiles import StaticFiles as StarletteStaticFiles
from app.config import settings
from app.database import engine
from app.models import init_db
from app.routes.admin import router as admin_router
from app.routes.compare import router as compare_router
from app.routes.pages import router as pages_router
from app.routes.search import router as search_router
from app.routes.trends import router as trends_router
from app.routes.user import router as user_router
logging.basicConfig(
@@ -49,11 +52,18 @@ def create_app() -> FastAPI:
# 静态文件
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# Phase 5: 论文图片静态服务
papers_images_dir = os.path.join("data", "papers")
os.makedirs(papers_images_dir, exist_ok=True)
app.mount("/papers", StaticFiles(directory=papers_images_dir), name="papers")
# 路由
app.include_router(pages_router)
app.include_router(admin_router)
app.include_router(search_router)
app.include_router(user_router)
app.include_router(trends_router)
app.include_router(compare_router)
# 调度器(Phase 4
@app.on_event("startup")
@@ -61,6 +71,12 @@ def create_app() -> FastAPI:
from app.services.scheduler import start_scheduler
start_scheduler()
# Phase 5: 初始化 ChromaDB
@app.on_event("startup")
async def _init_chroma():
from app.services.embedder import init_chroma
init_chroma()
@app.on_event("shutdown")
async def _stop_scheduler():
from app.services.scheduler import stop_scheduler
+115
View File
@@ -0,0 +1,115 @@
"""论文对比页路由 — 多篇论文结构化字段并排对比。"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
from app.database import get_db
from app.models import Paper
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/compare")
def compare_page(
request: Request,
ids: str = Query(default="", description="逗号分隔的 arxiv_id,最多 5 篇"),
db: Session = Depends(get_db),
):
"""论文对比页面。GET /compare?ids=id1,id2,id3"""
if not ids:
return templates.TemplateResponse(
request,
"compare.html",
{
"page_title": "论文对比",
"papers": [],
"error": None,
},
)
arxiv_ids = [i.strip() for i in ids.split(",") if i.strip()]
# 最多 5 篇
if len(arxiv_ids) > 5:
arxiv_ids = arxiv_ids[:5]
if not arxiv_ids:
return templates.TemplateResponse(
request,
"compare.html",
{
"page_title": "论文对比",
"papers": [],
"error": "请提供有效的论文 ID",
},
)
papers = (
db.query(Paper)
.filter(Paper.arxiv_id.in_(arxiv_ids))
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary),
joinedload(Paper.summary_status),
)
.all()
)
# 按请求顺序排列
id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
papers.sort(key=lambda p: id_order.get(p.arxiv_id, 999))
# 构建对比数据
compare_fields = [
("title_zh", "中文标题"),
("title_en", "英文标题"),
("one_line", "一句话摘要"),
("difficulty", "难度"),
("motivation_problem", "研究问题"),
("motivation_goal", "研究目标"),
("motivation_gap", "研究差距"),
("method_overview", "方法概述"),
("method_key_idea", "关键思路"),
("method_novelty", "新颖性"),
("results", "实验结果"),
("limitations", "局限与改进"),
]
rows = []
for field_key, field_label in compare_fields:
cells = []
for paper in papers:
if field_key in ("title_zh", "title_en"):
val = getattr(paper, field_key, None) or ""
elif paper.summary:
val = getattr(paper.summary, field_key, None) or ""
# JSON 字段直接展示
if field_key == "results" and not val:
val = paper.summary.results_main_json or ""
if field_key == "limitations" and not val:
val = paper.summary.limitations_json or ""
else:
val = ""
cells.append(val)
rows.append({"key": field_key, "label": field_label, "cells": cells})
return templates.TemplateResponse(
request,
"compare.html",
{
"page_title": "论文对比",
"papers": papers,
"rows": rows,
"ids_param": ids,
"error": None,
},
)
+128 -1
View File
@@ -1,9 +1,13 @@
"""页面路由 — 首页、日期页、论文详情。"""
from __future__ import annotations
import logging
from datetime import date, datetime, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
@@ -12,6 +16,8 @@ from app.config import settings
from app.database import get_db
from app.models import Paper
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@@ -99,11 +105,132 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
if paper.summary_status:
summary_state = paper.summary_status.status
# Phase 5: 相似论文推荐
similar_papers = _get_similar_papers(db, arxiv_id, top_k=6)
# Phase 5: 图片画廊
images = _get_paper_images(arxiv_id)
return templates.TemplateResponse(
request, "detail.html",
{
"paper": paper,
"summary_state": summary_state,
"similar_papers": similar_papers,
"paper_images": images,
"chroma_enabled": settings.CHROMA_ENABLED,
"page_title": paper.title_zh or paper.title_en,
},
)
# ── 相似论文 API (Phase 5) ────────────────────────────────────────────
@router.get("/api/similar/{arxiv_id}")
def similar_api(
arxiv_id: str,
top_k: int = Query(default=5, ge=1, le=20),
db: Session = Depends(get_db),
):
"""返回与指定论文相似的论文列表(JSON)。"""
similar = _get_similar_papers(db, arxiv_id, top_k=top_k + 1)
# 排除自身
items = [s for s in similar if s["arxiv_id"] != arxiv_id][:top_k]
return {"results": items}
def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
"""从 ChromaDB 获取相似论文,返回 [{arxiv_id, title_zh, distance, paper_date}, ...]。"""
if not settings.CHROMA_ENABLED:
return []
try:
from app.services.embedder import search_similar
# 用论文的 arxiv_id 从 ChromaDB 查询
col = None
try:
from app.services.embedder import get_collection
col = get_collection()
except Exception:
return []
if col is None:
return []
# 获取当前论文的 embedding
result = col.get(ids=[arxiv_id], include=["embeddings"])
if not result["embeddings"] or not result["embeddings"][0]:
return []
vec = result["embeddings"][0]
count = col.count()
if count == 0:
return []
results = col.query(
query_embeddings=[vec],
n_results=min(top_k, count),
include=["metadatas", "distances"],
)
if not results["ids"] or not results["ids"][0]:
return []
# 从 DB 加载论文信息
similar_ids = results["ids"][0]
distances = results["distances"][0] if results["distances"] else [0.0] * len(similar_ids)
# 排除自身
papers_info = {}
for i, sid in enumerate(similar_ids):
if sid != arxiv_id:
papers_info[sid] = distances[i]
if not papers_info:
return []
papers = (
db.query(Paper)
.filter(Paper.arxiv_id.in_(list(papers_info.keys())))
.options(joinedload(Paper.tags))
.all()
)
items = []
for p in papers:
items.append({
"arxiv_id": p.arxiv_id,
"title_zh": p.title_zh or p.title_en,
"distance": papers_info.get(p.arxiv_id, 0.0),
"paper_date": p.paper_date.isoformat() if p.paper_date else "",
"tags": [t.tag for t in p.tags[:3]],
})
# 按距离排序
items.sort(key=lambda x: x["distance"])
return items
except Exception:
logger.exception("Failed to get similar papers for %s", arxiv_id)
return []
# ── 图片画廊 (Phase 5) ────────────────────────────────────────────────
def _get_paper_images(arxiv_id: str) -> list[dict]:
"""获取论文提取的图片列表。"""
images_dir = Path("data/papers") / arxiv_id / "images"
if not images_dir.exists():
return []
images = []
for img_file in sorted(images_dir.iterdir()):
if img_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg"):
images.append({
"url": f"/papers/{arxiv_id}/images/{img_file.name}",
"name": img_file.name,
})
return images
+24 -17
View File
@@ -31,11 +31,12 @@ def search_page(
q: str = Query(default=""),
tag: str = Query(default=""),
sort: str = Query(default="relevance"),
mode: str = Query(default="keyword"),
page: int = Query(default=1, ge=1),
db: Session = Depends(get_db),
):
"""搜索页面。"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page)
"""搜索页面,支持 keyword 和 semantic 模式"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode)
all_tags = get_all_tags(db)
return templates.TemplateResponse(
@@ -45,8 +46,11 @@ def search_page(
"query": q,
"tag": tag,
"sort": sort,
"mode": mode,
"chroma_enabled": settings.CHROMA_ENABLED,
"results": result["results"],
"snippets": result["snippets"],
"distances": result.get("distances", {}),
"total": result["total"],
"page": result["page"],
"total_pages": result["total_pages"],
@@ -65,28 +69,31 @@ def search_api(
q: str = Query(default=""),
tag: str = Query(default=""),
sort: str = Query(default="relevance"),
mode: str = Query(default="keyword"),
page: int = Query(default=1, ge=1),
db: Session = Depends(get_db),
):
"""搜索 JSON API。"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page)
"""搜索 JSON API,支持 keyword 和 semantic 模式"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode)
distances = result.get("distances", {})
items = []
for paper in result["results"]:
snippet = result["snippets"].get(paper.id, {})
items.append(
{
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"title_zh": paper.title_zh,
"paper_date": paper.paper_date.isoformat() if paper.paper_date else None,
"upvotes": paper.upvotes,
"tags": [t.tag for t in paper.tags],
"authors": [a.name for a in paper.authors],
"snippet_title_zh": snippet.get("title_zh"),
"snippet_abstract": snippet.get("abstract"),
}
)
item = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"title_zh": paper.title_zh,
"paper_date": paper.paper_date.isoformat() if paper.paper_date else None,
"upvotes": paper.upvotes,
"tags": [t.tag for t in paper.tags],
"authors": [a.name for a in paper.authors],
"snippet_title_zh": snippet.get("title_zh"),
"snippet_abstract": snippet.get("abstract"),
}
if paper.arxiv_id in distances:
item["distance"] = distances[paper.arxiv_id]
items.append(item)
return {
"results": items,
+120
View File
@@ -0,0 +1,120 @@
"""趋势看板路由 — 论文统计图表页面和数据 API。"""
from __future__ import annotations
import logging
from datetime import date, timedelta
from fastapi import APIRouter, Depends, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy import func, text
from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/trends")
def trends_page(request: Request, db: Session = Depends(get_db)):
"""趋势看板页面。"""
stats = _get_trends_data(db)
return templates.TemplateResponse(
request,
"trends.html",
{
"page_title": "趋势看板",
"stats": stats,
"today": _today_str(),
},
)
@router.get("/api/stats/trends")
def trends_api(db: Session = Depends(get_db)):
"""趋势数据 JSON API。"""
return _get_trends_data(db)
def _get_trends_data(db: Session) -> dict:
"""从 DB 聚合趋势数据。"""
thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
# 1. 按日论文数量(近 30 天)
daily_rows = db.execute(text("""
SELECT paper_date, COUNT(*) as cnt
FROM papers
WHERE paper_date >= :start_date
GROUP BY paper_date
ORDER BY paper_date ASC
"""), {"start_date": thirty_days_ago}).fetchall()
daily_counts = [
{"date": str(row[0]), "count": row[1]}
for row in daily_rows
]
# 2. 热门标签 Top 20
tag_rows = db.execute(text("""
SELECT tag, COUNT(*) as cnt
FROM paper_tags
GROUP BY tag
ORDER BY cnt DESC
LIMIT 20
""")).fetchall()
top_tags = [
{"tag": row[0], "count": row[1]}
for row in tag_rows
]
# 3. Upvotes 分布
upvote_rows = db.execute(text("""
SELECT
CASE
WHEN upvotes >= 100 THEN '100+'
WHEN upvotes >= 50 THEN '50-99'
WHEN upvotes >= 20 THEN '20-49'
WHEN upvotes >= 10 THEN '10-19'
WHEN upvotes >= 5 THEN '5-9'
ELSE '0-4'
END as bucket,
COUNT(*) as cnt
FROM papers
GROUP BY bucket
ORDER BY MIN(upvotes) DESC
""")).fetchall()
upvotes_dist = [
{"range": row[0], "count": row[1]}
for row in upvote_rows
]
# 4. 总结完成率
summary_rows = db.execute(text("""
SELECT
COALESCE(ss.status, 'none') as status,
COUNT(*) as cnt
FROM papers p
LEFT JOIN summary_status ss ON ss.paper_id = p.id
GROUP BY status
""")).fetchall()
summary_completion = [
{"status": row[0], "count": row[1]}
for row in summary_rows
]
return {
"daily_counts": daily_counts,
"top_tags": top_tags,
"upvotes_dist": upvotes_dist,
"summary_completion": summary_completion,
}
def _today_str() -> str:
from datetime import datetime
from zoneinfo import ZoneInfo
tz = ZoneInfo(settings.APP_TIMEZONE)
return datetime.now(tz).strftime("%Y-%m-%d")
+7
View File
@@ -139,6 +139,13 @@ async def delete_papers_by_date_range(
{"paper_id": paper_id},
)
# 1.5 Phase 5: 从 ChromaDB 删除语义索引
try:
from app.services.embedder import delete_paper
delete_paper(arxiv_id)
except Exception:
logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True)
# 2. 删除本地文件 data/papers/{arxiv_id}/
paper_dir = _PAPERS_DIR / arxiv_id
if paper_dir.exists():
+314
View File
@@ -0,0 +1,314 @@
"""ChromaDB 语义搜索服务 — embedding 生成、索引管理、相似查询。"""
from __future__ import annotations
import logging
from pathlib import Path
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.models import Paper
logger = logging.getLogger(__name__)
# ── 单例客户端和 collection ─────────────────────────────────────────────
_client = None
_collection = None
def _chroma_dir() -> Path:
return Path(settings.CHROMA_DIR)
def init_chroma() -> None:
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
global _client, _collection
if not settings.CHROMA_ENABLED:
logger.debug("ChromaDB disabled, skip init")
return
if _client is not None:
return
try:
import chromadb
chroma_path = _chroma_dir()
chroma_path.mkdir(parents=True, exist_ok=True)
_client = chromadb.PersistentClient(path=str(chroma_path))
_collection = _get_or_create_collection()
logger.info("ChromaDB initialized at %s", chroma_path)
except Exception:
logger.exception("Failed to initialize ChromaDB")
_client = None
_collection = None
def _get_or_create_collection():
"""获取或创建 papers_embeddings collection,维度不匹配时记录日志并跳过。"""
import chromadb
try:
col = _client.get_collection("papers_embeddings")
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
return col
except Exception:
pass
col = _client.create_collection(
name="papers_embeddings",
metadata={"hnsw:space": "cosine"},
)
logger.info("ChromaDB collection 'papers_embeddings' created")
return col
def get_collection():
"""返回当前 collection,未初始化则返回 None。"""
if not settings.CHROMA_ENABLED:
return None
if _collection is None:
init_chroma()
return _collection
# ── Embedding API 调用 ──────────────────────────────────────────────────
def _get_embedding(text: str) -> list[float] | None:
"""调用 EMBED_API_BASE 的 embedding API 生成向量。
POST /v1/embeddings, model=EMBED_MODEL
校验返回向量长度 == EMBED_DIMENSIONS
失败时返回 None 并记录日志。
"""
if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
return None
url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
headers = {"Content-Type": "application/json"}
if settings.EMBED_API_KEY:
headers["Authorization"] = f"Bearer {settings.EMBED_API_KEY}"
payload = {
"model": settings.EMBED_MODEL,
"input": text,
}
try:
with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
resp = client.post(url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
vec = data["data"][0]["embedding"]
if settings.EMBED_DIMENSIONS > 0 and len(vec) != settings.EMBED_DIMENSIONS:
logger.warning(
"Embedding dimension mismatch: expected=%d got=%d, skip",
settings.EMBED_DIMENSIONS,
len(vec),
)
return None
return vec
except Exception:
logger.exception("Embedding API call failed")
return None
# ── 索引文本拼接 ────────────────────────────────────────────────────────
def _build_index_text(paper: Paper, summary_dict: dict | None) -> str:
"""拼接高信号字段为一段文本用于 embedding。"""
parts: list[str] = []
if paper.title_zh:
parts.append(paper.title_zh)
if paper.title_en:
parts.append(paper.title_en)
if paper.tags:
tag_str = " ".join(t.tag for t in paper.tags)
if tag_str:
parts.append(tag_str)
if summary_dict:
for key in ("one_line", "motivation_problem", "method_key_idea"):
val = summary_dict.get(key)
if val:
parts.append(val)
return " ".join(parts)
# ── 单篇索引 ────────────────────────────────────────────────────────────
def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
"""将单篇论文写入 ChromaDB 语义索引。
Args:
paper_id: arxiv_id
texts_dict: 可选的预拼接字段 dict,包含 title_zh, title_en, tags, one_line 等
如果为 None,则从 DB 加载。
Returns:
True 表示成功,False 表示失败或跳过。
"""
col = get_collection()
if col is None:
return False
try:
# 如果没传 texts_dict,从 DB 加载
if texts_dict is None:
from app.database import SessionLocal
db = SessionLocal()
try:
paper = (
db.query(Paper)
.filter(Paper.arxiv_id == paper_id)
.options(joinedload(Paper.tags), joinedload(Paper.summary))
.first()
)
if not paper:
logger.warning("Paper %s not found for indexing", paper_id)
return False
summary_dict = None
if paper.summary:
summary_dict = {
"one_line": paper.summary.one_line or "",
"motivation_problem": paper.summary.motivation_problem or "",
"method_key_idea": paper.summary.method_key_idea or "",
}
index_text = _build_index_text(paper, summary_dict)
arxiv_id = paper.arxiv_id
title_zh = paper.title_zh or ""
paper_date = paper.paper_date.isoformat() if paper.paper_date else ""
finally:
db.close()
else:
index_text = " ".join(
v for v in texts_dict.values() if isinstance(v, str) and v
)
arxiv_id = texts_dict.get("arxiv_id", paper_id)
title_zh = texts_dict.get("title_zh", "")
paper_date = texts_dict.get("paper_date", "")
if not index_text.strip():
logger.warning("Empty index text for %s, skip", paper_id)
return False
vec = _get_embedding(index_text)
if vec is None:
return False
col.upsert(
ids=[arxiv_id],
embeddings=[vec],
metadatas=[{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}],
)
logger.info("Indexed paper %s in ChromaDB", arxiv_id)
return True
except Exception:
logger.exception("Failed to index paper %s in ChromaDB", paper_id)
return False
# ── 批量索引 ────────────────────────────────────────────────────────────
def index_batch(paper_ids: list[str]) -> dict:
"""批量索引论文,单篇失败不影响其他。
Returns:
{"total": int, "success": int, "failed": int}
"""
if not paper_ids:
return {"total": 0, "success": 0, "failed": 0}
col = get_collection()
if col is None:
return {"total": len(paper_ids), "success": 0, "failed": len(paper_ids)}
success = 0
failed = 0
for pid in paper_ids:
if index_paper(pid):
success += 1
else:
failed += 1
logger.info("Batch index: total=%d success=%d failed=%d", len(paper_ids), success, failed)
return {"total": len(paper_ids), "success": success, "failed": failed}
# ── 删除 ────────────────────────────────────────────────────────────────
def delete_paper(paper_id: str) -> bool:
"""从 ChromaDB 删除论文索引。
Args:
paper_id: arxiv_id
"""
col = get_collection()
if col is None:
return False
try:
col.delete(ids=[paper_id])
logger.info("Deleted paper %s from ChromaDB", paper_id)
return True
except Exception:
logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
return False
# ── 相似查询 ────────────────────────────────────────────────────────────
def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
"""语义相似搜索。
Args:
query_text: 查询文本
top_k: 返回前 K 个结果
Returns:
[{"arxiv_id": str, "distance": float}, ...]
"""
col = get_collection()
if col is None:
return []
try:
vec = _get_embedding(query_text)
if vec is None:
return []
results = col.query(
query_embeddings=[vec],
n_results=min(top_k, col.count()) if col.count() > 0 else top_k,
include=["metadatas", "distances"],
)
if not results["ids"] or not results["ids"][0]:
return []
items = []
for i, arxiv_id in enumerate(results["ids"][0]):
dist = results["distances"][0][i] if results["distances"] else 0.0
items.append({"arxiv_id": arxiv_id, "distance": dist})
return items
except Exception:
logger.exception("ChromaDB search_similar failed")
return []
+84 -2
View File
@@ -1,15 +1,19 @@
"""FTS5 全文搜索服务 — 关键词 + 标签筛选,命中片段高亮,分页。"""
"""搜索服务 — FTS5 关键词搜索 + ChromaDB 语义搜索,命中片段高亮,分页。"""
from __future__ import annotations
import logging
import math
import re
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.models import Paper
logger = logging.getLogger(__name__)
# ── 输入清洗 ──────────────────────────────────────────────────────────
# FTS5 查询语法中的特殊字符,用户输入时需要移除
@@ -41,8 +45,9 @@ def search_papers(
sort: str = "relevance",
page: int = 1,
page_size: int = 20,
mode: str = "keyword",
) -> dict:
"""FTS5 搜索论文。
"""搜索论文,支持 keyword (FTS5) 和 semantic (ChromaDB) 两种模式
返回::
{
@@ -51,8 +56,14 @@ def search_papers(
"total": int,
"page": int,
"total_pages": int,
"distances": dict[str, float], # arxiv_id → distance (仅 semantic)
}
"""
# ── semantic 模式 ──
if mode == "semantic" and settings.CHROMA_ENABLED and query:
return _search_semantic(db, query, tag, sort, page, page_size)
# ── keyword 模式(默认)──
match_expr = _sanitize_query(query) if query else None
# ── 无关键词 + 无标签 → 空结果 ──
@@ -63,6 +74,7 @@ def search_papers(
"total": 0,
"page": page,
"total_pages": 0,
"distances": {},
}
# ── 构建条件性 JOIN 和 WHERE 片段 ──
@@ -146,6 +158,75 @@ def _search_with_fts(
"total": total,
"page": page,
"total_pages": math.ceil(total / page_size) if total else 0,
"distances": {},
}
def _search_semantic(
db: Session,
query: str,
tag: str | None,
sort: str,
page: int,
page_size: int,
) -> dict:
"""ChromaDB 语义搜索,失败时回退到 FTS5。"""
try:
from app.services.embedder import search_similar
top_k = page_size * 3 # 多取一些用于 tag 过滤
candidates = search_similar(query, top_k=top_k)
except Exception:
logger.exception("Semantic search failed, falling back to keyword")
candidates = []
if not candidates:
# 回退到 FTS5
return _search_with_fts(
db,
_sanitize_query(query) or query,
"JOIN paper_tags pt ON pt.paper_id = p.id" if tag else "",
"AND pt.tag = :tag" if tag else "",
{"tag": tag} if tag else {},
sort, page, page_size, (page - 1) * page_size,
)
# 按 arxiv_id 从 DB 加载完整数据
arxiv_ids = [c["arxiv_id"] for c in candidates]
distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
papers_query = (
db.query(Paper)
.filter(Paper.arxiv_id.in_(arxiv_ids))
.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
)
)
if tag:
papers_query = papers_query.filter(Paper.tags.any(tag=tag))
papers = papers_query.all()
# 按语义距离排序
id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
papers.sort(key=lambda p: id_order.get(p.arxiv_id, 999))
# 分页
total = len(papers)
start = (page - 1) * page_size
page_papers = papers[start:start + page_size]
return {
"results": page_papers,
"snippets": {},
"total": total,
"page": page,
"total_pages": math.ceil(total / page_size) if total else 0,
"distances": distance_map,
}
@@ -187,6 +268,7 @@ def _search_tag_only(
"total": total,
"page": page,
"total_pages": math.ceil(total / page_size) if total else 0,
"distances": {},
}
+145
View File
@@ -359,6 +359,127 @@ def _cleanup_tmp(arxiv_id: str) -> None:
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
# ── LaTeX 图片提取(Phase 5)───────────────────────────────────────────
_INCLUDEGRAPHICS_RE = re.compile(
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
"""从 LaTeX 源码中提取图片文件。
流程:
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
2. 扫描 .tex 文件中的 \\includegraphics
3. 复制图片到 data/papers/{arxiv_id}/images/
4. 清理源码临时文件
Returns:
提取的图片数量
"""
tmp_source = _tmp_dir(arxiv_id) / "source"
images_dest = _paper_dir(arxiv_id) / "images"
try:
# 下载源码 zip(如果还没下载)
if not tmp_source.exists():
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
await _download_source_zip(arxiv_id, source_url, tmp_source)
if not tmp_source.exists():
return 0
# 扫描 .tex 文件,收集图片路径
image_paths: set[str] = set()
for tex_file in tmp_source.rglob("*.tex"):
try:
content = tex_file.read_text(encoding="utf-8", errors="replace")
for match in _INCLUDEGRAPHICS_RE.finditer(content):
img_path = match.group(1).strip()
image_paths.add(img_path)
except Exception:
continue
if not image_paths:
return 0
# 查找并复制图片
images_dest.mkdir(parents=True, exist_ok=True)
copied = 0
for img_rel in image_paths:
# 尝试在源码目录中找到文件
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
candidate = tmp_source / (img_rel + ext)
if candidate.is_file():
dest_name = candidate.name
# 避免文件名冲突
dest = images_dest / dest_name
if dest.exists():
stem = dest.stem
suffix = dest.suffix
dest = images_dest / f"{stem}_{copied}{suffix}"
shutil.copy2(candidate, dest)
copied += 1
break
if copied > 0:
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
return copied
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
return 0
async def _download_source_zip(
arxiv_id: str, source_url: str, dest_dir: Path
) -> None:
"""下载 arXiv 源码并解压。"""
import zipfile
dest_dir.mkdir(parents=True, exist_ok=True)
zip_path = _tmp_dir(arxiv_id) / "source.zip"
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
try:
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
follow_redirects=True,
) as client:
resp = await client.get(source_url)
resp.raise_for_status()
zip_path.write_bytes(resp.content)
except Exception as exc:
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
return
try:
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(dest_dir)
logger.debug("Extracted source for %s", arxiv_id)
except zipfile.BadZipFile:
# 可能是 tar.gz
import tarfile
try:
with tarfile.open(zip_path, "r:*") as tf:
tf.extractall(dest_dir)
logger.debug("Extracted source (tar) for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
finally:
if zip_path.exists():
zip_path.unlink()
# ── 单篇总结 ────────────────────────────────────────────────────────────
@@ -441,6 +562,30 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
status.raw_output_saved = True
db.commit()
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
try:
await _extract_images_from_source(arxiv_id)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
# Phase 5: 同步写入语义索引(失败仅 log)
try:
from app.services.embedder import index_paper
texts_dict = {
"arxiv_id": arxiv_id,
"title_zh": schema.title_zh or "",
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"one_line": schema.one_line or "",
"motivation_problem": schema.motivation_problem or "",
"method_key_idea": schema.method_key_idea or "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
index_paper(arxiv_id, texts_dict)
except Exception:
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
+182
View File
@@ -556,3 +556,185 @@ mark {
.reading-list-filters { gap: 4px; }
.filter-chip { padding: 4px 10px; font-size: 0.8rem; }
}
/* ── Search Mode Toggle (Phase 5) ─────────────────────────────── */
.search-mode-toggle {
display: flex;
gap: 0;
border: 1px solid var(--border);
border-radius: var(--radius);
overflow: hidden;
}
.mode-option {
padding: 8px 14px;
font-size: 0.85rem;
color: var(--ink-light);
cursor: pointer;
transition: all 0.2s;
display: flex;
align-items: center;
gap: 4px;
}
.mode-option input[type="radio"] { display: none; }
.mode-option:hover { background: var(--bg); }
.mode-option.active {
background: var(--accent);
color: #fff;
}
/* ── Similarity Score (Phase 5) ────────────────────────────────── */
.similarity-score {
font-size: 0.8rem;
color: var(--accent);
white-space: nowrap;
padding: 2px 8px;
background: #eef3f8;
border-radius: 4px;
}
/* ── Similar Papers (Phase 5) ──────────────────────────────────── */
.similar-papers {
margin-top: 32px;
padding-top: 24px;
border-top: 1px solid var(--border);
}
.similar-papers h2 {
font-family: var(--font-body);
font-size: 1.1rem;
font-weight: 600;
margin-bottom: 12px;
color: var(--accent);
}
.similar-paper-item {
display: flex;
justify-content: space-between;
align-items: center;
padding: 10px 0;
border-bottom: 1px solid var(--border);
}
.similar-paper-item:last-child { border-bottom: none; }
.similar-paper-title a {
font-size: 0.92rem;
color: var(--ink);
}
.similar-paper-title a:hover { color: var(--accent); }
.similar-paper-dist {
font-size: 0.8rem;
color: var(--ink-light);
}
/* ── Trends Dashboard (Phase 5) ────────────────────────────────── */
.trends-page h1 {
font-family: var(--font-body);
font-size: 1.5rem;
font-weight: 700;
margin-bottom: 24px;
}
.charts-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 24px;
margin-bottom: 32px;
}
.chart-card {
background: var(--surface);
border: 1px solid var(--border);
border-radius: var(--radius);
padding: 20px;
}
.chart-card h2 {
font-family: var(--font-body);
font-size: 1rem;
font-weight: 600;
margin-bottom: 12px;
color: var(--accent);
}
.chart-card canvas {
width: 100% !important;
max-height: 300px;
}
@media (max-width: 768px) {
.charts-grid { grid-template-columns: 1fr; }
}
/* ── Compare Page (Phase 5) ────────────────────────────────────── */
.compare-page h1 {
font-family: var(--font-body);
font-size: 1.5rem;
font-weight: 700;
margin-bottom: 24px;
}
.compare-table-wrapper {
overflow-x: auto;
margin-bottom: 24px;
}
.compare-table {
width: 100%;
border-collapse: collapse;
background: var(--surface);
border: 1px solid var(--border);
border-radius: var(--radius);
font-size: 0.88rem;
}
.compare-table th,
.compare-table td {
padding: 12px 16px;
border: 1px solid var(--border);
vertical-align: top;
text-align: left;
}
.compare-table th {
background: var(--bg);
font-weight: 600;
color: var(--ink-light);
white-space: nowrap;
min-width: 100px;
}
.compare-table td.field-label {
background: var(--bg);
font-weight: 500;
color: var(--ink);
white-space: nowrap;
min-width: 90px;
}
.compare-table td.paper-col {
min-width: 200px;
}
.compare-table td .no-summary {
color: var(--ink-light);
font-style: italic;
}
/* ── Image Gallery (Phase 5) ───────────────────────────────────── */
.image-gallery {
margin-top: 24px;
}
.image-gallery h2 {
font-family: var(--font-body);
font-size: 1.05rem;
font-weight: 600;
margin-bottom: 12px;
color: var(--accent);
}
.gallery-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
gap: 16px;
}
.gallery-item {
background: var(--surface);
border: 1px solid var(--border);
border-radius: var(--radius);
overflow: hidden;
}
.gallery-item img {
width: 100%;
height: auto;
display: block;
}
.gallery-item .gallery-caption {
padding: 8px 12px;
font-size: 0.8rem;
color: var(--ink-light);
text-align: center;
}
+1
View File
@@ -16,6 +16,7 @@
<div class="nav-links">
<a href="/day/{{ today if today else '' }}">今日</a>
<a href="/search">搜索</a>
<a href="/trends">趋势</a>
<a href="/reading-list">阅读列表</a>
<a href="/admin/logs">管理</a>
</div>
+86
View File
@@ -0,0 +1,86 @@
{% extends "base.html" %}
{% block title %}{{ page_title }} — HF Daily Papers{% endblock %}
{% block content %}
<section class="compare-page">
<h1>论文对比</h1>
{# ID 输入表单 #}
<form class="search-form" method="get" action="/compare">
<input type="text" name="ids" value="{{ ids_param }}"
placeholder="输入 arXiv ID,逗号分隔(最多 5 篇),如 2401.12345,2401.67890"
class="search-input">
<button type="submit" class="search-btn">对比</button>
</form>
{% if error %}
<div class="empty-state">
<p>{{ error }}</p>
</div>
{% endif %}
{% if papers %}
<div class="compare-table-wrapper">
<table class="compare-table">
<thead>
<tr>
<th>字段</th>
{% for paper in papers %}
<th>
<a href="/paper/{{ paper.arxiv_id }}">{{ paper.arxiv_id }}</a>
<br>
<small style="color: var(--ink-light);">
{{ paper.upvotes }} 👍 · {{ paper.paper_date }}
</small>
</th>
{% endfor %}
</tr>
</thead>
<tbody>
{# 作者行 #}
<tr>
<td class="field-label">作者</td>
{% for paper in papers %}
<td class="paper-col">{{ paper.authors|map(attribute='name')|join(', ') }}</td>
{% endfor %}
</tr>
{# 标签行 #}
<tr>
<td class="field-label">标签</td>
{% for paper in papers %}
<td class="paper-col">
{% for t in paper.tags[:5] %}
<span class="tag">{{ t.tag }}</span>
{% endfor %}
</td>
{% endfor %}
</tr>
{# 结构化对比字段 #}
{% for row in rows %}
<tr>
<td class="field-label">{{ row.label }}</td>
{% for cell in row.cells %}
<td class="paper-col">
{% if cell %}
{{ cell }}
{% else %}
<span class="no-summary">暂无总结</span>
{% endif %}
</td>
{% endfor %}
</tr>
{% endfor %}
</tbody>
</table>
</div>
{% elif ids_param and not error %}
<div class="empty-state">
<p>未找到匹配的论文</p>
<p class="hint">请检查 arXiv ID 是否正确</p>
</div>
{% endif %}
</section>
{% endblock %}
+30
View File
@@ -117,5 +117,35 @@
<p class="abstract-en">{{ paper.abstract }}</p>
</section>
{% endif %}
{# Phase 5: 图片画廊 #}
{% if paper_images %}
<section class="image-gallery">
<h2>论文图片</h2>
<div class="gallery-grid">
{% for img in paper_images %}
<div class="gallery-item">
<img src="{{ img.url }}" alt="{{ img.name }}" loading="lazy">
<div class="gallery-caption">{{ img.name }}</div>
</div>
{% endfor %}
</div>
</section>
{% endif %}
{# Phase 5: 相似论文推荐 #}
{% if similar_papers %}
<section class="similar-papers">
<h2>相似论文推荐</h2>
{% for sp in similar_papers %}
<div class="similar-paper-item">
<span class="similar-paper-title">
<a href="/paper/{{ sp.arxiv_id }}">{{ sp.title_zh }}</a>
</span>
<span class="similar-paper-dist">🎯 {{ "%.3f"|format(sp.distance) }}</span>
</div>
{% endfor %}
</section>
{% endif %}
</article>
{% endblock %}
+27 -7
View File
@@ -11,6 +11,21 @@
{% if tag %}
<input type="hidden" name="tag" value="{{ tag }}">
{% endif %}
{# 模式切换 #}
{% if chroma_enabled %}
<div class="search-mode-toggle">
<label class="mode-option {% if mode == 'keyword' or not mode %}active{% endif %}">
<input type="radio" name="mode" value="keyword" {% if mode == 'keyword' or not mode %}checked{% endif %}>
关键词
</label>
<label class="mode-option {% if mode == 'semantic' %}active{% endif %}">
<input type="radio" name="mode" value="semantic" {% if mode == 'semantic' %}checked{% endif %}>
语义搜索
</label>
</div>
{% endif %}
<button type="submit" class="search-btn">搜索</button>
</form>
@@ -18,10 +33,10 @@
{% if all_tags %}
<div class="tag-filter">
<span class="tag-filter-label">标签:</span>
<a href="/search{% if query %}?q={{ query }}{% endif %}"
<a href="/search?q={{ query }}&mode={{ mode }}{% if tag %}&tag={{ tag }}{% endif %}"
class="tag-chip {% if not tag %}active{% endif %}">全部</a>
{% for t in all_tags %}
<a href="/search?q={{ query }}&tag={{ t }}"
<a href="/search?q={{ query }}&tag={{ t }}&mode={{ mode }}"
class="tag-chip {% if t == tag %}active{% endif %}">{{ t }}</a>
{% endfor %}
</div>
@@ -30,12 +45,12 @@
{% if query or tag %}
{# 搜索结果元信息 #}
<div class="search-meta">
<span>找到 {{ total }} 条结果</span>
<span>找到 {{ total }} 条结果{% if mode == 'semantic' %}(语义模式){% endif %}</span>
<div class="sort-toggle">
<a href="/search?q={{ query }}&tag={{ tag }}&sort=relevance"
<a href="/search?q={{ query }}&tag={{ tag }}&mode={{ mode }}&sort=relevance"
class="{% if sort == 'relevance' %}active{% endif %}">相关性</a>
<span class="sort-divider">|</span>
<a href="/search?q={{ query }}&tag={{ tag }}&sort=date"
<a href="/search?q={{ query }}&tag={{ tag }}&mode={{ mode }}&sort=date"
class="{% if sort == 'date' %}active{% endif %}">日期</a>
</div>
</div>
@@ -58,6 +73,11 @@
</a>
</h2>
<span class="paper-upvotes">👍 {{ paper.upvotes }}</span>
{% if distances and paper.arxiv_id in distances %}
<span class="similarity-score" title="语义相似度距离">
🎯 {{ "%.3f"|format(distances[paper.arxiv_id]) }}
</span>
{% endif %}
</div>
{% if snippet and snippet.abstract %}
@@ -103,11 +123,11 @@
{% if total_pages > 1 %}
<nav class="pagination">
{% if page > 1 %}
<a href="/search?q={{ query }}&tag={{ tag }}&sort={{ sort }}&page={{ page - 1 }}" class="page-btn">← 上一页</a>
<a href="/search?q={{ query }}&tag={{ tag }}&sort={{ sort }}&mode={{ mode }}&page={{ page - 1 }}" class="page-btn">← 上一页</a>
{% endif %}
<span class="page-info">{{ page }} / {{ total_pages }}</span>
{% if page < total_pages %}
<a href="/search?q={{ query }}&tag={{ tag }}&sort={{ sort }}&page={{ page + 1 }}" class="page-btn">下一页 →</a>
<a href="/search?q={{ query }}&tag={{ tag }}&sort={{ sort }}&mode={{ mode }}&page={{ page + 1 }}" class="page-btn">下一页 →</a>
{% endif %}
</nav>
{% endif %}
+185
View File
@@ -0,0 +1,185 @@
{% extends "base.html" %}
{% block title %}{{ page_title }} — HF Daily Papers{% endblock %}
{% block content %}
<section class="trends-page">
<h1>趋势看板</h1>
<div class="charts-grid">
{# 按日论文数量折线图 #}
<div class="chart-card">
<h2>📅 每日论文数量(近 30 天)</h2>
<canvas id="dailyChart"></canvas>
</div>
{# 热门标签 Top 20 #}
<div class="chart-card">
<h2>🏷️ 热门标签 Top 20</h2>
<canvas id="tagsChart"></canvas>
</div>
{# Upvotes 分布 #}
<div class="chart-card">
<h2>👍 Upvotes 分布</h2>
<canvas id="upvotesChart"></canvas>
</div>
{# 总结完成率 #}
<div class="chart-card">
<h2>📝 总结完成率</h2>
<canvas id="summaryChart"></canvas>
</div>
</div>
</section>
{% endblock %}
{% block scripts %}
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.7/dist/chart.umd.min.js"></script>
<script>
// 颜色配置(kami 风格墨蓝色系)
const COLORS = {
primary: '#2d5f8a',
primaryLight: 'rgba(45, 95, 138, 0.2)',
accent: '#5a9bc7',
success: '#388e3c',
warning: '#f57f17',
danger: '#c62828',
muted: '#4a4a6a',
palette: [
'#2d5f8a', '#5a9bc7', '#388e3c', '#f57f17', '#c62828',
'#7b1fa2', '#00838f', '#ef6c00', '#455a64', '#827717',
'#1565c0', '#ad1457', '#00695c', '#e65100', '#283593',
'#9e9d24', '#6a1b9a', '#00838f', '#4e342e', '#37474f',
],
};
const statsData = {{ stats | tojson }};
// 每日论文数量折线图
(function() {
const ctx = document.getElementById('dailyChart').getContext('2d');
const labels = statsData.daily_counts.map(d => d.date);
const data = statsData.daily_counts.map(d => d.count);
new Chart(ctx, {
type: 'line',
data: {
labels: labels,
datasets: [{
label: '论文数',
data: data,
borderColor: COLORS.primary,
backgroundColor: COLORS.primaryLight,
fill: true,
tension: 0.3,
pointRadius: 3,
pointHoverRadius: 6,
}]
},
options: {
responsive: true,
plugins: { legend: { display: false } },
scales: {
x: { ticks: { maxTicksLimit: 10, font: { size: 11 } } },
y: { beginAtZero: true, ticks: { stepSize: 1 } },
}
}
});
})();
// 热门标签柱状图
(function() {
const ctx = document.getElementById('tagsChart').getContext('2d');
const labels = statsData.top_tags.map(d => d.tag);
const data = statsData.top_tags.map(d => d.count);
new Chart(ctx, {
type: 'bar',
data: {
labels: labels,
datasets: [{
label: '论文数',
data: data,
backgroundColor: COLORS.palette.slice(0, data.length),
borderRadius: 4,
}]
},
options: {
responsive: true,
indexAxis: 'y',
plugins: { legend: { display: false } },
scales: {
x: { beginAtZero: true, ticks: { stepSize: 1 } },
}
}
});
})();
// Upvotes 分布
(function() {
const ctx = document.getElementById('upvotesChart').getContext('2d');
const labels = statsData.upvotes_dist.map(d => d.range);
const data = statsData.upvotes_dist.map(d => d.count);
new Chart(ctx, {
type: 'bar',
data: {
labels: labels,
datasets: [{
label: '论文数',
data: data,
backgroundColor: COLORS.accent,
borderRadius: 4,
}]
},
options: {
responsive: true,
plugins: { legend: { display: false } },
scales: {
y: { beginAtZero: true, ticks: { stepSize: 1 } },
}
}
});
})();
// 总结完成率环形图
(function() {
const ctx = document.getElementById('summaryChart').getContext('2d');
const statusLabels = {
'done': '已完成',
'pending': '待总结',
'processing': '总结中',
'failed': '失败',
'permanent_failure': '永久失败',
'none': '未开始',
};
const statusColors = {
'done': COLORS.success,
'pending': COLORS.warning,
'processing': COLORS.primary,
'failed': COLORS.danger,
'permanent_failure': '#b71c1c',
'none': '#bdbdbd',
};
const labels = statsData.summary_completion.map(d => statusLabels[d.status] || d.status);
const data = statsData.summary_completion.map(d => d.count);
const colors = statsData.summary_completion.map(d => statusColors[d.status] || COLORS.muted);
new Chart(ctx, {
type: 'doughnut',
data: {
labels: labels,
datasets: [{
data: data,
backgroundColor: colors,
borderWidth: 2,
borderColor: '#fff',
}]
},
options: {
responsive: true,
plugins: {
legend: { position: 'bottom', labels: { padding: 12 } },
}
}
});
})();
</script>
{% endblock %}