diff --git a/app/main.py b/app/main.py index 5493505..0ab932b 100644 --- a/app/main.py +++ b/app/main.py @@ -5,13 +5,16 @@ import os from fastapi import FastAPI from fastapi.staticfiles import StaticFiles +from starlette.staticfiles import StaticFiles as StarletteStaticFiles from app.config import settings from app.database import engine from app.models import init_db from app.routes.admin import router as admin_router +from app.routes.compare import router as compare_router from app.routes.pages import router as pages_router from app.routes.search import router as search_router +from app.routes.trends import router as trends_router from app.routes.user import router as user_router logging.basicConfig( @@ -49,11 +52,18 @@ def create_app() -> FastAPI: # 静态文件 app.mount("/static", StaticFiles(directory="app/static"), name="static") + # Phase 5: 论文图片静态服务 + papers_images_dir = os.path.join("data", "papers") + os.makedirs(papers_images_dir, exist_ok=True) + app.mount("/papers", StaticFiles(directory=papers_images_dir), name="papers") + # 路由 app.include_router(pages_router) app.include_router(admin_router) app.include_router(search_router) app.include_router(user_router) + app.include_router(trends_router) + app.include_router(compare_router) # 调度器(Phase 4) @app.on_event("startup") @@ -61,6 +71,12 @@ def create_app() -> FastAPI: from app.services.scheduler import start_scheduler start_scheduler() + # Phase 5: 初始化 ChromaDB + @app.on_event("startup") + async def _init_chroma(): + from app.services.embedder import init_chroma + init_chroma() + @app.on_event("shutdown") async def _stop_scheduler(): from app.services.scheduler import stop_scheduler diff --git a/app/routes/compare.py b/app/routes/compare.py new file mode 100644 index 0000000..8c418c3 --- /dev/null +++ b/app/routes/compare.py @@ -0,0 +1,115 @@ +"""论文对比页路由 — 多篇论文结构化字段并排对比。""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, Depends, HTTPException, Query, Request +from 
# Maximum number of papers that can be shown side by side.
_MAX_COMPARE_PAPERS = 5

# (row key, label shown in the first column). Title rows are read from the
# paper itself; every other row is read from the paper's summary.
_COMPARE_FIELDS = [
    ("title_zh", "中文标题"),
    ("title_en", "英文标题"),
    ("one_line", "一句话摘要"),
    ("difficulty", "难度"),
    ("motivation_problem", "研究问题"),
    ("motivation_goal", "研究目标"),
    ("motivation_gap", "研究差距"),
    ("method_overview", "方法概述"),
    ("method_key_idea", "关键思路"),
    ("method_novelty", "新颖性"),
    ("results", "实验结果"),
    ("limitations", "局限与改进"),
]

# Rows that fall back to a raw JSON column when the plain attribute is empty.
_JSON_FALLBACK_ATTRS = {
    "results": "results_main_json",
    "limitations": "limitations_json",
}


def _parse_compare_ids(ids: str, limit: int = _MAX_COMPARE_PAPERS) -> list[str]:
    """Split a comma-separated ID string into at most ``limit`` unique IDs.

    Blank entries are dropped and duplicates are removed (first occurrence
    wins) *before* the cap is applied, so repeated IDs do not use up slots.
    """
    unique = dict.fromkeys(part.strip() for part in ids.split(","))
    unique.pop("", None)
    return list(unique)[:limit]


def _compare_cell(paper, field_key: str) -> str:
    """Return the text shown for ``field_key`` in ``paper``'s column."""
    if field_key in ("title_zh", "title_en"):
        return getattr(paper, field_key, None) or ""

    summary = paper.summary
    if not summary:
        return ""

    value = getattr(summary, field_key, None) or ""
    fallback_attr = _JSON_FALLBACK_ATTRS.get(field_key)
    if not value and fallback_attr:
        # Show the raw JSON column as-is when no plain attribute is available.
        value = getattr(summary, fallback_attr) or ""
    return value


def _compare_response(request: Request, ids: str, papers: list, rows: list, error):
    """Render ``compare.html`` with the same context keys on every code path."""
    return templates.TemplateResponse(
        request,
        "compare.html",
        {
            "page_title": "论文对比",
            "papers": papers,
            "rows": rows,
            "ids_param": ids,
            "error": error,
        },
    )


@router.get("/compare")
def compare_page(
    request: Request,
    ids: str = Query(default="", description="逗号分隔的 arxiv_id,最多 5 篇"),
    db: Session = Depends(get_db),
):
    """Render the side-by-side paper comparison page.

    ``GET /compare?ids=id1,id2,id3``

    Args:
        request: Incoming request, forwarded to the template.
        ids: Comma-separated arXiv IDs; only the first five unique IDs are used.
        db: Database session.

    Returns:
        The rendered ``compare.html`` response. Its context always contains
        ``papers``, ``rows``, ``ids_param`` and ``error``.
    """
    if not ids:
        return _compare_response(request, ids, [], [], None)

    arxiv_ids = _parse_compare_ids(ids)
    if not arxiv_ids:
        # The parameter was present but contained only separators/whitespace.
        return _compare_response(request, ids, [], [], "请提供有效的论文 ID")

    papers = (
        db.query(Paper)
        .filter(Paper.arxiv_id.in_(arxiv_ids))
        .options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary),
            joinedload(Paper.summary_status),
        )
        .all()
    )

    # Columns follow the order in which the IDs were requested.
    position = {aid: idx for idx, aid in enumerate(arxiv_ids)}
    papers.sort(key=lambda p: position.get(p.arxiv_id, len(position)))

    rows = [
        {
            "key": field_key,
            "label": field_label,
            "cells": [_compare_cell(paper, field_key) for paper in papers],
        }
        for field_key, field_label in _COMPARE_FIELDS
    ]

    return _compare_response(request, ids, papers, rows, None)
# Suffixes the detail-page gallery is able to display.
_GALLERY_SUFFIXES = (".png", ".jpg", ".jpeg", ".gif", ".svg")


