Files
daily-paper/app/routes/pages.py
T
Rain-Bus 18f44ac244 feat: improve PDF image extraction with caption-based labeling and fallback matching
- Enhance pdf_image_extractor with caption text extraction near images/tables
- Add figure/table type correction based on caption content
- Implement sequential numbering fallback for unmatched items
- Improve figure linking in pages with manifest ID matching and fallback strategies
- Remove docling dependency, add dev dependency group
2026-06-09 14:07:21 +08:00

359 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""页面路由 — 首页、日期页、论文详情。"""
from __future__ import annotations
import json
import logging
import re
from datetime import date, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import PAPER_FULL_LOAD, Paper
from app.utils import PAPERS_DIR, safe_json_loads, templates, today_str, latest_paper_date
# Router collecting this module's page routes and JSON API routes.
router = APIRouter()
# Module-level logger.
logger = logging.getLogger(__name__)
@router.get("/")
def index(request: Request, db: Session = Depends(get_db)):
    """Redirect the site root to the day page of the newest date that has papers."""
    newest = latest_paper_date(db)
    target_url = f"/day/{newest}"
    return RedirectResponse(url=target_url)
@router.get("/day/{date_str}")
def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
    """Render the paper list for one day (``index.html``).

    Papers dated on the requested day are listed by descending upvotes,
    together with prev/next-day links and up to 30 most recent dates that
    have papers.

    Args:
        date_str: Requested day; anything ``date.fromisoformat`` accepts.
        request: Incoming request, passed through to the template.
        db: Database session.

    Raises:
        HTTPException: 404 when ``date_str`` is not a valid ISO date.
    """
    try:
        target = date.fromisoformat(date_str)
    except ValueError:
        raise HTTPException(status_code=404, detail="Invalid date format") from None

    # date.fromisoformat accepts non-canonical spellings on Python 3.11+
    # (e.g. "20240105"), so derive the canonical YYYY-MM-DD form once and
    # display that instead of echoing the raw path segment.
    current_date = target.isoformat()
    prev_day = (target - timedelta(days=1)).isoformat()
    next_day = (target + timedelta(days=1)).isoformat()
    today = today_str()

    # Bind the parsed date rather than the raw string: paper_date values are
    # read back as date objects in this module (.isoformat() is called on
    # them), and typed Date binds may reject str input (SQLite's Date does).
    papers = (
        db.execute(
            select(Paper)
            .where(Paper.paper_date == target)
            .options(*PAPER_FULL_LOAD)
            .order_by(Paper.upvotes.desc())
        )
        .scalars()
        .unique()
        .all()
    )

    # The 30 most recent distinct dates that have papers, for the date list.
    date_rows = db.execute(
        select(Paper.paper_date)
        .distinct()
        .order_by(Paper.paper_date.desc())
        .limit(30)
    ).all()
    available_dates = [
        value.isoformat() if isinstance(value, date) else str(value)
        for (value,) in date_rows
    ]

    return templates.TemplateResponse(
        request,
        "index.html",
        {
            "papers": papers,
            "current_date": current_date,
            "prev_day": prev_day,
            "next_day": next_day,
            "today": today,
            "available_dates": available_dates,
            "page_title": f"{current_date} 论文列表",
        },
    )
@router.get("/paper/{arxiv_id}")
def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db)):
    """Render the detail page (``detail.html``) for one paper.

    Besides the paper itself, the template receives similar papers, the
    extracted images, pre-parsed summary JSON fields, and the summary's
    figures split into table screenshots and general figures.

    Raises:
        HTTPException: 404 when no paper with ``arxiv_id`` exists.
    """
    paper = (
        db.execute(
            select(Paper)
            .where(Paper.arxiv_id == arxiv_id)
            .options(
                joinedload(Paper.summary),
                joinedload(Paper.note),
                *PAPER_FULL_LOAD,
            )
        )
        .unique()
        .scalar_one_or_none()
    )
    if not paper:
        raise HTTPException(status_code=404, detail="Paper not found")

    summary_state = paper.summary_status.status if paper.summary_status else "none"

    # Similar-paper recommendations and the extracted-image gallery.
    similar_papers = _get_similar_papers(db, arxiv_id, top_k=6)
    images = _get_paper_images(arxiv_id)

    # Pre-parse JSON fields so the template can use them directly.
    summary = paper.summary
    prereqs = safe_json_loads(
        summary.prerequisites_json if summary else None, default={}
    )
    benchmarks = safe_json_loads(
        summary.results_benchmarks_json if summary else None, default=[]
    )
    figures_raw = safe_json_loads(
        summary.figures_json if summary else None, default=[]
    )
    # The stored figures JSON is not validated anywhere here: ignore a
    # non-list payload and non-dict entries instead of crashing on fig.get().
    if not isinstance(figures_raw, list):
        figures_raw = []
    figures_raw = [fig for fig in figures_raw if isinstance(fig, dict)]
    linked_figures = _link_figures_with_images(figures_raw, images, arxiv_id)

    # Split: Table entries that got a screenshot go to the experiment-results
    # section (table_figures); everything else goes to the figure gallery.
    table_figures = []
    figures = []
    for fig in linked_figures:
        # `or ""` also covers an explicit null id; normalizing first applies
        # the same ID canonicalization that image linking uses.
        fig_id = _normalize_figure_id(str(fig.get("id") or ""))
        is_table = fig_id.lower().startswith("table")
        if is_table and fig.get("image_url"):
            table_figures.append(fig)
        else:
            figures.append(fig)

    return templates.TemplateResponse(
        request,
        "detail.html",
        {
            "paper": paper,
            "summary_state": summary_state,
            "similar_papers": similar_papers,
            "paper_images": images,
            "prereqs": prereqs,
            "benchmarks": benchmarks,
            "figures": figures,
            "table_figures": table_figures,
            "chroma_enabled": settings.CHROMA_ENABLED,
            "page_title": paper.title_zh or paper.title_en,
        },
    )
# ── Similar-papers API ───────────────────────────────────────────────
@router.get("/api/similar/{arxiv_id}")
def similar_api(
    arxiv_id: str,
    top_k: int = Query(default=5, ge=1, le=20),
    db: Session = Depends(get_db),
):
    """Return papers similar to ``arxiv_id`` as JSON: ``{"results": [...]}``.

    One extra neighbour is requested so that, once the paper itself is
    dropped, up to ``top_k`` results remain.
    """
    candidates = _get_similar_papers(db, arxiv_id, top_k=top_k + 1)
    results = []
    for candidate in candidates:
        if candidate["arxiv_id"] == arxiv_id:
            continue
        results.append(candidate)
        if len(results) == top_k:
            break
    return {"results": results}
def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
    """Find papers similar to ``arxiv_id`` via the ChromaDB collection.

    The paper's stored embedding is the query vector. The paper itself is
    dropped from the neighbours, and the remaining ids are resolved against
    the database (ids not found in the DB are skipped).

    Args:
        db: Database session.
        arxiv_id: Paper whose neighbours are wanted.
        top_k: Neighbours requested from Chroma; this count includes the
            paper itself when it is indexed.

    Returns:
        ``[{arxiv_id, title_zh, distance, paper_date, tags}, ...]`` sorted by
        ascending distance (``title_zh`` falls back to ``title_en``; ``tags``
        holds at most three tags). ``[]`` when Chroma is disabled or
        unavailable, the paper has no embedding, or any error occurs (errors
        are logged).
    """
    if not settings.CHROMA_ENABLED:
        return []
    try:
        from app.services.embedder import get_collection

        col = get_collection()
        if col is None:
            return []

        # Chroma clients may return embeddings as numpy arrays, whose truth
        # value is ambiguous (`not arr` raises ValueError), so emptiness is
        # checked with `is None` / len() instead.
        stored = col.get(ids=[arxiv_id], include=["embeddings"])
        embeddings = stored.get("embeddings")
        if embeddings is None or len(embeddings) == 0:
            return []
        vec = embeddings[0]
        if vec is None or len(vec) == 0:
            return []

        count = col.count()
        if count == 0:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )
        id_rows = results.get("ids")
        if not id_rows or not id_rows[0]:
            return []
        similar_ids = id_rows[0]
        distance_rows = results.get("distances")
        if distance_rows is not None and len(distance_rows) > 0:
            # float() turns numpy scalars into JSON-serializable Python floats.
            distances = [float(d) for d in distance_rows[0]]
        else:
            distances = [0.0] * len(similar_ids)

        # Neighbour id -> distance, excluding the query paper itself.
        papers_info = {
            sid: dist
            for sid, dist in zip(similar_ids, distances)
            if sid != arxiv_id
        }
        if not papers_info:
            return []

        # SQLAlchemy requires unique() when a collection (tags) is
        # joined-eager-loaded; without it .all() raises InvalidRequestError.
        papers = (
            db.execute(
                select(Paper)
                .where(Paper.arxiv_id.in_(list(papers_info)))
                .options(joinedload(Paper.tags))
            )
            .scalars()
            .unique()
            .all()
        )
        items = [
            {
                "arxiv_id": p.arxiv_id,
                "title_zh": p.title_zh or p.title_en,
                "distance": papers_info.get(p.arxiv_id, 0.0),
                "paper_date": p.paper_date.isoformat() if p.paper_date else "",
                "tags": [t.tag for t in p.tags[:3]],
            }
            for p in papers
        ]
        # Nearest first.
        items.sort(key=lambda item: item["distance"])
        return items
    except Exception:
        # Best-effort: recommendations must never break the calling page.
        logger.exception("Failed to get similar papers for %s", arxiv_id)
        return []
# ── Image gallery ────────────────────────────────────────────────────
def _get_paper_images(arxiv_id: str) -> list[dict]:
    """List a paper's extracted image files, sorted by path.

    Only regular files whose suffix is .png/.jpg/.jpeg/.gif/.svg
    (case-insensitive) are returned, so ``manifest.json`` and any
    subdirectories are skipped.

    Returns:
        ``[{"url": "/papers/<arxiv_id>/images/<name>", "name": <name>}, ...]``;
        ``[]`` when the images directory does not exist.
    """
    images_dir = PAPERS_DIR / arxiv_id / "images"
    # is_dir() rather than exists(): a stray file at this path would make
    # iterdir() raise NotADirectoryError.
    if not images_dir.is_dir():
        return []
    images = []
    for img_file in sorted(images_dir.iterdir()):
        if not img_file.is_file():
            continue
        if img_file.suffix.lower() not in (".png", ".jpg", ".jpeg", ".gif", ".svg"):
            continue
        images.append(
            {
                "url": f"/papers/{arxiv_id}/images/{img_file.name}",
                "name": img_file.name,
            }
        )
    return images
def _link_figures_with_images(
    figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
    """Attach an ``image_url`` to summary figure entries where an image fits.

    Matching runs in two passes:

    1. Exact ID match: IDs from ``images/manifest.json`` (each entry's
       ``label`` field, plus the legacy ``figures``/``tables`` lists) and each
       figure's ``id`` are both normalized with ``_normalize_figure_id`` and
       compared.
    2. Sequential fallback for entries still without an ``image_url``: the
       N-th unmatched Figure gets the N-th extracted image whose name lacks
       "table", and the N-th unmatched Table the N-th image whose name
       contains "table". Images are ordered by the page/index parsed from
       their filename. Images already linked in pass 1 are excluded, so one
       picture is never shown under two different IDs.

    Args:
        figures: Figure metadata dicts (read: ``id``; written: ``image_url``),
            updated in place. Non-dict entries are ignored.
        images: Entries with ``"url"`` and ``"name"`` keys.
        arxiv_id: Paper whose ``images/manifest.json`` is consulted.

    Returns:
        The same ``figures`` list.
    """
    if not figures or not images:
        return figures

    # Pair every usable entry with its normalized id; `or ""` covers null ids.
    entries = [
        (fig, _normalize_figure_id(str(fig.get("id") or "")))
        for fig in figures
        if isinstance(fig, dict)
    ]

    # ── Pass 1: exact match on manifest IDs ──
    id_to_url = _load_manifest_id_index(arxiv_id)
    for fig, norm_id in entries:
        if norm_id in id_to_url:
            fig["image_url"] = id_to_url[norm_id]

    # ── Pass 2: sequential fallback for whatever is still unmatched ──
    unmatched = [(fig, norm_id) for fig, norm_id in entries if not fig.get("image_url")]
    if not unmatched:
        return figures

    claimed = {
        fig["image_url"]
        for fig, _ in entries
        if isinstance(fig.get("image_url"), str)
    }
    # Sort once, then split; sorting is stable so each split keeps this order.
    pool = sorted(
        (img for img in images if img["url"] not in claimed),
        key=lambda img: _image_order_key(img["name"]),
    )
    fig_images = [img for img in pool if "table" not in img["name"].lower()]
    table_images = [img for img in pool if "table" in img["name"].lower()]

    fig_slots = [fig for fig, norm_id in unmatched if _is_figure_type(norm_id)]
    table_slots = [fig for fig, norm_id in unmatched if not _is_figure_type(norm_id)]
    for fig, img in zip(fig_slots, fig_images):
        fig["image_url"] = img["url"]
    for fig, img in zip(table_slots, table_images):
        fig["image_url"] = img["url"]
    return figures


def _load_manifest_id_index(arxiv_id: str) -> dict[str, str]:
    """Map normalized figure/table IDs from ``images/manifest.json`` to image URLs."""
    manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
    if not manifest_path.exists():
        return {}
    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError, TypeError):
        logger.warning("Unreadable image manifest: %s", manifest_path)
        return {}
    if not isinstance(manifest, dict):
        return {}

    id_to_url: dict[str, str] = {}
    for filename, info in manifest.items():
        if not isinstance(info, dict):
            continue
        url = f"/papers/{arxiv_id}/images/{filename}"
        # Newer format: a single `label`; it takes precedence over legacy ids.
        label = info.get("label")
        if isinstance(label, str) and label:
            id_to_url[_normalize_figure_id(label)] = url
        # Legacy format: `figures` / `tables` lists; never override a label.
        for key in ("figures", "tables"):
            legacy_ids = info.get(key)
            if not isinstance(legacy_ids, list):
                continue
            for legacy_id in legacy_ids:
                if isinstance(legacy_id, str) and legacy_id:
                    id_to_url.setdefault(_normalize_figure_id(legacy_id), url)
    return id_to_url


_PAGE_INDEX_RE = re.compile(r'page(\d+)_(?:img|table)(\d+)')


def _image_order_key(name: str) -> tuple[int, int]:
    """Sort key (page, index) parsed from an extracted image's filename; (0, 0) if absent."""
    m = _PAGE_INDEX_RE.search(name)
    if m:
        return (int(m.group(1)), int(m.group(2)))
    return (0, 0)
def _normalize_figure_id(raw_id: str) -> str:
"""归一化 Figure/Table ID'Figure 1'/'Fig.1''Figure 1'"""
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
if m:
return f"Figure {m.group(1)}"
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
if m2:
return f"Table {m2.group(1)}"
return raw_id
def _is_figure_type(fig_id: str) -> bool:
"""判断是否为 Figure 类型(非 Table)。"""
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)