feat: refactor summarizer and PDF extraction pipeline

- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
This commit is contained in:
2026-06-13 13:16:47 +08:00
parent e2f0e1a8be
commit 21f16e6756
43 changed files with 3304 additions and 1494 deletions
+94 -144
View File
@@ -3,6 +3,7 @@
from __future__ import annotations
import hashlib
import hmac
import json
import logging
from datetime import date
@@ -10,7 +11,7 @@ from datetime import date
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, field_validator
from sqlalchemy import func, select, text
from sqlalchemy import bindparam, func, select, text
from sqlalchemy.orm import Session
from app.config import settings
@@ -22,15 +23,15 @@ from app.models import (
PaperTag,
SummaryState,
SummaryStatus,
TaskLock,
)
from app.services import admin as admin_svc
from app.services.admin import get_admin_stats
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
from app.services.crawler import crawl_daily, refresh_upvotes
from app.services.pipeline import run_pipeline
from app.services.crawler import refresh_upvotes
from app.services.pipeline import run_crawl, run_pipeline
from app.services.scheduler import get_scheduler
from app.services.summarizer import summarize_batch, summarize_single
from app.utils import release_lock, templates, today_str, utc_now
from app.utils import templates, today_str, utc_now
logger = logging.getLogger(__name__)
@@ -41,14 +42,15 @@ router = APIRouter(prefix="/admin", tags=["admin"])
def _check_password(password: str) -> bool:
    """Return True when *password* matches the configured admin password.

    ``settings.ADMIN_PASSWORD`` may hold either the plaintext password or the
    SHA-256 hex digest of it; both forms are accepted. Comparisons use
    ``hmac.compare_digest`` so they run in constant time.

    Args:
        password: The password submitted by the client.

    Returns:
        True on a match; False on a mismatch or when no admin password is
        configured.
    """
    stored = settings.ADMIN_PASSWORD
    if not stored:
        return False

    # hmac.compare_digest() accepts str only when it is pure ASCII and raises
    # TypeError otherwise. Comparing UTF-8 bytes keeps a non-ASCII password
    # (submitted or configured) a plain mismatch instead of a server error.
    # assumes settings.ADMIN_PASSWORD is a str — TODO confirm against app.config
    supplied_bytes = password.encode("utf-8")
    stored_bytes = stored.encode("utf-8")

    plaintext_ok = hmac.compare_digest(supplied_bytes, stored_bytes)

    # The stored value may also be the SHA-256 hex digest of the password.
    digest_bytes = hashlib.sha256(supplied_bytes).hexdigest().encode("ascii")
    digest_ok = hmac.compare_digest(digest_bytes, stored_bytes)

    # Both comparisons always run, so timing does not reveal which form matched.
    return plaintext_ok or digest_ok