def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
    """Return papers whose stored embedding is closest to ``arxiv_id``'s.

    The paper's own vector is read back from ChromaDB and used as the query.
    The paper itself is dropped from the hits, so at most ``top_k - 1`` items
    come back when it is among the nearest ``top_k`` neighbours.

    Args:
        db: Database session used to load the matching ``Paper`` rows.
        arxiv_id: ID of the reference paper.
        top_k: Number of neighbours requested from ChromaDB.

    Returns:
        A list of ``{"arxiv_id", "title_zh", "distance", "paper_date", "tags"}``
        dicts sorted by ascending distance. Empty when ChromaDB is disabled,
        the paper has no stored embedding, or any step fails (failures are
        logged, never raised).
    """
    if not settings.CHROMA_ENABLED:
        return []

    try:
        from app.services.embedder import get_collection

        col = get_collection()
        if col is None:
            return []

        # Fetch the reference paper's own embedding.
        stored = col.get(ids=[arxiv_id], include=["embeddings"])
        embeddings = stored["embeddings"]
        # Use explicit length checks: truth-testing raises when the client
        # hands embeddings back as arrays instead of plain lists.
        if embeddings is None or len(embeddings) == 0:
            return []
        vec = embeddings[0]
        if vec is None or len(vec) == 0:
            return []

        count = col.count()
        if count == 0:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )

        if not results["ids"] or not results["ids"][0]:
            return []

        similar_ids = results["ids"][0]
        distances = (
            results["distances"][0] if results["distances"] else [0.0] * len(similar_ids)
        )

        # Map each neighbour to its distance, leaving out the paper itself.
        distance_by_id = {
            sid: dist for sid, dist in zip(similar_ids, distances) if sid != arxiv_id
        }
        if not distance_by_id:
            return []

        papers = (
            db.query(Paper)
            .filter(Paper.arxiv_id.in_(list(distance_by_id)))
            .options(joinedload(Paper.tags))
            .all()
        )

        items = [
            {
                "arxiv_id": p.arxiv_id,
                "title_zh": p.title_zh or p.title_en,
                "distance": distance_by_id.get(p.arxiv_id, 0.0),
                "paper_date": p.paper_date.isoformat() if p.paper_date else "",
                "tags": [t.tag for t in p.tags[:3]],
            }
            for p in papers
        ]
        items.sort(key=lambda item: item["distance"])
        return items

    except Exception:
        # Recommendations are optional; never let them break the detail page.
        logger.exception("Failed to get similar papers for %s", arxiv_id)
        return []


# ── Image gallery (Phase 5) ───────────────────────────────────────────


def _get_paper_images(arxiv_id: str) -> list[dict]:
    """List the extracted images stored for a paper.

    Looks in ``data/papers/<arxiv_id>/images`` (relative to the working
    directory) and keeps regular files with a displayable image suffix.

    Args:
        arxiv_id: ID of the paper whose image folder is scanned.

    Returns:
        ``[{"url": ..., "name": ...}, ...]`` sorted by path; empty when the
        folder is missing. ``name`` is the raw file name; in ``url`` the file
        name is percent-encoded so spaces, ``#`` or ``?`` cannot break it.
    """
    from urllib.parse import quote

    images_dir = Path("data/papers") / arxiv_id / "images"
    # is_dir() also covers the case where the path exists but is a file.
    if not images_dir.is_dir():
        return []

    return [
        {
            "url": f"/papers/{arxiv_id}/images/{quote(img_file.name)}",
            "name": img_file.name,
        }
        for img_file in sorted(images_dir.iterdir())
        if img_file.is_file() and img_file.suffix.lower() in _GALLERY_SUFFIXES
    ]
@router.get("/trends")
def trends_page(request: Request, db: Session = Depends(get_db)):
    """Render the trends dashboard page."""
    stats = _get_trends_data(db)
    return templates.TemplateResponse(
        request,
        "trends.html",
        {
            "page_title": "趋势看板",
            "stats": stats,
            "today": _today_str(),
        },
    )


@router.get("/api/stats/trends")
def trends_api(db: Session = Depends(get_db)):
    """Return the aggregated trends data as JSON."""
    return _get_trends_data(db)


def _get_trends_data(db: Session) -> dict:
    """Aggregate the dashboard statistics from the database.

    Args:
        db: Database session used to run the aggregate queries.

    Returns:
        A dict with four lists:

        * ``daily_counts`` – ``{"date", "count"}`` per day for the last 30 days
        * ``top_tags`` – ``{"tag", "count"}`` for the 20 most used tags
        * ``upvotes_dist`` – ``{"range", "count"}`` per upvote bucket
        * ``summary_completion`` – ``{"status", "count"}`` per summary status,
          with ``"none"`` for papers that have no status row
    """
    # Anchor the window on the application's timezone so it agrees with the
    # "today" value rendered on the page (see _today_str).
    thirty_days_ago = (_today() - timedelta(days=30)).isoformat()

    # 1. Papers per day (last 30 days)
    daily_rows = db.execute(text("""
        SELECT paper_date, COUNT(*) as cnt
        FROM papers
        WHERE paper_date >= :start_date
        GROUP BY paper_date
        ORDER BY paper_date ASC
    """), {"start_date": thirty_days_ago}).fetchall()
    daily_counts = [{"date": str(row[0]), "count": row[1]} for row in daily_rows]

    # 2. Top 20 tags
    tag_rows = db.execute(text("""
        SELECT tag, COUNT(*) as cnt
        FROM paper_tags
        GROUP BY tag
        ORDER BY cnt DESC
        LIMIT 20
    """)).fetchall()
    top_tags = [{"tag": row[0], "count": row[1]} for row in tag_rows]

    # 3. Upvote distribution, highest bucket first
    upvote_rows = db.execute(text("""
        SELECT
            CASE
                WHEN upvotes >= 100 THEN '100+'
                WHEN upvotes >= 50 THEN '50-99'
                WHEN upvotes >= 20 THEN '20-49'
                WHEN upvotes >= 10 THEN '10-19'
                WHEN upvotes >= 5 THEN '5-9'
                ELSE '0-4'
            END as bucket,
            COUNT(*) as cnt
        FROM papers
        GROUP BY bucket
        ORDER BY MIN(upvotes) DESC
    """)).fetchall()
    upvotes_dist = [{"range": row[0], "count": row[1]} for row in upvote_rows]

    # 4. Summary completion by status
    summary_rows = db.execute(text("""
        SELECT
            COALESCE(ss.status, 'none') as status,
            COUNT(*) as cnt
        FROM papers p
        LEFT JOIN summary_status ss ON ss.paper_id = p.id
        GROUP BY status
    """)).fetchall()
    summary_completion = [{"status": row[0], "count": row[1]} for row in summary_rows]

    return {
        "daily_counts": daily_counts,
        "top_tags": top_tags,
        "upvotes_dist": upvotes_dist,
        "summary_completion": summary_completion,
    }


def _today() -> date:
    """Return the current date in the timezone named by ``settings.APP_TIMEZONE``."""
    from datetime import datetime
    from zoneinfo import ZoneInfo

    return datetime.now(ZoneInfo(settings.APP_TIMEZONE)).date()


def _today_str() -> str:
    """Return today's date in the application timezone as ``YYYY-MM-DD``."""
    return _today().strftime("%Y-%m-%d")
# ── Singleton client and collection ────────────────────────────────────
_client = None
_collection = None

# Keys of a ``texts_dict`` that carry metadata rather than searchable content.
_INDEX_METADATA_KEYS = frozenset({"arxiv_id", "paper_date"})


def _chroma_dir() -> Path:
    """Return the ChromaDB storage directory from ``settings.CHROMA_DIR``."""
    return Path(settings.CHROMA_DIR)


def init_chroma() -> None:
    """Create the persistent ChromaDB client and collection if enabled.

    Does nothing when ``settings.CHROMA_ENABLED`` is false or a client already
    exists. Any failure is logged and leaves both singletons as ``None``.
    """
    global _client, _collection
    if not settings.CHROMA_ENABLED:
        logger.debug("ChromaDB disabled, skip init")
        return

    if _client is not None:
        return

    try:
        import chromadb

        chroma_path = _chroma_dir()
        chroma_path.mkdir(parents=True, exist_ok=True)

        _client = chromadb.PersistentClient(path=str(chroma_path))
        _collection = _get_or_create_collection()
        logger.info("ChromaDB initialized at %s", chroma_path)
    except Exception:
        logger.exception("Failed to initialize ChromaDB")
        _client = None
        _collection = None


