feat: add compare, trends routes, embedder service, and phase5 tests

This commit is contained in:
2026-06-05 23:32:06 +08:00
parent 2cfd1a8a9f
commit ba9afa212c
17 changed files with 2122 additions and 27 deletions
+16
View File
@@ -5,13 +5,16 @@ import os
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from starlette.staticfiles import StaticFiles as StarletteStaticFiles
from app.config import settings
from app.database import engine
from app.models import init_db
from app.routes.admin import router as admin_router
from app.routes.compare import router as compare_router
from app.routes.pages import router as pages_router
from app.routes.search import router as search_router
from app.routes.trends import router as trends_router
from app.routes.user import router as user_router
logging.basicConfig(
@@ -49,11 +52,18 @@ def create_app() -> FastAPI:
# 静态文件
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# Phase 5: 论文图片静态服务
papers_images_dir = os.path.join("data", "papers")
os.makedirs(papers_images_dir, exist_ok=True)
app.mount("/papers", StaticFiles(directory=papers_images_dir), name="papers")
# 路由
app.include_router(pages_router)
app.include_router(admin_router)
app.include_router(search_router)
app.include_router(user_router)
app.include_router(trends_router)
app.include_router(compare_router)
# 调度器(Phase 4)
@app.on_event("startup")
@@ -61,6 +71,12 @@ def create_app() -> FastAPI:
from app.services.scheduler import start_scheduler
start_scheduler()
# Phase 5: 初始化 ChromaDB
@app.on_event("startup")
async def _init_chroma():
from app.services.embedder import init_chroma
init_chroma()
@app.on_event("shutdown")
async def _stop_scheduler():
from app.services.scheduler import stop_scheduler
+115
View File
@@ -0,0 +1,115 @@
"""论文对比页路由 — 多篇论文结构化字段并排对比。"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
from app.database import get_db
from app.models import Paper
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/compare")
def compare_page(
    request: Request,
    ids: str = Query(default="", description="逗号分隔的 arxiv_id,最多 5 篇"),
    db: Session = Depends(get_db),
):
    """Render the side-by-side paper comparison page (GET /compare?ids=id1,id2,id3).

    Args:
        request: Incoming request, forwarded to the template renderer.
        ids: Comma-separated arXiv IDs. Blank entries are dropped, duplicates
            are collapsed, and only the first five distinct IDs are compared.
        db: SQLAlchemy session.

    Returns:
        The rendered ``compare.html`` template. The context always carries the
        same keys (``page_title``, ``papers``, ``rows``, ``ids_param``,
        ``error``) so the template never has to deal with undefined names.
    """

    def render(papers, rows, error):
        # Single exit point so every response has an identical context shape.
        return templates.TemplateResponse(
            request,
            "compare.html",
            {
                "page_title": "论文对比",
                "papers": papers,
                "rows": rows,
                "ids_param": ids,
                "error": error,
            },
        )

    if not ids:
        return render([], [], None)

    # dict.fromkeys de-duplicates while keeping first-seen order, so a repeated
    # ID cannot use up one of the five comparison slots.
    cleaned = (part.strip() for part in ids.split(","))
    arxiv_ids = list(dict.fromkeys(part for part in cleaned if part))[:5]

    if not arxiv_ids:
        return render([], [], "请提供有效的论文 ID")

    papers = (
        db.query(Paper)
        .filter(Paper.arxiv_id.in_(arxiv_ids))
        .options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary),
            joinedload(Paper.summary_status),
        )
        .all()
    )

    # Present the papers in the order the caller asked for them.
    id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
    papers.sort(key=lambda p: id_order.get(p.arxiv_id, 999))

    # (attribute name, label shown in the first table column)
    compare_fields = [
        ("title_zh", "中文标题"),
        ("title_en", "英文标题"),
        ("one_line", "一句话摘要"),
        ("difficulty", "难度"),
        ("motivation_problem", "研究问题"),
        ("motivation_goal", "研究目标"),
        ("motivation_gap", "研究差距"),
        ("method_overview", "方法概述"),
        ("method_key_idea", "关键思路"),
        ("method_novelty", "新颖性"),
        ("results", "实验结果"),
        ("limitations", "局限与改进"),
    ]

    rows = []
    for field_key, field_label in compare_fields:
        cells = []
        for paper in papers:
            if field_key in ("title_zh", "title_en"):
                # Titles live on the paper itself, not on the summary.
                val = getattr(paper, field_key, None) or ""
            elif paper.summary:
                val = getattr(paper.summary, field_key, None) or ""
                # Fall back to the raw JSON columns when no plain value exists.
                if field_key == "results" and not val:
                    val = paper.summary.results_main_json or ""
                if field_key == "limitations" and not val:
                    val = paper.summary.limitations_json or ""
            else:
                val = ""
            cells.append(val)
        rows.append({"key": field_key, "label": field_label, "cells": cells})

    return render(papers, rows, None)
+128 -1
View File
@@ -1,9 +1,13 @@
"""页面路由 — 首页、日期页、论文详情。"""
from __future__ import annotations
import logging
from datetime import date, datetime, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
@@ -12,6 +16,8 @@ from app.config import settings
from app.database import get_db
from app.models import Paper
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@@ -99,11 +105,132 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
if paper.summary_status:
summary_state = paper.summary_status.status
# Phase 5: 相似论文推荐
similar_papers = _get_similar_papers(db, arxiv_id, top_k=6)
# Phase 5: 图片画廊
images = _get_paper_images(arxiv_id)
return templates.TemplateResponse(
request, "detail.html",
{
"paper": paper,
"summary_state": summary_state,
"similar_papers": similar_papers,
"paper_images": images,
"chroma_enabled": settings.CHROMA_ENABLED,
"page_title": paper.title_zh or paper.title_en,
},
)
# ── 相似论文 API (Phase 5) ────────────────────────────────────────────
@router.get("/api/similar/{arxiv_id}")
def similar_api(
    arxiv_id: str,
    top_k: int = Query(default=5, ge=1, le=20),
    db: Session = Depends(get_db),
):
    """Return up to ``top_k`` papers similar to ``arxiv_id`` as JSON."""
    # Ask for one extra hit because the paper itself may be among the results.
    candidates = _get_similar_papers(db, arxiv_id, top_k=top_k + 1)

    others = []
    for entry in candidates:
        if entry["arxiv_id"] == arxiv_id:
            continue
        others.append(entry)

    return {"results": others[:top_k]}
def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
    """Return papers whose stored embedding is closest to that of ``arxiv_id``.

    The paper's own vector is read back from the ChromaDB collection and used
    as the query, so no embedding API call is made here.

    Args:
        db: SQLAlchemy session used to load the display fields.
        arxiv_id: Paper whose neighbours are wanted.
        top_k: Number of neighbours requested from ChromaDB. The paper itself
            is removed from the hits, so fewer than ``top_k`` items may come back.

    Returns:
        A list of ``{arxiv_id, title_zh, distance, paper_date, tags}`` dicts
        sorted by ascending distance. Returns ``[]`` when ChromaDB is disabled,
        the paper is not indexed, or any error occurs (errors are logged).
    """
    if not settings.CHROMA_ENABLED:
        return []

    try:
        from app.services.embedder import get_collection

        col = get_collection()
        if col is None:
            return []

        # Fetch the embedding already stored for this paper.
        result = col.get(ids=[arxiv_id], include=["embeddings"])
        embeddings = result["embeddings"]
        # Use len() instead of truthiness: an array-like return value raises on
        # `not array`, which would be swallowed below and always yield [].
        # NOTE(review): newer ChromaDB releases appear to return numpy arrays
        # here — confirm against the pinned version.
        if embeddings is None or len(embeddings) == 0:
            return []
        vec = embeddings[0]
        if vec is None or len(vec) == 0:
            return []

        count = col.count()
        if count == 0:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )
        if not results["ids"] or not results["ids"][0]:
            return []

        similar_ids = results["ids"][0]
        raw_distances = results["distances"]
        if raw_distances is not None and len(raw_distances) > 0:
            distances = raw_distances[0]
        else:
            distances = [0.0] * len(similar_ids)

        # Drop the query paper itself; coerce to plain float so the value is
        # JSON-serialisable whatever numeric type the store returns.
        papers_info = {
            sid: float(dist)
            for sid, dist in zip(similar_ids, distances)
            if sid != arxiv_id
        }
        if not papers_info:
            return []

        papers = (
            db.query(Paper)
            .filter(Paper.arxiv_id.in_(list(papers_info)))
            .options(joinedload(Paper.tags))
            .all()
        )

        items = [
            {
                "arxiv_id": p.arxiv_id,
                "title_zh": p.title_zh or p.title_en,
                "distance": papers_info.get(p.arxiv_id, 0.0),
                "paper_date": p.paper_date.isoformat() if p.paper_date else "",
                "tags": [t.tag for t in p.tags[:3]],
            }
            for p in papers
        ]
        # The DB returns rows in arbitrary order; restore nearest-first.
        items.sort(key=lambda item: item["distance"])
        return items
    except Exception:
        # Similar-paper lookup is an optional extra; never break the caller.
        logger.exception("Failed to get similar papers for %s", arxiv_id)
        return []
# ── 图片画廊 (Phase 5) ────────────────────────────────────────────────
def _get_paper_images(arxiv_id: str) -> list[dict]:
"""获取论文提取的图片列表。"""
images_dir = Path("data/papers") / arxiv_id / "images"
if not images_dir.exists():
return []
images = []
for img_file in sorted(images_dir.iterdir()):
if img_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg"):
images.append({
"url": f"/papers/{arxiv_id}/images/{img_file.name}",
"name": img_file.name,
})
return images
+14 -7
View File
@@ -31,11 +31,12 @@ def search_page(
q: str = Query(default=""),
tag: str = Query(default=""),
sort: str = Query(default="relevance"),
mode: str = Query(default="keyword"),
page: int = Query(default=1, ge=1),
db: Session = Depends(get_db),
):
"""搜索页面。"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page)
"""搜索页面,支持 keyword 和 semantic 模式"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode)
all_tags = get_all_tags(db)
return templates.TemplateResponse(
@@ -45,8 +46,11 @@ def search_page(
"query": q,
"tag": tag,
"sort": sort,
"mode": mode,
"chroma_enabled": settings.CHROMA_ENABLED,
"results": result["results"],
"snippets": result["snippets"],
"distances": result.get("distances", {}),
"total": result["total"],
"page": result["page"],
"total_pages": result["total_pages"],
@@ -65,17 +69,18 @@ def search_api(
q: str = Query(default=""),
tag: str = Query(default=""),
sort: str = Query(default="relevance"),
mode: str = Query(default="keyword"),
page: int = Query(default=1, ge=1),
db: Session = Depends(get_db),
):
"""搜索 JSON API。"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page)
"""搜索 JSON API,支持 keyword 和 semantic 模式"""
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode)
distances = result.get("distances", {})
items = []
for paper in result["results"]:
snippet = result["snippets"].get(paper.id, {})
items.append(
{
item = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"title_zh": paper.title_zh,
@@ -86,7 +91,9 @@ def search_api(
"snippet_title_zh": snippet.get("title_zh"),
"snippet_abstract": snippet.get("abstract"),
}
)
if paper.arxiv_id in distances:
item["distance"] = distances[paper.arxiv_id]
items.append(item)
return {
"results": items,
+120
View File
@@ -0,0 +1,120 @@
"""趋势看板路由 — 论文统计图表页面和数据 API。"""
from __future__ import annotations
import logging
from datetime import date, timedelta
from fastapi import APIRouter, Depends, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy import func, text
from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/trends")
def trends_page(request: Request, db: Session = Depends(get_db)):
    """Render the trends dashboard with the aggregated statistics."""
    context = {
        "page_title": "趋势看板",
        "stats": _get_trends_data(db),
        "today": _today_str(),
    }
    return templates.TemplateResponse(request, "trends.html", context)
@router.get("/api/stats/trends")
def trends_api(db: Session = Depends(get_db)):
    """Serve the same aggregated trend statistics as the dashboard, as JSON."""
    payload = _get_trends_data(db)
    return payload
def _get_trends_data(db: Session) -> dict:
    """Aggregate the dashboard statistics from the database.

    Args:
        db: SQLAlchemy session used to run the raw SQL aggregations.

    Returns:
        A dict with four lists:
        ``daily_counts`` (``{date, count}`` per day over the last 30 days),
        ``top_tags`` (``{tag, count}``, 20 most frequent tags),
        ``upvotes_dist`` (``{range, count}`` per upvote bucket) and
        ``summary_completion`` (``{status, count}``; papers without a
        ``summary_status`` row are reported as ``'none'``).
    """
    # Derive "today" from _today_str() so the 30-day window uses the same
    # configured app timezone as the rest of this module, not the server's
    # local date.
    today = date.fromisoformat(_today_str())
    thirty_days_ago = (today - timedelta(days=30)).isoformat()

    # 1. Papers per day over the last 30 days.
    daily_rows = db.execute(text("""
        SELECT paper_date, COUNT(*) as cnt
        FROM papers
        WHERE paper_date >= :start_date
        GROUP BY paper_date
        ORDER BY paper_date ASC
    """), {"start_date": thirty_days_ago}).fetchall()
    daily_counts = [
        {"date": str(row[0]), "count": row[1]}
        for row in daily_rows
    ]

    # 2. Top 20 tags by usage.
    tag_rows = db.execute(text("""
        SELECT tag, COUNT(*) as cnt
        FROM paper_tags
        GROUP BY tag
        ORDER BY cnt DESC
        LIMIT 20
    """)).fetchall()
    top_tags = [
        {"tag": row[0], "count": row[1]}
        for row in tag_rows
    ]

    # 3. Upvote distribution, highest bucket first.
    upvote_rows = db.execute(text("""
        SELECT
            CASE
                WHEN upvotes >= 100 THEN '100+'
                WHEN upvotes >= 50 THEN '50-99'
                WHEN upvotes >= 20 THEN '20-49'
                WHEN upvotes >= 10 THEN '10-19'
                WHEN upvotes >= 5 THEN '5-9'
                ELSE '0-4'
            END as bucket,
            COUNT(*) as cnt
        FROM papers
        GROUP BY bucket
        ORDER BY MIN(upvotes) DESC
    """)).fetchall()
    upvotes_dist = [
        {"range": row[0], "count": row[1]}
        for row in upvote_rows
    ]

    # 4. Summary completion: paper count per summary status.
    summary_rows = db.execute(text("""
        SELECT
            COALESCE(ss.status, 'none') as status,
            COUNT(*) as cnt
        FROM papers p
        LEFT JOIN summary_status ss ON ss.paper_id = p.id
        GROUP BY status
    """)).fetchall()
    summary_completion = [
        {"status": row[0], "count": row[1]}
        for row in summary_rows
    ]

    return {
        "daily_counts": daily_counts,
        "top_tags": top_tags,
        "upvotes_dist": upvotes_dist,
        "summary_completion": summary_completion,
    }
def _today_str() -> str:
    """Return today's date as ``YYYY-MM-DD`` in the configured app timezone."""
    from datetime import datetime
    from zoneinfo import ZoneInfo

    now_local = datetime.now(ZoneInfo(settings.APP_TIMEZONE))
    return now_local.date().isoformat()
+7
View File
@@ -139,6 +139,13 @@ async def delete_papers_by_date_range(
{"paper_id": paper_id},
)
# 1.5 Phase 5: 从 ChromaDB 删除语义索引
try:
from app.services.embedder import delete_paper
delete_paper(arxiv_id)
except Exception:
logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True)
# 2. 删除本地文件 data/papers/{arxiv_id}/
paper_dir = _PAPERS_DIR / arxiv_id
if paper_dir.exists():
+314
View File
@@ -0,0 +1,314 @@
"""ChromaDB 语义搜索服务 — embedding 生成、索引管理、相似查询。"""
from __future__ import annotations
import logging
from pathlib import Path
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.models import Paper
logger = logging.getLogger(__name__)
# ── 单例客户端和 collection ─────────────────────────────────────────────
_client = None
_collection = None
def _chroma_dir() -> Path:
    """Return the configured ChromaDB storage directory as a ``Path``."""
    configured = settings.CHROMA_DIR
    return Path(configured)
def init_chroma() -> None:
    """Create the persistent ChromaDB client and collection once.

    Does nothing when ``CHROMA_ENABLED`` is false or a client already exists.
    On any failure the error is logged and both module-level singletons are
    reset to ``None``.
    """
    global _client, _collection

    if not settings.CHROMA_ENABLED:
        logger.debug("ChromaDB disabled, skip init")
        return

    if _client is None:
        try:
            import chromadb

            storage = _chroma_dir()
            storage.mkdir(parents=True, exist_ok=True)
            _client = chromadb.PersistentClient(path=str(storage))
            _collection = _get_or_create_collection()
            logger.info("ChromaDB initialized at %s", storage)
        except Exception:
            logger.exception("Failed to initialize ChromaDB")
            _client, _collection = None, None
def _get_or_create_collection():
    """Return the ``papers_embeddings`` collection, creating it if needed.

    An existing collection is loaded first. If that lookup fails for any
    reason, the failure is logged at debug level and a new collection using
    cosine distance is created. No embedding-dimension check is done here.

    Returns:
        The collection object from the module-level client.

    Raises:
        Whatever ``create_collection`` raises; the caller handles it.
    """
    try:
        col = _client.get_collection("papers_embeddings")
        logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
        return col
    except Exception:
        # Usually this just means "not created yet"; keep the traceback at
        # debug level so a real failure can still be diagnosed.
        logger.debug(
            "ChromaDB collection 'papers_embeddings' not loadable, creating it",
            exc_info=True,
        )

    col = _client.create_collection(
        name="papers_embeddings",
        metadata={"hnsw:space": "cosine"},
    )
    logger.info("ChromaDB collection 'papers_embeddings' created")
    return col
def get_collection():
    """Return the shared collection, initialising lazily; ``None`` if unavailable."""
    if settings.CHROMA_ENABLED:
        if _collection is None:
            # First use (or an earlier init failed): try to initialise now.
            init_chroma()
        return _collection
    return None
# ── Embedding API 调用 ──────────────────────────────────────────────────
def _get_embedding(text: str) -> list[float] | None:
    """Request an embedding vector for ``text`` from the configured API.

    Sends ``POST {EMBED_API_BASE}/v1/embeddings`` with ``model=EMBED_MODEL``.
    When ``EMBED_DIMENSIONS`` is positive, the returned vector must have
    exactly that length.

    Returns:
        The vector, or ``None`` when the API is not configured, the call
        fails, or the dimension check fails (each case is logged).
    """
    api_base = settings.EMBED_API_BASE
    model = settings.EMBED_MODEL
    if not (api_base and model):
        logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
        return None

    endpoint = api_base.rstrip("/") + "/v1/embeddings"
    request_headers = {"Content-Type": "application/json"}
    api_key = settings.EMBED_API_KEY
    if api_key:
        request_headers["Authorization"] = f"Bearer {api_key}"
    body = {"model": model, "input": text}

    try:
        with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
            response = client.post(endpoint, json=body, headers=request_headers)
            response.raise_for_status()
            vector = response.json()["data"][0]["embedding"]

        expected = settings.EMBED_DIMENSIONS
        if expected > 0 and len(vector) != expected:
            logger.warning(
                "Embedding dimension mismatch: expected=%d got=%d, skip",
                expected,
                len(vector),
            )
            return None
        return vector
    except Exception:
        logger.exception("Embedding API call failed")
        return None
# ── 索引文本拼接 ────────────────────────────────────────────────────────
def _build_index_text(paper: Paper, summary_dict: dict | None) -> str:
"""拼接高信号字段为一段文本用于 embedding。"""
parts: list[str] = []
if paper.title_zh:
parts.append(paper.title_zh)
if paper.title_en:
parts.append(paper.title_en)
if paper.tags:
tag_str = " ".join(t.tag for t in paper.tags)
if tag_str:
parts.append(tag_str)
if summary_dict:
for key in ("one_line", "motivation_problem", "method_key_idea"):
val = summary_dict.get(key)
if val:
parts.append(val)
return " ".join(parts)
# ── 单篇索引 ────────────────────────────────────────────────────────────
def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
    """Write a single paper into the ChromaDB semantic index.

    Args:
        paper_id: The paper's arxiv_id.
        texts_dict: Optional pre-assembled fields (title_zh, title_en, tags,
            one_line, ...). The ``arxiv_id`` and ``paper_date`` entries are
            treated as metadata only and are not part of the embedded text.
            When ``None``, the paper is loaded from the database instead.

    Returns:
        ``True`` on success; ``False`` when the index is unavailable, the
        paper is missing, there is no text to embed, the embedding call
        fails, or any other error occurs (errors are logged).
    """
    col = get_collection()
    if col is None:
        return False

    try:
        if texts_dict is None:
            from app.database import SessionLocal

            db = SessionLocal()
            try:
                paper = (
                    db.query(Paper)
                    .filter(Paper.arxiv_id == paper_id)
                    .options(joinedload(Paper.tags), joinedload(Paper.summary))
                    .first()
                )
                if not paper:
                    logger.warning("Paper %s not found for indexing", paper_id)
                    return False

                summary_dict = None
                if paper.summary:
                    summary_dict = {
                        "one_line": paper.summary.one_line or "",
                        "motivation_problem": paper.summary.motivation_problem or "",
                        "method_key_idea": paper.summary.method_key_idea or "",
                    }
                # Read everything needed before the session is closed.
                index_text = _build_index_text(paper, summary_dict)
                arxiv_id = paper.arxiv_id
                title_zh = paper.title_zh or ""
                paper_date = paper.paper_date.isoformat() if paper.paper_date else ""
            finally:
                db.close()
        else:
            # Keep identifier/date strings out of the embedded text so both
            # code paths produce text from the same kind of content as
            # _build_index_text.
            metadata_only = ("arxiv_id", "paper_date")
            index_text = " ".join(
                value
                for key, value in texts_dict.items()
                if key not in metadata_only and isinstance(value, str) and value
            )
            arxiv_id = texts_dict.get("arxiv_id", paper_id)
            title_zh = texts_dict.get("title_zh", "")
            paper_date = texts_dict.get("paper_date", "")

        if not index_text.strip():
            logger.warning("Empty index text for %s, skip", paper_id)
            return False

        vec = _get_embedding(index_text)
        if vec is None:
            return False

        # upsert makes re-indexing the same paper idempotent.
        col.upsert(
            ids=[arxiv_id],
            embeddings=[vec],
            metadatas=[{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}],
        )
        logger.info("Indexed paper %s in ChromaDB", arxiv_id)
        return True
    except Exception:
        logger.exception("Failed to index paper %s in ChromaDB", paper_id)
        return False
# ── 批量索引 ────────────────────────────────────────────────────────────
def index_batch(paper_ids: list[str]) -> dict:
    """Index several papers; one failure does not stop the others.

    Returns:
        ``{"total": int, "success": int, "failed": int}``. When the collection
        is unavailable, every paper is counted as failed.
    """
    total = len(paper_ids) if paper_ids else 0
    if total == 0:
        return {"total": 0, "success": 0, "failed": 0}

    if get_collection() is None:
        return {"total": total, "success": 0, "failed": total}

    outcomes = [bool(index_paper(pid)) for pid in paper_ids]
    success = sum(outcomes)
    failed = total - success

    logger.info("Batch index: total=%d success=%d failed=%d", total, success, failed)
    return {"total": total, "success": success, "failed": failed}
# ── 删除 ────────────────────────────────────────────────────────────────
def delete_paper(paper_id: str) -> bool:
    """Remove a paper's entry from the ChromaDB index.

    Args:
        paper_id: The paper's arxiv_id.

    Returns:
        ``True`` when the delete call succeeded, ``False`` when the collection
        is unavailable or the call raised (the error is logged).
    """
    collection = get_collection()
    deleted = False
    if collection is not None:
        try:
            collection.delete(ids=[paper_id])
            logger.info("Deleted paper %s from ChromaDB", paper_id)
            deleted = True
        except Exception:
            logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
    return deleted
# ── 相似查询 ────────────────────────────────────────────────────────────
def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
    """Run a semantic similarity search for free text.

    Args:
        query_text: Text to embed and search with.
        top_k: Maximum number of hits to return.

    Returns:
        ``[{"arxiv_id": str, "distance": float}, ...]`` in the order returned
        by ChromaDB. Returns ``[]`` when the index is unavailable or empty,
        ``top_k`` is not positive, the embedding call fails, or any error
        occurs (errors are logged).
    """
    col = get_collection()
    if col is None or top_k <= 0:
        return []

    try:
        # Check the size first: an empty index cannot produce hits, so skip
        # the embedding API call and the query altogether.
        count = col.count()
        if count == 0:
            return []

        vec = _get_embedding(query_text)
        if vec is None:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )
        if not results["ids"] or not results["ids"][0]:
            return []

        hit_ids = results["ids"][0]
        raw_distances = results["distances"]
        if raw_distances is not None and len(raw_distances) > 0:
            distances = raw_distances[0]
        else:
            distances = [0.0] * len(hit_ids)

        # float() keeps the payload JSON-serialisable whatever numeric type
        # the store hands back.
        return [
            {"arxiv_id": hit_id, "distance": float(dist)}
            for hit_id, dist in zip(hit_ids, distances)
        ]
    except Exception:
        logger.exception("ChromaDB search_similar failed")
        return []