async def verify_admin(request: Request) -> None:
@@ -204,32 +206,12 @@ async def admin_crawl(
):
"""手动抓取指定日期,默认今天。"""
target_date = date or today_str()
# TaskLock 防重入
now = utc_now()
lock = TaskLock(
task="crawl",
lock_key=target_date,
status="running",
owner="admin_crawl",
acquired_at=now,
)
try:
db.add(lock)
db.commit()
except Exception:
db.rollback()
raise HTTPException(
status_code=409, detail=f"Crawl already running for {target_date}"
)
try:
result = await crawl_daily(db, target_date)
return result
return await run_crawl(db, target_date, owner="admin_crawl")
except RuntimeError as exc:
raise HTTPException(status_code=409, detail=str(exc))
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
finally:
release_lock(db, lock)
# ── 总结 ──────────────────────────────────────────────────────────────
@@ -241,12 +223,7 @@ async def admin_summarize_batch(
db: Session = Depends(get_db),
):
"""批量总结所有 pending 论文。"""
result = await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
if result.get("status") == "conflict":
raise HTTPException(
status_code=409, detail=result.get("error", "batch already running")
)
return result
return await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
@router.post("/summarize/{arxiv_id}")
@@ -256,10 +233,9 @@ async def admin_summarize_single(
db: Session = Depends(get_db),
):
"""总结或重跑单篇论文。"""
result = await summarize_single(db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE)
if result.get("status") == "not_found":
raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
return result
return await summarize_single(
db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE
)
# ── 清理 ──────────────────────────────────────────────────────────────
@@ -284,10 +260,13 @@ async def admin_cleanup(
result = cleanup_tmp()
log_entry.status = "success"
log_entry.completed_at = utc_now()
log_entry.details_json = json.dumps({
"scanned": result.get("scanned", 0),
"removed": result.get("removed", 0),
}, ensure_ascii=False)
log_entry.details_json = json.dumps(
{
"scanned": result.get("scanned", 0),
"removed": result.get("removed", 0),
},
ensure_ascii=False,
)
if result.get("errors"):
log_entry.error = "; ".join(result["errors"])[:2000]
db.commit()
@@ -358,19 +337,34 @@ async def admin_logs(
# 总结状态统计概要
summary_total = db.scalar(select(func.count(Paper.id))) or 0
summary_done = db.scalar(
select(func.count(SummaryStatus.id)).where(SummaryStatus.status == SummaryState.DONE)
) or 0
summary_pending = db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.PROCESSING])
summary_done = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status == SummaryState.DONE
)
)
) or 0
summary_failed = db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE])
or 0
)
summary_pending = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_(
[SummaryState.PENDING, SummaryState.PROCESSING]
)
)
)
) or 0
or 0
)
summary_failed = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
)
or 0
)
return templates.TemplateResponse(
request,
@@ -414,13 +408,8 @@ async def admin_summary_status(
else:
query = query.where(SummaryStatus.status == status)
total = db.scalar(
select(func.count()).select_from(query.subquery())
)
results = (
db.execute(query.offset((page - 1) * per_page).limit(per_page))
.all()
)
total = db.scalar(select(func.count()).select_from(query.subquery()))
results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
# 判断是否 HTMX 请求
is_htmx = request.headers.get("HX-Request") == "true"
@@ -465,7 +454,11 @@ async def admin_summary_retry_failed(
db.execute(
select(Paper.arxiv_id)
.join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
.where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
)
.scalars()
.all()
@@ -477,7 +470,11 @@ async def admin_summary_retry_failed(
# 重置失败任务的状态为 pending
db.execute(
SummaryStatus.__table__.update()
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
.where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
.values(status=SummaryState.PENDING, error=None, error_type=None)
)
db.commit()
@@ -492,15 +489,6 @@ async def admin_summary_retry_failed(
# ── 论文管理 ────────────────────────────────────────────────────────
# 排序映射
_SORT_MAP = {
"date_desc": Paper.paper_date.desc(),
"date_asc": Paper.paper_date.asc(),
"upvotes_desc": Paper.upvotes.desc(),
"title_asc": Paper.title_en.asc(),
}
@router.get("/papers")
async def admin_papers(
request: Request,
@@ -516,66 +504,18 @@ async def admin_papers(
per_page: int = Query(20, ge=1, le=100),
):
"""论文管理列表页面。"""
query = select(Paper)
# 搜索
if q.strip():
query = query.where(
Paper.title_en.ilike(f"%{q}%")
| Paper.title_zh.ilike(f"%{q}%")
| Paper.abstract.ilike(f"%{q}%")
)
# 日期范围
if date_from:
query = query.where(Paper.paper_date >= date_from)
if date_to:
query = query.where(Paper.paper_date <= date_to)
# 标签筛选
if tag:
query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
PaperTag.tag == tag
)
# 总结状态筛选
if summary_status != "all":
if summary_status == "none":
query = query.outerjoin(
SummaryStatus, SummaryStatus.paper_id == Paper.id
).where(SummaryStatus.paper_id == None) # noqa: E711
else:
query = query.join(
SummaryStatus, SummaryStatus.paper_id == Paper.id
).where(SummaryStatus.status == summary_status)
# 排序
order = _SORT_MAP.get(sort, Paper.paper_date.desc())
query = query.order_by(order)
# 计数
total = db.scalar(select(func.count()).select_from(query.subquery()))
# 分页
papers = (
db.execute(query.offset((page - 1) * per_page).limit(per_page))
.scalars()
.all()
papers, total, statuses = admin_svc.query_papers(
db,
q=q,
date_from=date_from,
date_to=date_to,
tag=tag,
summary_status=summary_status,
sort=sort,
page=page,
per_page=per_page,
)
# 获取每篇论文的总结状态
paper_ids = [p.id for p in papers]
statuses = {}
if paper_ids:
rows = db.execute(
select(SummaryStatus.paper_id, SummaryStatus.status).where(
SummaryStatus.paper_id.in_(paper_ids)
)
).all()
paper_id_to_arxiv = {p.id: p.arxiv_id for p in papers}
for pid, st in rows:
statuses[paper_id_to_arxiv.get(pid, "")] = st
# 构建分页 URL 辅助函数
def pagination_url(p: int) -> str:
params = dict(request.query_params)
@@ -588,7 +528,7 @@ async def admin_papers(
{
"papers": papers,
"paper_summary_statuses": statuses,
"total": total or 0,
"total": total,
"page": page,
"per_page": per_page,
"current_status": summary_status,
@@ -615,7 +555,9 @@ async def admin_paper_delete(
# 清理 FTS 索引
try:
db.execute(text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id})
db.execute(
text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id}
)
db.commit()
except Exception:
logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
@@ -646,9 +588,11 @@ async def admin_papers_batch_action(
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
if body.action == "delete":
papers = db.execute(
select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids))
).scalars().all()
papers = (
db.execute(select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids)))
.scalars()
.all()
)
count = 0
for paper in papers:
@@ -658,21 +602,27 @@ async def admin_papers_batch_action(
# 清理 FTS 索引
try:
db.execute(
text("DELETE FROM papers_fts WHERE arxiv_id IN :ids"),
{"ids": tuple(body.arxiv_ids)},
stmt = text("DELETE FROM papers_fts WHERE arxiv_id IN :ids").bindparams(
bindparam("ids", expanding=True)
)
db.execute(stmt, {"ids": body.arxiv_ids})
db.commit()
except Exception:
logger.warning("Failed to clean FTS index for batch delete", exc_info=True)
return {"status": "success", "message": f"已删除 {count} 篇论文", "count": count}
return {
"status": "success",
"message": f"已删除 {count} 篇论文",
"count": count,
}
elif body.action == "summarize":
# 将选中论文的总结状态重置为 pending
paper_ids = db.execute(
select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids))
).scalars().all()
paper_ids = (
db.execute(select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids)))
.scalars()
.all()
)
if paper_ids:
# 删除旧的 status 记录让其重新进入 pipeline
+5 -3
View File
@@ -12,6 +12,8 @@ from app.utils import templates
router = APIRouter()
MAX_COMPARE_PAPERS = 5
@router.get("/compare")
def compare_page(
@@ -33,9 +35,9 @@ def compare_page(
arxiv_ids = [i.strip() for i in ids.split(",") if i.strip()]
# 最多 5
if len(arxiv_ids) > 5:
arxiv_ids = arxiv_ids[:5]
# 最多 MAX_COMPARE_PAPERS
if len(arxiv_ids) > MAX_COMPARE_PAPERS:
arxiv_ids = arxiv_ids[:MAX_COMPARE_PAPERS]
if not arxiv_ids:
return templates.TemplateResponse(
+2 -99
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
import json
import logging
import re
from datetime import date, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, Request
@@ -15,6 +14,7 @@ from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import PAPER_FULL_LOAD, Paper
from app.services.pdf_image_extractor import link_figures_with_images
from app.utils import (
PAPERS_DIR,
safe_json_loads,
@@ -120,7 +120,7 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
paper.summary.figures_json if paper.summary else None, default=[]
)
linked_figures = _link_figures_with_images(figures_raw, images, arxiv_id)
linked_figures = link_figures_with_images(figures_raw, images, arxiv_id)
# 拆分图片到对应展示区域:
# table_figures → 实验结果区域(Table 截图,不变)
@@ -279,100 +279,3 @@ def _get_paper_images(arxiv_id: str) -> list[dict]:
}
)
return images
def _link_figures_with_images(
figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
"""将 summary figures 元数据与提取的图片文件关联。
策略:
1. 优先用 manifest.json 的 label 做 ID 精确匹配
2. 未匹配的 figure 用序号兜底:第 N 个 Figure → 第 N 张提取图
"""
if not figures or not images:
return figures
manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
# ── 策略 1manifest ID 精确匹配 ──
id_to_url: dict[str, str] = {}
if manifest_path.exists():
try:
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
except (ValueError, TypeError):
manifest = {}
for filename, info in manifest.items():
url = f"/papers/{arxiv_id}/images/{filename}"
# 优先用 label 字段(新格式)
label = info.get("label", "")
if label:
id_to_url[label] = url
# 也兼容 figures/tables 列表(旧格式)
for fig_id in info.get("figures", []) + info.get("tables", []):
if fig_id not in id_to_url:
id_to_url[fig_id] = url
for fig in figures:
raw_id = fig.get("id", "")
normalized = _normalize_figure_id(raw_id)
if normalized in id_to_url:
fig["image_url"] = id_to_url[normalized]
# ── 策略 2:序号兜底(manifest 匹配不到时) ──
unmatched = [f for f in figures if not f.get("image_url")]
if not unmatched:
return figures
# 按类型分流:Figure vs Table
fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
table_type_unmatched = [
f for f in unmatched if not _is_figure_type(f.get("id", ""))
]
# 提取的图片按类型分流,按文件名中的编号排序
def _sort_key(name: str) -> tuple[int, int]:
# 新格式:figure_1.jpg, table_1.jpg
m = re.search(r"(?:figure|table)_(\d+)", name)
if m:
return (0, int(m.group(1)))
# 旧格式:page2_img1.png, page5_table1.png, figure_1.png
m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
if m2:
return (int(m2.group(1)), int(m2.group(2)))
return (0, 0)
fig_images = sorted(
[img for img in images if "table" not in img["name"].lower()],
key=lambda img: _sort_key(img["name"]),
)
table_images = sorted(
[img for img in images if "table" in img["name"].lower()],
key=lambda img: _sort_key(img["name"]),
)
for i, fig in enumerate(fig_type_unmatched):
if i < len(fig_images):
fig["image_url"] = fig_images[i]["url"]
for i, fig in enumerate(table_type_unmatched):
if i < len(table_images):
fig["image_url"] = table_images[i]["url"]
return figures
def _normalize_figure_id(raw_id: str) -> str:
"""归一化 Figure/Table ID'Figure 1'/'Fig.1''Figure 1'"""
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
if m:
return f"Figure {m.group(1)}"
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
if m2:
return f"Table {m2.group(1)}"
return raw_id
def _is_figure_type(fig_id: str) -> bool:
"""判断是否为 Figure 类型(非 Table)。"""
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
+5 -23
View File
@@ -2,12 +2,13 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.exceptions import NotFoundError
from app.services.user_data import (
get_note,
save_note,
@@ -37,9 +38,6 @@ def bookmark_toggle(arxiv_id: str, request: Request, db: Session = Depends(get_d
"""切换收藏状态。支持 HTMX 局部刷新和 JSON 响应。"""
result = toggle_bookmark(db, arxiv_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
# HTMX 请求 → 返回 HTML 片段
if request.headers.get("HX-Request"):
star = "" if result["bookmarked"] else ""
@@ -66,18 +64,7 @@ def reading_status_update(
db: Session = Depends(get_db),
):
"""更新阅读状态。"""
result = set_reading_status(db, arxiv_id, body.status)
if "error" in result:
if result["error"] == "not_found":
raise HTTPException(status_code=404, detail="Paper not found")
elif result["error"] == "invalid_status":
raise HTTPException(
status_code=422,
detail=f"Invalid status. Valid: {result['valid']}",
)
return result
return set_reading_status(db, arxiv_id, body.status)
# ── 笔记 ──────────────────────────────────────────────────────────────
@@ -88,16 +75,11 @@ def note_get(arxiv_id: str, db: Session = Depends(get_db)):
"""获取笔记。"""
result = get_note(db, arxiv_id)
if result is None:
raise HTTPException(status_code=404, detail="Paper not found")
raise NotFoundError(f"Paper not found: {arxiv_id}")
return result
@router.post("/note/{arxiv_id}")
def note_save(arxiv_id: str, body: NoteRequest, db: Session = Depends(get_db)):
"""保存笔记。"""
result = save_note(db, arxiv_id, body.content)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
return save_note(db, arxiv_id, body.content)