Files
daily-paper/app/routes/pages.py
T
Rain-Bus 90fe705e8f refactor: 迁移布局检测模型从 PicoDet 到 DocLayout-YOLO
- 核心变更:
  - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024)
  - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备
  - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式
  - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls]
  - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table)

- 新增文件:
  - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行)
  - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例)

- 配置更新:
  - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID)
  - app/config.py: Settings 类对应字段
  - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等)

- 删除旧文件:
  - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本

- 文档更新:
  - README.md: 更新环境变量说明
  - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py)

此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
2026-06-14 10:41:44 +08:00

281 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""页面路由 — 首页、日期页、论文详情。"""
from __future__ import annotations
import logging
from datetime import date, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import PAPER_FULL_LOAD, Paper
from app.services.pdf_image_extractor import link_figures_with_images
from app.utils import (
PAPERS_DIR,
safe_json_loads,
templates,
today_str,
latest_paper_date,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router on which the route handlers defined in this file are registered.
router = APIRouter()
@router.get("/")
def index(request: Request, db: Session = Depends(get_db)):
    """Redirect to the day page of the most recent date that has papers.

    The target date is whatever ``latest_paper_date(db)`` returns.
    """
    latest = latest_paper_date(db)
    return RedirectResponse(url=f"/day/{latest}")
@router.get("/day/{date_str}")
def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
    """Render the paper list for a single date.

    Args:
        date_str: ISO-format date taken from the URL path.
        request: Incoming request, passed through to the template response.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The rendered ``index.html`` template response.

    Raises:
        HTTPException: 404 when ``date_str`` cannot be parsed as an ISO date.
    """
    try:
        target = date.fromisoformat(date_str)
    except ValueError:
        # The parse error carries nothing useful for the client, so the
        # exception context is dropped explicitly.
        raise HTTPException(status_code=404, detail="Invalid date format") from None

    # Python 3.11+ also accepts non-canonical spellings such as "20240101".
    # Use the canonical YYYY-MM-DD form for the query and the template so
    # those URLs resolve to the same papers as the canonical one.
    current_date = target.isoformat()

    one_day = timedelta(days=1)
    prev_day = (target - one_day).isoformat()
    next_day = (target + one_day).isoformat()
    today = today_str()

    papers = (
        db.execute(
            select(Paper)
            .where(Paper.paper_date == current_date)
            .options(*PAPER_FULL_LOAD)
            .order_by(Paper.upvotes.desc())
        )
        .scalars()
        .unique()
        .all()
    )

    # The 30 most recent distinct paper dates, newest first.
    dates_raw = db.execute(
        select(Paper.paper_date).distinct().order_by(Paper.paper_date.desc()).limit(30)
    ).all()
    # The column value may or may not come back as a date object; handle both.
    available_dates = [
        d[0].isoformat() if isinstance(d[0], date) else str(d[0]) for d in dates_raw
    ]

    return templates.TemplateResponse(
        request,
        "index.html",
        {
            "papers": papers,
            "current_date": current_date,
            "prev_day": prev_day,
            "next_day": next_day,
            "today": today,
            "available_dates": available_dates,
            "page_title": f"{current_date} 论文列表",
        },
    )
def _bucket_figures(
    linked_figures: list[dict],
) -> tuple[list[dict], list[dict], list[dict], list[dict]]:
    """Split linked figures into the display areas of the detail page.

    Returns:
        ``(table_figures, method_figures, results_figures, gallery_figures)``:

        * table_figures: entries whose id starts with "table" and that have
          an image.
        * method_figures: non-table entries with ``section == "method"`` and
          an image.
        * results_figures: non-table entries with ``section == "results"``
          and an image.
        * gallery_figures: everything else, including every entry without
          an image.
    """
    table_figures: list[dict] = []
    method_figures: list[dict] = []
    results_figures: list[dict] = []
    gallery_figures: list[dict] = []

    for fig in linked_figures:
        # "id" may be missing, null or non-string in the decoded data; coerce
        # it so that .lower() cannot raise.
        fig_id = str(fig.get("id") or "")
        section = fig.get("section", "")
        is_table = fig_id.lower().startswith("table")

        if not fig.get("image_url"):
            gallery_figures.append(fig)
        elif is_table:
            table_figures.append(fig)
        elif section == "method":
            method_figures.append(fig)
        elif section == "results":
            results_figures.append(fig)
        else:
            gallery_figures.append(fig)

    return table_figures, method_figures, results_figures, gallery_figures


@router.get("/paper/{arxiv_id}")
def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db)):
    """Render the detail page of one paper.

    Args:
        arxiv_id: Identifier of the paper, taken from the URL path.
        request: Incoming request, passed through to the template response.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The rendered ``detail.html`` template response.

    Raises:
        HTTPException: 404 when no paper with ``arxiv_id`` exists.
    """
    paper = (
        db.execute(
            select(Paper)
            .where(Paper.arxiv_id == arxiv_id)
            .options(
                joinedload(Paper.summary),
                joinedload(Paper.note),
                *PAPER_FULL_LOAD,
            )
        )
        .unique()
        .scalar_one_or_none()
    )
    if not paper:
        raise HTTPException(status_code=404, detail="Paper not found")

    # "none" is the state reported when the paper has no summary_status.
    summary_state = paper.summary_status.status if paper.summary_status else "none"

    # Similar-paper recommendations.
    similar_papers = _get_similar_papers(db, arxiv_id, top_k=6)

    # Image gallery.
    images = _get_paper_images(arxiv_id)

    # Decode the summary's JSON fields up front so the template can use them
    # directly; each falls back to an empty container when there is no summary.
    summary = paper.summary
    prereqs = safe_json_loads(
        summary.prerequisites_json if summary else None, default={}
    )
    benchmarks = safe_json_loads(
        summary.results_benchmarks_json if summary else None, default=[]
    )
    figures_raw = safe_json_loads(
        summary.figures_json if summary else None, default=[]
    )
    linked_figures = link_figures_with_images(figures_raw, images, arxiv_id)

    # Route each figure to its display area:
    #   table_figures   -> results area (table screenshots)
    #   method_figures  -> core-method area (section == "method")
    #   results_figures -> results area (figures with section == "results")
    #   gallery_figures -> bottom gallery (everything else, or no image)
    table_figures, method_figures, results_figures, gallery_figures = (
        _bucket_figures(linked_figures)
    )

    return templates.TemplateResponse(
        request,
        "detail.html",
        {
            "paper": paper,
            "summary_state": summary_state,
            "similar_papers": similar_papers,
            "paper_images": images,
            "prereqs": prereqs,
            "benchmarks": benchmarks,
            "figures": gallery_figures,
            "table_figures": table_figures,
            "method_figures": method_figures,
            "results_figures": results_figures,
            "chroma_enabled": settings.CHROMA_ENABLED,
            "page_title": paper.title_zh or paper.title_en,
        },
    )
# ── 相似论文 API ──────────────────────────────────────────────────────
@router.get("/api/similar/{arxiv_id}")
def similar_api(
    arxiv_id: str,
    top_k: int = Query(default=5, ge=1, le=20),
    db: Session = Depends(get_db),
):
    """Return the papers similar to ``arxiv_id`` as ``{"results": [...]}``."""
    # Ask for top_k + 1 candidates, drop the queried paper itself, then cap
    # the list at top_k entries.
    candidates = _get_similar_papers(db, arxiv_id, top_k=top_k + 1)
    others = []
    for entry in candidates:
        if entry["arxiv_id"] == arxiv_id:
            continue
        others.append(entry)
    return {"results": others[:top_k]}
def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict]:
    """Look up papers similar to ``arxiv_id`` through the ChromaDB collection.

    Args:
        db: Database session used to load the matching ``Paper`` rows.
        arxiv_id: Paper whose stored embedding is used as the query vector.
        top_k: Number of neighbours requested from the collection. The paper
            itself is removed from the neighbours, so fewer entries may be
            returned.

    Returns:
        A list of dicts with the keys ``arxiv_id``, ``title_zh``,
        ``distance``, ``paper_date`` and ``tags`` (at most three tags),
        sorted by ascending distance. An empty list is returned when
        ``settings.CHROMA_ENABLED`` is false, when nothing is found, or when
        any step fails (the failure is logged).
    """
    if not settings.CHROMA_ENABLED:
        return []
    try:
        from app.services.embedder import get_collection

        col = get_collection()
        if col is None:
            return []

        # Fetch the embedding stored for the current paper.
        stored = col.get(ids=[arxiv_id], include=["embeddings"])
        embeddings = stored["embeddings"]
        # Use explicit None/len checks instead of truthiness: the embeddings
        # may be a numpy array, and `not array` raises ValueError for arrays
        # with more than one element.
        if embeddings is None or len(embeddings) == 0:
            return []
        vec = embeddings[0]
        if vec is None or len(vec) == 0:
            return []

        count = col.count()
        if count == 0:
            return []

        results = col.query(
            query_embeddings=[vec],
            n_results=min(top_k, count),
            include=["metadatas", "distances"],
        )
        if not results["ids"] or not results["ids"][0]:
            return []

        similar_ids = results["ids"][0]
        distances = (
            results["distances"][0]
            if results["distances"]
            else [0.0] * len(similar_ids)
        )

        # Map neighbour id -> distance, leaving out the paper itself.
        papers_info = {
            sid: dist
            for sid, dist in zip(similar_ids, distances)
            if sid != arxiv_id
        }
        if not papers_info:
            return []

        # Load the paper rows from the database. joinedload() on the tags
        # collection requires .unique() on the result; without it SQLAlchemy
        # raises InvalidRequestError for 2.0-style queries.
        papers = (
            db.execute(
                select(Paper)
                .where(Paper.arxiv_id.in_(list(papers_info)))
                .options(joinedload(Paper.tags))
            )
            .scalars()
            .unique()
            .all()
        )

        items = []
        for p in papers:
            paper_date = p.paper_date
            if not paper_date:
                date_text = ""
            elif isinstance(paper_date, date):
                date_text = paper_date.isoformat()
            else:
                # Tolerate a non-date column value, as day_page does.
                date_text = str(paper_date)
            items.append(
                {
                    "arxiv_id": p.arxiv_id,
                    "title_zh": p.title_zh or p.title_en,
                    "distance": papers_info.get(p.arxiv_id, 0.0),
                    "paper_date": date_text,
                    "tags": [t.tag for t in p.tags[:3]],
                }
            )

        # Closest papers first.
        items.sort(key=lambda item: item["distance"])
        return items
    except Exception:
        # Best-effort feature: a failure here is logged and yields an empty
        # list rather than propagating to the caller.
        logger.exception("Failed to get similar papers for %s", arxiv_id)
        return []
# ── 图片画廊 ──────────────────────────────────────────────────────────
def _get_paper_images(arxiv_id: str) -> list[dict]:
    """List the image files stored for a paper.

    Looks in ``PAPERS_DIR / arxiv_id / "images"`` and keeps regular files
    whose suffix is .png, .jpg, .jpeg, .gif or .svg (case-insensitive).

    Args:
        arxiv_id: Identifier of the paper; also the name of its directory.

    Returns:
        A list of ``{"url": ..., "name": ...}`` dicts ordered by sorted path,
        or an empty list when the images directory is absent.
    """
    images_dir = PAPERS_DIR / arxiv_id / "images"
    # is_dir() rather than exists(): a regular file at this path would pass
    # exists() and then make iterdir() raise NotADirectoryError.
    if not images_dir.is_dir():
        return []

    allowed_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".svg")
    return [
        {
            "url": f"/papers/{arxiv_id}/images/{img_file.name}",
            "name": img_file.name,
        }
        for img_file in sorted(images_dir.iterdir())
        # Skip sub-directories that merely carry an image-like suffix.
        if img_file.suffix.lower() in allowed_suffixes and img_file.is_file()
    ]