+84 -2
View File
@@ -1,15 +1,19 @@
"""FTS5 全文搜索服务 — 关键词 + 标签筛选,命中片段高亮,分页。"""
"""搜索服务 — FTS5 关键词搜索 + ChromaDB 语义搜索,命中片段高亮,分页。"""
from __future__ import annotations
import logging
import math
import re
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.models import Paper
logger = logging.getLogger(__name__)
# ── 输入清洗 ──────────────────────────────────────────────────────────
# FTS5 查询语法中的特殊字符,用户输入时需要移除
@@ -41,8 +45,9 @@ def search_papers(
sort: str = "relevance",
page: int = 1,
page_size: int = 20,
mode: str = "keyword",
) -> dict:
"""FTS5 搜索论文。
"""搜索论文,支持 keyword (FTS5) 和 semantic (ChromaDB) 两种模式
返回::
{
@@ -51,8 +56,14 @@ def search_papers(
"total": int,
"page": int,
"total_pages": int,
"distances": dict[str, float], # arxiv_id → distance (仅 semantic)
}
"""
# ── semantic 模式 ──
if mode == "semantic" and settings.CHROMA_ENABLED and query:
return _search_semantic(db, query, tag, sort, page, page_size)
# ── keyword 模式(默认)──
match_expr = _sanitize_query(query) if query else None
# ── 无关键词 + 无标签 → 空结果 ──
@@ -63,6 +74,7 @@ def search_papers(
"total": 0,
"page": page,
"total_pages": 0,
"distances": {},
}
# ── 构建条件性 JOIN 和 WHERE 片段 ──
@@ -146,6 +158,75 @@ def _search_with_fts(
"total": total,
"page": page,
"total_pages": math.ceil(total / page_size) if total else 0,
"distances": {},
}
def _search_semantic(
    db: Session,
    query: str,
    tag: str | None,
    sort: str,
    page: int,
    page_size: int,
) -> dict:
    """Semantic search through ChromaDB, falling back to FTS5 keyword search.

    Up to ``page_size * 3`` candidates are requested so that some remain after
    the optional tag filter. If the semantic lookup raises or yields nothing,
    the query is handed to ``_search_with_fts`` instead.

    Returns:
        The usual search result dict; ``distances`` maps arxiv_id to distance
        for every candidate, and ``snippets`` is empty in semantic mode.
    """
    try:
        from app.services.embedder import search_similar

        candidate_limit = page_size * 3
        candidates = search_similar(query, top_k=candidate_limit)
    except Exception:
        logger.exception("Semantic search failed, falling back to keyword")
        candidates = []

    if not candidates:
        # Fallback: plain keyword search with the same tag restriction.
        if tag:
            tag_join = "JOIN paper_tags pt ON pt.paper_id = p.id"
            tag_where = "AND pt.tag = :tag"
            tag_params = {"tag": tag}
        else:
            tag_join = ""
            tag_where = ""
            tag_params = {}
        offset = (page - 1) * page_size
        return _search_with_fts(
            db,
            _sanitize_query(query) or query,
            tag_join,
            tag_where,
            tag_params,
            sort, page, page_size, offset,
        )

    ranked_ids = [hit["arxiv_id"] for hit in candidates]
    distance_map = {hit["arxiv_id"]: hit["distance"] for hit in candidates}

    loader = db.query(Paper).filter(Paper.arxiv_id.in_(ranked_ids)).options(
        joinedload(Paper.authors),
        joinedload(Paper.tags),
        joinedload(Paper.summary_status),
        joinedload(Paper.bookmark),
        joinedload(Paper.reading_status),
    )
    if tag:
        loader = loader.filter(Paper.tags.any(tag=tag))

    # Restore the ranking returned by the vector store (nearest first).
    rank = {aid: position for position, aid in enumerate(ranked_ids)}
    ordered = sorted(loader.all(), key=lambda paper: rank.get(paper.arxiv_id, 999))

    total = len(ordered)
    first = (page - 1) * page_size
    last = first + page_size

    return {
        "results": ordered[first:last],
        "snippets": {},
        "total": total,
        "page": page,
        "total_pages": math.ceil(total / page_size) if total else 0,
        "distances": distance_map,
    }
@@ -187,6 +268,7 @@ def _search_tag_only(
"total": total,
"page": page,
"total_pages": math.ceil(total / page_size) if total else 0,
"distances": {},
}
+145
View File
@@ -359,6 +359,127 @@ def _cleanup_tmp(arxiv_id: str) -> None:
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
# ── LaTeX 图片提取(Phase 5)───────────────────────────────────────────
_INCLUDEGRAPHICS_RE = re.compile(
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
    """Copy the figures referenced by a paper's LaTeX source into its image folder.

    Steps:
        1. Download and unpack the source into ``tmp_source`` if that directory
           does not exist yet.
        2. Collect every ``\\includegraphics`` target from the ``.tex`` files.
        3. Copy each target found inside the source tree to
           ``<paper dir>/images/``.

    Args:
        arxiv_id: Paper identifier.
        tmp_source: Directory holding the unpacked source. Defaults to
            ``<tmp dir>/source`` for this paper.

    Returns:
        Number of images copied; ``0`` on any failure (failures are logged).
    """
    if tmp_source is None:
        tmp_source = _tmp_dir(arxiv_id) / "source"
    images_dest = _paper_dir(arxiv_id) / "images"

    try:
        # Download the source archive unless it is already unpacked.
        if not tmp_source.exists():
            source_url = f"https://arxiv.org/e-print/{arxiv_id}"
            await _download_source_zip(arxiv_id, source_url, tmp_source)
        if not tmp_source.exists():
            return 0

        # Scan the .tex files for referenced image paths.
        image_paths: set[str] = set()
        for tex_file in tmp_source.rglob("*.tex"):
            try:
                content = tex_file.read_text(encoding="utf-8", errors="replace")
            except OSError:
                continue
            image_paths.update(
                match.group(1).strip()
                for match in _INCLUDEGRAPHICS_RE.finditer(content)
            )
        if not image_paths:
            return 0

        images_dest.mkdir(parents=True, exist_ok=True)
        source_root = tmp_source.resolve()
        used_names: set[str] = set()
        copied = 0

        # Sorted so collision suffixes are the same from run to run.
        for img_rel in sorted(image_paths):
            # LaTeX allows the extension to be omitted, so try the common ones.
            for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
                try:
                    candidate = (tmp_source / (img_rel + ext)).resolve()
                except (OSError, ValueError):
                    continue
                # The path comes from untrusted LaTeX: ignore anything that
                # resolves outside the unpacked source tree ("../", absolute
                # paths, symlinks), otherwise arbitrary local files could be
                # copied into the image folder.
                if source_root not in candidate.parents:
                    continue
                if not candidate.is_file():
                    continue
                if candidate.suffix.lower() not in _IMAGE_EXTS:
                    continue

                # Different sub-folders may hold files with the same name;
                # add a numeric suffix until the name is unique for this run.
                dest_name = candidate.name
                serial = 0
                while dest_name in used_names:
                    dest_name = f"{candidate.stem}_{serial}{candidate.suffix}"
                    serial += 1
                used_names.add(dest_name)

                shutil.copy2(candidate, images_dest / dest_name)
                copied += 1
                break

        if copied > 0:
            logger.info("Extracted %d images from source for %s", copied, arxiv_id)
        return copied
    except Exception:
        # Image extraction is an optional extra; never fail the caller.
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
        return 0
async def _download_source_zip(
    arxiv_id: str, source_url: str, dest_dir: Path
) -> None:
    """Download a paper's source archive and unpack it into ``dest_dir``.

    The payload is tried as a zip file first and as a tar archive (any
    compression) second. Download and extraction failures are logged and
    swallowed; the temporary archive file is always removed.

    Args:
        arxiv_id: Paper identifier, used for the temp file location and logs.
        source_url: URL of the source archive.
        dest_dir: Directory to unpack into (created if missing).
    """
    import tarfile
    import zipfile

    dest_dir.mkdir(parents=True, exist_ok=True)
    zip_path = _tmp_dir(arxiv_id) / "source.zip"

    transport = None
    if settings.http_proxy:
        transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)

    try:
        async with httpx.AsyncClient(
            timeout=settings.HTTP_TIMEOUT_SECONDS,
            headers={"User-Agent": settings.HTTP_USER_AGENT},
            transport=transport,
            follow_redirects=True,
        ) as client:
            resp = await client.get(source_url)
            resp.raise_for_status()
            zip_path.write_bytes(resp.content)
    except Exception as exc:
        logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
        # Do not leave a partially written archive behind.
        if zip_path.exists():
            zip_path.unlink()
        return

    try:
        try:
            # ZipFile.extractall already strips absolute paths and ".." parts.
            with zipfile.ZipFile(zip_path, "r") as zf:
                zf.extractall(dest_dir)
            logger.debug("Extracted source for %s", arxiv_id)
        except zipfile.BadZipFile:
            # Not a zip, so it may be a tar archive.
            with tarfile.open(zip_path, "r:*") as tf:
                if hasattr(tarfile, "data_filter"):
                    # The archive is untrusted: the "data" filter rejects
                    # absolute paths, "../" escapes, links pointing outside
                    # the destination and special files.
                    tf.extractall(dest_dir, filter="data")
                else:
                    # Older Python without extraction filters: only unpack
                    # plain files/directories that stay inside dest_dir.
                    root = dest_dir.resolve()
                    safe_members = []
                    for member in tf.getmembers():
                        if not (member.isfile() or member.isdir()):
                            continue
                        target = (root / member.name).resolve()
                        if target == root or root in target.parents:
                            safe_members.append(member)
                    tf.extractall(dest_dir, members=safe_members)
            logger.debug("Extracted source (tar) for %s", arxiv_id)
    except Exception:
        logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
    finally:
        if zip_path.exists():
            zip_path.unlink()
# ── 单篇总结 ────────────────────────────────────────────────────────────
@@ -441,6 +562,30 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
status.raw_output_saved = True
db.commit()
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
try:
await _extract_images_from_source(arxiv_id)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
# Phase 5: 同步写入语义索引(失败仅 log)
try:
from app.services.embedder import index_paper
texts_dict = {
"arxiv_id": arxiv_id,
"title_zh": schema.title_zh or "",
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"one_line": schema.one_line or "",
"motivation_problem": schema.motivation_problem or "",
"method_key_idea": schema.method_key_idea or "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
index_paper(arxiv_id, texts_dict)
except Exception:
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
logger.info("Summarize done: %s quality=%s", arxiv_id, quality)
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
+182
View File
@@ -556,3 +556,185 @@ mark {
.reading-list-filters { gap: 4px; }
.filter-chip { padding: 4px 10px; font-size: 0.8rem; }
}
/* ── Search Mode Toggle (Phase 5) ─────────────────────────────── */
.search-mode-toggle {
display: flex;
gap: 0;
border: 1px solid var(--border);
border-radius: var(--radius);
overflow: hidden;
}
.mode-option {
padding: 8px 14px;
font-size: 0.85rem;
color: var(--ink-light);
cursor: pointer;
transition: all 0.2s;
display: flex;
align-items: center;
gap: 4px;
}
.mode-option input[type="radio"] { display: none; }
.mode-option:hover { background: var(--bg); }
.mode-option.active {
background: var(--accent);
color: #fff;
}
/* ── Similarity Score (Phase 5) ────────────────────────────────── */
.similarity-score {
font-size: 0.8rem;
color: var(--accent);
white-space: nowrap;
padding: 2px 8px;
background: #eef3f8;
border-radius: 4px;
}
/* ── Similar Papers (Phase 5) ──────────────────────────────────── */
.similar-papers {
margin-top: 32px;
padding-top: 24px;
border-top: 1px solid var(--border);
}
.similar-papers h2 {
font-family: var(--font-body);
font-size: 1.1rem;
font-weight: 600;
margin-bottom: 12px;
color: var(--accent);
}
.similar-paper-item {
display: flex;
justify-content: space-between;
align-items: center;
padding: 10px 0;
border-bottom: 1px solid var(--border);
}
.similar-paper-item:last-child { border-bottom: none; }
.similar-paper-title a {
font-size: 0.92rem;
color: var(--ink);
}
.similar-paper-title a:hover { color: var(--accent); }
.similar-paper-dist {
font-size: 0.8rem;
color: var(--ink-light);
}
/* ── Trends Dashboard (Phase 5) ────────────────────────────────── */
.trends-page h1 {
font-family: var(--font-body);
font-size: 1.5rem;
font-weight: 700;
margin-bottom: 24px;
}
.charts-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 24px;
margin-bottom: 32px;
}
.chart-card {
background: var(--surface);
border: 1px solid var(--border);
border-radius: var(--radius);
padding: 20px;
}
.chart-card h2 {
font-family: var(--font-body);
font-size: 1rem;
font-weight: 600;
margin-bottom: 12px;
color: var(--accent);
}
.chart-card canvas {
width: 100% !important;
max-height: 300px;
}
@media (max-width: 768px) {
.charts-grid { grid-template-columns: 1fr; }
}
/* ── Compare Page (Phase 5) ────────────────────────────────────── */
.compare-page h1 {
font-family: var(--font-body);
font-size: 1.5rem;
font-weight: 700;
margin-bottom: 24px;
}
.compare-table-wrapper {
overflow-x: auto;
margin-bottom: 24px;
}
.compare-table {
width: 100%;
border-collapse: collapse;
background: var(--surface);
border: 1px solid var(--border);
border-radius: var(--radius);
font-size: 0.88rem;
}
.compare-table th,
.compare-table td {
padding: 12px 16px;
border: 1px solid var(--border);
vertical-align: top;
text-align: left;
}
.compare-table th {
background: var(--bg);
font-weight: 600;
color: var(--ink-light);
white-space: nowrap;
min-width: 100px;
}
.compare-table td.field-label {
background: var(--bg);
font-weight: 500;
color: var(--ink);
white-space: nowrap;
min-width: 90px;
}
.compare-table td.paper-col {
min-width: 200px;
}
.compare-table td .no-summary {
color: var(--ink-light);
font-style: italic;
}
/* ── Image Gallery (Phase 5) ───────────────────────────────────── */
.image-gallery {
margin-top: 24px;
}
.image-gallery h2 {
font-family: var(--font-body);
font-size: 1.05rem;
font-weight: 600;
margin-bottom: 12px;
color: var(--accent);
}
.gallery-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
gap: 16px;
}
.gallery-item {
background: var(--surface);
border: 1px solid var(--border);
border-radius: var(--radius);
overflow: hidden;
}
.gallery-item img {
width: 100%;
height: auto;
display: block;
}
.gallery-item .gallery-caption {
padding: 8px 12px;
font-size: 0.8rem;
color: var(--ink-light);
text-align: center;
}
+1
View File
@@ -16,6 +16,7 @@
<div class="nav-links">
<a href="/day/{{ today if today else '' }}">今日</a>
<a href="/search">搜索</a>
<a href="/trends">趋势</a>
<a href="/reading-list">阅读列表</a>
<a href="/admin/logs">管理</a>
</div>
+86
View File
@@ -0,0 +1,86 @@
{% extends "base.html" %}
{% block title %}{{ page_title }} — HF Daily Papers{% endblock %}
{% block content %}
{#
  Paper comparison page.
  Context used by this template: page_title, ids_param (the submitted ID
  string, echoed back into the form), error (optional message), papers
  (the papers to compare) and rows (items with .label and .cells).
#}
<section class="compare-page">
  <h1>论文对比</h1>
  {# ID entry form: submits a comma-separated arXiv ID list back to /compare via GET #}
  <form class="search-form" method="get" action="/compare">
    {# aria-label gives the field an accessible name; a placeholder alone is not a label #}
    <input type="text" name="ids" value="{{ ids_param }}"
           placeholder="输入 arXiv ID,逗号分隔(最多 5 篇),如 2401.12345,2401.67890"
           aria-label="arXiv ID"
           class="search-input">
    <button type="submit" class="search-btn">对比</button>
  </form>
  {% if error %}
  <div class="empty-state">
    <p>{{ error }}</p>
  </div>
  {% endif %}
  {% if papers %}
  <div class="compare-table-wrapper">
    <table class="compare-table">
      <thead>
        <tr>
          <th scope="col">字段</th>
          {# One header column per paper: ID link plus upvotes and date #}
          {% for paper in papers %}
          <th scope="col">
            <a href="/paper/{{ paper.arxiv_id }}">{{ paper.arxiv_id }}</a>
            <br>
            <small style="color: var(--ink-light);">
              {{ paper.upvotes }} 👍 · {{ paper.paper_date }}
            </small>
          </th>
          {% endfor %}
        </tr>
      </thead>
      <tbody>
        {# Authors row #}
        <tr>
          <td class="field-label">作者</td>
          {% for paper in papers %}
          <td class="paper-col">{{ paper.authors|map(attribute='name')|join(', ') }}</td>
          {% endfor %}
        </tr>
        {# Tags row: at most the first five tags of each paper #}
        <tr>
          <td class="field-label">标签</td>
          {% for paper in papers %}
          <td class="paper-col">
            {% for t in paper.tags[:5] %}
            <span class="tag">{{ t.tag }}</span>
            {% endfor %}
          </td>
          {% endfor %}
        </tr>
        {# Structured comparison fields: one table row per entry in rows #}
        {% for row in rows %}
        <tr>
          <td class="field-label">{{ row.label }}</td>
          {% for cell in row.cells %}
          <td class="paper-col">
            {% if cell %}
            {{ cell }}
            {% else %}
            {# Falsy cell value: show the "no summary yet" placeholder #}
            <span class="no-summary">暂无总结</span>
            {% endif %}
          </td>
          {% endfor %}
        </tr>
        {% endfor %}
      </tbody>
    </table>
  </div>
  {% elif ids_param and not error %}
  {# IDs were submitted, no error was reported, and there are no papers to show #}
  <div class="empty-state">
    <p>未找到匹配的论文</p>
    <p class="hint">请检查 arXiv ID 是否正确</p>
  </div>
  {% endif %}
</section>
{% endblock %}
+30
View File
@@ -117,5 +117,35 @@
<p class="abstract-en">{{ paper.abstract }}</p>
</section>
{% endif %}
{# Phase 5: image gallery. Rendered only when paper_images is non-empty; each item supplies .url and .name. #}
{% if paper_images %}
<section class="image-gallery">
<h2>论文图片</h2>
<div class="gallery-grid">
{% for img in paper_images %}
<div class="gallery-item">
{# loading="lazy" defers fetching off-screen images #}
<img src="{{ img.url }}" alt="{{ img.name }}" loading="lazy">
<div class="gallery-caption">{{ img.name }}</div>
</div>
{% endfor %}
</div>
</section>
{% endif %}
{# Phase 5: similar-paper recommendations. Rendered only when similar_papers is non-empty. #}
{% if similar_papers %}
<section class="similar-papers">
  <h2>相似论文推荐</h2>
  {% for sp in similar_papers %}
  <div class="similar-paper-item">
    <span class="similar-paper-title">
      {# Fall back to the arXiv ID so the link is never empty or "None" when title_zh is missing #}
      <a href="/paper/{{ sp.arxiv_id }}">{{ sp.title_zh or sp.arxiv_id }}</a>
    </span>
    {# Distance value formatted to three decimals #}
    <span class="similar-paper-dist">🎯 {{ "%.3f"|format(sp.distance) }}</span>
  </div>
  {% endfor %}
</section>
{% endif %}
</article>
{% endblock %}
+27 -7
View File
@@ -11,6 +11,21 @@
{% if tag %}
<input type="hidden" name="tag" value="{{ tag }}">
{% endif %}
{# Search-mode toggle: offered only when chroma_enabled is truthy. "keyword" is selected when mode is 'keyword' or unset. #}
{% if chroma_enabled %}
<div class="search-mode-toggle">
<label class="mode-option {% if mode == 'keyword' or not mode %}active{% endif %}">
<input type="radio" name="mode" value="keyword" {% if mode == 'keyword' or not mode %}checked{% endif %}>
关键词
</label>
<label class="mode-option {% if mode == 'semantic' %}active{% endif %}">
<input type="radio" name="mode" value="semantic" {% if mode == 'semantic' %}checked{% endif %}>
语义搜索
</label>
</div>
{% endif %}
<button type="submit" class="search-btn">搜索</button>
</form>
@@ -18,10 +33,10 @@
{% if all_tags %}
<div class="tag-filter">
<span class="tag-filter-label">标签:</span>
<a href="/search{% if query %}?q={{ query }}{% endif %}"
{# "All" chip: the tag parameter is deliberately omitted so that clicking it clears the active tag filter #}
<a href="/search?q={{ query }}&mode={{ mode }}"
class="tag-chip {% if not tag %}active{% endif %}">全部</a>
{% for t in all_tags %}
<a href="/search?q={{ query }}&tag={{ t }}"
<a href="/search?q={{ query }}&tag={{ t }}&mode={{ mode }}"
class="tag-chip {% if t == tag %}active{% endif %}">{{ t }}</a>
{% endfor %}
</div>
@@ -30,12 +45,12 @@
{% if query or tag %}
{# 搜索结果元信息 #}
<div class="search-meta">
<span>找到 {{ total }} 条结果</span>
<span>找到 {{ total }} 条结果{% if mode == 'semantic' %}(语义模式){% endif %}</span>
<div class="sort-toggle">
<a href="/search?q={{ query }}&tag={{ tag }}&sort=relevance"
<a href="/search?q={{ query }}&tag={{ tag }}&mode={{ mode }}&sort=relevance"
class="{% if sort == 'relevance' %}active{% endif %}">相关性</a>
<span class="sort-divider">|</span>
<a href="/search?q={{ query }}&tag={{ tag }}&sort=date"
<a href="/search?q={{ query }}&tag={{ tag }}&mode={{ mode }}&sort=date"
class="{% if sort == 'date' %}active{% endif %}">日期</a>
</div>
</div>
@@ -58,6 +73,11 @@
</a>
</h2>
<span class="paper-upvotes">👍 {{ paper.upvotes }}</span>
{# Distance badge: shown only when a distances mapping is present and has an entry for this paper's arXiv ID #}
{% if distances and paper.arxiv_id in distances %}
<span class="similarity-score" title="语义相似度距离">
🎯 {{ "%.3f"|format(distances[paper.arxiv_id]) }}
</span>
{% endif %}
</div>
{% if snippet and snippet.abstract %}
@@ -103,11 +123,11 @@
{% if total_pages > 1 %}
<nav class="pagination">
{% if page > 1 %}
<a href="/search?q={{ query }}&tag={{ tag }}&sort={{ sort }}&page={{ page - 1 }}" class="page-btn">← 上一页</a>
<a href="/search?q={{ query }}&tag={{ tag }}&sort={{ sort }}&mode={{ mode }}&page={{ page - 1 }}" class="page-btn">← 上一页</a>
{% endif %}
<span class="page-info">{{ page }} / {{ total_pages }}</span>
{% if page < total_pages %}
<a href="/search?q={{ query }}&tag={{ tag }}&sort={{ sort }}&page={{ page + 1 }}" class="page-btn">下一页 →</a>
<a href="/search?q={{ query }}&tag={{ tag }}&sort={{ sort }}&mode={{ mode }}&page={{ page + 1 }}" class="page-btn">下一页 →</a>
{% endif %}
</nav>
{% endif %}
+185
View File
@@ -0,0 +1,185 @@
{% extends "base.html" %}
{% block title %}{{ page_title }} — HF Daily Papers{% endblock %}
{% block content %}
{# Trends dashboard: four chart cards, each holding a canvas that the page script looks up by id #}
<section class="trends-page">
<h1>趋势看板</h1>
<div class="charts-grid">
{# Line chart: number of papers per day #}
<div class="chart-card">
<h2>📅 每日论文数量(近 30 天)</h2>
<canvas id="dailyChart"></canvas>
</div>
{# Bar chart: top 20 tags #}
<div class="chart-card">
<h2>🏷️ 热门标签 Top 20</h2>
<canvas id="tagsChart"></canvas>
</div>
{# Bar chart: upvotes distribution #}
<div class="chart-card">
<h2>👍 Upvotes 分布</h2>
<canvas id="upvotesChart"></canvas>
</div>
{# Doughnut chart: summary completion by status #}
<div class="chart-card">
<h2>📝 总结完成率</h2>
<canvas id="summaryChart"></canvas>
</div>
</div>
</section>
{% endblock %}
{% block scripts %}
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.7/dist/chart.umd.min.js"></script>
<script>
// Colour configuration shared by all charts on this page (ink-blue "kami" theme).
// `palette` supplies one colour per bar in the tags chart; all 20 entries are
// distinct so that no two bars share a colour.
const COLORS = {
  primary: '#2d5f8a',
  primaryLight: 'rgba(45, 95, 138, 0.2)',
  accent: '#5a9bc7',
  success: '#388e3c',
  warning: '#f57f17',
  danger: '#c62828',
  muted: '#4a4a6a',
  palette: [
    '#2d5f8a', '#5a9bc7', '#388e3c', '#f57f17', '#c62828',
    '#7b1fa2', '#00838f', '#ef6c00', '#455a64', '#827717',
    '#1565c0', '#ad1457', '#00695c', '#e65100', '#283593',
    // The third entry on this row used to repeat '#00838f' from the second row.
    '#9e9d24', '#6a1b9a', '#558b2f', '#4e342e', '#37474f',
  ],
};
const statsData = {{ stats | tojson }};
// Line chart: number of papers per day.
(() => {
  const context2d = document.getElementById('dailyChart').getContext('2d');
  const perDay = statsData.daily_counts;
  const series = {
    label: '论文数',
    data: perDay.map((entry) => entry.count),
    borderColor: COLORS.primary,
    backgroundColor: COLORS.primaryLight,
    fill: true,
    tension: 0.3,
    pointRadius: 3,
    pointHoverRadius: 6,
  };
  new Chart(context2d, {
    type: 'line',
    data: {
      labels: perDay.map((entry) => entry.date),
      datasets: [series],
    },
    options: {
      responsive: true,
      plugins: { legend: { display: false } },
      scales: {
        // Cap the number of x-axis ticks so date labels stay readable.
        x: { ticks: { maxTicksLimit: 10, font: { size: 11 } } },
        y: { beginAtZero: true, ticks: { stepSize: 1 } },
      },
    },
  });
})();
// Horizontal bar chart: most frequent tags.
(() => {
  const context2d = document.getElementById('tagsChart').getContext('2d');
  const tagNames = statsData.top_tags.map((entry) => entry.tag);
  const tagCounts = statsData.top_tags.map((entry) => entry.count);
  const series = {
    label: '论文数',
    data: tagCounts,
    // One palette colour per bar.
    backgroundColor: COLORS.palette.slice(0, tagCounts.length),
    borderRadius: 4,
  };
  new Chart(context2d, {
    type: 'bar',
    data: { labels: tagNames, datasets: [series] },
    options: {
      responsive: true,
      // indexAxis 'y' turns the bars horizontal.
      indexAxis: 'y',
      plugins: { legend: { display: false } },
      scales: {
        x: { beginAtZero: true, ticks: { stepSize: 1 } },
      },
    },
  });
})();
// Bar chart: distribution of papers across upvote ranges.
(() => {
  const context2d = document.getElementById('upvotesChart').getContext('2d');
  const buckets = statsData.upvotes_dist;
  const rangeLabels = buckets.map((bucket) => bucket.range);
  const bucketCounts = buckets.map((bucket) => bucket.count);
  new Chart(context2d, {
    type: 'bar',
    data: {
      labels: rangeLabels,
      datasets: [
        {
          label: '论文数',
          data: bucketCounts,
          backgroundColor: COLORS.accent,
          borderRadius: 4,
        },
      ],
    },
    options: {
      responsive: true,
      plugins: { legend: { display: false } },
      scales: {
        y: { beginAtZero: true, ticks: { stepSize: 1 } },
      },
    },
  });
})();
// Doughnut chart: summary completion broken down by status.
(() => {
  const context2d = document.getElementById('summaryChart').getContext('2d');
  // Display label for each known status; unknown statuses are shown as-is.
  const labelByStatus = {
    'done': '已完成',
    'pending': '待总结',
    'processing': '总结中',
    'failed': '失败',
    'permanent_failure': '永久失败',
    'none': '未开始',
  };
  // Slice colour for each known status; unknown statuses use COLORS.muted.
  const colorByStatus = {
    'done': COLORS.success,
    'pending': COLORS.warning,
    'processing': COLORS.primary,
    'failed': COLORS.danger,
    'permanent_failure': '#b71c1c',
    'none': '#bdbdbd',
  };
  const slices = statsData.summary_completion;
  const sliceLabels = slices.map((slice) => labelByStatus[slice.status] || slice.status);
  const sliceCounts = slices.map((slice) => slice.count);
  const sliceColors = slices.map((slice) => colorByStatus[slice.status] || COLORS.muted);
  new Chart(context2d, {
    type: 'doughnut',
    data: {
      labels: sliceLabels,
      datasets: [
        {
          data: sliceCounts,
          backgroundColor: sliceColors,
          borderWidth: 2,
          borderColor: '#fff',
        },
      ],
    },
    options: {
      responsive: true,
      plugins: {
        legend: { position: 'bottom', labels: { padding: 12 } },
      },
    },
  });
})();
</script>
{% endblock %}
+1
View File
@@ -15,6 +15,7 @@ dependencies = [
"typer>=0.15",
"python-dotenv>=1.0",
"apscheduler>=3.10",
"chromadb>=1.0",
]
[project.optional-dependencies]
+657
View File
@@ -0,0 +1,657 @@
"""Phase 5 后续增强测试 — embedder、semantic search、trends、compare、image extraction。"""
from __future__ import annotations
import json
import shutil
import time
from datetime import date, datetime, timezone
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import select
from app.config import settings
from app.database import get_db
from app.models import (
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
)
# ── Fixtures ────────────────────────────────────────────────────────────
ADMIN_TOKEN = "test-admin-token-12345"
@pytest.fixture
def admin_headers():
    """Return the Authorization header carrying the test admin bearer token."""
    return {"Authorization": f"Bearer {ADMIN_TOKEN}"}
@pytest.fixture
def auth_client(client, monkeypatch):
    """Return the test client with a known admin token and ChromaDB switched off."""
    overrides = {"ADMIN_TOKEN": ADMIN_TOKEN, "CHROMA_ENABLED": False}
    for setting_name, setting_value in overrides.items():
        monkeypatch.setattr(settings, setting_name, setting_value)
    return client
@pytest.fixture
def sample_papers_with_summary(db_session):
    """Insert five papers with authors, tags, status and FTS rows; return them.

    The first four papers get status ``done`` and a ``PaperSummary``; the
    fifth is left ``pending`` with no summary. Everything is committed before
    the list of ``Paper`` objects is returned.
    """
    # Local import: this module only imports ``select`` from SQLAlchemy.
    from sqlalchemy import text

    now = datetime.now(timezone.utc)
    # Built once and reused for every paper rather than rebuilt per iteration.
    fts_insert = text(
        "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
        "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
    )
    seed_rows = [
        ("2401.20001", "2024-01-10"),
        ("2401.20002", "2024-01-11"),
        ("2401.20003", "2024-01-12"),
        ("2401.20004", "2024-01-13"),
        ("2401.20005", "2024-01-14"),
    ]

    papers = []
    for i, (arxiv_id, paper_date_str) in enumerate(seed_rows):
        paper_date = date.fromisoformat(paper_date_str)
        p = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {i+1}",
            title_zh=f"测试论文 {i+1}",
            abstract=f"Abstract for paper {i+1}.",
            paper_date=paper_date,
            crawled_at=now,
            upvotes=i * 10 + 5,
        )
        db_session.add(p)
        # Flush so that p.id is populated before the dependent rows use it.
        db_session.flush()

        db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
        db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf"))
        db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))
        db_session.add(SummaryStatus(
            paper_id=p.id,
            status="done" if i < 4 else "pending",
            quality="normal",
        ))

        # Only the first four papers receive a summary.
        if i < 4:
            summary = PaperSummary(
                paper_id=p.id,
                one_line=f"这是论文{i+1}的一句话摘要",
                difficulty="中级",
                motivation_problem=f"论文{i+1}的研究问题",
                motivation_goal=f"论文{i+1}的研究目标",
                method_key_idea=f"论文{i+1}的关键思路",
                method_overview=f"论文{i+1}的方法概述",
                updated_at=now,
                full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}),
            )
            db_session.add(summary)

        # Insert the matching FTS5 row by hand.
        db_session.execute(
            fts_insert,
            {
                "id": p.id,
                "title_en": p.title_en,
                "title_zh": p.title_zh or "",
                "abstract": p.abstract or "",
                "authors": f"Author {i+1}",
                "tags": f"NLP, Tag{i+1}",
            },
        )
        papers.append(p)

    db_session.commit()
    return papers
# ═══════════════════════════════════════════════════════════════════════
# Embedder 服务测试
# ═══════════════════════════════════════════════════════════════════════
class TestEmbedderInit:
    """Initialisation tests for ``app.services.embedder``."""

    @staticmethod
    def _blank_embedder():
        """Import the embedder module and clear its cached client and collection."""
        import app.services.embedder as emb

        emb._client = emb._collection = None
        return emb

    def test_chroma_disabled_skip_init(self, monkeypatch):
        """init_chroma() leaves the client unset when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        emb = self._blank_embedder()
        emb.init_chroma()
        assert emb._client is None

    def test_chroma_init_success(self, monkeypatch, tmp_path):
        """init_chroma() creates a client and a collection when CHROMA_ENABLED is true."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
        emb = self._blank_embedder()
        emb.init_chroma()
        assert emb._client is not None
        assert emb._collection is not None
        # Clean up the module-level state.
        emb._client = emb._collection = None

    def test_get_collection_returns_none_when_disabled(self, monkeypatch):
        """get_collection() returns None when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        emb = self._blank_embedder()
        assert emb.get_collection() is None
class TestEmbedderIndexing:
    """Indexing, deletion and similarity-search tests for ``app.services.embedder``."""

    @staticmethod
    def _blank_embedder():
        """Import the embedder module and clear its cached client and collection."""
        import app.services.embedder as emb

        emb._client = emb._collection = None
        return emb

    def test_index_paper_disabled(self, monkeypatch):
        """index_paper() returns False when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        emb = self._blank_embedder()
        assert emb.index_paper("test-id") is False

    def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
        """index_paper() returns False when EMBED_API_BASE is empty."""
        configured = {
            "CHROMA_ENABLED": True,
            "CHROMA_DIR": str(tmp_path / "chroma"),
            "EMBED_API_BASE": "",
            "EMBED_MODEL": "",
        }
        for setting_name, setting_value in configured.items():
            monkeypatch.setattr(settings, setting_name, setting_value)
        emb = self._blank_embedder()
        emb.init_chroma()
        outcome = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
        assert outcome is False
        emb._client = emb._collection = None

    def test_index_batch_disabled(self, monkeypatch):
        """index_batch() reports every item as failed when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        emb = self._blank_embedder()
        outcome = emb.index_batch(["a", "b"])
        assert outcome["success"] == 0
        assert outcome["failed"] == 2

    def test_index_batch_empty(self, monkeypatch):
        """index_batch() reports a total of 0 for an empty list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        outcome = emb.index_batch([])
        assert outcome["total"] == 0

    def test_delete_paper_disabled(self, monkeypatch):
        """delete_paper() returns False when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        emb = self._blank_embedder()
        assert emb.delete_paper("test-id") is False

    def test_search_similar_disabled(self, monkeypatch):
        """search_similar() returns an empty list when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        emb = self._blank_embedder()
        assert emb.search_similar("test query") == []
class TestEmbeddingApi:
    """Tests for ``embedder._get_embedding``."""

    def test_no_api_base_returns_none(self, monkeypatch):
        """Returns None when EMBED_API_BASE is empty."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb

        assert emb._get_embedding("test") is None

    def test_dimension_mismatch_returns_none(self, monkeypatch):
        """Returns None when the API vector length differs from EMBED_DIMENSIONS."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 128)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb

        # The response carries a 64-dim vector while 128 dimensions are configured.
        mock_resp = MagicMock()
        mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]}
        mock_resp.raise_for_status = MagicMock()
        with patch("httpx.Client") as mock_client:
            # The client must be what __enter__ yields, and its post() must return
            # the prepared response. Previously __enter__ yielded the response
            # itself, so post() produced an unrelated MagicMock and the 64-dim
            # payload was never seen by the code under test.
            mock_http = mock_client.return_value
            mock_http.__enter__.return_value = mock_http
            mock_http.post.return_value = mock_resp
            result = emb._get_embedding("test")
        assert result is None

    def test_api_failure_returns_none(self, monkeypatch):
        """Returns None when the HTTP call raises."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", 0)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)
        import app.services.embedder as emb

        with patch("httpx.Client") as mock_client:
            mock_http = mock_client.return_value
            mock_http.__enter__.return_value = mock_http
            mock_http.post.side_effect = Exception("timeout")
            result = emb._get_embedding("test")
        assert result is None