def _get_or_create_collection():
    """Load the ``papers_embeddings`` collection, creating it if loading fails.

    A new collection is created with cosine distance. No dimension check is
    done here; vector length is validated in ``_get_embedding``.
    """
    try:
        col = _client.get_collection("papers_embeddings")
        logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
        return col
    except Exception:
        # Any failure to load is treated as "does not exist yet"; creation
        # below raises if the problem is something else.
        pass

    col = _client.create_collection(
        name="papers_embeddings",
        metadata={"hnsw:space": "cosine"},
    )
    logger.info("ChromaDB collection 'papers_embeddings' created")
    return col


def get_collection():
    """Return the collection, initialising lazily; ``None`` if unavailable."""
    if not settings.CHROMA_ENABLED:
        return None
    if _collection is None:
        init_chroma()
    return _collection


# ── Embedding API call ─────────────────────────────────────────────────


def _get_embedding(text: str) -> list[float] | None:
    """Request an embedding vector for ``text`` from the configured API.

    Sends ``POST {EMBED_API_BASE}/v1/embeddings`` with ``model=EMBED_MODEL``
    and, when ``EMBED_API_KEY`` is set, a bearer token.

    Args:
        text: Text to embed.

    Returns:
        The vector from ``data[0].embedding``, or ``None`` when the API is not
        configured, the call fails, or the vector length differs from a
        positive ``EMBED_DIMENSIONS``. Failures are logged, never raised.
    """
    if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
        logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
        return None

    url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
    headers = {"Content-Type": "application/json"}
    if settings.EMBED_API_KEY:
        headers["Authorization"] = f"Bearer {settings.EMBED_API_KEY}"

    payload = {
        "model": settings.EMBED_MODEL,
        "input": text,
    }

    try:
        with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
            resp = client.post(url, json=payload, headers=headers)
            resp.raise_for_status()
            data = resp.json()

        vec = data["data"][0]["embedding"]
        if settings.EMBED_DIMENSIONS > 0 and len(vec) != settings.EMBED_DIMENSIONS:
            logger.warning(
                "Embedding dimension mismatch: expected=%d got=%d, skip",
                settings.EMBED_DIMENSIONS,
                len(vec),
            )
            return None
        return vec
    except Exception:
        logger.exception("Embedding API call failed")
        return None


# ── Index text assembly ────────────────────────────────────────────────


def _build_index_text(paper: Paper, summary_dict: dict | None) -> str:
    """Join a paper's high-signal fields into one string for embedding.

    Order: ``title_zh``, ``title_en``, space-joined tag names, then the
    ``one_line``, ``motivation_problem`` and ``method_key_idea`` entries of
    ``summary_dict``. Empty values are skipped.
    """
    parts: list[str] = []

    if paper.title_zh:
        parts.append(paper.title_zh)
    if paper.title_en:
        parts.append(paper.title_en)
    if paper.tags:
        tag_str = " ".join(t.tag for t in paper.tags)
        if tag_str:
            parts.append(tag_str)

    if summary_dict:
        parts.extend(
            summary_dict[key]
            for key in ("one_line", "motivation_problem", "method_key_idea")
            if summary_dict.get(key)
        )

    return " ".join(parts)


def _join_index_fields(texts_dict: dict) -> str:
    """Join the non-empty string content fields of ``texts_dict``.

    ``arxiv_id`` and ``paper_date`` are metadata and are left out, so the
    embedded text matches what ``_build_index_text`` produces for a paper
    loaded from the database.
    """
    return " ".join(
        value
        for key, value in texts_dict.items()
        if key not in _INDEX_METADATA_KEYS and isinstance(value, str) and value
    )


# ── Single-paper indexing ──────────────────────────────────────────────


def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
    """Write one paper into the ChromaDB semantic index.

    Args:
        paper_id: The paper's arxiv_id.
        texts_dict: Optional pre-assembled fields (``title_zh``, ``title_en``,
            ``tags``, ``one_line`` ...). String values are embedded, except
            ``arxiv_id`` and ``paper_date``, which are stored only as
            metadata. When ``None`` the paper is loaded from the database.

    Returns:
        ``True`` on success; ``False`` when the index is unavailable, the
        paper is missing, there is no text to embed, or any step fails.
    """
    col = get_collection()
    if col is None:
        return False

    try:
        if texts_dict is None:
            from app.database import SessionLocal

            db = SessionLocal()
            try:
                paper = (
                    db.query(Paper)
                    .filter(Paper.arxiv_id == paper_id)
                    .options(joinedload(Paper.tags), joinedload(Paper.summary))
                    .first()
                )
                if not paper:
                    logger.warning("Paper %s not found for indexing", paper_id)
                    return False

                summary_dict = None
                if paper.summary:
                    summary_dict = {
                        "one_line": paper.summary.one_line or "",
                        "motivation_problem": paper.summary.motivation_problem or "",
                        "method_key_idea": paper.summary.method_key_idea or "",
                    }
                # Read everything needed while the session is still open.
                index_text = _build_index_text(paper, summary_dict)
                arxiv_id = paper.arxiv_id
                title_zh = paper.title_zh or ""
                paper_date = paper.paper_date.isoformat() if paper.paper_date else ""
            finally:
                db.close()
        else:
            index_text = _join_index_fields(texts_dict)
            arxiv_id = texts_dict.get("arxiv_id", paper_id)
            title_zh = texts_dict.get("title_zh", "")
            paper_date = texts_dict.get("paper_date", "")

        if not index_text.strip():
            logger.warning("Empty index text for %s, skip", paper_id)
            return False

        vec = _get_embedding(index_text)
        if vec is None:
            return False

        col.upsert(
            ids=[arxiv_id],
            embeddings=[vec],
            metadatas=[{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}],
        )
        logger.info("Indexed paper %s in ChromaDB", arxiv_id)
        return True

    except Exception:
        logger.exception("Failed to index paper %s in ChromaDB", paper_id)
        return False


# ── Batch indexing ─────────────────────────────────────────────────────


def index_batch(paper_ids: list[str]) -> dict:
    """Index several papers; one failure does not stop the rest.

    Returns:
        ``{"total": int, "success": int, "failed": int}``. When the collection
        is unavailable every paper counts as failed.
    """
    total = len(paper_ids)
    if not paper_ids:
        return {"total": 0, "success": 0, "failed": 0}

    if get_collection() is None:
        return {"total": total, "success": 0, "failed": total}

    success = sum(1 for pid in paper_ids if index_paper(pid))
    failed = total - success

    logger.info("Batch index: total=%d success=%d failed=%d", total, success, failed)
    return {"total": total, "success": success, "failed": failed}


# ── Deletion ───────────────────────────────────────────────────────────


def delete_paper(paper_id: str) -> bool:
    """Remove a paper from the ChromaDB index.

    Args:
        paper_id: The paper's arxiv_id.

    Returns:
        ``True`` if the delete call succeeded, ``False`` if the index is
        unavailable or the call raised (the error is logged).
    """
    col = get_collection()
    if col is None:
        return False

    try:
        col.delete(ids=[paper_id])
        logger.info("Deleted paper %s from ChromaDB", paper_id)
        return True
    except Exception:
        logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
        return False


