feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
This commit is contained in:
+94
-144
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
from datetime import date
|
||||
@@ -10,7 +11,7 @@ from datetime import date
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy import bindparam, func, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
@@ -22,15 +23,15 @@ from app.models import (
|
||||
PaperTag,
|
||||
SummaryState,
|
||||
SummaryStatus,
|
||||
TaskLock,
|
||||
)
|
||||
from app.services import admin as admin_svc
|
||||
from app.services.admin import get_admin_stats
|
||||
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
||||
from app.services.crawler import crawl_daily, refresh_upvotes
|
||||
from app.services.pipeline import run_pipeline
|
||||
from app.services.crawler import refresh_upvotes
|
||||
from app.services.pipeline import run_crawl, run_pipeline
|
||||
from app.services.scheduler import get_scheduler
|
||||
from app.services.summarizer import summarize_batch, summarize_single
|
||||
from app.utils import release_lock, templates, today_str, utc_now
|
||||
from app.utils import templates, today_str, utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,14 +42,15 @@ router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
def _check_password(password: str) -> bool:
    """Check a submitted admin password against ``settings.ADMIN_PASSWORD``.

    The configured value may be either the plaintext password or the
    SHA-256 hex digest of the password. Both comparisons use
    ``hmac.compare_digest`` so they run in constant time.

    Args:
        password: The password supplied by the client.

    Returns:
        ``True`` if ``password`` matches the configured value directly or via
        its SHA-256 hex digest; ``False`` otherwise, including when no admin
        password is configured.
    """
    stored = settings.ADMIN_PASSWORD
    if not stored:
        return False

    # hmac.compare_digest() only accepts ASCII-only str arguments and raises
    # TypeError otherwise, so compare the UTF-8 encoded bytes instead. This
    # keeps non-ASCII passwords from turning a login attempt into an error.
    supplied = password.encode("utf-8")
    expected = stored.encode("utf-8")

    if hmac.compare_digest(supplied, expected):
        return True

    # The stored value may also be a SHA-256 hex digest of the password.
    hashed = hashlib.sha256(supplied).hexdigest().encode("ascii")
    return hmac.compare_digest(hashed, expected)