# ═══════════════════════════════════════════════════════════════════════
# Searcher 语义模式测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchSemanticMode:
    """Tests for the ``mode`` handling of ``searcher.search_papers``."""

    def test_keyword_mode_default(self, db_session, sample_papers_with_summary):
        """Keyword mode finds the seeded papers and returns no distances."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test Paper", mode="keyword")
        assert outcome["total"] >= 1
        assert outcome["distances"] == {}

    def test_semantic_mode_disabled_fallback(self, db_session, monkeypatch, sample_papers_with_summary):
        """Semantic mode with CHROMA_ENABLED=false still returns results."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test", mode="semantic")
        # Expected to fall back to the FTS5 search.
        assert outcome["total"] >= 1

    def test_search_returns_distances_dict(self, db_session, sample_papers_with_summary):
        """The result always carries a ``distances`` dict."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, query="Test Paper")
        assert "distances" in outcome
        assert isinstance(outcome["distances"], dict)

    def test_empty_query_returns_empty(self, db_session):
        """No query and no tag yields an empty result."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session)
        assert outcome["total"] == 0
        assert outcome["results"] == []

    def test_tag_only_search(self, db_session, sample_papers_with_summary):
        """Searching by tag alone returns matches."""
        from app.services.searcher import search_papers

        outcome = search_papers(db_session, tag="NLP")
        assert outcome["total"] >= 1