# ── Similarity query ───────────────────────────────────────────────────


def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
    """Run a semantic similarity search.

    Args:
        query_text: Text to search for.
        top_k: Maximum number of hits.

    Returns:
        ``[{"arxiv_id": str, "distance": float}, ...]`` in the order returned
        by ChromaDB; empty when the index is unavailable or empty, the
        embedding call fails, or the query raises.
    """
    col = get_collection()
    if col is None:
        return []

    try:
        # Check emptiness first so an empty index costs no embedding request.
        count = col.count()
        if count == 0:
            return []

        vec = _get_embedding(query_text)
        if vec is None:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )

        if not results["ids"] or not results["ids"][0]:
            return []

        hit_ids = results["ids"][0]
        hit_distances = results["distances"][0] if results["distances"] else None
        return [
            {
                "arxiv_id": arxiv_id,
                # Plain floats keep the result JSON-serialisable.
                "distance": float(hit_distances[i]) if hit_distances is not None else 0.0,
            }
            for i, arxiv_id in enumerate(hit_ids)
        ]

    except Exception:
        logger.exception("ChromaDB search_similar failed")
        return []
def _search_semantic(
    db: Session,
    query: str,
    tag: str | None,
    sort: str,
    page: int,
    page_size: int,
) -> dict:
    """Run a ChromaDB semantic search, falling back to keyword search.

    Args:
        db: Database session used to load the matching papers.
        query: Free-text query to embed.
        tag: Optional tag; when given, only papers carrying it are kept.
        sort: Used only by the keyword fallback; semantic hits are always
            ordered by distance.
        page: 1-based page number.
        page_size: Results per page.

    Returns:
        The same dict shape as ``search_papers``. ``snippets`` is empty and
        ``distances`` maps arxiv_id to distance for every candidate.

    Note:
        Only ``page_size * 3`` candidates are fetched, so ``total`` is capped
        at that number and later pages are empty.
    """
    try:
        from app.services.embedder import search_similar

        top_k = page_size * 3  # over-fetch so tag filtering still leaves results
        candidates = search_similar(query, top_k=top_k)
    except Exception:
        logger.exception("Semantic search failed, falling back to keyword")
        candidates = []

    if not candidates:
        # Reuse the keyword entry point rather than rebuilding its arguments
        # here: it sanitises the query and handles the tag-only and empty
        # cases, so raw user input never reaches the FTS5 MATCH expression.
        return search_papers(
            db,
            query=query,
            tag=tag,
            sort=sort,
            page=page,
            page_size=page_size,
            mode="keyword",
        )

    # Load the full rows for the candidate IDs.
    arxiv_ids = [c["arxiv_id"] for c in candidates]
    distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}

    papers_query = (
        db.query(Paper)
        .filter(Paper.arxiv_id.in_(arxiv_ids))
        .options(
            joinedload(Paper.authors),
            joinedload(Paper.tags),
            joinedload(Paper.summary_status),
            joinedload(Paper.bookmark),
            joinedload(Paper.reading_status),
        )
    )
    if tag:
        papers_query = papers_query.filter(Paper.tags.any(tag=tag))

    papers = papers_query.all()

    # Restore the order in which the candidates were returned.
    rank = {aid: idx for idx, aid in enumerate(arxiv_ids)}
    papers.sort(key=lambda p: rank.get(p.arxiv_id, len(rank)))

    # Paginate in memory.
    total = len(papers)
    start = (page - 1) * page_size
    page_papers = papers[start:start + page_size]

    return {
        "results": page_papers,
        "snippets": {},
        "total": total,
        "page": page,
        "total_pages": math.ceil(total / page_size) if total else 0,
        "distances": distance_map,
    }
# ── LaTeX image extraction (Phase 5) ───────────────────────────────────

_INCLUDEGRAPHICS_RE = re.compile(
    r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}

# Suffixes tried, in order, when an \includegraphics path omits its extension.
_IMAGE_GUESS_EXTS = ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps")


def _is_within(path: Path, root: Path) -> bool:
    """Return True if resolved ``path`` lies inside resolved ``root``."""
    try:
        path.relative_to(root)
    except ValueError:
        return False
    return True


def _locate_source_image(source_root: Path, img_rel: str) -> Path | None:
    """Find the file an ``\\includegraphics`` path refers to.

    The path comes from untrusted LaTeX, so a candidate is accepted only when
    its fully resolved location (symlinks followed) stays inside
    ``source_root``. ``../`` chains, absolute paths and escaping symlinks are
    rejected.

    Returns:
        The matching path under ``source_root``, or ``None`` if there is none.
    """
    root = source_root.resolve()
    for ext in _IMAGE_GUESS_EXTS:
        candidate = source_root / (img_rel + ext)
        if not _is_within(candidate.resolve(), root):
            continue
        if candidate.is_file():
            return candidate
    return None


def _safe_tar_members(tf, dest_dir: Path) -> list:
    """Return the archive members that are safe to extract into ``dest_dir``.

    Keeps regular files and directories whose target stays inside
    ``dest_dir``; links, devices and escaping paths are dropped.
    """
    root = dest_dir.resolve()
    safe = []
    for member in tf.getmembers():
        if not (member.isfile() or member.isdir()):
            continue
        if _is_within((dest_dir / member.name).resolve(), root):
            safe.append(member)
    return safe


async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
    """Copy the images referenced by a paper's LaTeX source.

    Steps:
        1. Download and unpack the source into the temp source directory,
           unless that directory already exists.
        2. Scan every ``.tex`` file for ``\\includegraphics`` paths.
        3. Copy each referenced file into ``<paper dir>/images/``.

    Args:
        arxiv_id: ID of the paper.
        tmp_source: Directory holding (or to receive) the unpacked source.
            Defaults to ``_tmp_dir(arxiv_id) / "source"``.

    Returns:
        Number of images copied; ``0`` when nothing was found or on any
        failure (failures are logged, never raised).
    """
    if tmp_source is None:
        tmp_source = _tmp_dir(arxiv_id) / "source"
    images_dest = _paper_dir(arxiv_id) / "images"

    try:
        # Download the source archive if it has not been unpacked yet.
        if not tmp_source.exists():
            source_url = f"https://arxiv.org/e-print/{arxiv_id}"
            await _download_source_zip(arxiv_id, source_url, tmp_source)

        if not tmp_source.exists():
            return 0

        # Collect every path referenced by \includegraphics.
        image_paths: set[str] = set()
        for tex_file in tmp_source.rglob("*.tex"):
            try:
                content = tex_file.read_text(encoding="utf-8", errors="replace")
            except Exception:
                continue
            image_paths.update(
                match.group(1).strip() for match in _INCLUDEGRAPHICS_RE.finditer(content)
            )

        if not image_paths:
            return 0

        images_dest.mkdir(parents=True, exist_ok=True)
        copied = 0
        for img_rel in image_paths:
            candidate = _locate_source_image(tmp_source, img_rel)
            if candidate is None:
                continue

            # Never overwrite an existing file: add a numeric suffix and keep
            # counting until the name is free.
            dest = images_dest / candidate.name
            suffix_no = copied
            while dest.exists():
                dest = images_dest / f"{candidate.stem}_{suffix_no}{candidate.suffix}"
                suffix_no += 1

            shutil.copy2(candidate, dest)
            copied += 1

        if copied > 0:
            logger.info("Extracted %d images from source for %s", copied, arxiv_id)
        return copied

    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
        return 0


