18f44ac244
- Enhance pdf_image_extractor with caption text extraction near images/tables - Add figure/table type correction based on caption content - Implement sequential numbering fallback for unmatched items - Improve figure linking in pages with manifest ID matching and fallback strategies - Remove docling dependency, add dev dependency group
359 lines
11 KiB
Python
359 lines
11 KiB
Python
"""页面路由 — 首页、日期页、论文详情。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from datetime import date, timedelta
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||
from fastapi.responses import RedirectResponse
|
||
from sqlalchemy import select
|
||
from sqlalchemy.orm import Session, joinedload
|
||
|
||
from app.config import settings
|
||
from app.database import get_db
|
||
from app.models import PAPER_FULL_LOAD, Paper
|
||
from app.utils import PAPERS_DIR, safe_json_loads, templates, today_str, latest_paper_date
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
router = APIRouter()
|
||
|
||
|
||
@router.get("/")
def index(request: Request, db: Session = Depends(get_db)):
    """Redirect the site root to the day page of the latest date that has papers."""
    newest_date = latest_paper_date(db)
    destination = f"/day/{newest_date}"
    return RedirectResponse(url=destination)
|
||
|
||
|
||
@router.get("/day/{date_str}")
def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
    """Render the paper list for a single date.

    Args:
        date_str: Date taken from the URL, expected in ISO format.
        request: Incoming request, handed to the template response.
        db: Database session.

    Returns:
        The rendered ``index.html`` template.

    Raises:
        HTTPException: 404 when ``date_str`` cannot be parsed as a date.
    """
    try:
        target = date.fromisoformat(date_str)
    except ValueError:
        # The parse error carries no extra information for a 404 response.
        raise HTTPException(status_code=404, detail="Invalid date format") from None

    # date.min / date.max have no neighbouring day; stepping past them raises
    # OverflowError, so the navigation links stay on the boundary day instead.
    one_day = timedelta(days=1)
    prev_day = (target - one_day if target > date.min else target).isoformat()
    next_day = (target + one_day if target < date.max else target).isoformat()
    today = today_str()

    # Papers of the requested day, most upvoted first.
    papers_stmt = (
        select(Paper)
        .where(Paper.paper_date == date_str)
        .options(*PAPER_FULL_LOAD)
        .order_by(Paper.upvotes.desc())
    )
    papers = db.execute(papers_stmt).scalars().unique().all()

    # The 30 most recent distinct paper dates, newest first.
    dates_stmt = (
        select(Paper.paper_date)
        .distinct()
        .order_by(Paper.paper_date.desc())
        .limit(30)
    )
    dates_raw = db.execute(dates_stmt).all()
    available_dates = [
        row[0].isoformat() if isinstance(row[0], date) else str(row[0])
        for row in dates_raw
    ]

    return templates.TemplateResponse(
        request,
        "index.html",
        {
            "papers": papers,
            "current_date": date_str,
            "prev_day": prev_day,
            "next_day": next_day,
            "today": today,
            "available_dates": available_dates,
            "page_title": f"{date_str} 论文列表",
        },
    )
|
||
|
||
|
||
@router.get("/paper/{arxiv_id}")
def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db)):
    """Render the detail page of one paper.

    Args:
        arxiv_id: Identifier of the paper, taken from the URL.
        request: Incoming request, handed to the template response.
        db: Database session.

    Returns:
        The rendered ``detail.html`` template.

    Raises:
        HTTPException: 404 when no paper with ``arxiv_id`` is found.
    """
    paper = (
        db.execute(
            select(Paper)
            .where(Paper.arxiv_id == arxiv_id)
            .options(
                joinedload(Paper.summary),
                joinedload(Paper.note),
                *PAPER_FULL_LOAD,
            )
        )
        .unique()
        .scalar_one_or_none()
    )
    if not paper:
        raise HTTPException(status_code=404, detail="Paper not found")

    summary_state = "none"
    if paper.summary_status:
        summary_state = paper.summary_status.status

    # Similar-paper recommendations
    similar_papers = _get_similar_papers(db, arxiv_id, top_k=6)

    # Image gallery
    images = _get_paper_images(arxiv_id)

    # Decode the JSON columns once so the template can use them directly
    summary = paper.summary
    prereqs = safe_json_loads(
        summary.prerequisites_json if summary else None, default={}
    )
    benchmarks = safe_json_loads(
        summary.results_benchmarks_json if summary else None, default=[]
    )
    figures_raw = safe_json_loads(
        summary.figures_json if summary else None, default=[]
    )

    linked_figures = _link_figures_with_images(figures_raw, images, arxiv_id)

    # Split: table entries that have a screenshot are shown in the results
    # section; everything else goes to the figure gallery.
    table_figures = []
    figures = []
    for fig in linked_figures:
        # The stored id may be missing, null or not a string; coerce it so the
        # prefix test below cannot raise.
        fig_id = str(fig.get("id") or "")
        is_table = fig_id.lower().startswith("table")
        if is_table and fig.get("image_url"):
            table_figures.append(fig)
        else:
            figures.append(fig)

    return templates.TemplateResponse(
        request,
        "detail.html",
        {
            "paper": paper,
            "summary_state": summary_state,
            "similar_papers": similar_papers,
            "paper_images": images,
            "prereqs": prereqs,
            "benchmarks": benchmarks,
            "figures": figures,
            "table_figures": table_figures,
            "chroma_enabled": settings.CHROMA_ENABLED,
            "page_title": paper.title_zh or paper.title_en,
        },
    )
|
||
|
||
|
||
# ── 相似论文 API ──────────────────────────────────────────────────────
|
||
|
||
|
||
@router.get("/api/similar/{arxiv_id}")
def similar_api(
    arxiv_id: str,
    top_k: int = Query(default=5, ge=1, le=20),
    db: Session = Depends(get_db),
):
    """Return the papers most similar to ``arxiv_id`` as a JSON payload.

    The response has the shape ``{"results": [...]}`` with at most ``top_k``
    entries, none of which is the queried paper itself.
    """
    # Request one extra candidate so that dropping the paper itself can still
    # leave top_k entries.
    candidates = _get_similar_papers(db, arxiv_id, top_k=top_k + 1)

    others = []
    for entry in candidates:
        if entry["arxiv_id"] == arxiv_id:
            continue
        others.append(entry)

    return {"results": others[:top_k]}
|
||
|
||
|
||
def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
    """Find papers whose embeddings are closest to the given paper in ChromaDB.

    Args:
        db: Database session used to load the matching ``Paper`` rows.
        arxiv_id: Paper whose neighbours are requested.
        top_k: Number of neighbours requested from the collection. The paper
            itself is removed from the result, so fewer entries may come back.

    Returns:
        A list of dicts with the keys ``arxiv_id``, ``title_zh``, ``distance``,
        ``paper_date`` and ``tags``, sorted by ascending distance. The list is
        empty when Chroma is disabled, the paper has no embedding, or any step
        fails (failures are logged, not raised).
    """
    if not settings.CHROMA_ENABLED:
        return []

    try:
        from app.services.embedder import get_collection

        col = get_collection()
        if col is None:
            return []

        # Fetch the embedding of the current paper.
        result = col.get(ids=[arxiv_id], include=["embeddings"])
        embeddings = result["embeddings"]
        # Test with len() instead of truthiness: array-like embeddings would
        # raise "truth value of an array is ambiguous" on `not embeddings`.
        if embeddings is None or len(embeddings) == 0:
            return []
        vec = embeddings[0]
        if vec is None or len(vec) == 0:
            return []

        count = col.count()
        if count == 0:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )

        if not results["ids"] or not results["ids"][0]:
            return []

        similar_ids = results["ids"][0]
        distances = (
            results["distances"][0]
            if results["distances"]
            else [0.0] * len(similar_ids)
        )

        # Map neighbour id -> distance, leaving out the paper itself.
        papers_info = {
            sid: dist
            for sid, dist in zip(similar_ids, distances)
            if sid != arxiv_id
        }
        if not papers_info:
            return []

        # Load the paper rows from the database. A joined eager load of a
        # collection (Paper.tags) must be de-duplicated with unique();
        # without it SQLAlchemy refuses to return the rows.
        papers = (
            db.execute(
                select(Paper)
                .where(Paper.arxiv_id.in_(list(papers_info)))
                .options(joinedload(Paper.tags))
            )
            .scalars()
            .unique()
            .all()
        )

        items = [
            {
                "arxiv_id": p.arxiv_id,
                "title_zh": p.title_zh or p.title_en,
                "distance": papers_info.get(p.arxiv_id, 0.0),
                "paper_date": p.paper_date.isoformat() if p.paper_date else "",
                "tags": [t.tag for t in p.tags[:3]],
            }
            for p in papers
        ]

        # Closest first.
        items.sort(key=lambda item: item["distance"])
        return items

    except Exception:
        # Best effort: the caller gets an empty list, the failure is logged.
        logger.exception("Failed to get similar papers for %s", arxiv_id)
        return []
|
||
|
||
|
||
# ── 图片画廊 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
def _get_paper_images(arxiv_id: str) -> list[dict]:
    """List the image files found in a paper's ``images`` directory.

    Returns:
        One ``{"url", "name"}`` dict per file with a png/jpg/jpeg/gif/svg
        suffix (case-insensitive), in sorted path order. Empty when the
        directory does not exist.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".svg")

    folder = PAPERS_DIR / arxiv_id / "images"
    if not folder.exists():
        return []

    return [
        {
            "url": f"/papers/{arxiv_id}/images/{entry.name}",
            "name": entry.name,
        }
        for entry in sorted(folder.iterdir())
        if entry.suffix.lower() in image_suffixes
    ]