# ═══════════════════════════════════════════════════════════════════════
# Search Routes 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSearchRoutes:
    """Tests for the search page and the search API route."""

    def test_search_page_keyword(self, auth_client, sample_papers_with_summary):
        """The search page renders results in keyword mode."""
        response = auth_client.get("/search?q=Test&mode=keyword")
        assert response.status_code == 200
        assert "Test" in response.text or "测试" in response.text

    def test_search_page_semantic_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
        """Semantic mode still responds when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/search?q=Test&mode=semantic")
        assert response.status_code == 200

    def test_search_api_with_mode(self, auth_client, sample_papers_with_summary):
        """The search API accepts the mode parameter."""
        response = auth_client.get("/api/search?q=Test&mode=keyword")
        assert response.status_code == 200
        payload = response.json()
        assert "results" in payload
        assert "total" in payload
# ═══════════════════════════════════════════════════════════════════════
# Similar Paper API 测试
# ═══════════════════════════════════════════════════════════════════════
class TestSimilarAPI:
    """Tests for the similar-papers API."""

    def test_similar_api_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
        """Returns an empty result list when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/api/similar/2401.20001")
        assert response.status_code == 200
        payload = response.json()
        assert payload["results"] == []

    def test_similar_api_paper_not_found(self, auth_client, monkeypatch):
        """An unknown paper ID yields an empty result list."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/api/similar/nonexistent.99999")
        assert response.status_code == 200
        assert response.json()["results"] == []

    def test_similar_api_with_top_k(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The endpoint accepts a top_k query parameter."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/api/similar/2401.20001?top_k=3")
        assert response.status_code == 200
# ═══════════════════════════════════════════════════════════════════════
# Detail Page 相似论文测试
# ═══════════════════════════════════════════════════════════════════════
class TestDetailSimilarPapers:
    """Tests for the paper detail page."""

    def test_detail_page_renders(self, auth_client, sample_papers_with_summary):
        """The detail page of a seeded paper renders with its title."""
        response = auth_client.get("/paper/2401.20001")
        assert response.status_code == 200
        assert "测试论文" in response.text or "Test Paper" in response.text

    def test_detail_page_not_found(self, auth_client):
        """An unknown paper ID returns 404."""
        response = auth_client.get("/paper/nonexistent.99999")
        assert response.status_code == 404
# ═══════════════════════════════════════════════════════════════════════
# Trends Dashboard 测试
# ═══════════════════════════════════════════════════════════════════════
class TestTrendsDashboard:
    """Tests for the trends dashboard page and its stats API."""

    def test_trends_page_renders(self, auth_client, sample_papers_with_summary):
        """The trends page renders and mentions a chart."""
        resp = auth_client.get("/trends")
        assert resp.status_code == 200
        assert "趋势看板" in resp.text
        # Case-insensitive check; a separate "Chart" check would be redundant.
        assert "chart" in resp.text.lower()

    def test_trends_api_returns_data(self, auth_client, sample_papers_with_summary):
        """The trends API returns the four expected list-valued keys."""
        resp = auth_client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        assert "daily_counts" in data
        assert "top_tags" in data
        assert "upvotes_dist" in data
        assert "summary_completion" in data
        assert isinstance(data["daily_counts"], list)
        assert isinstance(data["top_tags"], list)
        assert isinstance(data["upvotes_dist"], list)
        assert isinstance(data["summary_completion"], list)

    def test_trends_api_daily_counts(self, auth_client, sample_papers_with_summary, monkeypatch):
        """Daily counts contain one entry of count 1 per seeded paper date."""
        # Pin "today" to a date shortly after the seeded papers. The
        # side_effect keeps date(...) constructor calls inside the route
        # module producing real date objects.
        with patch("app.routes.trends.date") as mock_date:
            mock_date.today.return_value = date(2024, 1, 20)
            mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
            resp = auth_client.get("/api/stats/trends")
        data = resp.json()
        assert len(data["daily_counts"]) == 5
        for item in data["daily_counts"]:
            assert "date" in item
            assert "count" in item
            assert item["count"] == 1

    def test_trends_api_top_tags(self, auth_client, sample_papers_with_summary):
        """The NLP tag is counted once per seeded paper."""
        resp = auth_client.get("/api/stats/trends")
        data = resp.json()
        tags = {t["tag"]: t["count"] for t in data["top_tags"]}
        assert "NLP" in tags
        assert tags["NLP"] == 5  # every seeded paper carries the NLP tag

    def test_trends_api_summary_completion(self, auth_client, sample_papers_with_summary):
        """Four of the seeded papers are reported with status done."""
        resp = auth_client.get("/api/stats/trends")
        data = resp.json()
        statuses = {s["status"]: s["count"] for s in data["summary_completion"]}
        assert "done" in statuses
        assert statuses["done"] == 4  # four papers were seeded as done

    def test_trends_empty_db(self, auth_client):
        """The API responds with empty lists when there is no data."""
        resp = auth_client.get("/api/stats/trends")
        assert resp.status_code == 200
        data = resp.json()
        assert data["daily_counts"] == []
        assert data["top_tags"] == []
# ═══════════════════════════════════════════════════════════════════════
# Compare Page 测试
# ═══════════════════════════════════════════════════════════════════════
class TestComparePage:
    """Tests for the paper comparison page."""

    def test_compare_page_no_ids(self, auth_client):
        """Without IDs the page still renders."""
        resp = auth_client.get("/compare")
        assert resp.status_code == 200
        assert "对比" in resp.text

    def test_compare_page_with_ids(self, auth_client, sample_papers_with_summary):
        """Comparing two seeded papers renders both IDs and the comparison fields."""
        resp = auth_client.get("/compare?ids=2401.20001,2401.20002")
        assert resp.status_code == 200
        assert "2401.20001" in resp.text
        assert "2401.20002" in resp.text
        # The structured comparison field labels should be present.
        assert "一句话摘要" in resp.text
        assert "研究问题" in resp.text

    def test_compare_page_max_5(self, auth_client, sample_papers_with_summary):
        """Five IDs are accepted."""
        ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005"
        resp = auth_client.get(f"/compare?ids={ids}")
        assert resp.status_code == 200

    def test_compare_page_over_5_truncates(self, auth_client, sample_papers_with_summary):
        """More than five IDs does not break the page."""
        ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006"
        resp = auth_client.get(f"/compare?ids={ids}")
        assert resp.status_code == 200
        # The sixth ID does not exist in the fixture data. Its text cannot be
        # asserted absent here because the form echoes the submitted ID string.

    def test_compare_page_invalid_ids(self, auth_client):
        """An unknown ID renders the page without a comparison table."""
        resp = auth_client.get("/compare?ids=nonexistent.99999")
        assert resp.status_code == 200
        # The former assertion ended with `or resp.status_code == 200` and so
        # could never fail. The template emits the table only when there are
        # papers to compare, so its absence is what gets checked.
        assert '<table class="compare-table">' not in resp.text

    def test_compare_page_shows_no_summary_placeholder(self, auth_client, sample_papers_with_summary):
        """A paper without a summary shows the placeholder text."""
        # 2401.20005 has no summary (status=pending).
        resp = auth_client.get("/compare?ids=2401.20005")
        assert resp.status_code == 200
        assert "暂无总结" in resp.text
# ═══════════════════════════════════════════════════════════════════════
# Image Extraction 测试
# ═══════════════════════════════════════════════════════════════════════
class TestImageExtraction:
    """Tests for extracting images referenced by LaTeX sources."""

    @staticmethod
    def _redirect_dirs(monkeypatch, tmp_path):
        """Point the summarizer's tmp and paper directory helpers at tmp_path."""
        monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
        monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)

    @pytest.mark.asyncio
    async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
        """Returns 0 when the source directory does not exist."""
        self._redirect_dirs(monkeypatch, tmp_path)
        from app.services.summarizer import _extract_images_from_source

        extracted = await _extract_images_from_source("2401.99999")
        assert extracted == 0

    @pytest.mark.asyncio
    async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
        """Copies the images that a .tex file references and that exist on disk."""
        source_dir = tmp_path / "tmp" / "2401.00001" / "source"
        source_dir.mkdir(parents=True)
        figs_dir = source_dir / "figs"
        figs_dir.mkdir()
        (figs_dir / "figure1.png").write_bytes(b"\x89PNG\r\n")
        (figs_dir / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0")
        # The .tex file references two existing images and one missing file.
        tex_content = r"""
\documentclass{article}
\begin{document}
\begin{figure}
\includegraphics[width=0.8\textwidth]{figs/figure1.png}
\includegraphics{figs/figure2.jpg}
\includegraphics[angle=90]{figs/nonexistent.pdf}
\end{figure}
\end{document}
"""
        (source_dir / "main.tex").write_text(tex_content)
        papers_dir = tmp_path / "papers" / "2401.00001"
        self._redirect_dirs(monkeypatch, tmp_path)
        from app.services.summarizer import _extract_images_from_source

        extracted = await _extract_images_from_source("2401.00001")
        assert extracted == 2
        copied_dir = papers_dir / "images"
        assert copied_dir.exists()
        assert (copied_dir / "figure1.png").exists()
        assert (copied_dir / "figure2.jpg").exists()

    @pytest.mark.asyncio
    async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
        """Returns 0 when the .tex file references no images."""
        source_dir = tmp_path / "tmp" / "2401.00002" / "source"
        source_dir.mkdir(parents=True)
        (source_dir / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
        self._redirect_dirs(monkeypatch, tmp_path)
        from app.services.summarizer import _extract_images_from_source

        extracted = await _extract_images_from_source("2401.00002")
        assert extracted == 0
# ═══════════════════════════════════════════════════════════════════════
# Nav Bar 测试
# ═══════════════════════════════════════════════════════════════════════
class TestNavBar:
    """Tests for the navigation bar."""

    def test_nav_includes_trends_link(self, auth_client):
        """A rendered page contains the /trends link."""
        response = auth_client.get("/search")
        assert response.status_code == 200
        assert "/trends" in response.text

    def test_nav_includes_compare_implicitly(self, auth_client):
        """The compare page is reachable."""
        response = auth_client.get("/compare")
        assert response.status_code == 200
# ═══════════════════════════════════════════════════════════════════════
# Graceful Degradation 测试
# ═══════════════════════════════════════════════════════════════════════
class TestGracefulDegradation:
    """Checks that the app keeps working when CHROMA_ENABLED is false."""

    def test_search_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """Search still returns the seeded papers with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/search?q=Test")
        assert response.status_code == 200
        assert "Test Paper" in response.text or "测试论文" in response.text

    def test_detail_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The detail page renders with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/paper/2401.20001")
        assert response.status_code == 200

    def test_trends_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The trends page renders with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/trends")
        assert response.status_code == 200

    def test_compare_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary):
        """The compare page renders with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/compare?ids=2401.20001,2401.20002")
        assert response.status_code == 200

    @pytest.mark.asyncio
    async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch):
        """Deleting papers by date range succeeds with Chroma off."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        import app.services.embedder as emb

        emb._client = emb._collection = None
        from app.services.cleaner import delete_papers_by_date_range

        # A single-day range that covers exactly one seeded paper.
        only_day = date(2024, 1, 10)
        outcome = await delete_papers_by_date_range(db_session, only_day, date(2024, 1, 10))
        assert outcome["status"] == "success"
        assert outcome["deleted"] == 1