feat: refactor summarizer and PDF extraction pipeline

- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
This commit is contained in:
2026-06-13 13:16:47 +08:00
parent e2f0e1a8be
commit 21f16e6756
43 changed files with 3304 additions and 1494 deletions
+27 -14
View File
@@ -42,9 +42,16 @@ def crawl(
try:
# 检查是否已抓取过(非 force 模式)
if not force and not date_str:
existing = db.scalar(select(func.count(Paper.id)).where(Paper.paper_date == target)) or 0
existing = (
db.scalar(
select(func.count(Paper.id)).where(Paper.paper_date == target)
)
or 0
)
if existing > 0:
typer.echo(f"⏭️ {target} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)")
typer.echo(
f"⏭️ {target} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
)
return
typer.echo(f"📡 开始抓取 {target} ...")
@@ -56,7 +63,12 @@ def crawl(
)
if need_fallback:
fallback = yesterday_str()
existing = db.scalar(select(func.count(Paper.id)).where(Paper.paper_date == fallback)) or 0
existing = (
db.scalar(
select(func.count(Paper.id)).where(Paper.paper_date == fallback)
)
or 0
)
if existing > 0:
typer.echo(
f"⏭️ {fallback} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
@@ -103,7 +115,9 @@ def summarize(
import os
if pdf_mode not in ("auto", "inject", "search"):
typer.echo(f"❌ 无效的 pdf_mode: {pdf_mode},只支持 auto / inject / search", err=True)
typer.echo(
f"❌ 无效的 pdf_mode: {pdf_mode},只支持 auto / inject / search", err=True
)
raise typer.Exit(code=1)
if backend:
@@ -122,6 +136,8 @@ def summarize(
datefmt="%H:%M:%S",
)
from app.exceptions import ConflictError, NotFoundError
db = SessionLocal()
try:
if arxiv_id:
@@ -131,16 +147,13 @@ def summarize(
typer.echo(f"🤖 开始批量总结 pending 论文 (mode={pdf_mode}) ...")
result = asyncio.run(summarize_batch(db, pdf_mode=pdf_mode))
if result.get("status") in ("success", "done"):
typer.echo(f"✅ 总结完成:{result}")
elif result.get("status") == "conflict":
typer.echo("⚠️ 已有批量总结任务在运行中", err=True)
raise typer.Exit(code=1)
elif result.get("status") == "not_found":
typer.echo(f"❌ 论文未找到:{arxiv_id}", err=True)
raise typer.Exit(code=1)
else:
typer.echo(f"⚠️ 总结结果:{result}", err=True)
typer.echo(f"✅ 总结完成:{result}")
except NotFoundError as exc:
typer.echo(f"{exc.message}", err=True)
raise typer.Exit(code=1) from exc
except ConflictError as exc:
typer.echo(f"⚠️ {exc.message}", err=True)
raise typer.Exit(code=1) from exc
finally:
db.close()
+8 -1
View File
@@ -27,6 +27,7 @@ class Settings(BaseSettings):
HTTP_TIMEOUT_SECONDS: int = 30
HTTP_MAX_RETRIES: int = 3
HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1"
PDF_DOWNLOAD_TIMEOUT: int = 120
# AI 总结
SUMMARY_BACKEND: str = "pi" # "pi" | "claude"
@@ -36,7 +37,9 @@ class Settings(BaseSettings):
SUMMARY_CONCURRENCY: int = 3
SUMMARY_TIMEOUT_SECONDS: int = 1200
SUMMARY_MAX_RETRIES: int = 2
SUMMARY_PDF_MODE: str = "auto" # "auto" = ≤80k 用 inject>80k 用 search;也可强制 "inject" / "search"
SUMMARY_PDF_MODE: str = (
"auto" # "auto" = ≤80k 用 inject>80k 用 search;也可强制 "inject" / "search"
)
# 调度
SCHEDULER_ENABLED: bool = False
@@ -56,6 +59,10 @@ class Settings(BaseSettings):
EMBED_MODEL: str = ""
EMBED_DIMENSIONS: int = 0
# 布局检测
LAYOUT_MODEL_PATH: str = "data/models/picodet_layout_3cls.onnx"
LAYOUT_THRESHOLD: float = 0.5
model_config = {
"env_file": str(BASE_DIR / ".env"),
"env_file_encoding": "utf-8",
+2 -5
View File
@@ -82,15 +82,12 @@ def _migrate(engine) -> None:
for table, columns in _MIGRATIONS.items():
# 获取已有列名
existing = {
row[1]
for row in conn.execute(text(f"PRAGMA table_info({table})"))
row[1] for row in conn.execute(text(f"PRAGMA table_info({table})"))
}
for col_name, col_type in columns:
if col_name not in existing:
conn.execute(
text(
f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}"
)
text(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")
)
logger.info("Migrated: %s.%s added", table, col_name)
conn.commit()
+35
View File
@@ -0,0 +1,35 @@
"""业务异常体系 — 统一错误类型,供路由层和 service 层使用。
路由层通过 main.py 的 @app.exception_handler(AppError) 统一捕获,
转为对应 HTTP 状态码 + JSON 响应。
"""
from __future__ import annotations
class AppError(Exception):
    """Root of the application's business-exception hierarchy.

    The human-readable text is exposed as ``self.message`` and is also
    passed to ``Exception`` so ``str(exc)`` yields the same text.

    Args:
        message: Primary error text.
        detail: Keyword-only fallback text, used only when *message* is
            empty. When both are empty the concrete class name is used.
    """

    def __init__(self, message: str = "", *, detail: str = ""):
        # Resolution order: message -> detail -> class name.
        if message:
            resolved = message
        elif detail:
            resolved = detail
        else:
            resolved = type(self).__name__
        self.message = resolved
        super().__init__(resolved)
class NotFoundError(AppError):
    """The requested resource does not exist (HTTP 404)."""
class ValidationError(AppError):
    """Request parameter validation failed (HTTP 400)."""
class ExternalAPIError(AppError):
    """A call to an external API failed (HTTP 502)."""
class PdfProcessError(AppError):
    """PDF processing failed (HTTP 500)."""
class ConflictError(AppError):
    """Resource conflict (HTTP 409), e.g. a lock conflict or a concurrent task conflict."""
+30 -3
View File
@@ -5,10 +5,12 @@ import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from starlette.middleware.sessions import SessionMiddleware
from app.config import settings
from app.exceptions import AppError, ConflictError, ExternalAPIError, NotFoundError, PdfProcessError, ValidationError
from app.database import engine, init_db
from app.routes.admin import router as admin_router
from app.routes.compare import router as compare_router
@@ -38,8 +40,10 @@ async def lifespan(app: FastAPI):
# ── shutdown ──
from app.services.scheduler import stop_scheduler
from app.services.pdf_downloader import close_http_session
stop_scheduler()
close_http_session()
def create_app() -> FastAPI:
@@ -60,15 +64,38 @@ def create_app() -> FastAPI:
# Session 中间件
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY)
# ── 统一业务异常处理 ──
@app.exception_handler(NotFoundError)
async def _not_found_handler(request, exc):
return JSONResponse(status_code=404, content={"error": exc.message})
@app.exception_handler(ValidationError)
async def _validation_handler(request, exc):
return JSONResponse(status_code=400, content={"error": exc.message})
@app.exception_handler(ExternalAPIError)
async def _external_api_handler(request, exc):
return JSONResponse(status_code=502, content={"error": exc.message})
@app.exception_handler(PdfProcessError)
async def _pdf_process_handler(request, exc):
return JSONResponse(status_code=500, content={"error": exc.message})
@app.exception_handler(ConflictError)
async def _conflict_handler(request, exc):
return JSONResponse(status_code=409, content={"error": exc.message})
@app.exception_handler(AppError)
async def _app_error_handler(request, exc):
return JSONResponse(status_code=500, content={"error": exc.message})
# 安全警告
if settings.SECRET_KEY == "change-me":
logger.warning(
"⚠️ SECRET_KEY is the default value 'change-me'. Please change it in .env!"
)
if not settings.ADMIN_PASSWORD:
logger.warning(
"⚠️ ADMIN_PASSWORD is empty. Please set it in .env!"
)
logger.warning("⚠️ ADMIN_PASSWORD is empty. Please set it in .env!")
# 静态文件
app.mount("/static", StaticFiles(directory="app/static"), name="static")
+17 -4
View File
@@ -12,6 +12,7 @@ from sqlalchemy import (
String,
Text,
UniqueConstraint,
select,
)
from sqlalchemy.orm import joinedload, relationship
@@ -93,7 +94,7 @@ class PaperAuthor(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
paper_id = Column(
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False, index=True
)
name = Column(String, nullable=False)
position = Column(Integer, default=0)
@@ -108,7 +109,7 @@ class PaperTag(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
paper_id = Column(
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False, index=True
)
tag = Column(String, nullable=False)
source = Column(String, default="hf")
@@ -155,7 +156,7 @@ class SummaryStatus(Base):
paper_id = Column(
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
)
status = Column(String, nullable=False, default="pending")
status = Column(String, nullable=False, default="pending", index=True)
quality = Column(String)
error_type = Column(String)
error = Column(Text)
@@ -219,7 +220,7 @@ class UserReadingStatus(Base):
paper_id = Column(
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
)
status = Column(String, nullable=False, default="unread")
status = Column(String, nullable=False, default="unread", index=True)
updated_at = Column(DateTime, nullable=False)
paper = relationship("Paper", back_populates="reading_status")
@@ -271,3 +272,15 @@ PAPER_FULL_LOAD = (
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
)
def get_paper_by_arxiv_id(db, arxiv_id: str, *, load=PAPER_DEFAULT_LOAD):
    """Look up a single paper by its arXiv ID, applying the given loader options.

    Args:
        db: Session-like object exposing ``execute()``.
        arxiv_id: Value matched against ``Paper.arxiv_id``.
        load: Iterable of loader options unpacked into ``.options()``.

    Returns:
        The matching ``Paper``, or ``None`` when no row matches.
    """
    query = select(Paper).where(Paper.arxiv_id == arxiv_id)
    query = query.options(*load)
    # unique() collapses duplicate parent rows before the scalar is taken.
    rows = db.execute(query).unique()
    return rows.scalar_one_or_none()
def get_paper_by_id(db, paper_id: int, *, load=PAPER_DEFAULT_LOAD):
    """Look up a single paper by primary key, applying the given loader options.

    Args:
        db: Session-like object exposing ``execute()``.
        paper_id: Value matched against ``Paper.id``.
        load: Iterable of loader options unpacked into ``.options()``.

    Returns:
        The matching ``Paper``, or ``None`` when no row matches.
    """
    query = select(Paper).where(Paper.id == paper_id)
    query = query.options(*load)
    # unique() collapses duplicate parent rows before the scalar is taken.
    rows = db.execute(query).unique()
    return rows.scalar_one_or_none()
+94 -144
View File
@@ -3,6 +3,7 @@
from __future__ import annotations
import hashlib
import hmac
import json
import logging
from datetime import date
@@ -10,7 +11,7 @@ from datetime import date
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, field_validator
from sqlalchemy import func, select, text
from sqlalchemy import bindparam, func, select, text
from sqlalchemy.orm import Session
from app.config import settings
@@ -22,15 +23,15 @@ from app.models import (
PaperTag,
SummaryState,
SummaryStatus,
TaskLock,
)
from app.services import admin as admin_svc
from app.services.admin import get_admin_stats
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
from app.services.crawler import crawl_daily, refresh_upvotes
from app.services.pipeline import run_pipeline
from app.services.crawler import refresh_upvotes
from app.services.pipeline import run_crawl, run_pipeline
from app.services.scheduler import get_scheduler
from app.services.summarizer import summarize_batch, summarize_single
from app.utils import release_lock, templates, today_str, utc_now
from app.utils import templates, today_str, utc_now
logger = logging.getLogger(__name__)
@@ -41,14 +42,15 @@ router = APIRouter(prefix="/admin", tags=["admin"])
def _check_password(password: str) -> bool:
    """Check *password* against the configured admin password.

    ``settings.ADMIN_PASSWORD`` may hold either the plaintext password or
    its SHA-256 hex digest; a match against either form is accepted.
    Comparisons go through ``hmac.compare_digest``.

    Args:
        password: The password submitted by the client.

    Returns:
        True on a match; False otherwise, including when no admin
        password is configured.
    """
    stored = settings.ADMIN_PASSWORD
    if not stored:
        return False

    # compare_digest() only accepts str arguments that are pure ASCII and
    # raises TypeError otherwise. Comparing UTF-8 bytes keeps a non-ASCII
    # password attempt from surfacing as an unhandled server error.
    candidate = password.encode("utf-8")
    expected = stored.encode("utf-8")
    if hmac.compare_digest(candidate, expected):
        return True

    # The stored value may also be the SHA-256 hex digest of the password.
    hashed = hashlib.sha256(candidate).hexdigest().encode("ascii")
    return hmac.compare_digest(hashed, expected)
async def verify_admin(request: Request) -> None:
@@ -204,32 +206,12 @@ async def admin_crawl(
):
"""手动抓取指定日期,默认今天。"""
target_date = date or today_str()
# TaskLock 防重入
now = utc_now()
lock = TaskLock(
task="crawl",
lock_key=target_date,
status="running",
owner="admin_crawl",
acquired_at=now,
)
try:
db.add(lock)
db.commit()
except Exception:
db.rollback()
raise HTTPException(
status_code=409, detail=f"Crawl already running for {target_date}"
)
try:
result = await crawl_daily(db, target_date)
return result
return await run_crawl(db, target_date, owner="admin_crawl")
except RuntimeError as exc:
raise HTTPException(status_code=409, detail=str(exc))
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
finally:
release_lock(db, lock)
# ── 总结 ──────────────────────────────────────────────────────────────
@@ -241,12 +223,7 @@ async def admin_summarize_batch(
db: Session = Depends(get_db),
):
"""批量总结所有 pending 论文。"""
result = await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
if result.get("status") == "conflict":
raise HTTPException(
status_code=409, detail=result.get("error", "batch already running")
)
return result
return await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
@router.post("/summarize/{arxiv_id}")
@@ -256,10 +233,9 @@ async def admin_summarize_single(
db: Session = Depends(get_db),
):
"""总结或重跑单篇论文。"""
result = await summarize_single(db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE)
if result.get("status") == "not_found":
raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
return result
return await summarize_single(
db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE
)
# ── 清理 ──────────────────────────────────────────────────────────────
@@ -284,10 +260,13 @@ async def admin_cleanup(
result = cleanup_tmp()
log_entry.status = "success"
log_entry.completed_at = utc_now()
log_entry.details_json = json.dumps({
"scanned": result.get("scanned", 0),
"removed": result.get("removed", 0),
}, ensure_ascii=False)
log_entry.details_json = json.dumps(
{
"scanned": result.get("scanned", 0),
"removed": result.get("removed", 0),
},
ensure_ascii=False,
)
if result.get("errors"):
log_entry.error = "; ".join(result["errors"])[:2000]
db.commit()
@@ -358,19 +337,34 @@ async def admin_logs(
# 总结状态统计概要
summary_total = db.scalar(select(func.count(Paper.id))) or 0
summary_done = db.scalar(
select(func.count(SummaryStatus.id)).where(SummaryStatus.status == SummaryState.DONE)
) or 0
summary_pending = db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.PROCESSING])
summary_done = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status == SummaryState.DONE
)
)
) or 0
summary_failed = db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE])
or 0
)
summary_pending = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_(
[SummaryState.PENDING, SummaryState.PROCESSING]
)
)
)
) or 0
or 0
)
summary_failed = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
)
or 0
)
return templates.TemplateResponse(
request,
@@ -414,13 +408,8 @@ async def admin_summary_status(
else:
query = query.where(SummaryStatus.status == status)
total = db.scalar(
select(func.count()).select_from(query.subquery())
)
results = (
db.execute(query.offset((page - 1) * per_page).limit(per_page))
.all()
)
total = db.scalar(select(func.count()).select_from(query.subquery()))
results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
# 判断是否 HTMX 请求
is_htmx = request.headers.get("HX-Request") == "true"
@@ -465,7 +454,11 @@ async def admin_summary_retry_failed(
db.execute(
select(Paper.arxiv_id)
.join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
.where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
)
.scalars()
.all()
@@ -477,7 +470,11 @@ async def admin_summary_retry_failed(
# 重置失败任务的状态为 pending
db.execute(
SummaryStatus.__table__.update()
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
.where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
.values(status=SummaryState.PENDING, error=None, error_type=None)
)
db.commit()
@@ -492,15 +489,6 @@ async def admin_summary_retry_failed(
# ── 论文管理 ────────────────────────────────────────────────────────
# 排序映射
_SORT_MAP = {
"date_desc": Paper.paper_date.desc(),
"date_asc": Paper.paper_date.asc(),
"upvotes_desc": Paper.upvotes.desc(),
"title_asc": Paper.title_en.asc(),
}
@router.get("/papers")
async def admin_papers(
request: Request,
@@ -516,66 +504,18 @@ async def admin_papers(
per_page: int = Query(20, ge=1, le=100),
):
"""论文管理列表页面。"""
query = select(Paper)
# 搜索
if q.strip():
query = query.where(
Paper.title_en.ilike(f"%{q}%")
| Paper.title_zh.ilike(f"%{q}%")
| Paper.abstract.ilike(f"%{q}%")
)
# 日期范围
if date_from:
query = query.where(Paper.paper_date >= date_from)
if date_to:
query = query.where(Paper.paper_date <= date_to)
# 标签筛选
if tag:
query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
PaperTag.tag == tag
)
# 总结状态筛选
if summary_status != "all":
if summary_status == "none":
query = query.outerjoin(
SummaryStatus, SummaryStatus.paper_id == Paper.id
).where(SummaryStatus.paper_id == None) # noqa: E711
else:
query = query.join(
SummaryStatus, SummaryStatus.paper_id == Paper.id
).where(SummaryStatus.status == summary_status)
# 排序
order = _SORT_MAP.get(sort, Paper.paper_date.desc())
query = query.order_by(order)
# 计数
total = db.scalar(select(func.count()).select_from(query.subquery()))
# 分页
papers = (
db.execute(query.offset((page - 1) * per_page).limit(per_page))
.scalars()
.all()
papers, total, statuses = admin_svc.query_papers(
db,
q=q,
date_from=date_from,
date_to=date_to,
tag=tag,
summary_status=summary_status,
sort=sort,
page=page,
per_page=per_page,
)
# 获取每篇论文的总结状态
paper_ids = [p.id for p in papers]
statuses = {}
if paper_ids:
rows = db.execute(
select(SummaryStatus.paper_id, SummaryStatus.status).where(
SummaryStatus.paper_id.in_(paper_ids)
)
).all()
paper_id_to_arxiv = {p.id: p.arxiv_id for p in papers}
for pid, st in rows:
statuses[paper_id_to_arxiv.get(pid, "")] = st
# 构建分页 URL 辅助函数
def pagination_url(p: int) -> str:
params = dict(request.query_params)
@@ -588,7 +528,7 @@ async def admin_papers(
{
"papers": papers,
"paper_summary_statuses": statuses,
"total": total or 0,
"total": total,
"page": page,
"per_page": per_page,
"current_status": summary_status,
@@ -615,7 +555,9 @@ async def admin_paper_delete(
# 清理 FTS 索引
try:
db.execute(text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id})
db.execute(
text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id}
)
db.commit()
except Exception:
logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
@@ -646,9 +588,11 @@ async def admin_papers_batch_action(
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
if body.action == "delete":
papers = db.execute(
select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids))
).scalars().all()
papers = (
db.execute(select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids)))
.scalars()
.all()
)
count = 0
for paper in papers:
@@ -658,21 +602,27 @@ async def admin_papers_batch_action(
# 清理 FTS 索引
try:
db.execute(
text("DELETE FROM papers_fts WHERE arxiv_id IN :ids"),
{"ids": tuple(body.arxiv_ids)},
stmt = text("DELETE FROM papers_fts WHERE arxiv_id IN :ids").bindparams(
bindparam("ids", expanding=True)
)
db.execute(stmt, {"ids": body.arxiv_ids})
db.commit()
except Exception:
logger.warning("Failed to clean FTS index for batch delete", exc_info=True)
return {"status": "success", "message": f"已删除 {count} 篇论文", "count": count}
return {
"status": "success",
"message": f"已删除 {count} 篇论文",
"count": count,
}
elif body.action == "summarize":
# 将选中论文的总结状态重置为 pending
paper_ids = db.execute(
select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids))
).scalars().all()
paper_ids = (
db.execute(select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids)))
.scalars()
.all()
)
if paper_ids:
# 删除旧的 status 记录让其重新进入 pipeline
+5 -3
View File
@@ -12,6 +12,8 @@ from app.utils import templates
router = APIRouter()
MAX_COMPARE_PAPERS = 5
@router.get("/compare")
def compare_page(
@@ -33,9 +35,9 @@ def compare_page(
arxiv_ids = [i.strip() for i in ids.split(",") if i.strip()]
# 最多 5
if len(arxiv_ids) > 5:
arxiv_ids = arxiv_ids[:5]
# 最多 MAX_COMPARE_PAPERS
if len(arxiv_ids) > MAX_COMPARE_PAPERS:
arxiv_ids = arxiv_ids[:MAX_COMPARE_PAPERS]
if not arxiv_ids:
return templates.TemplateResponse(
+2 -99
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
import json
import logging
import re
from datetime import date, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, Request
@@ -15,6 +14,7 @@ from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import PAPER_FULL_LOAD, Paper
from app.services.pdf_image_extractor import link_figures_with_images
from app.utils import (
PAPERS_DIR,
safe_json_loads,
@@ -120,7 +120,7 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
paper.summary.figures_json if paper.summary else None, default=[]
)
linked_figures = _link_figures_with_images(figures_raw, images, arxiv_id)
linked_figures = link_figures_with_images(figures_raw, images, arxiv_id)
# 拆分图片到对应展示区域:
# table_figures → 实验结果区域(Table 截图,不变)
@@ -279,100 +279,3 @@ def _get_paper_images(arxiv_id: str) -> list[dict]:
}
)
return images
def _link_figures_with_images(
    figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
    """Associate the summary's figure metadata with extracted image files.

    Strategy:
    1. Prefer an exact ID match using the labels found in manifest.json.
    2. Figures still unmatched fall back to position: the N-th Figure gets
       the N-th extracted image (tables are paired the same way).

    The dicts in *figures* are updated in place (an "image_url" key is set)
    and the same list object is returned. If either *figures* or *images*
    is empty, *figures* is returned unchanged.
    """
    if not figures or not images:
        return figures

    manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"

    # -- Strategy 1: exact ID match via the manifest --
    id_to_url: dict[str, str] = {}
    if manifest_path.exists():
        try:
            manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (ValueError, TypeError):
            # Unparseable manifest: continue as if it were empty.
            manifest = {}
        for filename, info in manifest.items():
            url = f"/papers/{arxiv_id}/images/{filename}"
            # Prefer the "label" field (new format).
            label = info.get("label", "")
            if label:
                id_to_url[label] = url
            # Also accept the "figures"/"tables" lists (old format);
            # an ID already mapped is not overwritten.
            for fig_id in info.get("figures", []) + info.get("tables", []):
                if fig_id not in id_to_url:
                    id_to_url[fig_id] = url

        for fig in figures:
            raw_id = fig.get("id", "")
            normalized = _normalize_figure_id(raw_id)
            if normalized in id_to_url:
                fig["image_url"] = id_to_url[normalized]

    # -- Strategy 2: positional fallback (when the manifest did not match) --
    unmatched = [f for f in figures if not f.get("image_url")]
    if not unmatched:
        return figures

    # Split the unmatched entries by type: Figure vs Table.
    fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
    table_type_unmatched = [
        f for f in unmatched if not _is_figure_type(f.get("id", ""))
    ]

    # Split extracted images by type and order them by the number in the
    # file name.
    def _sort_key(name: str) -> tuple[int, int]:
        # New format: figure_1.jpg, table_1.jpg
        m = re.search(r"(?:figure|table)_(\d+)", name)
        if m:
            return (0, int(m.group(1)))
        # Old format: page2_img1.png, page5_table1.png, figure_1.png
        m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
        if m2:
            return (int(m2.group(1)), int(m2.group(2)))
        return (0, 0)

    # NOTE(review): these lists contain every image, including ones already
    # assigned by strategy 1 — confirm that reuse is intended.
    fig_images = sorted(
        [img for img in images if "table" not in img["name"].lower()],
        key=lambda img: _sort_key(img["name"]),
    )
    table_images = sorted(
        [img for img in images if "table" in img["name"].lower()],
        key=lambda img: _sort_key(img["name"]),
    )

    for i, fig in enumerate(fig_type_unmatched):
        if i < len(fig_images):
            fig["image_url"] = fig_images[i]["url"]

    for i, fig in enumerate(table_type_unmatched):
        if i < len(table_images):
            fig["image_url"] = table_images[i]["url"]

    return figures
def _normalize_figure_id(raw_id: str) -> str:
"""归一化 Figure/Table ID'Figure 1'/'Fig.1''Figure 1'"""
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
if m:
return f"Figure {m.group(1)}"
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
if m2:
return f"Table {m2.group(1)}"
return raw_id
def _is_figure_type(fig_id: str) -> bool:
"""判断是否为 Figure 类型(非 Table)。"""
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
+5 -23
View File
@@ -2,12 +2,13 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.exceptions import NotFoundError
from app.services.user_data import (
get_note,
save_note,
@@ -37,9 +38,6 @@ def bookmark_toggle(arxiv_id: str, request: Request, db: Session = Depends(get_d
"""切换收藏状态。支持 HTMX 局部刷新和 JSON 响应。"""
result = toggle_bookmark(db, arxiv_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
# HTMX 请求 → 返回 HTML 片段
if request.headers.get("HX-Request"):
star = "" if result["bookmarked"] else ""
@@ -66,18 +64,7 @@ def reading_status_update(
db: Session = Depends(get_db),
):
"""更新阅读状态。"""
result = set_reading_status(db, arxiv_id, body.status)
if "error" in result:
if result["error"] == "not_found":
raise HTTPException(status_code=404, detail="Paper not found")
elif result["error"] == "invalid_status":
raise HTTPException(
status_code=422,
detail=f"Invalid status. Valid: {result['valid']}",
)
return result
return set_reading_status(db, arxiv_id, body.status)
# ── 笔记 ──────────────────────────────────────────────────────────────
@@ -88,16 +75,11 @@ def note_get(arxiv_id: str, db: Session = Depends(get_db)):
"""获取笔记。"""
result = get_note(db, arxiv_id)
if result is None:
raise HTTPException(status_code=404, detail="Paper not found")
raise NotFoundError(f"Paper not found: {arxiv_id}")
return result
@router.post("/note/{arxiv_id}")
def note_save(arxiv_id: str, body: NoteRequest, db: Session = Depends(get_db)):
"""保存笔记。"""
result = save_note(db, arxiv_id, body.content)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
return save_note(db, arxiv_id, body.content)
+94 -12
View File
@@ -9,10 +9,18 @@ from sqlalchemy import func, select, text
from sqlalchemy.orm import Session
from app.config import settings
from app.models import CrawlLog, Paper, SummaryState, TaskLock
from app.models import CrawlLog, Paper, PaperTag, SummaryState, SummaryStatus, TaskLock
from app.services.scheduler import get_scheduler
from app.utils import PAPERS_DIR, TMP_DIR
# admin_papers 排序映射
SORT_MAP = {
"date_desc": Paper.paper_date.desc(),
"date_asc": Paper.paper_date.asc(),
"upvotes_desc": Paper.upvotes.desc(),
"title_asc": Paper.title_en.asc(),
}
def _dir_size(path: Path) -> int:
"""递归计算目录总字节数。"""
@@ -52,7 +60,11 @@ def get_admin_stats(db: Session) -> dict:
status_counts = {row[0]: row[1] for row in summary_rows}
# ── 存储概况 ──────────────────────────────────────────────────────
db_size = _fmt_size(settings.db_path.stat().st_size) if settings.db_path.exists() else "0 B"
db_size = (
_fmt_size(settings.db_path.stat().st_size)
if settings.db_path.exists()
else "0 B"
)
papers_size = _fmt_size(_dir_size(PAPERS_DIR))
tmp_size = _fmt_size(_dir_size(TMP_DIR))
@@ -68,22 +80,14 @@ def get_admin_stats(db: Session) -> dict:
# ── 最近日志(5 条) ──────────────────────────────────────────────
recent_logs = (
db.execute(
select(CrawlLog)
.order_by(CrawlLog.started_at.desc())
.limit(5)
)
db.execute(select(CrawlLog).order_by(CrawlLog.started_at.desc()).limit(5))
.scalars()
.all()
)
# ── 活跃锁 ────────────────────────────────────────────────────────
active_locks = (
db.execute(
select(TaskLock).where(TaskLock.status == "running")
)
.scalars()
.all()
db.execute(select(TaskLock).where(TaskLock.status == "running")).scalars().all()
)
return {
@@ -108,3 +112,81 @@ def get_admin_stats(db: Session) -> dict:
"active_locks": active_locks,
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
}
def query_papers(
    db: Session,
    *,
    q: str = "",
    date_from: str | None = None,
    date_to: str | None = None,
    tag: str = "",
    summary_status: str = "all",
    sort: str = "date_desc",
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[Paper], int, dict[str, str]]:
    """Run the admin paper-management query: filter, sort and paginate.

    Args:
        db: Database session.
        q: Search text matched case-insensitively as a substring of the
            English title, Chinese title or abstract. Surrounding
            whitespace is ignored; an empty/blank value disables search.
        date_from: Inclusive lower bound on ``Paper.paper_date``.
        date_to: Inclusive upper bound on ``Paper.paper_date``.
        tag: When non-empty, only papers carrying exactly this tag.
        summary_status: "all" for no filter, "none" for papers without a
            summary-status row, otherwise the exact status value.
        sort: Key into ``SORT_MAP``; unknown keys fall back to newest first.
        page: 1-based page number.
        per_page: Page size.

    Returns:
        ``(papers, total, statuses)`` — the page of papers, the total
        number of matches, and ``{arxiv_id: summary_status}`` for the
        papers on this page.
    """
    query = select(Paper)

    # Search: use the stripped term in the pattern as well, so that input
    # such as " attention" still matches a title that starts with the word.
    keyword = q.strip()
    if keyword:
        pattern = f"%{keyword}%"
        query = query.where(
            Paper.title_en.ilike(pattern)
            | Paper.title_zh.ilike(pattern)
            | Paper.abstract.ilike(pattern)
        )

    # Date range
    if date_from:
        query = query.where(Paper.paper_date >= date_from)
    if date_to:
        query = query.where(Paper.paper_date <= date_to)

    # Tag filter
    if tag:
        query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
            PaperTag.tag == tag
        )

    # Summary-status filter
    if summary_status != "all":
        if summary_status == "none":
            # Outer join + IS NULL selects papers with no status row.
            query = query.outerjoin(
                SummaryStatus, SummaryStatus.paper_id == Paper.id
            ).where(SummaryStatus.paper_id.is_(None))
        else:
            query = query.join(SummaryStatus, SummaryStatus.paper_id == Paper.id).where(
                SummaryStatus.status == summary_status
            )

    # Ordering
    order = SORT_MAP.get(sort, Paper.paper_date.desc())
    query = query.order_by(order)

    # Total count of matches (before pagination)
    total = db.scalar(select(func.count()).select_from(query.subquery()))

    # Pagination
    papers = (
        db.execute(query.offset((page - 1) * per_page).limit(per_page)).scalars().all()
    )

    # Summary status for each paper on this page, keyed by arxiv_id
    paper_ids = [p.id for p in papers]
    statuses: dict[str, str] = {}
    if paper_ids:
        rows = db.execute(
            select(SummaryStatus.paper_id, SummaryStatus.status).where(
                SummaryStatus.paper_id.in_(paper_ids)
            )
        ).all()
        paper_id_to_arxiv = {p.id: p.arxiv_id for p in papers}
        statuses = {paper_id_to_arxiv.get(pid, ""): st for pid, st in rows}

    return papers, total or 0, statuses
+8 -5
View File
@@ -207,11 +207,14 @@ async def delete_papers_by_date_range(
completed_at=utc_now(),
papers_found=total,
papers_new=deleted,
details_json=json.dumps({
"total_before": total,
"deleted": deleted,
"failed": len(failed_items),
}, ensure_ascii=False),
details_json=json.dumps(
{
"total_before": total,
"deleted": deleted,
"failed": len(failed_items),
},
ensure_ascii=False,
),
error=job_error,
)
db.add(log_entry)
+9 -5
View File
@@ -189,11 +189,15 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
db = SessionLocal()
try:
paper = db.execute(
select(Paper)
.where(Paper.arxiv_id == paper_id)
.options(joinedload(Paper.tags), joinedload(Paper.summary))
).unique().scalar_one_or_none()
paper = (
db.execute(
select(Paper)
.where(Paper.arxiv_id == paper_id)
.options(joinedload(Paper.tags), joinedload(Paper.summary))
)
.unique()
.scalar_one_or_none()
)
if not paper:
logger.warning("Paper %s not found for indexing", paper_id)
return False
+174
View File
@@ -0,0 +1,174 @@
"""PicoDet-S_layout_3cls 布局检测 — 纯 ONNX Runtime 推理.
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框。
输入:
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图片
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
输出:
fetch_name_0: (N, 6) float32 — [class_id, score, xmin, ymin, xmax, ymax]
fetch_name_1: (1,) int32 — 有效框数量 N
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import onnxruntime as ort
import pymupdf
from app.config import settings
logger = logging.getLogger(__name__)
# 模型输入尺寸
_MODEL_SIZE = 480
# ImageNet normalize
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# PicoDet label → 内部 boxclass
_LABEL_MAP: dict[int, str] = {
0: "picture", # PicoDet "image" → "picture"
1: "table",
# 2: seal — 忽略
}
# 最小 bbox 尺寸(PDF 点)
_MIN_BOX_SIZE = 20
@dataclass
class LayoutBox:
    """A detected layout region.

    Kept compatible with the existing _process_page code (per the original
    note). detect_page() populates the coordinates in PDF points.
    """

    x0: float  # minimum x of the box
    y0: float  # minimum y of the box
    x1: float  # maximum x of the box
    y1: float  # maximum y of the box
    boxclass: str  # "picture" | "table"
class _LayoutDetector:
    """Owns the ONNX Runtime session used for page-layout detection.

    The session is created lazily on first use and then reused. The module
    keeps a single instance of this class.
    """

    def __init__(self) -> None:
        # Populated by _init_session() on first use.
        self._session: ort.InferenceSession | None = None

    def _init_session(self) -> ort.InferenceSession:
        """Return the cached inference session, loading the model if needed.

        Raises:
            FileNotFoundError: ``settings.LAYOUT_MODEL_PATH`` does not exist.
        """
        if self._session is None:
            model_file = Path(settings.LAYOUT_MODEL_PATH)
            if not model_file.exists():
                raise FileNotFoundError(
                    f"Layout model not found: {model_file}. "
                    "Run scripts/export_picodet_onnx.py first."
                )

            logger.info("Loading ONNX layout model: %s", model_file)
            self._session = ort.InferenceSession(
                str(model_file), providers=["CPUExecutionProvider"]
            )
            logger.info("ONNX layout model loaded")
        return self._session

    def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
        """Detect figure / table regions on a single PDF page.

        Steps:
            1. Render the page with pymupdf, scaled to the model input size.
            2. Apply ImageNet normalisation and reorder to NCHW.
            3. Run ONNX inference (the module docstring states the model's
               output is already decoded and NMS-filtered).
            4. Convert pixel coordinates back to PDF points.
            5. Drop unmapped classes, low-confidence boxes and tiny boxes.

        Args:
            page: The pymupdf Page to analyse.

        Returns:
            A list of LayoutBox with coordinates in PDF points.
        """
        session = self._init_session()

        width_pt = page.rect.width
        height_pt = page.rect.height

        # 1. Render: x and y are zoomed independently so the page is
        #    stretched onto a _MODEL_SIZE x _MODEL_SIZE canvas.
        render_matrix = pymupdf.Matrix(_MODEL_SIZE / width_pt, _MODEL_SIZE / height_pt)
        pixmap = page.get_pixmap(matrix=render_matrix)

        # 2. Preprocess: bytes -> HWC float32 in [0, 1].
        raw_bytes = np.frombuffer(pixmap.samples, dtype=np.uint8)
        pixels = (
            raw_bytes.reshape(pixmap.height, pixmap.width, pixmap.n).astype(np.float32)
            / 255.0
        )
        # Drop the alpha channel when present.
        if pixels.shape[2] == 4:
            pixels = pixels[:, :, :3]
        normalised = (pixels - _MEAN) / _STD
        tensor = normalised.transpose(2, 0, 1)[np.newaxis]  # (1, 3, H, W)

        # 3. Inference. A second model input, if any, receives an identity
        #    scale factor.
        names = [model_input.name for model_input in session.get_inputs()]
        feed = {names[0]: tensor}
        if len(names) > 1:
            feed[names[1]] = np.array([[1.0, 1.0]], dtype=np.float32)
        outputs = session.run(None, feed)

        detections = outputs[0]  # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
        valid_count = int(outputs[1][0])  # number of valid boxes
        if valid_count == 0:
            return []

        # 4. Pixel -> PDF point conversion factors.
        x_ratio = width_pt / _MODEL_SIZE
        y_ratio = height_pt / _MODEL_SIZE

        found: list[LayoutBox] = []
        for row in range(min(valid_count, len(detections))):
            class_raw, score, xmin, ymin, xmax, ymax = detections[row]

            # 5a. Skip classes without a mapping (e.g. seal) and boxes
            #     below the configured confidence threshold.
            label = _LABEL_MAP.get(int(class_raw))
            if label is None or score < settings.LAYOUT_THRESHOLD:
                continue

            left, top = xmin * x_ratio, ymin * y_ratio
            right, bottom = xmax * x_ratio, ymax * y_ratio

            # 5b. Skip very small regions.
            if (right - left) < _MIN_BOX_SIZE or (bottom - top) < _MIN_BOX_SIZE:
                continue

            found.append(LayoutBox(x0=left, y0=top, x1=right, y1=bottom, boxclass=label))

        return found
# Module-level shared instance
_detector = _LayoutDetector()
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
    """Run layout detection on one PDF page via the module-level detector.

    Returns:
        The LayoutBox list produced by the detector's ``detect_page`` for
        this page.
    """
    boxes = _detector.detect_page(page)
    return boxes
+16 -1
View File
@@ -9,6 +9,7 @@ from pathlib import Path
import requests
from app.config import settings
from app.utils import PAPERS_DIR, TMP_DIR
logger = logging.getLogger(__name__)
@@ -51,6 +52,14 @@ def _get_session() -> requests.Session:
return _http_session
def close_http_session() -> None:
    """Close the module-wide HTTP session, if one exists, and forget it.

    Intended to be called on application shutdown. Does nothing when no
    session is currently held.
    """
    global _http_session
    if _http_session is None:
        return
    _http_session.close()
    _http_session = None
async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
if not pdf_url:
@@ -62,10 +71,16 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
try:
session = _get_session()
resp = session.get(pdf_url, timeout=120, allow_redirects=True)
resp = session.get(pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True)
resp.raise_for_status()
dest.write_bytes(resp.content)
except Exception as exc:
# 清理残留的部分文件
if dest.exists():
try:
dest.unlink()
except OSError:
pass
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
+383 -462
View File
@@ -1,12 +1,12 @@
"""PDF 图片与表格提取 — 基于 pymupdf4llm layout analysis
"""PDF 图片与表格提取 — 两阶段流水线
用 pymupdf4llm 的 layout analysis 检测 table / picture 区域,
再通过 caption 文字匹配确定 Figure/Table 编号,渲染为 JPEG。
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
相比旧方案(caption 正则 + pdfplumber/find_tables/文本块扫描三套策略):
- layout analysis 直接给出区域 bbox,不存在相邻表格互相侵入的问题
- 无需手动调参(最大高度、间隙阈值等)
- 页面级 caption 匹配:每个 caption 只分配给最近的 box,避免上下相邻表格抢夺同一个 caption
相比旧方案(正则匹配 caption):
- 不再依赖正则,用 LLM 输出的 ID 直接搜索 PDF 文本
- page.search_for() 精确搜索 + 空间距离过滤,避免正文引用误匹配
- 通用标签兜底,LLM 没提到的图表不会被丢弃
"""
from __future__ import annotations
@@ -17,44 +17,30 @@ import re
from pathlib import Path
import pymupdf
import pymupdf4llm.helpers.document_layout as dl
from app.services.layout_detector import LayoutBox, detect_page_layout
from app.services.pdf_downloader import paper_dir
from app.utils import TMP_DIR
from app.utils import PAPERS_DIR, TMP_DIR
logger = logging.getLogger(__name__)
# ── Caption 正则 ───────────────────────────────────────────────────────
# 用于从 caption 文字中提取 Figure/Table 编号
_FIGURE_CAPTION_RE = re.compile(
r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
re.IGNORECASE,
)
_TABLE_CAPTION_RE = re.compile(
r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
re.IGNORECASE,
)
# caption 与 table/picture 的最大匹配距离(点)
_CAPTION_MATCH_DISTANCE = 100
# 截图区域的外边距
# 截图区域的外边距(单位: pt
_REGION_PADDING = 5
# 3x 渲染,保证清晰度
# 渲染倍率(3x 保证清晰度
_RENDER_ZOOM = 3
# 相邻 box 聚类间距()— 同一 figure/table 的碎片间距通常 < 15pt
# 相邻 box 聚类间距(单位: pt)— 同一 figure/table 的碎片间距通常 < 15pt
_CLUSTER_GAP = 15
# 最小 bbox 面积(单位: pt²)— 过滤 icon/logo 等微小误检
_MIN_BOX_AREA = 2000
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
_LABEL_MATCH_DISTANCE = 100
# ── Box 聚类 ─────────────────────────────────────────────────────────
class _BoxCluster:
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。
pymupdf4llm 有时将一个大图拆成多个小 picture box(如视频帧网格),
聚类后用整体 bbox 作为渲染区域。
"""
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。"""
__slots__ = ("x0", "y0", "x1", "y1", "boxclass")
@@ -63,17 +49,12 @@ class _BoxCluster:
self.y0 = min(b.y0 for b in boxes)
self.x1 = max(b.x1 for b in boxes)
self.y1 = max(b.y1 for b in boxes)
# table-fallback 归一化为 tablelayout model 检测到表格但无法提取结构)
raw = boxes[0].boxclass
self.boxclass = "table" if raw == "table-fallback" else raw
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
"""将相邻的同类型 box 合并为聚类。
用 union-find 将间距 ≤ gap 的同类型 box 归为一组,
每组生成一个 _BoxCluster(整体 bbox)。
"""
"""将相邻的同类型 box 合并为聚类。"""
if not boxes:
return []
@@ -111,242 +92,58 @@ def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
return [_BoxCluster(members) for members in groups.values()]
# ── 页面级 Caption 查找与匹配 ──────────────────────────────────────────
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
def _find_page_captions(page) -> list[dict]:
"""查找页面上所有 Figure/Table caption 文字块。"""
blocks = page.get_text("blocks")
captions = []
for b in blocks:
if len(b) < 5:
continue
bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
text = str(b[4]).strip()
first_line = text.split("\n")[0].strip()
cap_type = None
m = _TABLE_CAPTION_RE.match(first_line)
if m:
cap_type = "table"
else:
m = _FIGURE_CAPTION_RE.match(first_line)
if m:
cap_type = "figure"
if m is None:
continue
captions.append(
{
"label": f"{'Table' if cap_type == 'table' else 'Figure'} {m.group(1)}",
"type": cap_type,
"caption_text": text,
"caption_y0": by0,
"caption_y1": by1,
"caption_x0": bx0,
"caption_x1": bx1,
}
)
return captions
def _vertical_distance(cap_y0, cap_y1, box_y0, box_y1) -> float | None:
"""计算 caption 到 box 的垂直距离。不邻接时返回 None。
三种情况:caption 完全在 box 上方、完全在下方、与 box 有垂直重叠。
重叠(含部分溢出)视为 distance=0,确保 caption 延伸到 box 边界外时不会丢失。
"""
# Caption 完全在 box 上方
if cap_y1 <= box_y0:
dist = box_y0 - cap_y1
return dist if dist <= _CAPTION_MATCH_DISTANCE else None
# Caption 完全在 box 下方
if cap_y0 >= box_y1:
dist = cap_y0 - box_y1
return dist if dist <= _CAPTION_MATCH_DISTANCE else None
# Caption 与 box 有垂直重叠(内部、部分溢出都算)→ 距离 0
return 0
def _same_column(cap: dict, box, page_width: float) -> bool:
"""判断 caption 和 box 是否在同一列。
双栏论文中左右栏间距有限,简单的水平重叠检查会跨列匹配。
策略:用中心 X 坐标判断各自在哪半边,只有同半边才算同列。
跨栏图表(caption 或 box 宽度 >65% 页宽)不受此限制。
"""
cap_w = cap["caption_x1"] - cap["caption_x0"]
box_w = box.x1 - box.x0
# 跨栏元素:宽度超过页面的 65%
if cap_w > page_width * 0.65 or box_w > page_width * 0.65:
return True
cap_cx = (cap["caption_x0"] + cap["caption_x1"]) / 2
box_cx = (box.x0 + box.x1) / 2
mid = page_width / 2
# 同在左半边或同在右半边
return (cap_cx < mid) == (box_cx < mid)
def _match_captions_to_boxes(
page_boxes: list, captions: list[dict], page_width: float
) -> list[tuple[list[int], list[dict]]]:
"""将 caption 分配给 box,允许一个 caption 匹配多个同类型 box。
典型场景:
- Figure 由左右两个 picture box 组成,caption 同时靠近两者
- Table 的视觉内容被 layout analysis 误分类为 picture,需要跨类型匹配
Returns:
[(box_indices, captions), ...] 每组是一个独立的渲染任务
"""
# 每个 caption 找到所有距离在阈值内的 box
# 优先匹配同类型;如果找不到,再匹配任意 table/picture box
cap_to_boxes: dict[int, list[tuple[int, float]]] = {}
for ci, cap in enumerate(captions):
same_type: list[tuple[int, float]] = []
any_type: list[tuple[int, float]] = []
expected = "table" if cap["type"] == "table" else "picture"
for bi, box in enumerate(page_boxes):
# 列感知:双栏论文中只匹配同栏的 box
if not _same_column(cap, box, page_width):
continue
# 水平重叠检查(同列内仍需有重叠)
if not (
cap["caption_x1"] > box.x0 - 5 and cap["caption_x0"] < box.x1 + 5
):
continue
dist = _vertical_distance(
cap["caption_y0"], cap["caption_y1"], box.y0, box.y1
)
if dist is None:
continue
entry = (bi, dist)
any_type.append(entry)
if box.boxclass == expected:
same_type.append(entry)
# 优先用同类型匹配;没有时回退到任意类型;都没有则跳过
if same_type:
cap_to_boxes[ci] = same_type
elif any_type:
cap_to_boxes[ci] = any_type
# else: 该 caption 无匹配 box,不加入 cap_to_boxes
# 每个 caption → 最近的 box(用于分组),但记录所有匹配的 box
cap_primary: dict[int, int] = {} # caption → primary box index
cap_all_boxes: dict[int, list[int]] = {} # caption → all matched box indices
for ci, matches in cap_to_boxes.items():
matches.sort(key=lambda x: x[1])
cap_primary[ci] = matches[0][0]
# 所有距离最近的同组 box(距离差 < 20pt 视为同一组)
best_dist = matches[0][1]
cap_all_boxes[ci] = [bi for bi, d in matches if d <= best_dist + 20]
# 按 primary box 分组
box_to_caps: dict[int, list[int]] = {}
for ci, bi in cap_primary.items():
box_to_caps.setdefault(bi, []).append(ci)
# 构建渲染组:每个 caption 独立成组(共享 box 但各自渲染)
# 同类型同 label 的 caption 会合并;不同类型则分开
used_captions: set[int] = set()
groups: list[tuple[list[int], list[dict]]] = []
for bi in sorted(box_to_caps.keys()):
cis = box_to_caps[bi]
for ci in cis:
if ci in used_captions:
continue
used_captions.add(ci)
all_box_indices = set(cap_all_boxes.get(ci, [bi]))
# 只合并同 label 的 caption(同 figure/table 的重复 caption
merged_captions = [captions[ci]]
for other_bi in all_box_indices:
if other_bi in box_to_caps:
for other_ci in box_to_caps[other_bi]:
if other_ci not in used_captions:
other_cap = captions[other_ci]
if other_cap["label"] == captions[ci]["label"]:
used_captions.add(other_ci)
merged_captions.append(other_cap)
groups.append((sorted(all_box_indices), merged_captions))
return groups
# ── 单页处理 ─────────────────────────────────────────────────────────
def _render_and_save(
def _render_box(
page,
clip: pymupdf.Rect,
box: _BoxCluster,
images_dest: Path,
manifest: dict,
label: str,
filename: str,
cap_type: str,
caption_text: str,
page_num_1based: int,
arxiv_id: str,
page_num: int,
) -> bool:
"""渲染页面区域并保存 JPEG写入 manifest。成功返回 True。"""
"""渲染单个 box 区域并保存 JPEG,成功返回 True。"""
page_width = page.rect.width
clip = pymupdf.Rect(
max(0, box.x0 - _REGION_PADDING),
max(0, box.y0 - _REGION_PADDING),
min(page_width, box.x1 + _REGION_PADDING),
box.y1 + _REGION_PADDING,
)
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
try:
pix = page.get_pixmap(matrix=mat, clip=clip)
except Exception:
logger.debug("Failed to render %s for %s", label, arxiv_id)
return False
filename = f"{label.replace(' ', '_').lower()}.jpg"
(images_dest / filename).write_bytes(pix.tobytes("jpeg"))
manifest[filename] = {
"page": page_num_1based,
"type": cap_type,
"label": label,
"caption_text": caption_text[:200] if caption_text else "",
"figures" if cap_type == "figure" else "tables": [label],
}
logger.debug(
"Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) → %s",
label,
page_num_1based,
clip.x0,
clip.y0,
clip.x1,
clip.y1,
filename,
)
return True
def _process_page(
doc,
page_idx: int,
page_layout,
page_boxes: list[LayoutBox],
images_dest: Path,
manifest: dict,
seen_labels: set,
arxiv_id: str,
) -> int:
"""处理单页:caption 匹配 + orphan 兜底,返回本页提取数量"""
"""处理单页:检测 → 聚类 → 渲染,全部用通用标签"""
page = doc[page_idx]
page_width = page.rect.width
page_num = page_idx + 1
orphan_fig_counter = 0
orphan_tbl_counter = 0
fig_counter = 0
tbl_counter = 0
# 收集本页的 table/picture box(跳过极小区域)
raw_boxes = []
for box in page_layout.boxes:
for box in page_boxes:
if box.boxclass not in ("table", "table-fallback", "picture"):
continue
if (box.x1 - box.x0) < 20 or (box.y1 - box.y0) < 20:
w = box.x1 - box.x0
h = box.y1 - box.y0
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
continue
raw_boxes.append(box)
@@ -354,153 +151,48 @@ def _process_page(
return 0
# 聚类:将同一 figure/table 的碎片 box 合并
page_boxes = _cluster_boxes(raw_boxes)
clusters = _cluster_boxes(raw_boxes)
# 页面级匹配:查找所有 caption,分配给 box
captions = _find_page_captions(page)
groups = _match_captions_to_boxes(page_boxes, captions, page_width)
# 只合并同 label 的 group(同一个 figure/table 的重复 caption
# 不同 label 的 group 即使共享 box 也不合并(如 Figure 7 和 Figure 8),
# 渲染时用 caption 位置切割区域
_merged_groups: set[int] = set()
merged_groups: list[tuple[list[int], list[dict]]] = []
for gi, (box_indices, caps) in enumerate(groups):
if gi in _merged_groups:
continue
this_labels = {c["label"] for c in caps}
all_box_set = set(box_indices)
merge_targets = {gi}
for other_gi, (other_bi, other_caps) in enumerate(groups):
if other_gi <= gi or other_gi in _merged_groups:
continue
other_labels = {c["label"] for c in other_caps}
# 只在 label 有交集时合并(同一个 figure/table
if this_labels & other_labels and all_box_set & set(other_bi):
merge_targets.add(other_gi)
all_box_set |= set(other_bi)
all_caps = []
for mgi in sorted(merge_targets):
_merged_groups.add(mgi)
all_caps.extend(groups[mgi][1])
merged_groups.append((sorted(all_box_set), all_caps))
groups = merged_groups
# ── 阶段 1:渲染有 caption 匹配的图/表 ──
matched_box_indices: set[int] = set()
extracted = 0
for box_indices, caps in groups:
matched_box_indices.update(box_indices)
# 去重同一 label,跳过已处理的
unique_caps = []
for cap in caps:
if cap["label"] not in seen_labels:
seen_labels.add(cap["label"])
unique_caps.append(cap)
if not unique_caps:
continue
# 合并所有关联 box 的 bbox
bx0 = min(page_boxes[i].x0 for i in box_indices)
by0 = min(page_boxes[i].y0 for i in box_indices)
bx1 = max(page_boxes[i].x1 for i in box_indices)
by1 = max(page_boxes[i].y1 for i in box_indices)
# 渲染区域:box + caption
all_cap_y0 = min(c["caption_y0"] for c in unique_caps)
all_cap_y1 = max(c["caption_y1"] for c in unique_caps)
all_cap_x0 = min(c["caption_x0"] for c in unique_caps)
all_cap_x1 = max(c["caption_x1"] for c in unique_caps)
top = max(0, min(by0, all_cap_y0) - _REGION_PADDING)
bottom = max(by1, all_cap_y1) + _REGION_PADDING
rx0 = max(0, min(bx0, all_cap_x0) - _REGION_PADDING)
rx1 = min(page_width, max(bx1, all_cap_x1) + _REGION_PADDING)
clip = pymupdf.Rect(rx0, top, rx1, bottom)
# 多个 caption 可能共享同一区域(如 subfigure),只需渲染一次
jpeg_bytes = None
for cap in unique_caps:
if jpeg_bytes is None:
if not _render_and_save(
page,
clip,
images_dest,
manifest,
cap["label"],
cap["type"],
cap["caption_text"],
page_num,
arxiv_id,
):
break
# 读取刚写入的 bytes 供后续同名 caption 复用
filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
jpeg_bytes = (images_dest / filename).read_bytes()
extracted += 1
else:
# 同区域的不同 caption(如 subfigure),复用图片
filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
(images_dest / filename).write_bytes(jpeg_bytes)
cap_preview = cap["caption_text"][:200]
manifest[filename] = {
"page": page_num,
"type": cap["type"],
"label": cap["label"],
"caption_text": cap_preview,
"figures" if cap["type"] == "figure" else "tables": [cap["label"]],
}
extracted += 1
# ── 阶段 2:渲染无 caption 匹配的图/表(orphan boxes ──
orphan_indices = set(range(len(page_boxes))) - matched_box_indices
for bi in sorted(orphan_indices):
box = page_boxes[bi]
cap_type = "figure" if box.boxclass == "picture" else "table"
for cluster in clusters:
cap_type = "figure" if cluster.boxclass == "picture" else "table"
if cap_type == "figure":
orphan_fig_counter += 1
label = f"Figure (p{page_num}-{orphan_fig_counter})"
fig_counter += 1
label = f"Figure (p{page_num}-{fig_counter})"
else:
orphan_tbl_counter += 1
label = f"Table (p{page_num}-{orphan_tbl_counter})"
tbl_counter += 1
label = f"Table (p{page_num}-{tbl_counter})"
if label in seen_labels:
continue
seen_labels.add(label)
clip = pymupdf.Rect(
max(0, box.x0 - _REGION_PADDING),
max(0, box.y0 - _REGION_PADDING),
min(page_width, box.x1 + _REGION_PADDING),
box.y1 + _REGION_PADDING,
)
if _render_and_save(
page,
clip,
images_dest,
manifest,
label,
cap_type,
"",
page_num,
arxiv_id,
):
extracted += 1
filename = f"{label.replace(' ', '_').lower()}.jpg"
if not _render_box(page, cluster, images_dest, filename, cap_type, page_num):
continue
manifest[filename] = {
"page": page_num,
"type": cap_type,
"label": label,
"box": [
round(float(cluster.x0), 1),
round(float(cluster.y0), 1),
round(float(cluster.x1), 1),
round(float(cluster.y1), 1),
],
}
extracted += 1
return extracted
# ── 核心提取 ───────────────────────────────────────────────────────────
# ── Phase 1 核心入口 ───────────────────────────────────────────────────
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
"""从 PDF 提取 Figure/Table 截图,生成 manifest。
用 pymupdf4llm layout analysis 检测 table/picture 区域,
再通过 caption 文字确定编号,渲染为 JPEG。
"""Phase 1: 从 PDF 提取 Figure/Table 截图,生成通用标签的 manifest。
Args:
arxiv_id: 论文 ID
@@ -526,45 +218,31 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
if (images_dest / "manifest.json").exists():
(images_dest / "manifest.json").unlink()
doc = pymupdf.open(str(pdf_path))
with pymupdf.open(str(pdf_path)) as doc:
extracted = 0
manifest: dict[str, dict] = {}
seen_labels: set[str] = set()
# layout analysis
try:
parsed = dl.parse_document(
doc, filename=str(pdf_path), use_ocr=dl.OCRMode.NEVER
)
except Exception:
logger.warning(
"pymupdf4llm layout analysis failed for %s", arxiv_id, exc_info=True
)
doc.close()
return 0
extracted = 0
manifest: dict[str, dict] = {}
seen_labels: set[str] = set()
for page_idx, page_layout in enumerate(parsed.pages):
try:
extracted += _process_page(
doc,
page_idx,
page_layout,
images_dest=images_dest,
manifest=manifest,
seen_labels=seen_labels,
arxiv_id=arxiv_id,
)
except Exception:
logger.warning(
"Failed to process page %d for %s",
page_idx + 1,
arxiv_id,
exc_info=True,
)
continue
doc.close()
for page_idx in range(doc.page_count):
try:
page_boxes = detect_page_layout(doc[page_idx])
extracted += _process_page(
doc,
page_idx,
page_boxes,
images_dest=images_dest,
manifest=manifest,
seen_labels=seen_labels,
arxiv_id=arxiv_id,
)
except Exception:
logger.warning(
"Failed to process page %d for %s",
page_idx + 1,
arxiv_id,
exc_info=True,
)
continue
# 保存 manifest
manifest_path = images_dest / "manifest.json"
@@ -580,78 +258,321 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
return extracted
# ── 按 summary 过滤 ────────────────────────────────────────────────────
# ── Phase 2: 用 summary 的 figures ID 定位并重命名 ─────────────────────
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
"""根据 summary 中的 figures 字段过滤提取的图片/表格
def _distance_text_to_box(rect: pymupdf.Rect, box: list[float]) -> float | None:
"""计算搜索到的文本 rect 到 box 的距离。超出阈值返回 None
用 manifest.json 中的 label 匹配,保留被 AI 总结引用的图片
判断逻辑:rect 中心与 box 的垂直距离 + 水平重叠检查
"""
rect_cx = (rect.x0 + rect.x1) / 2
rect_cy = (rect.y0 + rect.y1) / 2
bx0, by0, bx1, by1 = box
# 水平重叠:rect 中心在 box 水平范围内(或接近)
if not (bx0 - 20 <= rect_cx <= bx1 + 20):
return None
# 垂直距离
if rect_cy < by0:
dist = by0 - rect_cy
elif rect_cy > by1:
dist = rect_cy - by1
else:
dist = 0
return dist if dist <= _LABEL_MATCH_DISTANCE else None
def _search_variants(fig_id: str) -> list[str]:
"""为 figure/table ID 生成搜索变体。
"Figure 1" → ["Figure 1", "Fig. 1", "Fig 1"]
"Fig. 1" → ["Fig. 1", "Figure 1", "Fig 1"]
"Table A1" → ["Table A1"]
"""
variants = [fig_id]
m = re.match(r"(Fig\.?|Figure)\s+(\d+.*)", fig_id, re.IGNORECASE)
if m:
num_part = m.group(2)
variants.extend(
[
f"Figure {num_part}",
f"Fig. {num_part}",
f"Fig {num_part}",
]
)
# 去重保序
seen = set()
result = []
for v in variants:
if v not in seen:
seen.add(v)
result.append(v)
return result
def label_images_by_summary(
    arxiv_id: str,
    figures: list[dict],
    pdf_path: Path | None = None,
) -> int:
    """Phase 2: locate each summary figure/table ID in the PDF and rename its image.

    For every ID in ``figures``:
        1. search the PDF text (``page.search_for``) for the ID and its
           spelling variants, on pages that hold a generically-labelled image;
        2. measure the distance from each hit to the box stored in the manifest;
        3. assign the closest pairs first (each figure and each image is used
           at most once), rename the file and update the manifest entry.

    A rename is skipped (and the entry kept as-is) when the target filename is
    already taken by another manifest entry.

    Args:
        arxiv_id: paper ID.
        figures: the summary's ``figures`` list; ``id`` and ``caption`` are read.
        pdf_path: PDF path; defaults to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of images that were labelled.
    """
    if not figures:
        return 0

    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
    if not pdf_path.exists():
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    manifest_path = images_dest / "manifest.json"
    if not manifest_path.exists():
        return 0

    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except ValueError:
        # Covers both invalid JSON and undecodable bytes.
        logger.warning("Unreadable image manifest for %s", arxiv_id, exc_info=True)
        return 0
    if not manifest or not isinstance(manifest, dict):
        return 0

    # Only entries that still carry a generic label ("Figure (p3-1)") are relabelled.
    candidates: dict[str, dict] = {
        fname: info
        for fname, info in manifest.items()
        if isinstance(info, dict) and "(p" in str(info.get("label", ""))
    }
    if not candidates:
        return 0

    # Hits on other pages can never match (same-page rule below), so only
    # pages that hold a candidate with a stored box are searched.
    candidate_pages = [
        info.get("page", 0) for info in candidates.values() if info.get("box")
    ]

    # All viable pairings: (fig_id, fig_index, filename, distance)
    matches: list[tuple[str, int, str, float]] = []
    with pymupdf.open(str(pdf_path)) as doc:
        for fig_idx, fig in enumerate(figures):
            fig_id = fig.get("id", "")
            if not fig_id or not isinstance(fig_id, str):
                continue

            search_terms = _search_variants(fig_id)

            # (1-based page number, Rect) for every distinct hit position
            search_hits: list[tuple[int, pymupdf.Rect]] = []
            for page_idx in range(doc.page_count):
                page_num = page_idx + 1
                if page_num not in candidate_pages:
                    continue
                page = doc[page_idx]
                seen_rects: set[tuple[float, float]] = set()
                for term in search_terms:
                    for r in page.search_for(term):
                        key = (round(r.x0, 1), round(r.y0, 1))
                        if key not in seen_rects:
                            seen_rects.add(key)
                            search_hits.append((page_num, r))

            if not search_hits:
                continue

            # For every candidate image keep its closest same-page hit.
            for fname, info in candidates.items():
                box = info.get("box")
                if not box:
                    continue
                manifest_page = info.get("page", 0)

                best_dist: float | None = None
                for hit_page, rect in search_hits:
                    if hit_page != manifest_page:
                        continue
                    dist = _distance_text_to_box(rect, box)
                    if dist is not None and (best_dist is None or dist < best_dist):
                        best_dist = dist

                if best_dist is not None:
                    matches.append((fig_id, fig_idx, fname, best_dist))

    if not matches:
        logger.info("No label matches for %s", arxiv_id)
        return 0

    # Resolve conflicts greedily: closest pairs win; one image per figure.
    matches.sort(key=lambda x: x[3])
    used_fig_ids: set[int] = set()
    used_filenames: set[str] = set()
    renames: list[tuple[str, str, int]] = []  # (old_fname, new_fname, fig_index)
    for fig_id, fig_idx, fname, _dist in matches:
        if fig_idx in used_fig_ids or fname in used_filenames:
            continue
        used_fig_ids.add(fig_idx)
        used_filenames.add(fname)
        # The ID comes from the summary; keep path separators / NUL out of the name.
        safe_id = re.sub(r"[\\/\x00]", "_", fig_id.replace(" ", "_").lower())
        renames.append((fname, f"{safe_id}.jpg", fig_idx))

    labeled = 0
    # Entries that were not matched are carried over unchanged.
    new_manifest: dict[str, dict] = {
        fname: info for fname, info in manifest.items() if fname not in used_filenames
    }

    for old_fname, new_fname, fig_idx in renames:
        old_path = images_dest / old_fname
        new_path = images_dest / new_fname

        if not old_path.exists():
            # File vanished: its manifest entry is dropped.
            continue

        if new_fname != old_fname and (
            new_fname in manifest or new_fname in new_manifest
        ):
            # Target name belongs to another image; do not overwrite it.
            logger.warning(
                "Skip relabel %s → %s for %s: target name already in use",
                old_fname,
                new_fname,
                arxiv_id,
            )
            new_manifest[old_fname] = manifest[old_fname]
            continue

        fig = figures[fig_idx]
        fig_id = fig["id"]
        caption_text = fig.get("caption", "")

        info = manifest[old_fname].copy()
        cap_type = info.get("type", "figure")
        info["label"] = fig_id
        info["caption_text"] = caption_text[:200] if caption_text else ""
        info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
            fig_id
        )

        if new_fname != old_fname:
            # replace() also overwrites a stale on-disk file on every platform.
            old_path.replace(new_path)

        new_manifest[new_fname] = info
        labeled += 1

    # Written as UTF-8 to match how the manifest is read (ensure_ascii=False
    # may emit non-ASCII caption text).
    manifest_path.write_text(
        json.dumps(new_manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    logger.info(
        "Labeled %d/%d images for %s using summary figures",
        labeled,
        len(manifest),
        arxiv_id,
    )
    return labeled
# ── Figure ↔ Image 关联 ────────────────────────────────────────────────
def _normalize_figure_id(raw_id: str) -> str:
"""归一化 Figure/Table ID'Figure 1'/'Fig.1''Figure 1'"""
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
if m:
return f"Figure {m.group(1)}"
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
if m2:
return f"Table {m2.group(1)}"
return raw_id
def _is_figure_type(fig_id: str) -> bool:
"""判断是否为 Figure 类型(非 Table)。"""
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
def _image_sort_key(name: str) -> tuple[int, int]:
"""按文件名中的编号排序提取的图片。"""
# 新格式:figure_1.jpg, table_1.jpg
m = re.search(r"(?:figure|table)_(\d+)", name)
if m:
return (0, int(m.group(1)))
# 旧格式:page2_img1.png, page5_table1.png, figure_1.png
m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
if m2:
return (int(m2.group(1)), int(m2.group(2)))
return (0, 0)
def link_figures_with_images(
    figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
    """Attach an ``image_url`` to summary figures using the extracted images.

    Strategy:
        1. ID match against ``manifest.json``: a figure's raw ID is tried
           first, then its normalized form. Manifest labels are indexed under
           both their stored and normalized spelling, so a label written as
           "Fig. 1" or "Table A1" is still found.
        2. Ordinal fallback for figures still without an image: the N-th
           unmatched figure (resp. table) takes the N-th extracted figure
           (resp. table) image that was not already linked in step 1.

    Args:
        figures: summary figure dicts; updated in place.
        images: extracted image dicts; ``name`` and ``url`` are read.
        arxiv_id: paper ID, used to locate the manifest and build URLs.

    Returns:
        The same ``figures`` list.
    """
    if not figures or not images:
        return figures

    manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"

    # ── Strategy 1: manifest ID match ──
    id_to_file: dict[str, str] = {}
    if manifest_path.exists():
        try:
            manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (OSError, ValueError, TypeError):
            manifest = {}
        if not isinstance(manifest, dict):
            manifest = {}

        for filename, info in manifest.items():
            if not isinstance(info, dict):
                continue
            # Prefer the label field (new format)
            label = info.get("label", "")
            if label:
                id_to_file[label] = filename
            # Also accept the figures/tables lists (old format)
            for fig_id in info.get("figures", []) + info.get("tables", []):
                id_to_file.setdefault(fig_id, filename)

        # Index normalized spellings too; exact keys registered above win.
        for key, filename in list(id_to_file.items()):
            if isinstance(key, str):
                id_to_file.setdefault(_normalize_figure_id(key), filename)

    linked_files: set[str] = set()
    for fig in figures:
        raw_id = str(fig.get("id") or "")
        for key in (raw_id, _normalize_figure_id(raw_id)):
            filename = id_to_file.get(key)
            if filename is not None:
                fig["image_url"] = f"/papers/{arxiv_id}/images/{filename}"
                linked_files.add(filename)
                break

    # ── Strategy 2: ordinal fallback (when the manifest gave no match) ──
    unmatched = [f for f in figures if not f.get("image_url")]
    if not unmatched:
        return figures

    # Split the remaining figures by kind: Figure vs Table
    fig_type_unmatched = []
    table_type_unmatched = []
    for f in unmatched:
        if _is_figure_type(str(f.get("id") or "")):
            fig_type_unmatched.append(f)
        else:
            table_type_unmatched.append(f)

    # Images already linked by ID are not handed out a second time.
    available = [img for img in images if img["name"] not in linked_files]
    fig_images = sorted(
        [img for img in available if "table" not in img["name"].lower()],
        key=lambda img: _image_sort_key(img["name"]),
    )
    table_images = sorted(
        [img for img in available if "table" in img["name"].lower()],
        key=lambda img: _image_sort_key(img["name"]),
    )

    # zip() stops at the shorter list, so surplus figures stay unlinked.
    for fig, img in zip(fig_type_unmatched, fig_images):
        fig["image_url"] = img["url"]
    for fig, img in zip(table_type_unmatched, table_images):
        fig["image_url"] = img["url"]

    return figures
+25 -10
View File
@@ -11,6 +11,7 @@ import uuid
from pathlib import Path
from app.config import settings
from app.utils import truncate_error
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
@@ -21,6 +22,9 @@ from app.services.summary_utils import (
logger = logging.getLogger(__name__)
# PDF 全文注入模式的字符上限 — 超过此阈值自动切换到 search 模式
_PDF_MAX_CHARS = 80_000
# 重新导出,保持向后兼容
__all__ = [
"PiTimeoutError",
@@ -45,7 +49,7 @@ class PiProcessError(Exception):
def __init__(self, returncode: int, stderr: str):
self.returncode = returncode
self.stderr = stderr
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
super().__init__(f"pi exited with code {returncode}: {truncate_error(stderr)}")
# ── pi CLI 调用 ────────────────────────────────────────────────────────
@@ -72,23 +76,27 @@ async def call_pi(
actual_mode = pdf_mode
if pdf_mode == "auto":
if txt_size > 80_000:
if txt_size > _PDF_MAX_CHARS:
actual_mode = "search"
logger.info(
"Auto mode: %s text=%d chars > 80k → search", arxiv_id, txt_size
"Auto mode: %s text=%d chars > %dk → search",
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
)
else:
actual_mode = "inject"
logger.info(
"Auto mode: %s text=%d chars ≤ 80k → inject", arxiv_id, txt_size
"Auto mode: %s text=%d chars ≤ %dk → inject",
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
)
# inject 模式需要截断过长的文本(避免撑爆 context)
if actual_mode == "inject" and txt_size > 80_000:
if actual_mode == "inject" and txt_size > _PDF_MAX_CHARS:
body = txt_path.read_text(encoding="utf-8")
trimmed = body[:80_000].rstrip()
trimmed = body[:_PDF_MAX_CHARS].rstrip()
txt_path.write_text(trimmed, encoding="utf-8")
logger.info("Truncated %s for inject: %d%d chars", arxiv_id, txt_size, len(trimmed))
logger.info(
"Truncated %s for inject: %d%d chars", arxiv_id, txt_size, len(trimmed)
)
prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
@@ -101,7 +109,8 @@ async def call_pi(
cmd = [
settings.PI_BIN,
"-p",
"--tools", tools,
"--tools",
tools,
]
if fix_errors:
cmd += ["--session", session_id, "--continue"]
@@ -118,10 +127,14 @@ async def call_pi(
logger.info(
"Calling pi for %s (fix=%s, session=%s, mode=%s)",
arxiv_id, bool(fix_errors), session_id, actual_mode,
arxiv_id,
bool(fix_errors),
session_id,
actual_mode,
)
import time as _time
_t_sub_start = _time.monotonic()
proc = await asyncio.create_subprocess_exec(
@@ -151,7 +164,9 @@ async def call_pi(
logger.info(
"pi subprocess for %s: %.2fs%s",
arxiv_id, _t_sub_end - _t_sub_start, _file_info,
arxiv_id,
_t_sub_end - _t_sub_start,
_file_info,
)
if proc.returncode != 0:
+56 -11
View File
@@ -8,6 +8,7 @@ from __future__ import annotations
import logging
from datetime import date as date_type
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.config import settings
@@ -15,11 +16,50 @@ from app.models import CrawlLog, TaskLock
from app.services.cleaner import cleanup_tmp
from app.services.crawler import crawl_daily
from app.services.summarizer import summarize_batch
from app.utils import utc_now, yesterday_str
from app.utils import release_lock, truncate_error, utc_now, yesterday_str
logger = logging.getLogger(__name__)
def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock:
    """Create and commit a "running" TaskLock row for ``task`` / ``lock_key``.

    Shared entry point for operations that must not run re-entrantly
    (crawl, pipeline, ...).

    Raises:
        RuntimeError: if the commit fails with an IntegrityError, which is
            treated as "a lock for this task/key is already held".
    """
    entry = TaskLock(
        task=task,
        lock_key=lock_key,
        status="running",
        owner=owner,
        acquired_at=utc_now(),
    )
    try:
        db.add(entry)
        db.commit()
    except IntegrityError:
        # Leave the session usable before reporting the conflict.
        db.rollback()
        raise RuntimeError(f"{task} already running for {lock_key}")
    else:
        return entry
async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") -> dict:
    """Run one crawl for ``target_date`` while holding the "crawl" lock.

    Args:
        db: database session.
        target_date: target date, YYYY-MM-DD.
        owner: identifier of the caller, recorded on the lock.

    Returns:
        Whatever ``crawl_daily()`` returns, unmodified.
    """
    lock = acquire_lock(db, "crawl", target_date, owner)
    try:
        result = await crawl_daily(db, target_date)
    finally:
        # Released on success and on failure alike.
        release_lock(db, lock)
    return result
async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
"""执行完整流水线:crawl → summarize → cleanup。
@@ -47,7 +87,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
try:
db.add(lock)
db.commit()
except Exception:
except IntegrityError:
db.rollback()
raise RuntimeError(f"Pipeline already running for {target_date}")
@@ -66,9 +106,13 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
try:
# Step 1: 抓取(先试今天,无数据则回退昨天)
crawl_result = await crawl_daily(db, target_date)
logger.info("Pipeline [%s]: crawl %s, found=%d new=%d",
owner, target_date,
crawl_result.get("found", 0), crawl_result.get("new", 0))
logger.info(
"Pipeline [%s]: crawl %s, found=%d new=%d",
owner,
target_date,
crawl_result.get("found", 0),
crawl_result.get("new", 0),
)
if crawl_result.get("status") == "success" and crawl_result.get("found") == 0:
yesterday = yesterday_str()
@@ -81,8 +125,11 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
# Step 3: 清理
cleanup_result = cleanup_tmp()
logger.info("Pipeline [%s]: cleanup done, removed=%d",
owner, cleanup_result.get("removed", 0))
logger.info(
"Pipeline [%s]: cleanup done, removed=%d",
owner,
cleanup_result.get("removed", 0),
)
log_entry.status = "success"
log_entry.papers_found = crawl_result.get("found", 0)
@@ -91,7 +138,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
except Exception as exc:
logger.exception("Pipeline [%s] failed", owner)
log_entry.status = "failed"
error_msg = str(exc)[:2000]
error_msg = truncate_error(exc, limit=2000)
finally:
log_entry.completed_at = utc_now()
@@ -99,9 +146,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
log_entry.error = error_msg
db.commit()
lock.status = "finished"
lock.released_at = utc_now()
db.commit()
release_lock(db, lock)
if error_msg:
return {"status": "failed", "error": error_msg}
+1
View File
@@ -90,6 +90,7 @@ class SummarySchema(BaseModel):
# ── 质量评估 ────────────────────────────────────────────────────────────
def assess_quality(schema: SummarySchema) -> str:
"""评估总结质量:normal / degraded / low。"""
# low:内容空洞的启发式判断
+2 -8
View File
@@ -213,11 +213,7 @@ def _search_semantic(
arxiv_ids = [c["arxiv_id"] for c in candidates]
distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
stmt = (
select(Paper)
.where(Paper.arxiv_id.in_(arxiv_ids))
.options(*PAPER_FULL_LOAD)
)
stmt = select(Paper).where(Paper.arxiv_id.in_(arxiv_ids)).options(*PAPER_FULL_LOAD)
if tag:
stmt = stmt.where(Paper.tags.any(tag=tag))
@@ -298,9 +294,7 @@ def _load_papers_by_ids(
papers = (
db.execute(
select(Paper)
.where(Paper.id.in_(paper_ids))
.options(*PAPER_FULL_LOAD)
select(Paper).where(Paper.id.in_(paper_ids)).options(*PAPER_FULL_LOAD)
)
.unique()
.scalars()
+39 -502
View File
@@ -1,233 +1,42 @@
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引"""
"""AI 总结编排服务 — 协调生成器、持久化、批量处理的顶层入口"""
from __future__ import annotations
import asyncio
import json
import logging
from pathlib import Path
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.config import settings
from app.database import SessionLocal
from app.exceptions import ConflictError, NotFoundError
from app.models import (
PAPER_DEFAULT_LOAD,
CrawlLog,
Paper,
PaperSummary,
PaperTag,
SummaryState,
SummaryStatus,
TaskLock,
get_paper_by_arxiv_id,
get_paper_by_id,
)
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp,
download_pdf,
paper_dir,
from app.services.pdf_downloader import download_pdf
from app.services.summary_utils import write_meta_json
from app.services.summary_generator import (
_generate_with_retry,
)
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
extract_json,
write_meta_json,
extract_pdf_text,
from app.services.summary_persister import (
_cleanup_old_images,
_handle_summary_failure,
_persist_summary,
)
from app.services.pi_client import (
PiProcessError,
PiTimeoutError,
call_pi,
)
from app.services import claude_backend
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.utils import TMP_DIR, release_lock, utc_now
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
logger = logging.getLogger(__name__)
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
"""将异常映射到 error_type 枚举值。"""
if isinstance(exc, PdfDownloadError):
return "pdf_download_failed"
if isinstance(exc, PiTimeoutError):
return "timeout"
if isinstance(exc, PiProcessError):
return "process_error"
if isinstance(exc, JsonNotFoundError):
return "json_not_found"
if isinstance(exc, json.JSONDecodeError):
return "json_invalid"
if isinstance(exc, ValidationError):
return classify_validation_error(exc)
return "unknown"
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
def _build_fts_summary_text(schema: SummarySchema) -> str:
"""拼接用于 FTS5 索引的总结文本。"""
parts = [
schema.one_line or "",
schema.motivation.problem or "",
schema.motivation.goal or "",
schema.method.overview or "",
schema.method.key_idea or "",
schema.results.main_findings or "",
]
return " ".join(p for p in parts if p)
# ── DB 更新 ─────────────────────────────────────────────────────────────
def _update_summary_in_db(
db: Session,
paper: Paper,
schema: SummarySchema,
quality: str,
raw_output: str,
) -> None:
"""将校验后的总结写入 DBpaper_summaries + papers + paper_tags + FTS5。"""
from sqlalchemy import text
# 1. paper_summariesupsert
existing = db.get(PaperSummary, paper.id)
flat = flatten_for_db(schema)
if existing:
for k, v in flat.items():
setattr(existing, k, v)
else:
db.add(PaperSummary(paper_id=paper.id, **flat))
# 2. papers 表
paper.title_zh = schema.title_zh
paper.summary_quality = quality
p_dir = paper_dir(paper.arxiv_id)
paper.summary_path = str(p_dir / "summary.json")
paper.raw_output_path = str(p_dir / "raw_output.txt")
# 3. AI 标签
existing_tag_names = {t.tag for t in paper.tags}
for tag_name in schema.tags:
if tag_name not in existing_tag_names:
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
existing_tag_names.add(tag_name)
# 4. FTS5 更新
summary_text = _build_fts_summary_text(schema)
db.execute(
text(
"UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
"WHERE rowid=:paper_id"
),
{
"title_zh": schema.title_zh,
"summary_text": summary_text,
"paper_id": paper.id,
},
)
db.commit()
logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── JSON 验证 ──────────────────────────────────────────────────────────
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
errors: list[str] = []
if not isinstance(json_data, dict):
return ["顶层必须是 JSON 对象"]
# 必填字段
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
if f not in json_data or not json_data[f]:
errors.append(f"缺少必填字段: {f}")
# tags 必须是非空数组
tags = json_data.get("tags")
if not isinstance(tags, list) or len(tags) == 0:
errors.append("tags 必须是非空数组")
# 字符串段落字段(必须是 str 且 ≥50 字)
string_fields = [
("motivation", "problem"), ("motivation", "goal"), ("motivation", "gap"),
("method", "overview"), ("method", "key_idea"), ("method", "steps"),
("method", "novelty"),
("results", "main_findings"), ("results", "limitations"),
("improvements", "weaknesses"), ("improvements", "future_work"),
("improvements", "reproducibility"),
]
for section, field in string_fields:
val = json_data.get(section, {}).get(field)
if isinstance(val, list):
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
elif not isinstance(val, str) or len(val.strip()) < 50:
errors.append(
f"{section}.{field} 必须是详细段落(≥50字),"
f"当前: {type(val).__name__} ({len(str(val))}字)"
)
# benchmarks 必须是数组
benchmarks = json_data.get("results", {}).get("benchmarks")
if benchmarks is not None and not isinstance(benchmarks, list):
errors.append("results.benchmarks 必须是数组")
# prerequisites.concepts 必须是对象数组,每个有 term
concepts = json_data.get("prerequisites", {}).get("concepts")
if concepts is not None:
if not isinstance(concepts, list):
errors.append("prerequisites.concepts 必须是数组")
elif len(concepts) == 0:
errors.append("prerequisites.concepts 不能为空")
else:
for i, c in enumerate(concepts):
if isinstance(c, str):
errors.append(f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串")
elif isinstance(c, dict) and not c.get("term"):
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
# figures 必须是数组,每个元素应有 id
figures = json_data.get("figures")
if figures is not None:
if not isinstance(figures, list):
errors.append("figures 必须是数组")
else:
for i, fig in enumerate(figures):
if isinstance(fig, dict) and not fig.get("id"):
errors.append(f"figures[{i}] 缺少 id 字段")
return errors
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
if schema:
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
encoding="utf-8",
)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
# ── 单篇总结 ────────────────────────────────────────────────────────────
@@ -264,277 +73,7 @@ async def summarize_one(
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
async def _generate_with_retry(
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
"""调用 AI 后端生成总结,最多 4 轮验证循环。
根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。
Returns:
(json_data, raw_output)
Raises:
ValueError: 4 轮验证仍未通过
"""
import time as _time
backend = settings.SUMMARY_BACKEND
validation_errors: list[str] = []
json_data: dict | None = None
raw_output = ""
session_id = None
summary_file = paper_dir(arxiv_id) / "summary.json"
# claude 后端需要预构建 promptpi 后端在 call_pi 内部构建)
claude_prompt: str | None = None
if backend == "claude":
_t0 = _time.monotonic()
txt_path = extract_pdf_text(pdf_path, max_chars=None)
body = txt_path.read_text(encoding="utf-8")
if len(body) > 80_000:
trimmed = body[:80_000].rstrip()
txt_path.write_text(trimmed, encoding="utf-8")
claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)
for attempt in range(1, 5):
# 清理上一轮写入的不完整文件
if summary_file.exists():
summary_file.unlink()
# 记录 AI 调用开始时间
_t_call_start = _time.monotonic()
if backend == "claude":
if attempt == 1:
raw_output, session_id = await claude_backend.call_claude(
claude_prompt, session_id=None,
)
else:
retry_prompt = build_prompt(
arxiv_id, meta_path,
extract_pdf_text(pdf_path, max_chars=80000),
"inject", fix_errors=validation_errors,
)
raw_output, session_id = await claude_backend.call_claude(
retry_prompt, session_id=session_id, fix_errors=validation_errors,
)
else:
if attempt == 1:
raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
else:
raw_output, session_id = await call_pi(
meta_path, pdf_path,
fix_errors=validation_errors,
session_id=session_id,
pdf_mode=pdf_mode,
)
_t_call_end = _time.monotonic()
# 检查 summary.json 是否由 AI 子进程写入
file_written_by_ai = summary_file.exists()
file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
file_size = summary_file.stat().st_size if file_written_by_ai else 0
logger.info(
" [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
arxiv_id, attempt,
_t_call_end - _t_call_start,
f"已写入({file_size}B)" if file_written_by_ai else "未写入",
f" mtime={file_mtime:.2f}" if file_mtime else "",
)
# 提取 JSON
_t_json_start = _time.monotonic()
try:
if file_written_by_ai:
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
else:
json_data = extract_json(raw_output)
except (json.JSONDecodeError, JsonNotFoundError) as exc:
_t_json_end = _time.monotonic()
logger.warning(
" [%s] JSON提取失败: %.2fs %s",
arxiv_id, _t_json_end - _t_json_start, str(exc)[:200],
)
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
continue
_t_json_end = _time.monotonic()
# 验证
_t_val_start = _time.monotonic()
validation_errors = _validate_summary(json_data, arxiv_id)
_t_val_end = _time.monotonic()
if not validation_errors:
logger.info(
" [%s] JSON提取: %.2fs 验证: %.2fs ✅",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
)
break
logger.warning(
" [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
"; ".join(validation_errors)[:200],
)
if validation_errors:
exc = ValueError(
f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}"
)
exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用
raise exc
return json_data, raw_output
def _persist_summary(
db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
import time as _time
arxiv_id = paper.arxiv_id
_t0 = _time.monotonic()
schema = SummarySchema.model_validate(json_data)
quality = assess_quality(schema)
_t1 = _time.monotonic()
_save_files(arxiv_id, schema, raw_output)
_t2 = _time.monotonic()
_update_summary_in_db(db, paper, schema, quality, raw_output)
_t3 = _time.monotonic()
# 状态 → done
paper.summary_status.status = SummaryState.DONE
paper.summary_status.quality = quality
paper.summary_status.completed_at = utc_now()
paper.summary_status.raw_output_saved = True
db.commit()
_t4 = _time.monotonic()
logger.info(
" [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
arxiv_id,
_t1 - _t0,
_t2 - _t1,
_t3 - _t2,
_t4 - _t3,
)
# 触发性增强(失败不影响总结)
_t5 = _time.monotonic()
_maybe_extract_images(arxiv_id, schema)
_t6 = _time.monotonic()
_maybe_index_chroma(arxiv_id, paper, schema)
_t7 = _time.monotonic()
logger.info(
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
arxiv_id,
_t6 - _t5,
_t7 - _t6,
)
return quality
def _handle_summary_failure(
db: Session, paper: Paper, exc: Exception, raw_output: str,
) -> dict:
"""记录失败:保存 raw_output、重试计数、错误分类。"""
error_type = _classify_error(exc)
logger.error(
"Summarize failed: %s error_type=%s %s",
paper.arxiv_id, error_type, str(exc)[:200],
)
status = paper.summary_status
if raw_output:
_save_files(paper.arxiv_id, None, raw_output)
status.raw_output_saved = True
status.retry_count = (status.retry_count or 0) + 1
status.error_type = error_type
status.error = str(exc)[:2000]
if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
status.status = SummaryState.PERMANENT_FAILURE
else:
status.status = SummaryState.PENDING
status.completed_at = utc_now()
db.commit()
return {
"arxiv_id": paper.arxiv_id,
"status": "failed",
"error_type": error_type,
"error": str(exc)[:200],
"retry_count": status.retry_count,
}
def _cleanup_old_images(db: Session, paper: Paper) -> None:
"""清理旧的图片文件和 figures_json,避免重新总结时残留。"""
arxiv_id = paper.arxiv_id
images_dir = paper_dir(arxiv_id) / "images"
if images_dir.exists():
for old_file in images_dir.iterdir():
if old_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg") or old_file.name == "manifest.json":
old_file.unlink(missing_ok=True)
# 清除数据库中的 figures_json
if paper.summary and paper.summary.figures_json:
paper.summary.figures_json = None
db.commit()
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
"""从 PDF 提取图片和表格(失败不影响总结)。"""
try:
from app.services.pdf_image_extractor import (
extract_images_from_pdf,
filter_images_by_summary,
)
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
extract_images_from_pdf(arxiv_id, pdf_path)
if schema.figures:
filter_images_by_summary(arxiv_id, schema.figures)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
"""写入 ChromaDB 语义索引(失败不影响总结)。"""
try:
from app.services.embedder import index_paper
texts_dict = {
"arxiv_id": arxiv_id,
"title_zh": schema.title_zh or "",
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"one_line": schema.one_line or "",
"motivation_problem": schema.motivation.problem or "",
"method_key_idea": schema.method.key_idea or "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
index_paper(arxiv_id, texts_dict)
except Exception:
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
async def _do_summarize_one(
db: Session, paper: Paper, pdf_mode: str = "auto"
) -> dict:
async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
arxiv_id = paper.arxiv_id
title_short = (paper.title_en or "")[:50]
@@ -548,6 +87,7 @@ async def _do_summarize_one(
# 清理旧的图片文件和 figures_json,避免重新总结时残留
import time as _time
_t_cleanup_start = _time.monotonic()
_cleanup_old_images(db, paper)
_t_cleanup_end = _time.monotonic()
@@ -567,7 +107,9 @@ async def _do_summarize_one(
logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
json_data, raw_output = await _generate_with_retry(
arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf",
arxiv_id,
meta_path,
TMP_DIR / arxiv_id / "paper.pdf",
pdf_mode=pdf_mode,
)
_t3 = _time.monotonic()
@@ -577,7 +119,9 @@ async def _do_summarize_one(
_t4 = _time.monotonic()
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
logger.info("✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0)
logger.info(
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0
)
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
except Exception as exc:
@@ -586,7 +130,7 @@ async def _do_summarize_one(
return _handle_summary_failure(db, paper, exc, fail_output)
finally:
cleanup_tmp(arxiv_id)
pass # cleanup_tmp(arxiv_id) # 暂时禁用,保留 PDF 用于调试图片提取
# ── 单篇入口 ────────────────────────────────────────────────────────────
@@ -604,25 +148,19 @@ async def summarize_single(
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
paper = db.execute(
select(Paper)
.where(Paper.arxiv_id == arxiv_id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
paper = get_paper_by_arxiv_id(db, arxiv_id)
if not paper:
return {"status": "not_found", "arxiv_id": arxiv_id}
raise NotFoundError(f"Paper not found: {arxiv_id}")
make_session = _session_factory or SessionLocal
# 每篇用独立 session 避免并发问题
paper_db = make_session()
try:
paper_in_new_session = paper_db.execute(
select(Paper)
.where(Paper.arxiv_id == arxiv_id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
result = await summarize_one(paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode)
paper_in_new_session = get_paper_by_arxiv_id(paper_db, arxiv_id)
result = await summarize_one(
paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode
)
finally:
paper_db.close()
@@ -656,10 +194,10 @@ async def summarize_batch(
try:
db.add(lock)
db.commit()
except Exception:
except IntegrityError:
db.rollback()
logger.warning("Summarize batch already running (lock conflict)")
return {"status": "conflict", "error": "summarize batch already running"}
raise ConflictError("summarize batch already running")
# CrawlLog
log_entry = CrawlLog(
@@ -717,19 +255,18 @@ async def summarize_batch(
break
paper_db = make_session()
try:
p = paper_db.execute(
select(Paper)
.where(Paper.id == paper.id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
p = get_paper_by_id(paper_db, paper.id)
result = await summarize_one(paper_db, p, pdf_mode=pdf_mode)
status = result.get("status", "failed")
progress[status] = progress.get(status, 0) + 1
finished = sum(progress.values())
logger.info(
"📊 进度: %d/%d (✅%d%d ⏭️%d) — %s",
finished, total,
progress["done"], progress["failed"], progress["skipped"],
finished,
total,
progress["done"],
progress["failed"],
progress["skipped"],
paper.arxiv_id,
)
results.append(result)
@@ -785,10 +322,10 @@ async def summarize_batch(
except Exception as exc:
logger.exception("Summarize batch failed")
log_entry.status = "failed"
log_entry.error = str(exc)[:2000]
log_entry.error = truncate_error(exc, limit=2000)
log_entry.completed_at = utc_now()
db.commit()
return {"status": "failed", "error": str(exc)}
return {"status": "failed", "error": truncate_error(exc)}
finally:
release_lock(db, lock)
+275
View File
@@ -0,0 +1,275 @@
"""AI 总结生成器 — AI 后端调用、重试循环、JSON 验证、错误分类。"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from pydantic import ValidationError
from app.config import settings
from app.services.pdf_downloader import (
PdfDownloadError,
paper_dir,
)
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
extract_json,
extract_pdf_text,
)
from app.services.pi_client import (
PiProcessError,
PiTimeoutError,
call_pi,
)
from app.services import claude_backend
from app.services.schemas import classify_validation_error
from app.utils import truncate_error
logger = logging.getLogger(__name__)
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
    """Translate an exception into its ``error_type`` label.

    Exception classes are tested in a fixed order and the first match wins.
    Pydantic ``ValidationError`` is delegated to ``classify_validation_error``;
    anything unrecognised is labelled ``"unknown"``.
    """
    # Order is significant: it mirrors the precedence of the checks.
    fixed_labels = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_type, label in fixed_labels:
        if isinstance(exc, exc_type):
            return label

    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)
    return "unknown"
# ── JSON 验证 ──────────────────────────────────────────────────────────
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
errors: list[str] = []
if not isinstance(json_data, dict):
return ["顶层必须是 JSON 对象"]
# 必填字段
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
if f not in json_data or not json_data[f]:
errors.append(f"缺少必填字段: {f}")
# tags 必须是非空数组
tags = json_data.get("tags")
if not isinstance(tags, list) or len(tags) == 0:
errors.append("tags 必须是非空数组")
# 字符串段落字段(必须是 str 且 ≥50 字)
string_fields = [
("motivation", "problem"),
("motivation", "goal"),
("motivation", "gap"),
("method", "overview"),
("method", "key_idea"),
("method", "steps"),
("method", "novelty"),
("results", "main_findings"),
("results", "limitations"),
("improvements", "weaknesses"),
("improvements", "future_work"),
("improvements", "reproducibility"),
]
for section, field in string_fields:
val = json_data.get(section, {}).get(field)
if isinstance(val, list):
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
elif not isinstance(val, str) or len(val.strip()) < 50:
errors.append(
f"{section}.{field} 必须是详细段落(≥50字),"
f"当前: {type(val).__name__} ({len(str(val))}字)"
)
# benchmarks 必须是数组
benchmarks = json_data.get("results", {}).get("benchmarks")
if benchmarks is not None and not isinstance(benchmarks, list):
errors.append("results.benchmarks 必须是数组")
# prerequisites.concepts 必须是对象数组,每个有 term
concepts = json_data.get("prerequisites", {}).get("concepts")
if concepts is not None:
if not isinstance(concepts, list):
errors.append("prerequisites.concepts 必须是数组")
elif len(concepts) == 0:
errors.append("prerequisites.concepts 不能为空")
else:
for i, c in enumerate(concepts):
if isinstance(c, str):
errors.append(
f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串"
)
elif isinstance(c, dict) and not c.get("term"):
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
# figures 必须是数组,每个元素应有 id
figures = json_data.get("figures")
if figures is not None:
if not isinstance(figures, list):
errors.append("figures 必须是数组")
else:
for i, fig in enumerate(figures):
if isinstance(fig, dict) and not fig.get("id"):
errors.append(f"figures[{i}] 缺少 id 字段")
return errors
# ── AI 调用 + 重试 ──────────────────────────────────────────────────────
async def _generate_with_retry(
    arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
    """Call the AI backend and repeat until the returned summary validates.

    The backend is chosen by ``settings.SUMMARY_BACKEND``: ``"claude"`` goes
    through ``claude_backend.call_claude``; any other value goes through
    ``call_pi``. At most ``settings.SUMMARY_MAX_RETRIES`` attempts are made,
    and always at least one. After a failed attempt, the validation errors
    and the backend session id are passed into the next call.

    Args:
        arxiv_id: Paper identifier; used for logging and to locate
            ``summary.json`` under ``paper_dir(arxiv_id)``.
        meta_path: Path to the paper's metadata file, forwarded to the backend.
        pdf_path: Path to the paper PDF.
        pdf_mode: Forwarded to ``call_pi`` only; the claude branch always
            builds its prompt in ``"inject"`` mode.

    Returns:
        ``(json_data, raw_output)``: the validated summary dict and the raw
        backend output of the successful attempt.

    Raises:
        ValueError: If the final attempt still fails JSON extraction or
            validation. The last backend output is attached to the exception
            as ``raw_output``.
    """
    import time as _time

    backend = settings.SUMMARY_BACKEND
    # A zero or negative setting would skip the loop entirely and return
    # (None, "") as if it were a valid summary; always try at least once.
    max_attempts = max(1, settings.SUMMARY_MAX_RETRIES)
    validation_errors: list[str] = []
    json_data: dict | None = None
    raw_output = ""
    session_id = None
    summary_file = paper_dir(arxiv_id) / "summary.json"

    # The claude backend needs the prompt built up front (call_pi builds its
    # own prompt internally).
    claude_prompt: str | None = None
    if backend == "claude":
        _t0 = _time.monotonic()
        txt_path = extract_pdf_text(pdf_path, max_chars=None)
        body = txt_path.read_text(encoding="utf-8")
        if len(body) > 80_000:
            # Cap the extracted text by rewriting the text file in place.
            trimmed = body[:80_000].rstrip()
            txt_path.write_text(trimmed, encoding="utf-8")
        claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
        logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)

    for attempt in range(1, max_attempts + 1):
        # Drop any file left by the previous round so that its presence after
        # the call means this round's backend run wrote it.
        summary_file.unlink(missing_ok=True)

        _t_call_start = _time.monotonic()

        if backend == "claude":
            if attempt == 1:
                raw_output, session_id = await claude_backend.call_claude(
                    claude_prompt,
                    session_id=None,
                )
            else:
                retry_prompt = build_prompt(
                    arxiv_id,
                    meta_path,
                    extract_pdf_text(pdf_path, max_chars=80000),
                    "inject",
                    fix_errors=validation_errors,
                )
                raw_output, session_id = await claude_backend.call_claude(
                    retry_prompt,
                    session_id=session_id,
                    fix_errors=validation_errors,
                )
        else:
            if attempt == 1:
                raw_output, session_id = await call_pi(
                    meta_path, pdf_path, pdf_mode=pdf_mode
                )
            else:
                raw_output, session_id = await call_pi(
                    meta_path,
                    pdf_path,
                    fix_errors=validation_errors,
                    session_id=session_id,
                    pdf_mode=pdf_mode,
                )

        _t_call_end = _time.monotonic()

        # Did the backend write summary.json itself?
        file_written_by_ai = summary_file.exists()
        if file_written_by_ai:
            file_stat = summary_file.stat()
            file_mtime = file_stat.st_mtime
            file_size = file_stat.st_size
        else:
            file_mtime = None
            file_size = 0
        logger.info(
            " [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
            arxiv_id,
            attempt,
            _t_call_end - _t_call_start,
            f"已写入({file_size}B)" if file_written_by_ai else "未写入",
            f" mtime={file_mtime:.2f}" if file_mtime else "",
        )

        # Prefer the file the backend wrote; otherwise pull JSON out of the
        # raw text output.
        _t_json_start = _time.monotonic()
        try:
            if file_written_by_ai:
                json_data = json.loads(summary_file.read_text(encoding="utf-8"))
                logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
            else:
                json_data = extract_json(raw_output)
        except (json.JSONDecodeError, JsonNotFoundError) as exc:
            _t_json_end = _time.monotonic()
            logger.warning(
                " [%s] JSON提取失败: %.2fs %s",
                arxiv_id,
                _t_json_end - _t_json_start,
                str(exc)[:200],
            )
            # Counts as a failed round; the message is sent back as the fix hint.
            validation_errors = [f"无法提取有效 JSON: {truncate_error(exc)}"]
            continue
        _t_json_end = _time.monotonic()

        _t_val_start = _time.monotonic()
        validation_errors = _validate_summary(json_data, arxiv_id)
        _t_val_end = _time.monotonic()

        if not validation_errors:
            logger.info(
                " [%s] JSON提取: %.2fs 验证: %.2fs ✅",
                arxiv_id,
                _t_json_end - _t_json_start,
                _t_val_end - _t_val_start,
            )
            break

        logger.warning(
            " [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
            arxiv_id,
            _t_json_end - _t_json_start,
            _t_val_end - _t_val_start,
            "; ".join(validation_errors)[:200],
        )

    if validation_errors:
        exc = ValueError(
            f"Summary validation failed after {max_attempts} attempts: {'; '.join(validation_errors)}"
        )
        exc.raw_output = raw_output  # read by the caller's failure handler
        raise exc

    return json_data, raw_output
+273
View File
@@ -0,0 +1,273 @@
"""AI 总结持久化 — DB 写入、文件保存、FTS 索引、图片提取、ChromaDB 索引。"""
from __future__ import annotations
import logging
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.models import (
Paper,
PaperSummary,
PaperTag,
SummaryState,
)
from app.services.pdf_downloader import paper_dir
from app.services.schemas import (
SummarySchema,
assess_quality,
flatten_for_db,
)
from app.services.summary_generator import _classify_error
from app.utils import TMP_DIR, truncate_error, utc_now
logger = logging.getLogger(__name__)
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
def _build_fts_summary_text(schema: SummarySchema) -> str:
"""拼接用于 FTS5 索引的总结文本。"""
parts = [
schema.one_line or "",
schema.motivation.problem or "",
schema.motivation.goal or "",
schema.method.overview or "",
schema.method.key_idea or "",
schema.results.main_findings or "",
]
return " ".join(p for p in parts if p)
# ── DB 更新 ─────────────────────────────────────────────────────────────
def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Write a validated summary to the database and commit.

    Touches four places: the ``PaperSummary`` row (insert or update), the
    summary columns on ``paper``, AI-sourced ``PaperTag`` rows, and the
    ``papers_fts`` row for this paper.

    Args:
        db: Database session; committed before returning.
        paper: The paper being summarised.
        schema: The validated summary.
        quality: Quality label stored on ``paper.summary_quality``.
        raw_output: Not used in this function.
    """
    # 1. paper_summaries: update the existing row or insert a new one.
    summary_row = db.get(PaperSummary, paper.id)
    columns = flatten_for_db(schema)
    if not summary_row:
        db.add(PaperSummary(paper_id=paper.id, **columns))
    else:
        for column, value in columns.items():
            setattr(summary_row, column, value)

    # 2. papers: title, quality and the on-disk artefact paths.
    base_dir = paper_dir(paper.arxiv_id)
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    paper.summary_path = str(base_dir / "summary.json")
    paper.raw_output_path = str(base_dir / "raw_output.txt")

    # 3. AI tags: add only tags the paper does not have yet; the set also
    #    de-duplicates repeats within schema.tags.
    known_tags = {t.tag for t in paper.tags}
    for tag_name in schema.tags:
        if tag_name in known_tags:
            continue
        db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
        known_tags.add(tag_name)

    # 4. FTS5: refresh the indexed text for this paper's row.
    fts_params = {
        "title_zh": schema.title_zh,
        "summary_text": _build_fts_summary_text(schema),
        "paper_id": paper.id,
    }
    fts_statement = text(
        "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
        "WHERE rowid=:paper_id"
    )
    db.execute(fts_statement, fts_params)

    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
    """Write the summary artefacts into the paper's directory.

    ``raw_output.txt`` is always written. ``summary.json`` is written only
    when a schema is supplied. The directory is created if it does not exist.
    """
    out_dir = paper_dir(arxiv_id)
    out_dir.mkdir(parents=True, exist_ok=True)

    if schema:
        serialized = schema.model_dump_json(ensure_ascii=False, indent=2)
        summary_path = out_dir / "summary.json"
        summary_path.write_text(serialized, encoding="utf-8")

    raw_path = out_dir / "raw_output.txt"
    raw_path.write_text(raw_output, encoding="utf-8")
# ── 失败处理 ────────────────────────────────────────────────────────────
def _handle_summary_failure(
    db: Session,
    paper: Paper,
    exc: Exception,
    raw_output: str,
) -> dict:
    """Record a failed summary attempt on the paper's status row and commit.

    Saves the raw backend output when there is any, bumps the retry counter,
    stores the classified error, and moves the status to PERMANENT_FAILURE
    once the counter reaches ``settings.SUMMARY_MAX_RETRIES + 1`` (otherwise
    back to PENDING).

    Returns:
        A result dict with ``arxiv_id``, ``status`` (always ``"failed"``),
        ``error_type``, a truncated ``error`` message and ``retry_count``.
    """
    from app.config import settings

    error_type = _classify_error(exc)
    logger.error(
        "Summarize failed: %s error_type=%s %s",
        paper.arxiv_id,
        error_type,
        truncate_error(exc),
    )

    state_row = paper.summary_status

    # Keep whatever the backend produced.
    if raw_output:
        _save_files(paper.arxiv_id, None, raw_output)
        state_row.raw_output_saved = True

    state_row.retry_count = (state_row.retry_count or 0) + 1
    state_row.error_type = error_type
    state_row.error = truncate_error(exc, limit=2000)

    exhausted = state_row.retry_count >= settings.SUMMARY_MAX_RETRIES + 1
    state_row.status = (
        SummaryState.PERMANENT_FAILURE if exhausted else SummaryState.PENDING
    )

    state_row.completed_at = utc_now()
    db.commit()

    return {
        "arxiv_id": paper.arxiv_id,
        "status": "failed",
        "error_type": error_type,
        "error": truncate_error(exc),
        "retry_count": state_row.retry_count,
    }
# ── 持久化 ──────────────────────────────────────────────────────────────
def _persist_summary(
    db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
    """Validate, store and post-process one generated summary.

    Steps, in order: Pydantic validation and quality assessment, writing the
    files, writing the DB rows, marking the status row DONE (committed), then
    the best-effort image extraction and ChromaDB indexing. Each phase is
    timed and logged.

    Returns:
        The quality label computed by ``assess_quality``.

    Raises:
        Whatever ``SummarySchema.model_validate`` raises for invalid data.
    """
    import time as _time

    clock = _time.monotonic
    arxiv_id = paper.arxiv_id

    started = clock()
    schema = SummarySchema.model_validate(json_data)
    quality = assess_quality(schema)
    validated = clock()

    _save_files(arxiv_id, schema, raw_output)
    files_written = clock()

    _update_summary_in_db(db, paper, schema, quality, raw_output)
    db_written = clock()

    # Flip the status row to DONE.
    state_row = paper.summary_status
    state_row.status = SummaryState.DONE
    state_row.quality = quality
    state_row.completed_at = utc_now()
    state_row.raw_output_saved = True
    db.commit()
    status_committed = clock()

    logger.info(
        " [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
        arxiv_id,
        validated - started,
        files_written - validated,
        db_written - files_written,
        status_committed - db_written,
    )

    # Optional enrichment; both helpers swallow their own failures.
    post_started = clock()
    _maybe_extract_images(arxiv_id, schema)
    images_done = clock()
    _maybe_index_chroma(arxiv_id, paper, schema)
    chroma_done = clock()
    logger.info(
        " [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
        arxiv_id,
        images_done - post_started,
        chroma_done - images_done,
    )

    return quality
# ── 清理 ────────────────────────────────────────────────────────────────
def _cleanup_old_images(db: Session, paper: Paper) -> None:
    """Remove stale image artefacts so a re-summarise starts clean.

    Deletes image files (png/jpg/jpeg/gif/svg, case-insensitive suffix) and
    ``manifest.json`` from the paper's ``images`` directory, then clears
    ``figures_json`` on the stored summary and commits if it was set.
    """
    image_dir = paper_dir(paper.arxiv_id) / "images"
    if image_dir.exists():
        image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".svg")
        for entry in image_dir.iterdir():
            is_image = entry.suffix.lower() in image_suffixes
            if is_image or entry.name == "manifest.json":
                entry.unlink(missing_ok=True)

    # Clear figures_json in the database; commit only when something changed.
    if paper.summary and paper.summary.figures_json:
        paper.summary.figures_json = None
        db.commit()
# ── 触发性增强 ──────────────────────────────────────────────────────────
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
    """Extract figures and tables from the paper PDF (best effort).

    Two-phase pipeline, as described by the original notes:
      1. PicoDet detection + rendered crops (generic labels).
      2. Locate the summary's figure IDs in the PDF and rename the crops.

    Any exception is logged as a warning and swallowed so that a failure
    here never affects the summary itself.

    Args:
        arxiv_id: Identifier used to locate ``TMP_DIR/<arxiv_id>/paper.pdf``.
        schema: Validated summary; its ``figures`` drive the labelling phase.
    """
    try:
        # Imported lazily so an import failure is also treated as best effort.
        from app.services.pdf_image_extractor import (
            extract_images_from_pdf,
            label_images_by_summary,
        )

        source_pdf = TMP_DIR / arxiv_id / "paper.pdf"
        extract_images_from_pdf(arxiv_id, source_pdf)

        # Labelling only makes sense when the summary lists any figures.
        if not schema.figures:
            return
        label_images_by_summary(arxiv_id, schema.figures, source_pdf)
    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
    """Write the paper into the ChromaDB semantic index (best effort).

    Builds a flat dict of text fields from the paper row and the validated
    summary and hands it to ``index_paper``. Any exception is logged as a
    warning and swallowed so that indexing never affects the summary.

    Args:
        arxiv_id: Identifier passed to ``index_paper`` and stored in the dict.
        paper: Source of the English title, tags and paper date.
        schema: Source of the Chinese title, one-liner, problem and key idea.
    """
    try:
        # Imported lazily so an import failure is also treated as best effort.
        from app.services.embedder import index_paper

        payload = {"arxiv_id": arxiv_id}
        payload["title_zh"] = schema.title_zh or ""
        payload["title_en"] = paper.title_en or ""
        if paper.tags:
            payload["tags"] = " ".join(t.tag for t in paper.tags)
        else:
            payload["tags"] = ""
        payload["one_line"] = schema.one_line or ""
        payload["motivation_problem"] = schema.motivation.problem or ""
        payload["method_key_idea"] = schema.method.key_idea or ""
        if paper.paper_date:
            payload["paper_date"] = paper.paper_date.isoformat()
        else:
            payload["paper_date"] = ""

        index_paper(arxiv_id, payload)
    except Exception:
        logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
+13 -8
View File
@@ -80,11 +80,16 @@ def _trim_body(text: str, max_chars: int | None = None) -> str:
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
if ack_match:
# 只删 Acknowledgments 本身,不删后面的内容
next_section = re.search(r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start():])
next_section = re.search(
r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start() :]
)
if next_section:
text = text[:ack_match.start()] + text[ack_match.start() + next_section.start():]
text = (
text[: ack_match.start()]
+ text[ack_match.start() + next_section.start() :]
)
else:
text = text[:ack_match.start()].rstrip()
text = text[: ack_match.start()].rstrip()
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
if max_chars is not None and len(text) > max_chars:
@@ -105,10 +110,9 @@ def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
return txt_path
doc = pymupdf.open(str(pdf_path))
# sort=True 启用阅读顺序检测,避免双栏论文中跨栏错位
raw_text = "\n\n".join(page.get_text(sort=True) for page in doc)
doc.close()
with pymupdf.open(str(pdf_path)) as doc:
# sort=True 启用阅读顺序检测,避免双栏论文中跨栏错位
raw_text = "\n\n".join(page.get_text(sort=True) for page in doc)
body = _trim_body(raw_text, max_chars=max_chars)
txt_path.write_text(body, encoding="utf-8")
@@ -160,7 +164,8 @@ def build_prompt(
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要","section":"method"},'
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要","section":"results"}]'
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Tableid 严格使用 \"Figure N\"\"Table N\" 格式"
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table。"
'id 必须严格复用论文原文的写法(原文用 "Fig. 1" 就写 "Fig. 1",用 "Figure A1" 就写 "Figure A1",用 "Table 1" 就写 "Table 1")。'
"section 必须是 motivation/method/results/limitations 之一,表示该图最适合展示在哪个章节。"
"}"
)
+28 -11
View File
@@ -5,7 +5,15 @@ from __future__ import annotations
from sqlalchemy import or_, select
from sqlalchemy.orm import Session, joinedload
from app.models import PAPER_FULL_LOAD, Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
from app.exceptions import NotFoundError, ValidationError
from app.models import (
PAPER_FULL_LOAD,
Paper,
PaperTag,
UserBookmark,
UserNote,
UserReadingStatus,
)
from app.utils import utc_now
# ── 收藏 ──────────────────────────────────────────────────────────────
@@ -13,9 +21,11 @@ from app.utils import utc_now
def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
"""切换收藏状态。返回 {"bookmarked": bool, "arxiv_id": str}。"""
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
paper = db.execute(
select(Paper).where(Paper.arxiv_id == arxiv_id)
).scalar_one_or_none()
if not paper:
return {"error": "not_found"}
raise NotFoundError(f"Paper not found: {arxiv_id}")
existing = db.execute(
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
@@ -42,11 +52,15 @@ VALID_STATUSES = {"unread", "skimmed", "read_summary", "read_full"}
def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
"""设置阅读状态。status 必须是 unread/skimmed/read_summary/read_full。"""
if status not in VALID_STATUSES:
return {"error": "invalid_status", "valid": sorted(VALID_STATUSES)}
raise ValidationError(
f"Invalid reading status: {status}. Valid: {', '.join(sorted(VALID_STATUSES))}"
)
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
paper = db.execute(
select(Paper).where(Paper.arxiv_id == arxiv_id)
).scalar_one_or_none()
if not paper:
return {"error": "not_found"}
raise NotFoundError(f"Paper not found: {arxiv_id}")
now = utc_now()
existing = db.execute(
@@ -72,7 +86,9 @@ def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
def get_note(db: Session, arxiv_id: str) -> dict | None:
"""获取笔记。返回 {"arxiv_id", "content", "updated_at"} 或 None(论文不存在时)。"""
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
paper = db.execute(
select(Paper).where(Paper.arxiv_id == arxiv_id)
).scalar_one_or_none()
if not paper:
return None
@@ -91,9 +107,11 @@ def get_note(db: Session, arxiv_id: str) -> dict | None:
def save_note(db: Session, arxiv_id: str, content: str) -> dict:
"""创建或更新笔记。返回 {"arxiv_id", "content", "updated_at"}。"""
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
paper = db.execute(
select(Paper).where(Paper.arxiv_id == arxiv_id)
).scalar_one_or_none()
if not paper:
return {"error": "not_found"}
raise NotFoundError(f"Paper not found: {arxiv_id}")
now = utc_now()
existing = db.execute(
@@ -154,8 +172,7 @@ def query_reading_list(
stmt.options(
joinedload(Paper.note),
*PAPER_FULL_LOAD,
)
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
).order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
)
.unique()
.scalars()
+42 -6
View File
@@ -137,12 +137,35 @@ def safe_json_loads(text: str | None, default: Any = None) -> Any:
# AI 生成内容中允许的 HTML 标签和属性
_ALLOWED_TAGS = {
"p", "br", "strong", "b", "em", "i", "u", "s", "del",
"h3", "h4", "h5", "h6",
"ul", "ol", "li",
"a", "code", "pre", "blockquote",
"table", "thead", "tbody", "tr", "th", "td",
"sup", "sub", "span",
"p",
"br",
"strong",
"b",
"em",
"i",
"u",
"s",
"del",
"h3",
"h4",
"h5",
"h6",
"ul",
"ol",
"li",
"a",
"code",
"pre",
"blockquote",
"table",
"thead",
"tbody",
"tr",
"th",
"td",
"sup",
"sub",
"span",
}
_ALLOWED_ATTRS = {
"a": {"href", "title"},
@@ -167,3 +190,16 @@ def sanitize_html(text: str | None) -> str:
strip=True,
)
return cleaned
# ── 错误消息截断 ────────────────────────────────────────────────────────
_ERROR_TRUNCATE_LIMIT = 500


def truncate_error(exc: Exception | str, limit: int = _ERROR_TRUNCATE_LIMIT) -> str:
    """Render an exception or string as a bounded, uniformly formatted message.

    Args:
        exc: Exception or string; converted with ``str()``.
        limit: Maximum number of characters of the original text to keep.
            Negative values are treated as 0.

    Returns:
        The text unchanged when it fits within ``limit``; otherwise its first
        ``limit`` characters followed by ``"... (<N> chars total)"``, where
        ``N`` is the full original length. Note that the returned string is
        therefore longer than ``limit`` by the length of that suffix.
    """
    text = str(exc)
    # A negative limit would make the slice below count from the end of the
    # string (dropping only the tail), so clamp it to "keep nothing".
    keep = max(limit, 0)
    if len(text) <= keep:
        return text
    return f"{text[:keep]}... ({len(text)} chars total)"