async def _download_source_zip(
    arxiv_id: str, source_url: str, dest_dir: Path
) -> None:
    """Download a paper's source archive and unpack it into ``dest_dir``.

    The download is tried as a zip first and as a tar archive (any
    compression) second. Download and extraction failures are logged and
    swallowed, and the downloaded archive file is always removed.

    Args:
        arxiv_id: ID of the paper (used for the temp path and log messages).
        source_url: URL of the source archive.
        dest_dir: Directory to unpack into; created if missing.
    """
    import tarfile
    import zipfile

    dest_dir.mkdir(parents=True, exist_ok=True)
    zip_path = _tmp_dir(arxiv_id) / "source.zip"

    transport = None
    if settings.http_proxy:
        transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)

    try:
        async with httpx.AsyncClient(
            timeout=settings.HTTP_TIMEOUT_SECONDS,
            headers={"User-Agent": settings.HTTP_USER_AGENT},
            transport=transport,
            follow_redirects=True,
        ) as client:
            resp = await client.get(source_url)
            resp.raise_for_status()
            zip_path.write_bytes(resp.content)
    except Exception as exc:
        logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
        return

    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            # zipfile strips absolute paths and ".." components on extraction.
            zf.extractall(dest_dir)
        logger.debug("Extracted source for %s", arxiv_id)
    except zipfile.BadZipFile:
        # Not a zip: try it as a tar archive.
        try:
            with tarfile.open(zip_path, "r:*") as tf:
                # The archive is untrusted: extract only vetted members.
                tf.extractall(dest_dir, members=_safe_tar_members(tf, dest_dir))
            logger.debug("Extracted source (tar) for %s", arxiv_id)
        except Exception:
            logger.warning("Cannot extract source for %s", arxiv_id)
    except Exception:
        logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
    finally:
        if zip_path.exists():
            zip_path.unlink()
import index_paper + + texts_dict = { + "arxiv_id": arxiv_id, + "title_zh": schema.title_zh or "", + "title_en": paper.title_en or "", + "tags": " ".join(t.tag for t in paper.tags) if paper.tags else "", + "one_line": schema.one_line or "", + "motivation_problem": schema.motivation_problem or "", + "method_key_idea": schema.method_key_idea or "", + "paper_date": paper.paper_date.isoformat() if paper.paper_date else "", + } + index_paper(arxiv_id, texts_dict) + except Exception: + logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True) + logger.info("Summarize done: %s quality=%s", arxiv_id, quality) return {"arxiv_id": arxiv_id, "status": "done", "quality": quality} diff --git a/app/static/css/style.css b/app/static/css/style.css index ff4ed02..a8f85d1 100644 --- a/app/static/css/style.css +++ b/app/static/css/style.css @@ -556,3 +556,185 @@ mark { .reading-list-filters { gap: 4px; } .filter-chip { padding: 4px 10px; font-size: 0.8rem; } } + +/* ── Search Mode Toggle (Phase 5) ─────────────────────────────── */ +.search-mode-toggle { + display: flex; + gap: 0; + border: 1px solid var(--border); + border-radius: var(--radius); + overflow: hidden; +} +.mode-option { + padding: 8px 14px; + font-size: 0.85rem; + color: var(--ink-light); + cursor: pointer; + transition: all 0.2s; + display: flex; + align-items: center; + gap: 4px; +} +.mode-option input[type="radio"] { display: none; } +.mode-option:hover { background: var(--bg); } +.mode-option.active { + background: var(--accent); + color: #fff; +} + +/* ── Similarity Score (Phase 5) ────────────────────────────────── */ +.similarity-score { + font-size: 0.8rem; + color: var(--accent); + white-space: nowrap; + padding: 2px 8px; + background: #eef3f8; + border-radius: 4px; +} + +/* ── Similar Papers (Phase 5) ──────────────────────────────────── */ +.similar-papers { + margin-top: 32px; + padding-top: 24px; + border-top: 1px solid var(--border); +} +.similar-papers h2 { + font-family: 
var(--font-body); + font-size: 1.1rem; + font-weight: 600; + margin-bottom: 12px; + color: var(--accent); +} +.similar-paper-item { + display: flex; + justify-content: space-between; + align-items: center; + padding: 10px 0; + border-bottom: 1px solid var(--border); +} +.similar-paper-item:last-child { border-bottom: none; } +.similar-paper-title a { + font-size: 0.92rem; + color: var(--ink); +} +.similar-paper-title a:hover { color: var(--accent); } +.similar-paper-dist { + font-size: 0.8rem; + color: var(--ink-light); +} + +/* ── Trends Dashboard (Phase 5) ────────────────────────────────── */ +.trends-page h1 { + font-family: var(--font-body); + font-size: 1.5rem; + font-weight: 700; + margin-bottom: 24px; +} +.charts-grid { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 24px; + margin-bottom: 32px; +} +.chart-card { + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius); + padding: 20px; +} +.chart-card h2 { + font-family: var(--font-body); + font-size: 1rem; + font-weight: 600; + margin-bottom: 12px; + color: var(--accent); +} +.chart-card canvas { + width: 100% !important; + max-height: 300px; +} +@media (max-width: 768px) { + .charts-grid { grid-template-columns: 1fr; } +} + +/* ── Compare Page (Phase 5) ────────────────────────────────────── */ +.compare-page h1 { + font-family: var(--font-body); + font-size: 1.5rem; + font-weight: 700; + margin-bottom: 24px; +} +.compare-table-wrapper { + overflow-x: auto; + margin-bottom: 24px; +} +.compare-table { + width: 100%; + border-collapse: collapse; + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius); + font-size: 0.88rem; +} +.compare-table th, +.compare-table td { + padding: 12px 16px; + border: 1px solid var(--border); + vertical-align: top; + text-align: left; +} +.compare-table th { + background: var(--bg); + font-weight: 600; + color: var(--ink-light); + white-space: nowrap; + min-width: 100px; +} +.compare-table 
td.field-label { + background: var(--bg); + font-weight: 500; + color: var(--ink); + white-space: nowrap; + min-width: 90px; +} +.compare-table td.paper-col { + min-width: 200px; +} +.compare-table td .no-summary { + color: var(--ink-light); + font-style: italic; +} + +/* ── Image Gallery (Phase 5) ───────────────────────────────────── */ +.image-gallery { + margin-top: 24px; +} +.image-gallery h2 { + font-family: var(--font-body); + font-size: 1.05rem; + font-weight: 600; + margin-bottom: 12px; + color: var(--accent); +} +.gallery-grid { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(260px, 1fr)); + gap: 16px; +} +.gallery-item { + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius); + overflow: hidden; +} +.gallery-item img { + width: 100%; + height: auto; + display: block; +} +.gallery-item .gallery-caption { + padding: 8px 12px; + font-size: 0.8rem; + color: var(--ink-light); + text-align: center; +} diff --git a/app/templates/base.html b/app/templates/base.html index eb32d2d..6e272f9 100644 --- a/app/templates/base.html +++ b/app/templates/base.html @@ -16,6 +16,7 @@ diff --git a/app/templates/compare.html b/app/templates/compare.html new file mode 100644 index 0000000..c59a934 --- /dev/null +++ b/app/templates/compare.html @@ -0,0 +1,86 @@ +{% extends "base.html" %} + +{% block title %}{{ page_title }} — HF Daily Papers{% endblock %} + +{% block content %} +
+