|
||
|
||
|
||
def _link_figures_with_images(
    figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
    """Attach an ``image_url`` to summary figure entries using extracted images.

    Strategy:
        1. Exact ID match against ``images/manifest.json``: the ``label`` field
           of an entry is preferred, the ``figures`` / ``tables`` lists of the
           older manifest format are used as a fallback.
        2. Figures still unmatched are paired by order: the N-th unmatched
           Figure (or Table) receives the N-th image of the same kind that has
           not already been linked in step 1.

    Args:
        figures: Figure metadata dicts; they are updated in place.
        images: Extracted images as ``{"url", "name"}`` dicts.
        arxiv_id: Paper identifier, used to locate the manifest and build URLs.

    Returns:
        The same ``figures`` list that was passed in.
    """
    if not figures or not images:
        return figures

    def _fig_id(fig: dict) -> str:
        # The id may be missing, null or not a string; the regex helpers need str.
        return str(fig.get("id") or "")

    # ── Strategy 1: exact ID match through the manifest ──
    id_to_url: dict[str, str] = {}
    manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
    if manifest_path.exists():
        try:
            manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (OSError, ValueError, TypeError):
            manifest = {}
        if not isinstance(manifest, dict):
            # A manifest whose root is not an object carries no usable mapping.
            manifest = {}
        for filename, info in manifest.items():
            if not isinstance(info, dict):
                continue
            url = f"/papers/{arxiv_id}/images/{filename}"
            # Prefer the label field (new format).
            label = info.get("label", "")
            if label:
                id_to_url[label] = url
            # Also accept the figures/tables lists (old format); a label
            # registered for the same id keeps priority.
            listed_ids = (info.get("figures") or []) + (info.get("tables") or [])
            for fig_id in listed_ids:
                if fig_id not in id_to_url:
                    id_to_url[fig_id] = url

    for fig in figures:
        normalized = _normalize_figure_id(_fig_id(fig))
        if normalized in id_to_url:
            fig["image_url"] = id_to_url[normalized]

    # ── Strategy 2: sequential fallback for whatever is still unmatched ──
    unmatched = [f for f in figures if not f.get("image_url")]
    if not unmatched:
        return figures

    # Split the unmatched entries by kind: Figure vs Table.
    fig_type_unmatched = [f for f in unmatched if _is_figure_type(_fig_id(f))]
    table_type_unmatched = [f for f in unmatched if not _is_figure_type(_fig_id(f))]

    page_index = re.compile(r'page(\d+)_(?:img|table)(\d+)')

    def _sort_key(name: str) -> tuple[int, int]:
        # Order by (page number, index on page); unrecognised names sort first.
        m = page_index.search(name)
        if m:
            return (int(m.group(1)), int(m.group(2)))
        return (0, 0)

    # An image already linked to a figure must not be handed to another one.
    used_urls = {f["image_url"] for f in figures if f.get("image_url")}
    remaining = sorted(
        (img for img in images if img["url"] not in used_urls),
        key=lambda img: _sort_key(img["name"]),
    )
    fig_images = [img for img in remaining if "table" not in img["name"].lower()]
    table_images = [img for img in remaining if "table" in img["name"].lower()]

    # zip() stops at the shorter side, so surplus figures simply stay unlinked.
    for fig, img in zip(fig_type_unmatched, fig_images):
        fig["image_url"] = img["url"]
    for fig, img in zip(table_type_unmatched, table_images):
        fig["image_url"] = img["url"]

    return figures
|
||
|
||
|
||
def _normalize_figure_id(raw_id: str) -> str:
    """Bring a figure/table identifier into its canonical spelling.

    A leading ``Fig`` / ``Fig.`` / ``Figure`` followed by a number becomes
    ``Figure <n>``; a leading ``Table`` followed by a number becomes
    ``Table <n>``. Matching ignores case and only looks at the start of the
    string. Anything else is returned unchanged.
    """
    canonical_forms = (
        (r"(?:Fig\.?|Figure)\s*(\d+)", "Figure"),
        (r"Table\s*(\d+)", "Table"),
    )
    for pattern, prefix in canonical_forms:
        hit = re.match(pattern, raw_id, re.IGNORECASE)
        if hit is not None:
            return f"{prefix} {hit.group(1)}"
    return raw_id
|
||
|
||
|
||
def _is_figure_type(fig_id: str) -> bool:
    """Tell whether an identifier denotes a Figure rather than a Table.

    Only identifiers that start with ``Table`` followed by a number
    (case-insensitive) count as tables; everything else is a figure.
    """
    table_match = re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
    return table_match is None
|