|
||||
|
||||
|
||||
async def verify_admin(request: Request) -> None:
|
||||
@@ -204,32 +206,12 @@ async def admin_crawl(
|
||||
):
|
||||
"""手动抓取指定日期,默认今天。"""
|
||||
target_date = date or today_str()
|
||||
|
||||
# TaskLock 防重入
|
||||
now = utc_now()
|
||||
lock = TaskLock(
|
||||
task="crawl",
|
||||
lock_key=target_date,
|
||||
status="running",
|
||||
owner="admin_crawl",
|
||||
acquired_at=now,
|
||||
)
|
||||
try:
|
||||
db.add(lock)
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409, detail=f"Crawl already running for {target_date}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await crawl_daily(db, target_date)
|
||||
return result
|
||||
return await run_crawl(db, target_date, owner="admin_crawl")
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
finally:
|
||||
release_lock(db, lock)
|
||||
|
||||
|
||||
# ── 总结 ──────────────────────────────────────────────────────────────
|
||||
@@ -241,12 +223,7 @@ async def admin_summarize_batch(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""批量总结所有 pending 论文。"""
|
||||
result = await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
|
||||
if result.get("status") == "conflict":
|
||||
raise HTTPException(
|
||||
status_code=409, detail=result.get("error", "batch already running")
|
||||
)
|
||||
return result
|
||||
return await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
|
||||
|
||||
|
||||
@router.post("/summarize/{arxiv_id}")
|
||||
@@ -256,10 +233,9 @@ async def admin_summarize_single(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""总结或重跑单篇论文。"""
|
||||
result = await summarize_single(db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE)
|
||||
if result.get("status") == "not_found":
|
||||
raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
|
||||
return result
|
||||
return await summarize_single(
|
||||
db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE
|
||||
)
|
||||
|
||||
|
||||
# ── 清理 ──────────────────────────────────────────────────────────────
|
||||
@@ -284,10 +260,13 @@ async def admin_cleanup(
|
||||
result = cleanup_tmp()
|
||||
log_entry.status = "success"
|
||||
log_entry.completed_at = utc_now()
|
||||
log_entry.details_json = json.dumps({
|
||||
"scanned": result.get("scanned", 0),
|
||||
"removed": result.get("removed", 0),
|
||||
}, ensure_ascii=False)
|
||||
log_entry.details_json = json.dumps(
|
||||
{
|
||||
"scanned": result.get("scanned", 0),
|
||||
"removed": result.get("removed", 0),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if result.get("errors"):
|
||||
log_entry.error = "; ".join(result["errors"])[:2000]
|
||||
db.commit()
|
||||
@@ -358,19 +337,34 @@ async def admin_logs(
|
||||
|
||||
# 总结状态统计概要
|
||||
summary_total = db.scalar(select(func.count(Paper.id))) or 0
|
||||
summary_done = db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(SummaryStatus.status == SummaryState.DONE)
|
||||
) or 0
|
||||
summary_pending = db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.PROCESSING])
|
||||
summary_done = (
|
||||
db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status == SummaryState.DONE
|
||||
)
|
||||
)
|
||||
) or 0
|
||||
summary_failed = db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE])
|
||||
or 0
|
||||
)
|
||||
summary_pending = (
|
||||
db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status.in_(
|
||||
[SummaryState.PENDING, SummaryState.PROCESSING]
|
||||
)
|
||||
)
|
||||
)
|
||||
) or 0
|
||||
or 0
|
||||
)
|
||||
summary_failed = (
|
||||
db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status.in_(
|
||||
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||
)
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
@@ -414,13 +408,8 @@ async def admin_summary_status(
|
||||
else:
|
||||
query = query.where(SummaryStatus.status == status)
|
||||
|
||||
total = db.scalar(
|
||||
select(func.count()).select_from(query.subquery())
|
||||
)
|
||||
results = (
|
||||
db.execute(query.offset((page - 1) * per_page).limit(per_page))
|
||||
.all()
|
||||
)
|
||||
total = db.scalar(select(func.count()).select_from(query.subquery()))
|
||||
results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
|
||||
|
||||
# 判断是否 HTMX 请求
|
||||
is_htmx = request.headers.get("HX-Request") == "true"
|
||||
@@ -465,7 +454,11 @@ async def admin_summary_retry_failed(
|
||||
db.execute(
|
||||
select(Paper.arxiv_id)
|
||||
.join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
|
||||
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
|
||||
.where(
|
||||
SummaryStatus.status.in_(
|
||||
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
@@ -477,7 +470,11 @@ async def admin_summary_retry_failed(
|
||||
# 重置失败任务的状态为 pending
|
||||
db.execute(
|
||||
SummaryStatus.__table__.update()
|
||||
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
|
||||
.where(
|
||||
SummaryStatus.status.in_(
|
||||
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||
)
|
||||
)
|
||||
.values(status=SummaryState.PENDING, error=None, error_type=None)
|
||||
)
|
||||
db.commit()
|
||||
@@ -492,15 +489,6 @@ async def admin_summary_retry_failed(
|
||||
# ── 论文管理 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# 排序映射
|
||||
_SORT_MAP = {
|
||||
"date_desc": Paper.paper_date.desc(),
|
||||
"date_asc": Paper.paper_date.asc(),
|
||||
"upvotes_desc": Paper.upvotes.desc(),
|
||||
"title_asc": Paper.title_en.asc(),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/papers")
|
||||
async def admin_papers(
|
||||
request: Request,
|
||||
@@ -516,66 +504,18 @@ async def admin_papers(
|
||||
per_page: int = Query(20, ge=1, le=100),
|
||||
):
|
||||
"""论文管理列表页面。"""
|
||||
query = select(Paper)
|
||||
|
||||
# 搜索
|
||||
if q.strip():
|
||||
query = query.where(
|
||||
Paper.title_en.ilike(f"%{q}%")
|
||||
| Paper.title_zh.ilike(f"%{q}%")
|
||||
| Paper.abstract.ilike(f"%{q}%")
|
||||
)
|
||||
|
||||
# 日期范围
|
||||
if date_from:
|
||||
query = query.where(Paper.paper_date >= date_from)
|
||||
if date_to:
|
||||
query = query.where(Paper.paper_date <= date_to)
|
||||
|
||||
# 标签筛选
|
||||
if tag:
|
||||
query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
|
||||
PaperTag.tag == tag
|
||||
)
|
||||
|
||||
# 总结状态筛选
|
||||
if summary_status != "all":
|
||||
if summary_status == "none":
|
||||
query = query.outerjoin(
|
||||
SummaryStatus, SummaryStatus.paper_id == Paper.id
|
||||
).where(SummaryStatus.paper_id == None) # noqa: E711
|
||||
else:
|
||||
query = query.join(
|
||||
SummaryStatus, SummaryStatus.paper_id == Paper.id
|
||||
).where(SummaryStatus.status == summary_status)
|
||||
|
||||
# 排序
|
||||
order = _SORT_MAP.get(sort, Paper.paper_date.desc())
|
||||
query = query.order_by(order)
|
||||
|
||||
# 计数
|
||||
total = db.scalar(select(func.count()).select_from(query.subquery()))
|
||||
|
||||
# 分页
|
||||
papers = (
|
||||
db.execute(query.offset((page - 1) * per_page).limit(per_page))
|
||||
.scalars()
|
||||
.all()
|
||||
papers, total, statuses = admin_svc.query_papers(
|
||||
db,
|
||||
q=q,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
tag=tag,
|
||||
summary_status=summary_status,
|
||||
sort=sort,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
)
|
||||
|
||||
# 获取每篇论文的总结状态
|
||||
paper_ids = [p.id for p in papers]
|
||||
statuses = {}
|
||||
if paper_ids:
|
||||
rows = db.execute(
|
||||
select(SummaryStatus.paper_id, SummaryStatus.status).where(
|
||||
SummaryStatus.paper_id.in_(paper_ids)
|
||||
)
|
||||
).all()
|
||||
paper_id_to_arxiv = {p.id: p.arxiv_id for p in papers}
|
||||
for pid, st in rows:
|
||||
statuses[paper_id_to_arxiv.get(pid, "")] = st
|
||||
|
||||
# 构建分页 URL 辅助函数
|
||||
def pagination_url(p: int) -> str:
|
||||
params = dict(request.query_params)
|
||||
@@ -588,7 +528,7 @@ async def admin_papers(
|
||||
{
|
||||
"papers": papers,
|
||||
"paper_summary_statuses": statuses,
|
||||
"total": total or 0,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"current_status": summary_status,
|
||||
@@ -615,7 +555,9 @@ async def admin_paper_delete(
|
||||
|
||||
# 清理 FTS 索引
|
||||
try:
|
||||
db.execute(text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id})
|
||||
db.execute(
|
||||
text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id}
|
||||
)
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
|
||||
@@ -646,9 +588,11 @@ async def admin_papers_batch_action(
|
||||
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
|
||||
|
||||
if body.action == "delete":
|
||||
papers = db.execute(
|
||||
select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids))
|
||||
).scalars().all()
|
||||
papers = (
|
||||
db.execute(select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids)))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
count = 0
|
||||
for paper in papers:
|
||||
@@ -658,21 +602,27 @@ async def admin_papers_batch_action(
|
||||
|
||||
# 清理 FTS 索引
|
||||
try:
|
||||
db.execute(
|
||||
text("DELETE FROM papers_fts WHERE arxiv_id IN :ids"),
|
||||
{"ids": tuple(body.arxiv_ids)},
|
||||
stmt = text("DELETE FROM papers_fts WHERE arxiv_id IN :ids").bindparams(
|
||||
bindparam("ids", expanding=True)
|
||||
)
|
||||
db.execute(stmt, {"ids": body.arxiv_ids})
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.warning("Failed to clean FTS index for batch delete", exc_info=True)
|
||||
|
||||
return {"status": "success", "message": f"已删除 {count} 篇论文", "count": count}
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"已删除 {count} 篇论文",
|
||||
"count": count,
|
||||
}
|
||||
|
||||
elif body.action == "summarize":
|
||||
# 将选中论文的总结状态重置为 pending
|
||||
paper_ids = db.execute(
|
||||
select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids))
|
||||
).scalars().all()
|
||||
paper_ids = (
|
||||
db.execute(select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids)))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
if paper_ids:
|
||||
# 删除旧的 status 记录让其重新进入 pipeline
|
||||
|
||||
@@ -12,6 +12,8 @@ from app.utils import templates
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_COMPARE_PAPERS = 5
|
||||
|
||||
|
||||
@router.get("/compare")
|
||||
def compare_page(
|
||||
@@ -33,9 +35,9 @@ def compare_page(
|
||||
|
||||
arxiv_ids = [i.strip() for i in ids.split(",") if i.strip()]
|
||||
|
||||
# 最多 5 篇
|
||||
if len(arxiv_ids) > 5:
|
||||
arxiv_ids = arxiv_ids[:5]
|
||||
# 最多 MAX_COMPARE_PAPERS 篇
|
||||
if len(arxiv_ids) > MAX_COMPARE_PAPERS:
|
||||
arxiv_ids = arxiv_ids[:MAX_COMPARE_PAPERS]
|
||||
|
||||
if not arxiv_ids:
|
||||
return templates.TemplateResponse(
|
||||
|
||||
+2
-99
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
@@ -15,6 +14,7 @@ from sqlalchemy.orm import Session, joinedload
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models import PAPER_FULL_LOAD, Paper
|
||||
from app.services.pdf_image_extractor import link_figures_with_images
|
||||
from app.utils import (
|
||||
PAPERS_DIR,
|
||||
safe_json_loads,
|
||||
@@ -120,7 +120,7 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
|
||||
paper.summary.figures_json if paper.summary else None, default=[]
|
||||
)
|
||||
|
||||
linked_figures = _link_figures_with_images(figures_raw, images, arxiv_id)
|
||||
linked_figures = link_figures_with_images(figures_raw, images, arxiv_id)
|
||||
|
||||
# 拆分图片到对应展示区域:
|
||||
# table_figures → 实验结果区域(Table 截图,不变)
|
||||
@@ -279,100 +279,3 @@ def _get_paper_images(arxiv_id: str) -> list[dict]:
|
||||
}
|
||||
)
|
||||
return images
|
||||
|
||||
|
||||
def _link_figures_with_images(
|
||||
figures: list[dict], images: list[dict], arxiv_id: str
|
||||
) -> list[dict]:
|
||||
"""将 summary figures 元数据与提取的图片文件关联。
|
||||
|
||||
策略:
|
||||
1. 优先用 manifest.json 的 label 做 ID 精确匹配
|
||||
2. 未匹配的 figure 用序号兜底:第 N 个 Figure → 第 N 张提取图
|
||||
"""
|
||||
if not figures or not images:
|
||||
return figures
|
||||
|
||||
manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
|
||||
|
||||
# ── 策略 1:manifest ID 精确匹配 ──
|
||||
id_to_url: dict[str, str] = {}
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||
except (ValueError, TypeError):
|
||||
manifest = {}
|
||||
for filename, info in manifest.items():
|
||||
url = f"/papers/{arxiv_id}/images/{filename}"
|
||||
# 优先用 label 字段(新格式)
|
||||
label = info.get("label", "")
|
||||
if label:
|
||||
id_to_url[label] = url
|
||||
# 也兼容 figures/tables 列表(旧格式)
|
||||
for fig_id in info.get("figures", []) + info.get("tables", []):
|
||||
if fig_id not in id_to_url:
|
||||
id_to_url[fig_id] = url
|
||||
|
||||
for fig in figures:
|
||||
raw_id = fig.get("id", "")
|
||||
normalized = _normalize_figure_id(raw_id)
|
||||
if normalized in id_to_url:
|
||||
fig["image_url"] = id_to_url[normalized]
|
||||
|
||||
# ── 策略 2:序号兜底(manifest 匹配不到时) ──
|
||||
unmatched = [f for f in figures if not f.get("image_url")]
|
||||
if not unmatched:
|
||||
return figures
|
||||
|
||||
# 按类型分流:Figure vs Table
|
||||
fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
|
||||
table_type_unmatched = [
|
||||
f for f in unmatched if not _is_figure_type(f.get("id", ""))
|
||||
]
|
||||
|
||||
# 提取的图片按类型分流,按文件名中的编号排序
|
||||
def _sort_key(name: str) -> tuple[int, int]:
|
||||
# 新格式:figure_1.jpg, table_1.jpg
|
||||
m = re.search(r"(?:figure|table)_(\d+)", name)
|
||||
if m:
|
||||
return (0, int(m.group(1)))
|
||||
# 旧格式:page2_img1.png, page5_table1.png, figure_1.png
|
||||
m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
|
||||
if m2:
|
||||
return (int(m2.group(1)), int(m2.group(2)))
|
||||
return (0, 0)
|
||||
|
||||
fig_images = sorted(
|
||||
[img for img in images if "table" not in img["name"].lower()],
|
||||
key=lambda img: _sort_key(img["name"]),
|
||||
)
|
||||
table_images = sorted(
|
||||
[img for img in images if "table" in img["name"].lower()],
|
||||
key=lambda img: _sort_key(img["name"]),
|
||||
)
|
||||
|
||||
for i, fig in enumerate(fig_type_unmatched):
|
||||
if i < len(fig_images):
|
||||
fig["image_url"] = fig_images[i]["url"]
|
||||
|
||||
for i, fig in enumerate(table_type_unmatched):
|
||||
if i < len(table_images):
|
||||
fig["image_url"] = table_images[i]["url"]
|
||||
|
||||
return figures
|
||||
|
||||
|
||||
def _normalize_figure_id(raw_id: str) -> str:
|
||||
"""归一化 Figure/Table ID:'Figure 1'/'Fig.1' → 'Figure 1'。"""
|
||||
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
|
||||
if m:
|
||||
return f"Figure {m.group(1)}"
|
||||
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
|
||||
if m2:
|
||||
return f"Table {m2.group(1)}"
|
||||
return raw_id
|
||||
|
||||
|
||||
def _is_figure_type(fig_id: str) -> bool:
|
||||
"""判断是否为 Figure 类型(非 Table)。"""
|
||||
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
|
||||
|
||||
+5
-23
@@ -2,12 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.exceptions import NotFoundError
|
||||
from app.services.user_data import (
|
||||
get_note,
|
||||
save_note,
|
||||
@@ -37,9 +38,6 @@ def bookmark_toggle(arxiv_id: str, request: Request, db: Session = Depends(get_d
|
||||
"""切换收藏状态。支持 HTMX 局部刷新和 JSON 响应。"""
|
||||
result = toggle_bookmark(db, arxiv_id)
|
||||
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=404, detail=result["error"])
|
||||
|
||||
# HTMX 请求 → 返回 HTML 片段
|
||||
if request.headers.get("HX-Request"):
|
||||
star = "★" if result["bookmarked"] else "☆"
|
||||
@@ -66,18 +64,7 @@ def reading_status_update(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新阅读状态。"""
|
||||
result = set_reading_status(db, arxiv_id, body.status)
|
||||
|
||||
if "error" in result:
|
||||
if result["error"] == "not_found":
|
||||
raise HTTPException(status_code=404, detail="Paper not found")
|
||||
elif result["error"] == "invalid_status":
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Invalid status. Valid: {result['valid']}",
|
||||
)
|
||||
|
||||
return result
|
||||
return set_reading_status(db, arxiv_id, body.status)
|
||||
|
||||
|
||||
# ── 笔记 ──────────────────────────────────────────────────────────────
|
||||
@@ -88,16 +75,11 @@ def note_get(arxiv_id: str, db: Session = Depends(get_db)):
|
||||
"""获取笔记。"""
|
||||
result = get_note(db, arxiv_id)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="Paper not found")
|
||||
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/note/{arxiv_id}")
|
||||
def note_save(arxiv_id: str, body: NoteRequest, db: Session = Depends(get_db)):
|
||||
"""保存笔记。"""
|
||||
result = save_note(db, arxiv_id, body.content)
|
||||
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=404, detail=result["error"])
|
||||
|
||||
return result
|
||||
return save_note(db, arxiv_id, body.content)
|
||||
|
||||
Reference in New Issue
Block a user