diff --git a/app/main.py b/app/main.py
index 5493505..0ab932b 100644
--- a/app/main.py
+++ b/app/main.py
@@ -5,13 +5,15 @@ import os
 
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 
 from app.config import settings
 from app.database import engine
 from app.models import init_db
 from app.routes.admin import router as admin_router
+from app.routes.compare import router as compare_router
 from app.routes.pages import router as pages_router
 from app.routes.search import router as search_router
+from app.routes.trends import router as trends_router
 from app.routes.user import router as user_router
 
 logging.basicConfig(
@@ -49,11 +51,19 @@ def create_app() -> FastAPI:
     # 静态文件
     app.mount("/static", StaticFiles(directory="app/static"), name="static")
 
+    # Phase 5: static serving of extracted paper images.
+    # NOTE(review): this exposes every file under data/papers, not just images/.
+    papers_images_dir = os.path.join("data", "papers")
+    os.makedirs(papers_images_dir, exist_ok=True)
+    app.mount("/papers", StaticFiles(directory=papers_images_dir), name="papers")
+
     # 路由
     app.include_router(pages_router)
     app.include_router(admin_router)
     app.include_router(search_router)
     app.include_router(user_router)
+    app.include_router(trends_router)
+    app.include_router(compare_router)
 
     # 调度器(Phase 4)
     @app.on_event("startup")
@@ -61,6 +71,12 @@ def create_app() -> FastAPI:
         from app.services.scheduler import start_scheduler
         start_scheduler()
 
+    # Phase 5: initialise ChromaDB
+    @app.on_event("startup")
+    async def _init_chroma():
+        from app.services.embedder import init_chroma
+        init_chroma()
+
     @app.on_event("shutdown")
     async def _stop_scheduler():
         from app.services.scheduler import stop_scheduler
diff --git a/app/routes/compare.py b/app/routes/compare.py
new file mode 100644
index 0000000..8c418c3
--- /dev/null
+++ b/app/routes/compare.py
@@ -0,0 +1,119 @@
+"""Paper comparison route: structured fields of several papers, side by side."""
+
+from __future__ import annotations
+
+import logging
+
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
+from fastapi.templating import Jinja2Templates
+from sqlalchemy.orm import Session, joinedload
+
+from app.database import get_db
+from app.models import Paper
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+templates = Jinja2Templates(directory="app/templates")
+
+
+@router.get("/compare")
+def compare_page(
+    request: Request,
+    ids: str = Query(default="", description="逗号分隔的 arxiv_id,最多 5 篇"),
+    db: Session = Depends(get_db),
+):
+    """Render the comparison page for ``GET /compare?ids=id1,id2,id3``.
+
+    ``ids`` is a comma-separated list of arXiv IDs. Blank entries are dropped,
+    duplicates are collapsed, and at most five papers are compared. Columns
+    follow the order in which the IDs were requested.
+    """
+    if not ids:
+        return templates.TemplateResponse(
+            request,
+            "compare.html",
+            {
+                "page_title": "论文对比",
+                "papers": [],
+                "error": None,
+            },
+        )
+
+    # Drop blanks and duplicates (keeping request order) *before* capping at
+    # five, so a repeated ID cannot crowd out a distinct one.
+    stripped = (part.strip() for part in ids.split(","))
+    arxiv_ids = list(dict.fromkeys(aid for aid in stripped if aid))[:5]
+
+    if not arxiv_ids:
+        return templates.TemplateResponse(
+            request,
+            "compare.html",
+            {
+                "page_title": "论文对比",
+                "papers": [],
+                "error": "请提供有效的论文 ID",
+            },
+        )
+
+    papers = (
+        db.query(Paper)
+        .filter(Paper.arxiv_id.in_(arxiv_ids))
+        .options(
+            joinedload(Paper.authors),
+            joinedload(Paper.tags),
+            joinedload(Paper.summary),
+            joinedload(Paper.summary_status),
+        )
+        .all()
+    )
+
+    # Restore the order in which the IDs were requested.
+    id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
+    papers.sort(key=lambda p: id_order.get(p.arxiv_id, 999))
+
+    # Build one table row per field: (attribute name, label shown to the user).
+    compare_fields = [
+        ("title_zh", "中文标题"),
+        ("title_en", "英文标题"),
+        ("one_line", "一句话摘要"),
+        ("difficulty", "难度"),
+        ("motivation_problem", "研究问题"),
+        ("motivation_goal", "研究目标"),
+        ("motivation_gap", "研究差距"),
+        ("method_overview", "方法概述"),
+        ("method_key_idea", "关键思路"),
+        ("method_novelty", "新颖性"),
+        ("results", "实验结果"),
+        ("limitations", "局限与改进"),
+    ]
+
+    rows = []
+    for field_key, field_label in compare_fields:
+        cells = []
+        for paper in papers:
+            if field_key in ("title_zh", "title_en"):
+                val = getattr(paper, field_key, None) or ""
+            elif paper.summary:
+                val = getattr(paper.summary, field_key, None) or ""
+                # Fall back to the raw JSON column when the plain field is empty.
+                if field_key == "results" and not val:
+                    val = paper.summary.results_main_json or ""
+                if field_key == "limitations" and not val:
+                    val = paper.summary.limitations_json or ""
+            else:
+                val = ""
+            cells.append(val)
+        rows.append({"key": field_key, "label": field_label, "cells": cells})
+
+    return templates.TemplateResponse(
+        request,
+        "compare.html",
+        {
+            "page_title": "论文对比",
+            "papers": papers,
+            "rows": rows,
+            "ids_param": ids,
+            "error": None,
+        },
+    )
diff --git a/app/routes/pages.py b/app/routes/pages.py
index cf94a87..a034433 100644
--- a/app/routes/pages.py
+++ b/app/routes/pages.py
@@ -1,9 +1,13 @@
 """页面路由 — 首页、日期页、论文详情。"""
 
+from __future__ import annotations
+
+import logging
 from datetime import date, datetime, timedelta
+from pathlib import Path
 from zoneinfo import ZoneInfo
 
-from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from fastapi.responses import RedirectResponse
 from fastapi.templating import Jinja2Templates
 from sqlalchemy.orm import Session, joinedload
@@ -12,6 +16,8 @@ from app.config import settings
 from app.database import get_db
 from app.models import Paper
 
+logger = logging.getLogger(__name__)
+
 router = APIRouter()
 templates = Jinja2Templates(directory="app/templates")
 
@@ -99,11 +105,147 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
if paper.summary_status: summary_state = paper.summary_status.status + # Phase 5: 相似论文推荐 + similar_papers = _get_similar_papers(db, arxiv_id, top_k=6) + + # Phase 5: 图片画廊 + images = _get_paper_images(arxiv_id) + return templates.TemplateResponse( request, "detail.html", { "paper": paper, "summary_state": summary_state, + "similar_papers": similar_papers, + "paper_images": images, + "chroma_enabled": settings.CHROMA_ENABLED, "page_title": paper.title_zh or paper.title_en, }, )
+
+
+# ── Similar papers API (Phase 5) ──────────────────────────────────────
+
+
+@router.get("/api/similar/{arxiv_id}")
+def similar_api(
+    arxiv_id: str,
+    top_k: int = Query(default=5, ge=1, le=20),
+    db: Session = Depends(get_db),
+):
+    """Return papers similar to ``arxiv_id`` as JSON (``{"results": [...]}``)."""
+    # Ask for one extra hit because the paper itself is among its neighbours.
+    similar = _get_similar_papers(db, arxiv_id, top_k=top_k + 1)
+    # Defensive re-filter of the paper itself, then cap at top_k.
+    items = [s for s in similar if s["arxiv_id"] != arxiv_id][:top_k]
+    return {"results": items}
+
+
+def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
+    """Return papers whose stored embedding is closest to that of ``arxiv_id``.
+
+    The paper's own vector is read from ChromaDB and used as the query, so the
+    paper itself is normally among the ``top_k`` hits; it is filtered out, which
+    means at most ``top_k - 1`` other papers come back (callers that need
+    exactly N should ask for N + 1). Each item is a dict with ``arxiv_id``,
+    ``title_zh``, ``distance``, ``paper_date`` and ``tags``, sorted by
+    ascending distance. Any failure is logged and yields an empty list so the
+    detail page still renders.
+    """
+    if not settings.CHROMA_ENABLED:
+        return []
+
+    try:
+        from app.services.embedder import get_collection
+
+        col = get_collection()
+        if col is None:
+            return []
+
+        # Fetch this paper's own embedding to use as the query vector.
+        stored = col.get(ids=[arxiv_id], include=["embeddings"])
+        embeddings = stored.get("embeddings")
+        # Explicit None/len checks: ChromaDB may hand back NumPy arrays, whose
+        # truth value is ambiguous and would raise inside a plain ``not``.
+        if embeddings is None or len(embeddings) == 0:
+            return []
+        vec = embeddings[0]
+        if vec is None or len(vec) == 0:
+            return []
+
+        count = col.count()
+        if count == 0:
+            return []
+
+        results = col.query(
+            query_embeddings=[vec],
+            n_results=min(top_k, count),
+            include=["metadatas", "distances"],
+        )
+
+        if not results["ids"] or not results["ids"][0]:
+            return []
+
+        similar_ids = results["ids"][0]
+        distances = results["distances"][0] if results["distances"] else [0.0] * len(similar_ids)
+
+        # Map neighbour id -> distance, leaving out the paper itself.
+        papers_info = {
+            sid: dist for sid, dist in zip(similar_ids, distances) if sid != arxiv_id
+        }
+        if not papers_info:
+            return []
+
+        # Load display data from the DB; ids missing there are silently dropped.
+        papers = (
+            db.query(Paper)
+            .filter(Paper.arxiv_id.in_(list(papers_info)))
+            .options(joinedload(Paper.tags))
+            .all()
+        )
+
+        items = [
+            {
+                "arxiv_id": p.arxiv_id,
+                "title_zh": p.title_zh or p.title_en,
+                "distance": papers_info.get(p.arxiv_id, 0.0),
+                "paper_date": p.paper_date.isoformat() if p.paper_date else "",
+                "tags": [t.tag for t in p.tags[:3]],
+            }
+            for p in papers
+        ]
+        items.sort(key=lambda item: item["distance"])
+        return items
+
+    except Exception:
+        logger.exception("Failed to get similar papers for %s", arxiv_id)
+        return []
+
+
+# ── Image gallery (Phase 5) ───────────────────────────────────────────
+
+
+def _get_paper_images(arxiv_id: str) -> list[dict]:
+    """List gallery images extracted for ``arxiv_id``.
+
+    Looks in ``data/papers/<arxiv_id>/images`` and returns one
+    ``{"url", "name"}`` dict per browser-displayable image file, sorted by
+    file name. Returns an empty list when the directory is missing.
+    """
+    from urllib.parse import quote
+
+    images_dir = Path("data/papers") / arxiv_id / "images"
+    if not images_dir.is_dir():
+        return []
+
+    # File names come from LaTeX archives and may contain spaces, '#', '?', ...
+    # so percent-encode them before embedding them in a URL.
+    return [
+        {
+            "url": f"/papers/{quote(arxiv_id)}/images/{quote(img_file.name)}",
+            "name": img_file.name,
+        }
+        for img_file in sorted(images_dir.iterdir())
+        if img_file.is_file()
+        and img_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg")
+    ]
diff --git a/app/routes/search.py b/app/routes/search.py index ab9e59d..cc23ece 100644 --- a/app/routes/search.py +++ b/app/routes/search.py @@ -31,11 +31,12 @@ def search_page( q: str = Query(default=""), tag: str = Query(default=""), sort: str = Query(default="relevance"), + mode: str = Query(default="keyword"), page: int = Query(default=1, ge=1), db: Session = Depends(get_db), ): - """搜索页面。""" - result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page) + """搜索页面,支持 keyword 和 semantic 模式。""" + result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode) all_tags = get_all_tags(db) return templates.TemplateResponse( @@ -45,8 +46,11 @@ def search_page( "query": q, "tag": tag, "sort": sort, + "mode": mode, + "chroma_enabled": settings.CHROMA_ENABLED, "results": result["results"], "snippets": result["snippets"], + "distances": result.get("distances", {}), "total": result["total"], "page": result["page"], "total_pages": result["total_pages"], @@ -65,28 +69,31 @@ def search_api( q: str = Query(default=""), tag: str = Query(default=""), sort: str = Query(default="relevance"), + mode: str = Query(default="keyword"), page: int = Query(default=1, ge=1), db: Session = Depends(get_db), ): - """搜索 JSON API。""" - result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page) + """搜索 JSON API,支持 keyword 和 semantic 模式。""" +
result = search_papers(db, query=q or None, tag=tag or None, sort=sort, page=page, mode=mode) + distances = result.get("distances", {}) items = [] for paper in result["results"]: snippet = result["snippets"].get(paper.id, {}) - items.append( - { - "arxiv_id": paper.arxiv_id, - "title_en": paper.title_en, - "title_zh": paper.title_zh, - "paper_date": paper.paper_date.isoformat() if paper.paper_date else None, - "upvotes": paper.upvotes, - "tags": [t.tag for t in paper.tags], - "authors": [a.name for a in paper.authors], - "snippet_title_zh": snippet.get("title_zh"), - "snippet_abstract": snippet.get("abstract"), - } - ) + item = { + "arxiv_id": paper.arxiv_id, + "title_en": paper.title_en, + "title_zh": paper.title_zh, + "paper_date": paper.paper_date.isoformat() if paper.paper_date else None, + "upvotes": paper.upvotes, + "tags": [t.tag for t in paper.tags], + "authors": [a.name for a in paper.authors], + "snippet_title_zh": snippet.get("title_zh"), + "snippet_abstract": snippet.get("abstract"), + } + if paper.arxiv_id in distances: + item["distance"] = distances[paper.arxiv_id] + items.append(item) return { "results": items,
diff --git a/app/routes/trends.py b/app/routes/trends.py
new file mode 100644
index 0000000..80e6ce5
--- /dev/null
+++ b/app/routes/trends.py
@@ -0,0 +1,131 @@
+"""Trends dashboard routes: statistics page and its JSON data API."""
+
+from __future__ import annotations
+
+import logging
+from datetime import date, datetime, timedelta
+from zoneinfo import ZoneInfo
+
+from fastapi import APIRouter, Depends, Request
+from fastapi.templating import Jinja2Templates
+from sqlalchemy import func, text
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.database import get_db
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+templates = Jinja2Templates(directory="app/templates")
+
+
+@router.get("/trends")
+def trends_page(request: Request, db: Session = Depends(get_db)):
+    """Render the trends dashboard page."""
+    stats = _get_trends_data(db)
+    return templates.TemplateResponse(
+        request,
+        "trends.html",
+        {
+            "page_title": "趋势看板",
+            "stats": stats,
+            "today": _today_str(),
+        },
+    )
+
+
+@router.get("/api/stats/trends")
+def trends_api(db: Session = Depends(get_db)):
+    """Return the trends statistics as JSON."""
+    return _get_trends_data(db)
+
+
+def _get_trends_data(db: Session) -> dict:
+    """Aggregate the dashboard statistics from the database.
+
+    Returns a dict with four lists: ``daily_counts`` (papers per day over the
+    last 30 days), ``top_tags`` (the 20 most used tags), ``upvotes_dist``
+    (papers per upvote bucket) and ``summary_completion`` (papers per
+    summary status, ``'none'`` when no status row exists).
+    """
+    # Anchor the window on the app time zone, like the "today" label on the page.
+    thirty_days_ago = (_today() - timedelta(days=30)).isoformat()
+
+    # 1. Papers per day (last 30 days)
+    daily_rows = db.execute(text("""
+        SELECT paper_date, COUNT(*) as cnt
+        FROM papers
+        WHERE paper_date >= :start_date
+        GROUP BY paper_date
+        ORDER BY paper_date ASC
+    """), {"start_date": thirty_days_ago}).fetchall()
+    daily_counts = [
+        {"date": str(row[0]), "count": row[1]}
+        for row in daily_rows
+    ]
+
+    # 2. Top 20 tags
+    tag_rows = db.execute(text("""
+        SELECT tag, COUNT(*) as cnt
+        FROM paper_tags
+        GROUP BY tag
+        ORDER BY cnt DESC
+        LIMIT 20
+    """)).fetchall()
+    top_tags = [
+        {"tag": row[0], "count": row[1]}
+        for row in tag_rows
+    ]
+
+    # 3. Upvote distribution
+    upvote_rows = db.execute(text("""
+        SELECT
+            CASE
+                WHEN upvotes >= 100 THEN '100+'
+                WHEN upvotes >= 50 THEN '50-99'
+                WHEN upvotes >= 20 THEN '20-49'
+                WHEN upvotes >= 10 THEN '10-19'
+                WHEN upvotes >= 5 THEN '5-9'
+                ELSE '0-4'
+            END as bucket,
+            COUNT(*) as cnt
+        FROM papers
+        GROUP BY bucket
+        ORDER BY MIN(upvotes) DESC
+    """)).fetchall()
+    upvotes_dist = [
+        {"range": row[0], "count": row[1]}
+        for row in upvote_rows
+    ]
+
+    # 4. Summary completion rate
+    summary_rows = db.execute(text("""
+        SELECT
+            COALESCE(ss.status, 'none') as status,
+            COUNT(*) as cnt
+        FROM papers p
+        LEFT JOIN summary_status ss ON ss.paper_id = p.id
+        GROUP BY status
+    """)).fetchall()
+    summary_completion = [
+        {"status": row[0], "count": row[1]}
+        for row in summary_rows
+    ]
+
+    return {
+        "daily_counts": daily_counts,
+        "top_tags": top_tags,
+        "upvotes_dist": upvotes_dist,
+        "summary_completion": summary_completion,
+    }
+
+
+def _today() -> date:
+    """Return the current date in the configured application time zone."""
+    return datetime.now(ZoneInfo(settings.APP_TIMEZONE)).date()
+
+
+def _today_str() -> str:
+    """Return today's date (application time zone) as ``YYYY-MM-DD``."""
+    return _today().isoformat()
diff --git a/app/services/cleaner.py b/app/services/cleaner.py index 56af5d1..f5f4c9b 100644 --- a/app/services/cleaner.py +++ b/app/services/cleaner.py @@ -139,6 +139,13 @@ async def delete_papers_by_date_range( {"paper_id": paper_id}, ) + # 1.5 Phase 5: 从 ChromaDB 删除语义索引 + try: + from app.services.embedder import delete_paper + delete_paper(arxiv_id) + except Exception: + logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True) + # 2.
删除本地文件 data/papers/{arxiv_id}/ paper_dir = _PAPERS_DIR / arxiv_id if paper_dir.exists():
diff --git a/app/services/embedder.py b/app/services/embedder.py
new file mode 100644
index 0000000..f153251
--- /dev/null
+++ b/app/services/embedder.py
@@ -0,0 +1,322 @@
+"""ChromaDB semantic search service: embeddings, index upkeep, similarity queries."""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+
+import httpx
+from sqlalchemy import select
+from sqlalchemy.orm import Session, joinedload
+
+from app.config import settings
+from app.models import Paper
+
+logger = logging.getLogger(__name__)
+
+# ── Singleton client and collection ───────────────────────────────────
+_client = None
+_collection = None
+
+# Keys of ``texts_dict`` that describe the paper but must not be embedded.
+_METADATA_KEYS = frozenset({"arxiv_id", "paper_date"})
+
+
+def _chroma_dir() -> Path:
+    """Return the on-disk ChromaDB directory taken from ``settings.CHROMA_DIR``."""
+    return Path(settings.CHROMA_DIR)
+
+
+def init_chroma() -> None:
+    """Create the persistent ChromaDB client and collection if CHROMA_ENABLED is set."""
+    global _client, _collection
+    if not settings.CHROMA_ENABLED:
+        logger.debug("ChromaDB disabled, skip init")
+        return
+
+    if _client is not None:
+        return
+
+    try:
+        import chromadb
+
+        chroma_path = _chroma_dir()
+        chroma_path.mkdir(parents=True, exist_ok=True)
+
+        _client = chromadb.PersistentClient(path=str(chroma_path))
+        _collection = _get_or_create_collection()
+        logger.info("ChromaDB initialized at %s", chroma_path)
+    except Exception:
+        logger.exception("Failed to initialize ChromaDB")
+        _client = None
+        _collection = None
+
+
+def _get_or_create_collection():
+    """Return the ``papers_embeddings`` collection, creating it if needed.
+
+    A new collection is created with cosine distance (``hnsw:space``). Errors
+    propagate to ``init_chroma``, which logs them and leaves ChromaDB unset.
+    """
+    col = _client.get_or_create_collection(
+        name="papers_embeddings",
+        metadata={"hnsw:space": "cosine"},
+    )
+    logger.info("ChromaDB collection 'papers_embeddings' ready, count=%d", col.count())
+    return col
+
+
+def get_collection():
+    """Return the collection (initialising lazily), or None if disabled or init failed."""
+    if not settings.CHROMA_ENABLED:
+        return None
+    if _collection is None:
+        init_chroma()
+    return _collection
+
+
+# ── Embedding API call ────────────────────────────────────────────────
+
+
+def _get_embedding(text: str) -> list[float] | None:
+    """Request an embedding vector for ``text`` from the EMBED_API_BASE service.
+
+    Sends POST /v1/embeddings with model=EMBED_MODEL.
+    When EMBED_DIMENSIONS > 0 the vector length must equal it.
+    Returns None (after logging) on any failure or mismatch.
+    """
+    if not settings.EMBED_API_BASE or not settings.EMBED_MODEL:
+        logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
+        return None
+
+    url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
+    headers = {"Content-Type": "application/json"}
+    if settings.EMBED_API_KEY:
+        headers["Authorization"] = f"Bearer {settings.EMBED_API_KEY}"
+
+    payload = {
+        "model": settings.EMBED_MODEL,
+        "input": text,
+    }
+
+    try:
+        with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
+            resp = client.post(url, json=payload, headers=headers)
+            resp.raise_for_status()
+            data = resp.json()
+
+        vec = data["data"][0]["embedding"]
+        if settings.EMBED_DIMENSIONS > 0 and len(vec) != settings.EMBED_DIMENSIONS:
+            logger.warning(
+                "Embedding dimension mismatch: expected=%d got=%d, skip",
+                settings.EMBED_DIMENSIONS,
+                len(vec),
+            )
+            return None
+        return vec
+    except Exception:
+        logger.exception("Embedding API call failed")
+        return None
+
+
+# ── Index text assembly ───────────────────────────────────────────────
+
+
+def _build_index_text(paper: Paper, summary_dict: dict | None) -> str:
+    """Join titles, tags and key summary fields into one space-separated string."""
+    parts: list[str] = []
+
+    if paper.title_zh:
+        parts.append(paper.title_zh)
+    if paper.title_en:
+        parts.append(paper.title_en)
+    if paper.tags:
+        tag_str = " ".join(t.tag for t in paper.tags)
+        if tag_str:
+            parts.append(tag_str)
+
+    if summary_dict:
+        for key in ("one_line", "motivation_problem", "method_key_idea"):
+            val = summary_dict.get(key)
+            if val:
+                parts.append(val)
+
+    return " ".join(parts)
+
+
+# ── Single-paper indexing ─────────────────────────────────────────────
+
+
+def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
+    """Write one paper into the ChromaDB semantic index (upsert).
+
+    Args:
+        paper_id: arxiv_id
+        texts_dict: optional pre-built fields (title_zh, title_en, tags, one_line, ...);
+            arxiv_id / paper_date are stored as metadata only. Loaded from the DB if None.
+
+    Returns:
+        True on success, False when indexing failed or was skipped.
+    """
+    col = get_collection()
+    if col is None:
+        return False
+
+    try:
+        # No texts_dict given: load the paper and its summary from the DB.
+        if texts_dict is None:
+            from app.database import SessionLocal
+
+            db = SessionLocal()
+            try:
+                paper = (
+                    db.query(Paper)
+                    .filter(Paper.arxiv_id == paper_id)
+                    .options(joinedload(Paper.tags), joinedload(Paper.summary))
+                    .first()
+                )
+                if not paper:
+                    logger.warning("Paper %s not found for indexing", paper_id)
+                    return False
+
+                summary_dict = None
+                if paper.summary:
+                    summary_dict = {
+                        "one_line": paper.summary.one_line or "",
+                        "motivation_problem": paper.summary.motivation_problem or "",
+                        "method_key_idea": paper.summary.method_key_idea or "",
+                    }
+                index_text = _build_index_text(paper, summary_dict)
+                arxiv_id = paper.arxiv_id
+                title_zh = paper.title_zh or ""
+                paper_date = paper.paper_date.isoformat() if paper.paper_date else ""
+            finally:
+                db.close()
+        else:
+            # arxiv_id / paper_date are metadata only; embedding them would
+            # add noise to the vector.
+            index_text = " ".join(
+                value
+                for key, value in texts_dict.items()
+                if key not in _METADATA_KEYS and isinstance(value, str) and value
+            )
+            arxiv_id = texts_dict.get("arxiv_id", paper_id)
+            title_zh = texts_dict.get("title_zh", "")
+            paper_date = texts_dict.get("paper_date", "")
+
+        if not index_text.strip():
+            logger.warning("Empty index text for %s, skip", paper_id)
+            return False
+
+        vec = _get_embedding(index_text)
+        if vec is None:
+            return False
+
+        col.upsert(
+            ids=[arxiv_id],
+            embeddings=[vec],
+            metadatas=[{"arxiv_id": arxiv_id, "title_zh": title_zh, "paper_date": paper_date}],
+        )
+        logger.info("Indexed paper %s in ChromaDB", arxiv_id)
+        return True
+
+    except Exception:
+        logger.exception("Failed to index paper %s in ChromaDB", paper_id)
+        return False
+
+
+# ── Batch indexing ────────────────────────────────────────────────────
+
+
+def index_batch(paper_ids: list[str]) -> dict:
+    """Index several papers; one failure does not stop the others.
+
+    Returns:
+        {"total": int, "success": int, "failed": int}
+    """
+    if not paper_ids:
+        return {"total": 0, "success": 0, "failed": 0}
+
+    col = get_collection()
+    if col is None:
+        return {"total": len(paper_ids), "success": 0, "failed": len(paper_ids)}
+
+    success = 0
+    failed = 0
+    for pid in paper_ids:
+        if index_paper(pid):
+            success += 1
+        else:
+            failed += 1
+
+    logger.info("Batch index: total=%d success=%d failed=%d", len(paper_ids), success, failed)
+    return {"total": len(paper_ids), "success": success, "failed": failed}
+
+
+# ── Deletion ──────────────────────────────────────────────────────────
+
+
+def delete_paper(paper_id: str) -> bool:
+    """Remove a paper from the ChromaDB index; True on success, False otherwise.
+
+    Args:
+        paper_id: arxiv_id
+    """
+    col = get_collection()
+    if col is None:
+        return False
+
+    try:
+        col.delete(ids=[paper_id])
+        logger.info("Deleted paper %s from ChromaDB", paper_id)
+        return True
+    except Exception:
+        logger.exception("Failed to delete paper %s from ChromaDB", paper_id)
+        return False
+
+
+# ── Similarity query ──────────────────────────────────────────────────
+
+
+def search_similar(query_text: str, top_k: int = 20) -> list[dict]:
+    """Semantic similarity search; returns [] when disabled, empty or failing.
+
+    Args:
+        query_text: text to embed and search for
+        top_k: maximum number of results
+
+    Returns:
+        [{"arxiv_id": str, "distance": float}, ...]
+    """
+    col = get_collection()
+    if col is None or top_k <= 0:
+        return []
+
+    try:
+        # An empty collection has nothing to return, so skip the embedding
+        # API call; the count also caps n_results below.
+        count = col.count()
+        if count == 0:
+            return []
+
+        vec = _get_embedding(query_text)
+        if vec is None:
+            return []
+
+        results = col.query(
+            query_embeddings=[vec],
+            n_results=min(top_k, count),
+            include=["metadatas", "distances"],
+        )
+
+        if not results["ids"] or not results["ids"][0]:
+            return []
+
+        items = []
+        for i, arxiv_id in enumerate(results["ids"][0]):
+            dist = results["distances"][0][i] if results["distances"] else 0.0
+            items.append({"arxiv_id": arxiv_id, "distance": dist})
+        return items
+
+    except Exception:
+        logger.exception("ChromaDB search_similar failed")
+        return []
diff --git a/app/services/searcher.py b/app/services/searcher.py index 35a0284..9dd88a1 100644 --- a/app/services/searcher.py +++ b/app/services/searcher.py @@ -1,15 +1,19 @@ -"""FTS5 全文搜索服务 — 关键词 + 标签筛选,命中片段高亮,分页。""" +"""搜索服务 — FTS5 关键词搜索 + ChromaDB 语义搜索,命中片段高亮,分页。""" from __future__ import annotations +import logging import math import re from sqlalchemy import text from sqlalchemy.orm import Session, joinedload +from app.config import settings from app.models import Paper +logger = logging.getLogger(__name__) + # ── 输入清洗 ────────────────────────────────────────────────────────── # FTS5 查询语法中的特殊字符,用户输入时需要移除 @@ -41,8 +45,9 @@ def search_papers( sort: str = "relevance", page: int = 1, page_size: int = 20, + mode: str = "keyword", ) -> dict: - """FTS5 搜索论文。 + """搜索论文,支持 keyword (FTS5) 和 semantic (ChromaDB) 两种模式。 返回:: { @@ -51,8 +56,14 @@ def search_papers( "total": int, "page": int, "total_pages": int, + "distances": dict[str, float], # arxiv_id → distance (仅 semantic) } """ + # ── semantic 模式 ── + if mode == "semantic" and settings.CHROMA_ENABLED and query: + return _search_semantic(db, query, tag, sort, page, page_size) + + # ── keyword 模式(默认)── match_expr = _sanitize_query(query) if query else None # ── 无关键词 + 无标签 → 空结果 ── @@ -63,6 +74,7 @@ def search_papers( "total": 0, "page": page, "total_pages": 0, + "distances": {}, }
# ── 构建条件性 JOIN 和 WHERE 片段 ──
@@ -146,6 +158,78 @@ def _search_with_fts(
         "total": total,
         "page": page,
         "total_pages": math.ceil(total / page_size) if total else 0,
+        "distances": {},
+    }
+
+
+def _search_semantic(
+    db: Session,
+    query: str,
+    tag: str | None,
+    sort: str,
+    page: int,
+    page_size: int,
+) -> dict:
+    """Semantic search via ChromaDB, falling back to keyword search.
+
+    Takes the ``page_size * 3`` nearest candidates, loads them from the DB
+    (optionally filtered by ``tag``), orders them by semantic distance and
+    paginates in memory. ``sort`` only matters on the keyword fallback.
+    """
+    try:
+        from app.services.embedder import search_similar
+
+        top_k = page_size * 3  # over-fetch so tag filtering still leaves results
+        candidates = search_similar(query, top_k=top_k)
+    except Exception:
+        logger.exception("Semantic search failed, falling back to keyword")
+        candidates = []
+
+    if not candidates:
+        # Fall back to keyword search. Going through search_papers keeps the
+        # query sanitisation and the tag-only / empty-query handling in one
+        # place instead of handing a raw, unsanitised string to FTS5.
+        return search_papers(
+            db, query=query, tag=tag, sort=sort,
+            page=page, page_size=page_size, mode="keyword",
+        )
+
+    # Load full rows from the DB by arxiv_id.
+    arxiv_ids = [c["arxiv_id"] for c in candidates]
+    distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
+
+    papers_query = (
+        db.query(Paper)
+        .filter(Paper.arxiv_id.in_(arxiv_ids))
+        .options(
+            joinedload(Paper.authors),
+            joinedload(Paper.tags),
+            joinedload(Paper.summary_status),
+            joinedload(Paper.bookmark),
+            joinedload(Paper.reading_status),
+        )
+    )
+    if tag:
+        papers_query = papers_query.filter(Paper.tags.any(tag=tag))
+
+    papers = papers_query.all()
+
+    # Order by semantic distance (candidates arrive nearest-first).
+    id_order = {aid: idx for idx, aid in enumerate(arxiv_ids)}
+    papers.sort(key=lambda p: id_order.get(p.arxiv_id, 999))
+
+    # Paginate in memory.
+    total = len(papers)
+    start = (page - 1) * page_size
+    page_papers = papers[start:start + page_size]
+
+    return {
+        "results": page_papers,
+        "snippets": {},
+        "total": total,
+        "page": page,
+        "total_pages": math.ceil(total / page_size) if total else 0,
+        "distances": distance_map,
     }
 
 
@@ -187,6 +271,7 @@ def _search_tag_only(
         "total": total,
         "page": page,
         "total_pages": math.ceil(total / page_size) if total else 0,
+        "distances": {},
     }
 
 
diff
--git a/app/services/summarizer.py b/app/services/summarizer.py
index 7003f95..523eb18 100644
--- a/app/services/summarizer.py
+++ b/app/services/summarizer.py
@@ -359,6 +359,159 @@ def _cleanup_tmp(arxiv_id: str) -> None:
         logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
 
 
+# ── LaTeX figure extraction (Phase 5) ─────────────────────────────────
+
+_INCLUDEGRAPHICS_RE = re.compile(
+    r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
+)
+_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
+
+
+async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
+    """Copy the figures referenced by a paper's LaTeX source into its image dir.
+
+    Steps:
+        1. Download and unpack the source archive into the temp source dir
+           (skipped when that directory already exists).
+        2. Scan every ``.tex`` file for ``\\includegraphics{...}`` targets.
+        3. Copy each target that resolves to an image file inside the source
+           tree to ``<paper dir>/images/``.
+
+    Args:
+        arxiv_id: Paper identifier, used to derive the temp and paper dirs.
+        tmp_source: Already-unpacked source directory; defaults to
+            ``_tmp_dir(arxiv_id) / "source"``.
+
+    Returns:
+        Number of image files copied (0 on any failure).
+    """
+    if tmp_source is None:
+        tmp_source = _tmp_dir(arxiv_id) / "source"
+    images_dest = _paper_dir(arxiv_id) / "images"
+
+    try:
+        # Download the source archive unless it has already been unpacked.
+        if not tmp_source.exists():
+            source_url = f"https://arxiv.org/e-print/{arxiv_id}"
+            await _download_source_zip(arxiv_id, source_url, tmp_source)
+
+        if not tmp_source.exists():
+            return 0
+
+        # Scan the .tex files and collect the referenced image paths.
+        image_paths: set[str] = set()
+        for tex_file in tmp_source.rglob("*.tex"):
+            try:
+                content = tex_file.read_text(encoding="utf-8", errors="replace")
+                for match in _INCLUDEGRAPHICS_RE.finditer(content):
+                    img_path = match.group(1).strip()
+                    image_paths.add(img_path)
+            except Exception:
+                continue
+
+        if not image_paths:
+            return 0
+
+        # Locate and copy the images.
+        images_dest.mkdir(parents=True, exist_ok=True)
+        copied = 0
+        source_root = tmp_source.resolve()
+        for img_rel in sorted(image_paths):
+            # LaTeX lets the extension be omitted, so try the bare path first
+            # and then each known image extension.
+            for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
+                candidate = (tmp_source / (img_rel + ext)).resolve()
+                # The path comes from an untrusted .tex file: refuse anything
+                # that escapes the source tree ("../", absolute paths, links)
+                # or that is not an image type.
+                if not candidate.is_relative_to(source_root):
+                    continue
+                if candidate.suffix.lower() not in _IMAGE_EXTS:
+                    continue
+                if candidate.is_file():
+                    # Avoid clobbering an already-copied file of the same name.
+                    dest = images_dest / candidate.name
+                    if dest.exists():
+                        dest = images_dest / f"{dest.stem}_{copied}{dest.suffix}"
+                    shutil.copy2(candidate, dest)
+                    copied += 1
+                    break
+
+        if copied > 0:
+            logger.info("Extracted %d images from source for %s", copied, arxiv_id)
+        return copied
+
+    except Exception:
+        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
+        return 0
+
+
+async def _download_source_zip(
+    arxiv_id: str, source_url: str, dest_dir: Path
+) -> None:
+    """Download the arXiv source archive and unpack it into ``dest_dir``.
+
+    Best effort: download and extraction failures are logged and swallowed.
+    ``dest_dir`` is only created once the download has succeeded.
+    """
+    import zipfile
+
+    zip_path = _tmp_dir(arxiv_id) / "source.zip"
+    zip_path.parent.mkdir(parents=True, exist_ok=True)
+
+    transport = None
+    if settings.http_proxy:
+        transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
+
+    try:
+        async with httpx.AsyncClient(
+            timeout=settings.HTTP_TIMEOUT_SECONDS,
+            headers={"User-Agent": settings.HTTP_USER_AGENT},
+            transport=transport,
+            follow_redirects=True,
+        ) as client:
+            resp = await client.get(source_url)
+            resp.raise_for_status()
+            zip_path.write_bytes(resp.content)
+    except Exception as exc:
+        logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
+        return
+
+    # Create the target only once there is something to unpack, so a failed
+    # download does not leave an empty directory that looks like a cached source.
+    dest_dir.mkdir(parents=True, exist_ok=True)
+
+    try:
+        with zipfile.ZipFile(zip_path, "r") as zf:
+            zf.extractall(dest_dir)
+        logger.debug("Extracted source for %s", arxiv_id)
+    except zipfile.BadZipFile:
+        # Not a zip: probably a tar.gz.
+        import tarfile
+        try:
+            with tarfile.open(zip_path, "r:*") as tf:
+                # The archive is untrusted: never let a member escape dest_dir.
+                if hasattr(tarfile, "data_filter"):
+                    tf.extractall(dest_dir, filter="data")
+                else:
+                    root = dest_dir.resolve()
+                    members = [
+                        m
+                        for m in tf.getmembers()
+                        if (m.isfile() or m.isdir())
+                        and (root / m.name).resolve().is_relative_to(root)
+                    ]
+                    tf.extractall(dest_dir, members=members)
+                logger.debug("Extracted source (tar) for %s", arxiv_id)
+        except Exception:
+            logger.warning("Cannot extract source for %s", arxiv_id)
+    except Exception:
+        logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
+    finally:
+        if zip_path.exists():
+            zip_path.unlink()
+
+
 # ── 单篇总结 ────────────────────────────────────────────────────────────
 
 
@@ -441,6 +594,30 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
status.raw_output_saved = True db.commit() + # Phase 5: LaTeX 图片提取(可选增强,失败不影响总结) + try: + await _extract_images_from_source(arxiv_id) + except Exception: + logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True) + + # Phase 5: 同步写入语义索引(失败仅 log) + try: + from app.services.embedder
import index_paper + + texts_dict = { + "arxiv_id": arxiv_id, + "title_zh": schema.title_zh or "", + "title_en": paper.title_en or "", + "tags": " ".join(t.tag for t in paper.tags) if paper.tags else "", + "one_line": schema.one_line or "", + "motivation_problem": schema.motivation_problem or "", + "method_key_idea": schema.method_key_idea or "", + "paper_date": paper.paper_date.isoformat() if paper.paper_date else "", + } + index_paper(arxiv_id, texts_dict) + except Exception: + logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True) + logger.info("Summarize done: %s quality=%s", arxiv_id, quality) return {"arxiv_id": arxiv_id, "status": "done", "quality": quality} diff --git a/app/static/css/style.css b/app/static/css/style.css index ff4ed02..a8f85d1 100644 --- a/app/static/css/style.css +++ b/app/static/css/style.css @@ -556,3 +556,185 @@ mark { .reading-list-filters { gap: 4px; } .filter-chip { padding: 4px 10px; font-size: 0.8rem; } } + +/* ── Search Mode Toggle (Phase 5) ─────────────────────────────── */ +.search-mode-toggle { + display: flex; + gap: 0; + border: 1px solid var(--border); + border-radius: var(--radius); + overflow: hidden; +} +.mode-option { + padding: 8px 14px; + font-size: 0.85rem; + color: var(--ink-light); + cursor: pointer; + transition: all 0.2s; + display: flex; + align-items: center; + gap: 4px; +} +.mode-option input[type="radio"] { display: none; } +.mode-option:hover { background: var(--bg); } +.mode-option.active { + background: var(--accent); + color: #fff; +} + +/* ── Similarity Score (Phase 5) ────────────────────────────────── */ +.similarity-score { + font-size: 0.8rem; + color: var(--accent); + white-space: nowrap; + padding: 2px 8px; + background: #eef3f8; + border-radius: 4px; +} + +/* ── Similar Papers (Phase 5) ──────────────────────────────────── */ +.similar-papers { + margin-top: 32px; + padding-top: 24px; + border-top: 1px solid var(--border); +} +.similar-papers h2 { + font-family: 
var(--font-body); + font-size: 1.1rem; + font-weight: 600; + margin-bottom: 12px; + color: var(--accent); +} +.similar-paper-item { + display: flex; + justify-content: space-between; + align-items: center; + padding: 10px 0; + border-bottom: 1px solid var(--border); +} +.similar-paper-item:last-child { border-bottom: none; } +.similar-paper-title a { + font-size: 0.92rem; + color: var(--ink); +} +.similar-paper-title a:hover { color: var(--accent); } +.similar-paper-dist { + font-size: 0.8rem; + color: var(--ink-light); +} + +/* ── Trends Dashboard (Phase 5) ────────────────────────────────── */ +.trends-page h1 { + font-family: var(--font-body); + font-size: 1.5rem; + font-weight: 700; + margin-bottom: 24px; +} +.charts-grid { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 24px; + margin-bottom: 32px; +} +.chart-card { + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius); + padding: 20px; +} +.chart-card h2 { + font-family: var(--font-body); + font-size: 1rem; + font-weight: 600; + margin-bottom: 12px; + color: var(--accent); +} +.chart-card canvas { + width: 100% !important; + max-height: 300px; +} +@media (max-width: 768px) { + .charts-grid { grid-template-columns: 1fr; } +} + +/* ── Compare Page (Phase 5) ────────────────────────────────────── */ +.compare-page h1 { + font-family: var(--font-body); + font-size: 1.5rem; + font-weight: 700; + margin-bottom: 24px; +} +.compare-table-wrapper { + overflow-x: auto; + margin-bottom: 24px; +} +.compare-table { + width: 100%; + border-collapse: collapse; + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius); + font-size: 0.88rem; +} +.compare-table th, +.compare-table td { + padding: 12px 16px; + border: 1px solid var(--border); + vertical-align: top; + text-align: left; +} +.compare-table th { + background: var(--bg); + font-weight: 600; + color: var(--ink-light); + white-space: nowrap; + min-width: 100px; +} +.compare-table 
td.field-label { + background: var(--bg); + font-weight: 500; + color: var(--ink); + white-space: nowrap; + min-width: 90px; +} +.compare-table td.paper-col { + min-width: 200px; +} +.compare-table td .no-summary { + color: var(--ink-light); + font-style: italic; +} + +/* ── Image Gallery (Phase 5) ───────────────────────────────────── */ +.image-gallery { + margin-top: 24px; +} +.image-gallery h2 { + font-family: var(--font-body); + font-size: 1.05rem; + font-weight: 600; + margin-bottom: 12px; + color: var(--accent); +} +.gallery-grid { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(260px, 1fr)); + gap: 16px; +} +.gallery-item { + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius); + overflow: hidden; +} +.gallery-item img { + width: 100%; + height: auto; + display: block; +} +.gallery-item .gallery-caption { + padding: 8px 12px; + font-size: 0.8rem; + color: var(--ink-light); + text-align: center; +} diff --git a/app/templates/base.html b/app/templates/base.html index eb32d2d..6e272f9 100644 --- a/app/templates/base.html +++ b/app/templates/base.html @@ -16,6 +16,7 @@
diff --git a/app/templates/compare.html b/app/templates/compare.html new file mode 100644 index 0000000..c59a934 --- /dev/null +++ b/app/templates/compare.html @@ -0,0 +1,86 @@ +{% extends "base.html" %} + +{% block title %}{{ page_title }} — HF Daily Papers{% endblock %} + +{% block content %} +{{ error }}
+| 字段 | + {% for paper in papers %} +
+ {{ paper.arxiv_id }}
+ + + {{ paper.upvotes }} 👍 · {{ paper.paper_date }} + + |
+ {% endfor %}
+
|---|---|
| 作者 | + {% for paper in papers %} +{{ paper.authors|map(attribute='name')|join(', ') }} | + {% endfor %} +
| 标签 | + {% for paper in papers %} ++ {% for t in paper.tags[:5] %} + {{ t.tag }} + {% endfor %} + | + {% endfor %} +
| {{ row.label }} | + {% for cell in row.cells %} ++ {% if cell %} + {{ cell }} + {% else %} + 暂无总结 + {% endif %} + | + {% endfor %} +
未找到匹配的论文
+请检查 arXiv ID 是否正确
+{{ paper.abstract }}
{% endif %} + + {# Phase 5: 图片画廊 #} + {% if paper_images %} +