论文对比

+ + {# ID 输入表单 #} +
+ + +
+ + {% if error %} +
+

{{ error }}

+
+ {% endif %} + + {% if papers %} +
+ + + + + {% for paper in papers %} + + {% endfor %} + + + + {# 作者行 #} + + + {% for paper in papers %} + + {% endfor %} + + + {# 标签行 #} + + + {% for paper in papers %} + + {% endfor %} + + + {# 结构化对比字段 #} + {% for row in rows %} + + + {% for cell in row.cells %} + + {% endfor %} + + {% endfor %} + +
字段 + {{ paper.arxiv_id }} +
+ + {{ paper.upvotes }} 👍 · {{ paper.paper_date }} + +
作者{{ paper.authors|map(attribute='name')|join(', ') }}
标签 + {% for t in paper.tags[:5] %} + {{ t.tag }} + {% endfor %} +
{{ row.label }} + {% if cell %} + {{ cell }} + {% else %} + 暂无总结 + {% endif %} +
+
+ {% elif ids_param and not error %} +
+

未找到匹配的论文

+

请检查 arXiv ID 是否正确

+
+ {% endif %} +
+{% endblock %} diff --git a/app/templates/detail.html b/app/templates/detail.html index f14635b..3979d6d 100644 --- a/app/templates/detail.html +++ b/app/templates/detail.html @@ -117,5 +117,35 @@

{{ paper.abstract }}

{% endif %} + + {# Phase 5: 图片画廊 #} + {% if paper_images %} + + {% endif %} + + {# Phase 5: 相似论文推荐 #} + {% if similar_papers %} +
+

相似论文推荐

+ {% for sp in similar_papers %} +
+ + {{ sp.title_zh }} + + 🎯 {{ "%.3f"|format(sp.distance) }} +
+ {% endfor %} +
+ {% endif %} {% endblock %} diff --git a/app/templates/search.html b/app/templates/search.html index b4224ca..4a4950a 100644 --- a/app/templates/search.html +++ b/app/templates/search.html @@ -11,6 +11,21 @@ {% if tag %} {% endif %} + + {# 模式切换 #} + {% if chroma_enabled %} +
+ + +
+ {% endif %} + @@ -18,10 +33,10 @@ {% if all_tags %}
标签: - 全部 {% for t in all_tags %} - {{ t }} {% endfor %}
@@ -30,12 +45,12 @@ {% if query or tag %} {# 搜索结果元信息 #}
- 找到 {{ total }} 条结果 + 找到 {{ total }} 条结果{% if mode == 'semantic' %}(语义模式){% endif %}
- 相关性 | - 日期
@@ -58,6 +73,11 @@ 👍 {{ paper.upvotes }} + {% if distances and paper.arxiv_id in distances %} + + 🎯 {{ "%.3f"|format(distances[paper.arxiv_id]) }} + + {% endif %} {% if snippet and snippet.abstract %} @@ -103,11 +123,11 @@ {% if total_pages > 1 %} {% endif %} diff --git a/app/templates/trends.html b/app/templates/trends.html new file mode 100644 index 0000000..c81dc40 --- /dev/null +++ b/app/templates/trends.html @@ -0,0 +1,185 @@ +{% extends "base.html" %} + +{% block title %}{{ page_title }} — HF Daily Papers{% endblock %} + +{% block content %} + +{% endblock %} + +{% block scripts %} + + +{% endblock %} diff --git a/pyproject.toml b/pyproject.toml index 63ec04f..a7ec16d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "typer>=0.15", "python-dotenv>=1.0", "apscheduler>=3.10", + "chromadb>=1.0", ] [project.optional-dependencies] diff --git a/tests/test_phase5.py b/tests/test_phase5.py new file mode 100644 index 0000000..cf73132 --- /dev/null +++ b/tests/test_phase5.py @@ -0,0 +1,657 @@ +"""Phase 5 后续增强测试 — embedder、semantic search、trends、compare、image extraction。""" + +from __future__ import annotations + +import json +import shutil +import time +from datetime import date, datetime, timezone +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import select + +from app.config import settings +from app.database import get_db +from app.models import ( + Paper, + PaperAuthor, + PaperSummary, + PaperTag, + SummaryStatus, +) + + +# ── Fixtures ──────────────────────────────────────────────────────────── + +ADMIN_TOKEN = "test-admin-token-12345" + + +@pytest.fixture +def admin_headers(): + return {"Authorization": f"Bearer " + ADMIN_TOKEN} + + +@pytest.fixture +def auth_client(client, monkeypatch): + monkeypatch.setattr(settings, "ADMIN_TOKEN", ADMIN_TOKEN) + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + return client + + 
@pytest.fixture
def sample_papers_with_summary(db_session):
    """Insert five papers (2024-01-10 .. 2024-01-14) and return them.

    Every paper gets one author, the tags ``NLP`` and ``Tag<n>``, a
    ``SummaryStatus`` row and a matching ``papers_fts`` row. The first four
    papers have status ``done`` and a ``PaperSummary``; the fifth is
    ``pending`` and has no summary.
    """
    from sqlalchemy import text

    now = datetime.now(timezone.utc)
    seed = [
        ("2401.20001", "2024-01-10"),
        ("2401.20002", "2024-01-11"),
        ("2401.20003", "2024-01-12"),
        ("2401.20004", "2024-01-13"),
        ("2401.20005", "2024-01-14"),
    ]
    # Built once: the statement is identical for every row.
    fts_insert = text(
        "INSERT INTO papers_fts(rowid, title_en, title_zh, abstract, authors, tags) "
        "VALUES (:id, :title_en, :title_zh, :abstract, :authors, :tags)"
    )

    papers = []
    for i, (arxiv_id, paper_date_str) in enumerate(seed):
        has_summary = i < 4  # only the first four papers are summarized
        p = Paper(
            arxiv_id=arxiv_id,
            title_en=f"Test Paper {i+1}",
            title_zh=f"测试论文 {i+1}",
            abstract=f"Abstract for paper {i+1}.",
            paper_date=date.fromisoformat(paper_date_str),
            crawled_at=now,
            upvotes=i * 10 + 5,
        )
        db_session.add(p)
        db_session.flush()  # assigns p.id for the child rows below

        db_session.add(PaperAuthor(paper_id=p.id, name=f"Author {i+1}", position=0))
        db_session.add(PaperTag(paper_id=p.id, tag="NLP", source="hf"))
        db_session.add(PaperTag(paper_id=p.id, tag=f"Tag{i+1}", source="hf"))

        db_session.add(SummaryStatus(
            paper_id=p.id,
            status="done" if has_summary else "pending",
            quality="normal",
        ))

        if has_summary:
            db_session.add(PaperSummary(
                paper_id=p.id,
                one_line=f"这是论文{i+1}的一句话摘要",
                difficulty="中级",
                motivation_problem=f"论文{i+1}的研究问题",
                motivation_goal=f"论文{i+1}的研究目标",
                method_key_idea=f"论文{i+1}的关键思路",
                method_overview=f"论文{i+1}的方法概述",
                updated_at=now,
                full_json=json.dumps({"title_zh": f"测试论文 {i+1}"}),
            ))

        # Keep the FTS5 index in step with the inserted paper.
        db_session.execute(
            fts_insert,
            {
                "id": p.id,
                "title_en": p.title_en,
                "title_zh": p.title_zh or "",
                "abstract": p.abstract or "",
                "authors": f"Author {i+1}",
                "tags": f"NLP, Tag{i+1}",
            },
        )
        papers.append(p)

    db_session.commit()
    return papers
class TestEmbedderInit:
    """Initialisation tests for ``app.services.embedder``."""

    @staticmethod
    def _fresh_embedder(monkeypatch):
        """Return the embedder module with its cached client/collection cleared.

        The globals are cleared through ``monkeypatch`` so their previous
        values are restored at teardown even when an assertion fails.
        """
        import app.services.embedder as emb

        monkeypatch.setattr(emb, "_client", None, raising=False)
        monkeypatch.setattr(emb, "_collection", None, raising=False)
        return emb

    def test_chroma_disabled_skip_init(self, monkeypatch):
        """With CHROMA_ENABLED=false, ``init_chroma`` creates no client."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        emb = self._fresh_embedder(monkeypatch)
        emb.init_chroma()
        assert emb._client is None

    def test_chroma_init_success(self, monkeypatch, tmp_path):
        """With CHROMA_ENABLED=true, ``init_chroma`` sets client and collection."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", True)
        monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))

        emb = self._fresh_embedder(monkeypatch)
        emb.init_chroma()

        assert emb._client is not None
        assert emb._collection is not None

    def test_get_collection_returns_none_when_disabled(self, monkeypatch):
        """With CHROMA_ENABLED=false, ``get_collection`` returns None."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        emb = self._fresh_embedder(monkeypatch)
        assert emb.get_collection() is None
class TestEmbeddingApi:
    """Tests for ``embedder._get_embedding``."""

    @staticmethod
    def _configure_fake_api(monkeypatch, *, dimensions):
        """Point the embedding settings at a fake endpoint."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "http://fake")
        monkeypatch.setattr(settings, "EMBED_MODEL", "test-model")
        monkeypatch.setattr(settings, "EMBED_API_KEY", "")
        monkeypatch.setattr(settings, "EMBED_DIMENSIONS", dimensions)
        monkeypatch.setattr(settings, "HTTP_TIMEOUT_SECONDS", 5)

    @staticmethod
    def _wire_client(mock_client_cls):
        """Make the patched ``httpx.Client`` usable with or without ``with``.

        Returns the client mock whose ``post`` the test configures.
        """
        client = mock_client_cls.return_value
        client.__enter__.return_value = client
        client.__exit__.return_value = False
        return client

    def test_no_api_base_returns_none(self, monkeypatch):
        """An empty EMBED_API_BASE yields None."""
        monkeypatch.setattr(settings, "EMBED_API_BASE", "")
        monkeypatch.setattr(settings, "EMBED_MODEL", "")
        import app.services.embedder as emb

        assert emb._get_embedding("test") is None

    def test_dimension_mismatch_returns_none(self, monkeypatch):
        """A 64-dim vector is rejected when 128 dimensions are configured."""
        self._configure_fake_api(monkeypatch, dimensions=128)

        import app.services.embedder as emb

        mock_resp = MagicMock()
        mock_resp.json.return_value = {"data": [{"embedding": [0.1] * 64}]}
        mock_resp.raise_for_status = MagicMock()

        with patch("httpx.Client") as mock_client:
            client = self._wire_client(mock_client)
            # The response must come back from post(); returning it from
            # __enter__ would hide the 64-dim payload from the code under test.
            client.post.return_value = mock_resp
            result = emb._get_embedding("test")
        assert result is None

    def test_api_failure_returns_none(self, monkeypatch):
        """An exception raised by the HTTP call yields None."""
        self._configure_fake_api(monkeypatch, dimensions=0)

        import app.services.embedder as emb

        with patch("httpx.Client") as mock_client:
            client = self._wire_client(mock_client)
            client.post.side_effect = Exception("timeout")
            result = emb._get_embedding("test")
        assert result is None
class TestSearchRoutes:
    """Route-level tests for the search page and the search API."""

    def test_search_page_keyword(self, auth_client, sample_papers_with_summary):
        """The search page renders in keyword mode."""
        response = auth_client.get("/search?q=Test&mode=keyword")
        assert response.status_code == 200
        body = response.text
        assert any(marker in body for marker in ("Test", "测试"))

    def test_search_page_semantic_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
        """Semantic mode still answers 200 when CHROMA_ENABLED is false."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        response = auth_client.get("/search?q=Test&mode=semantic")
        assert response.status_code == 200

    def test_search_api_with_mode(self, auth_client, sample_papers_with_summary):
        """The search API accepts ``mode`` and returns results plus a total."""
        response = auth_client.get("/api/search?q=Test&mode=keyword")
        assert response.status_code == 200
        payload = response.json()
        for required_key in ("results", "total"):
            assert required_key in payload
class TestSimilarAPI:
    """Tests for the similar-papers API with ChromaDB switched off."""

    @staticmethod
    def _get_without_chroma(auth_client, monkeypatch, url):
        """Disable ChromaDB, request ``url`` and return the response."""
        monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
        return auth_client.get(url)

    def test_similar_api_disabled(self, auth_client, monkeypatch, sample_papers_with_summary):
        """With CHROMA_ENABLED=false the result list is empty."""
        response = self._get_without_chroma(
            auth_client, monkeypatch, "/api/similar/2401.20001"
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["results"] == []

    def test_similar_api_paper_not_found(self, auth_client, monkeypatch):
        """An unknown paper yields an empty result list."""
        response = self._get_without_chroma(
            auth_client, monkeypatch, "/api/similar/nonexistent.99999"
        )
        assert response.status_code == 200
        assert response.json()["results"] == []

    def test_similar_api_with_top_k(self, auth_client, monkeypatch, sample_papers_with_summary):
        """A request carrying ``top_k`` is accepted."""
        response = self._get_without_chroma(
            auth_client, monkeypatch, "/api/similar/2401.20001?top_k=3"
        )
        assert response.status_code == 200
resp.status_code == 200 + assert "趋势看板" in resp.text + assert "chart" in resp.text.lower() or "Chart" in resp.text + + def test_trends_api_returns_data(self, auth_client, sample_papers_with_summary): + """趋势 API 返回正确数据结构。""" + resp = auth_client.get("/api/stats/trends") + assert resp.status_code == 200 + data = resp.json() + + assert "daily_counts" in data + assert "top_tags" in data + assert "upvotes_dist" in data + assert "summary_completion" in data + + assert isinstance(data["daily_counts"], list) + assert isinstance(data["top_tags"], list) + assert isinstance(data["upvotes_dist"], list) + assert isinstance(data["summary_completion"], list) + + def test_trends_api_daily_counts(self, auth_client, sample_papers_with_summary, monkeypatch): + """每日论文数量数据正确。""" + # 使用测试数据的日期范围 + from unittest.mock import patch as upatch + import app.routes.trends as trends_mod + + # monkeypatch _get_trends_data 中的 date.today + with upatch("app.routes.trends.date") as mock_date: + mock_date.today.return_value = date(2024, 1, 20) + mock_date.side_effect = lambda *a, **kw: date(*a, **kw) + + resp = auth_client.get("/api/stats/trends") + data = resp.json() + assert len(data["daily_counts"]) == 5 + for item in data["daily_counts"]: + assert "date" in item + assert "count" in item + assert item["count"] == 1 + + def test_trends_api_top_tags(self, auth_client, sample_papers_with_summary): + """热门标签数据正确。""" + resp = auth_client.get("/api/stats/trends") + data = resp.json() + tags = {t["tag"]: t["count"] for t in data["top_tags"]} + assert "NLP" in tags + assert tags["NLP"] == 5 # 所有论文都有 NLP + + def test_trends_api_summary_completion(self, auth_client, sample_papers_with_summary): + """总结完成率数据正确。""" + resp = auth_client.get("/api/stats/trends") + data = resp.json() + statuses = {s["status"]: s["count"] for s in data["summary_completion"]} + assert "done" in statuses + assert statuses["done"] == 4 # 4 篇已完成 + + def test_trends_empty_db(self, auth_client): + """无数据时不崩溃。""" + resp = 
class TestComparePage:
    """Tests for the paper comparison page (``GET /compare``)."""

    def test_compare_page_no_ids(self, auth_client):
        """Without IDs the page renders its input form."""
        resp = auth_client.get("/compare")
        assert resp.status_code == 200
        assert "对比" in resp.text

    def test_compare_page_with_ids(self, auth_client, sample_papers_with_summary):
        """Comparing several papers renders their IDs and the field labels."""
        resp = auth_client.get("/compare?ids=2401.20001,2401.20002")
        assert resp.status_code == 200
        assert "2401.20001" in resp.text
        assert "2401.20002" in resp.text
        # Labels of the structured comparison rows.
        assert "一句话摘要" in resp.text
        assert "研究问题" in resp.text

    def test_compare_page_max_5(self, auth_client, sample_papers_with_summary):
        """Five papers, the maximum, are all rendered."""
        arxiv_ids = [
            "2401.20001",
            "2401.20002",
            "2401.20003",
            "2401.20004",
            "2401.20005",
        ]
        ids = ",".join(arxiv_ids)
        resp = auth_client.get(f"/compare?ids={ids}")
        assert resp.status_code == 200
        for arxiv_id in arxiv_ids:
            assert arxiv_id in resp.text

    def test_compare_page_over_5_truncates(self, auth_client, sample_papers_with_summary):
        """More than five IDs are accepted; the extras are dropped."""
        ids = "2401.20001,2401.20002,2401.20003,2401.20004,2401.20005,2401.20006"
        resp = auth_client.get(f"/compare?ids={ids}")
        assert resp.status_code == 200
        # NOTE(review): the sixth ID is not in the fixture data, so its absence
        # from the page would not prove truncation; no assertion is made here.

    def test_compare_page_invalid_ids(self, auth_client):
        """Unknown IDs produce an empty-result message, not an error."""
        resp = auth_client.get("/compare?ids=nonexistent.99999")
        assert resp.status_code == 200
        # The earlier `or resp.status_code == 200` made this check always true.
        assert "未找到" in resp.text or "暂无" in resp.text

    def test_compare_page_shows_no_summary_placeholder(self, auth_client, sample_papers_with_summary):
        """A paper without a summary shows the placeholder text."""
        # 2401.20005 is the fixture paper with status=pending and no summary.
        resp = auth_client.get("/compare?ids=2401.20005")
        assert resp.status_code == 200
        assert "暂无总结" in resp.text
class TestImageExtraction:
    """Tests for ``summarizer._extract_images_from_source``."""

    @staticmethod
    def _redirect_dirs(monkeypatch, tmp_path):
        """Point the summarizer's tmp and paper directories into ``tmp_path``."""
        monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
        monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)

    @pytest.mark.asyncio
    async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
        """A missing source directory yields 0."""
        self._redirect_dirs(monkeypatch, tmp_path)
        from app.services.summarizer import _extract_images_from_source

        extracted = await _extract_images_from_source("2401.99999")
        assert extracted == 0

    @pytest.mark.asyncio
    async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
        """Images referenced from a .tex file are copied; missing ones are skipped."""
        source_root = tmp_path / "tmp" / "2401.00001" / "source"
        source_root.mkdir(parents=True)

        figs = source_root / "figs"
        figs.mkdir()
        (figs / "figure1.png").write_bytes(b"\x89PNG\r\n")
        (figs / "figure2.jpg").write_bytes(b"\xff\xd8\xff\xe0")

        # Two references resolve to files on disk; the third does not.
        tex_content = r"""
\documentclass{article}
\begin{document}
\begin{figure}
    \includegraphics[width=0.8\textwidth]{figs/figure1.png}
    \includegraphics{figs/figure2.jpg}
    \includegraphics[angle=90]{figs/nonexistent.pdf}
\end{figure}
\end{document}
"""
        (source_root / "main.tex").write_text(tex_content)

        paper_root = tmp_path / "papers" / "2401.00001"
        self._redirect_dirs(monkeypatch, tmp_path)

        from app.services.summarizer import _extract_images_from_source

        extracted = await _extract_images_from_source("2401.00001")

        assert extracted == 2
        copied_dir = paper_root / "images"
        assert copied_dir.exists()
        assert (copied_dir / "figure1.png").exists()
        assert (copied_dir / "figure2.jpg").exists()

    @pytest.mark.asyncio
    async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
        """A .tex file without image references yields 0."""
        source_root = tmp_path / "tmp" / "2401.00002" / "source"
        source_root.mkdir(parents=True)
        (source_root / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")

        self._redirect_dirs(monkeypatch, tmp_path)

        from app.services.summarizer import _extract_images_from_source

        extracted = await _extract_images_from_source("2401.00002")
        assert extracted == 0
200 + + def test_trends_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary): + """CHROMA 关闭时趋势看板正常工作。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = auth_client.get("/trends") + assert resp.status_code == 200 + + def test_compare_works_without_chroma(self, auth_client, monkeypatch, sample_papers_with_summary): + """CHROMA 关闭时对比页正常工作。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + resp = auth_client.get("/compare?ids=2401.20001,2401.20002") + assert resp.status_code == 200 + + @pytest.mark.asyncio + async def test_cleaner_works_without_chroma(self, db_session, sample_papers_with_summary, monkeypatch): + """CHROMA 关闭时删除论文正常工作。""" + monkeypatch.setattr(settings, "CHROMA_ENABLED", False) + import app.services.embedder as emb + emb._client = None + emb._collection = None + + from app.services.cleaner import delete_papers_by_date_range + result = await delete_papers_by_date_range( + db_session, + date(2024, 1, 10), + date(2024, 1, 10), + ) + assert result["status"] == "success" + assert result["deleted"] == 1