feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor into a two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
This commit is contained in:
+27
-14
@@ -42,9 +42,16 @@ def crawl(
|
||||
try:
|
||||
# 检查是否已抓取过(非 force 模式)
|
||||
if not force and not date_str:
|
||||
existing = db.scalar(select(func.count(Paper.id)).where(Paper.paper_date == target)) or 0
|
||||
existing = (
|
||||
db.scalar(
|
||||
select(func.count(Paper.id)).where(Paper.paper_date == target)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if existing > 0:
|
||||
typer.echo(f"⏭️ {target} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)")
|
||||
typer.echo(
|
||||
f"⏭️ {target} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
|
||||
)
|
||||
return
|
||||
|
||||
typer.echo(f"📡 开始抓取 {target} ...")
|
||||
@@ -56,7 +63,12 @@ def crawl(
|
||||
)
|
||||
if need_fallback:
|
||||
fallback = yesterday_str()
|
||||
existing = db.scalar(select(func.count(Paper.id)).where(Paper.paper_date == fallback)) or 0
|
||||
existing = (
|
||||
db.scalar(
|
||||
select(func.count(Paper.id)).where(Paper.paper_date == fallback)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if existing > 0:
|
||||
typer.echo(
|
||||
f"⏭️ {fallback} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
|
||||
@@ -103,7 +115,9 @@ def summarize(
|
||||
import os
|
||||
|
||||
if pdf_mode not in ("auto", "inject", "search"):
|
||||
typer.echo(f"❌ 无效的 pdf_mode: {pdf_mode},只支持 auto / inject / search", err=True)
|
||||
typer.echo(
|
||||
f"❌ 无效的 pdf_mode: {pdf_mode},只支持 auto / inject / search", err=True
|
||||
)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
if backend:
|
||||
@@ -122,6 +136,8 @@ def summarize(
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
from app.exceptions import ConflictError, NotFoundError
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if arxiv_id:
|
||||
@@ -131,16 +147,13 @@ def summarize(
|
||||
typer.echo(f"🤖 开始批量总结 pending 论文 (mode={pdf_mode}) ...")
|
||||
result = asyncio.run(summarize_batch(db, pdf_mode=pdf_mode))
|
||||
|
||||
if result.get("status") in ("success", "done"):
|
||||
typer.echo(f"✅ 总结完成:{result}")
|
||||
elif result.get("status") == "conflict":
|
||||
typer.echo("⚠️ 已有批量总结任务在运行中", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
elif result.get("status") == "not_found":
|
||||
typer.echo(f"❌ 论文未找到:{arxiv_id}", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
else:
|
||||
typer.echo(f"⚠️ 总结结果:{result}", err=True)
|
||||
typer.echo(f"✅ 总结完成:{result}")
|
||||
except NotFoundError as exc:
|
||||
typer.echo(f"❌ {exc.message}", err=True)
|
||||
raise typer.Exit(code=1) from exc
|
||||
except ConflictError as exc:
|
||||
typer.echo(f"⚠️ {exc.message}", err=True)
|
||||
raise typer.Exit(code=1) from exc
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
+8
-1
@@ -27,6 +27,7 @@ class Settings(BaseSettings):
|
||||
HTTP_TIMEOUT_SECONDS: int = 30
|
||||
HTTP_MAX_RETRIES: int = 3
|
||||
HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1"
|
||||
PDF_DOWNLOAD_TIMEOUT: int = 120
|
||||
|
||||
# AI 总结
|
||||
SUMMARY_BACKEND: str = "pi" # "pi" | "claude"
|
||||
@@ -36,7 +37,9 @@ class Settings(BaseSettings):
|
||||
SUMMARY_CONCURRENCY: int = 3
|
||||
SUMMARY_TIMEOUT_SECONDS: int = 1200
|
||||
SUMMARY_MAX_RETRIES: int = 2
|
||||
SUMMARY_PDF_MODE: str = "auto" # "auto" = ≤80k 用 inject,>80k 用 search;也可强制 "inject" / "search"
|
||||
SUMMARY_PDF_MODE: str = (
|
||||
"auto" # "auto" = ≤80k 用 inject,>80k 用 search;也可强制 "inject" / "search"
|
||||
)
|
||||
|
||||
# 调度
|
||||
SCHEDULER_ENABLED: bool = False
|
||||
@@ -56,6 +59,10 @@ class Settings(BaseSettings):
|
||||
EMBED_MODEL: str = ""
|
||||
EMBED_DIMENSIONS: int = 0
|
||||
|
||||
# 布局检测
|
||||
LAYOUT_MODEL_PATH: str = "data/models/picodet_layout_3cls.onnx"
|
||||
LAYOUT_THRESHOLD: float = 0.5
|
||||
|
||||
model_config = {
|
||||
"env_file": str(BASE_DIR / ".env"),
|
||||
"env_file_encoding": "utf-8",
|
||||
|
||||
+2
-5
@@ -82,15 +82,12 @@ def _migrate(engine) -> None:
|
||||
for table, columns in _MIGRATIONS.items():
|
||||
# 获取已有列名
|
||||
existing = {
|
||||
row[1]
|
||||
for row in conn.execute(text(f"PRAGMA table_info({table})"))
|
||||
row[1] for row in conn.execute(text(f"PRAGMA table_info({table})"))
|
||||
}
|
||||
for col_name, col_type in columns:
|
||||
if col_name not in existing:
|
||||
conn.execute(
|
||||
text(
|
||||
f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}"
|
||||
)
|
||||
text(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")
|
||||
)
|
||||
logger.info("Migrated: %s.%s added", table, col_name)
|
||||
conn.commit()
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""业务异常体系 — 统一错误类型,供路由层和 service 层使用。
|
||||
|
||||
路由层通过 main.py 的 @app.exception_handler(AppError) 统一捕获,
|
||||
转为对应 HTTP 状态码 + JSON 响应。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class AppError(Exception):
    """Base class for all business-level errors.

    Attributes:
        message: Human-readable text. Falls back to ``detail`` and then to
            the concrete class name when no message is given.
        detail: The optional ``detail`` keyword exactly as passed in.
        status_code: HTTP status code this error type corresponds to
            (class-level; subclasses override it).
    """

    status_code: int = 500

    def __init__(self, message: str = "", *, detail: str = ""):
        # Keep ``detail`` around instead of dropping it once ``message`` wins.
        self.detail = detail
        self.message = message or detail or self.__class__.__name__
        super().__init__(self.message)


class NotFoundError(AppError):
    """The requested resource does not exist (404)."""

    status_code = 404


class ValidationError(AppError):
    """Request parameter validation failed (400)."""

    status_code = 400


class ExternalAPIError(AppError):
    """A call to an external API failed (502)."""

    status_code = 502


class PdfProcessError(AppError):
    """PDF processing failed (500)."""

    status_code = 500


class ConflictError(AppError):
    """Resource conflict (409), e.g. lock contention or concurrent tasks."""

    status_code = 409
|
||||
+30
-3
@@ -5,10 +5,12 @@ import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.exceptions import AppError, ConflictError, ExternalAPIError, NotFoundError, PdfProcessError, ValidationError
|
||||
from app.database import engine, init_db
|
||||
from app.routes.admin import router as admin_router
|
||||
from app.routes.compare import router as compare_router
|
||||
@@ -38,8 +40,10 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# ── shutdown ──
|
||||
from app.services.scheduler import stop_scheduler
|
||||
from app.services.pdf_downloader import close_http_session
|
||||
|
||||
stop_scheduler()
|
||||
close_http_session()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -60,15 +64,38 @@ def create_app() -> FastAPI:
|
||||
# Session 中间件
|
||||
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY)
|
||||
|
||||
# ── 统一业务异常处理 ──
|
||||
@app.exception_handler(NotFoundError)
|
||||
async def _not_found_handler(request, exc):
|
||||
return JSONResponse(status_code=404, content={"error": exc.message})
|
||||
|
||||
@app.exception_handler(ValidationError)
|
||||
async def _validation_handler(request, exc):
|
||||
return JSONResponse(status_code=400, content={"error": exc.message})
|
||||
|
||||
@app.exception_handler(ExternalAPIError)
|
||||
async def _external_api_handler(request, exc):
|
||||
return JSONResponse(status_code=502, content={"error": exc.message})
|
||||
|
||||
@app.exception_handler(PdfProcessError)
|
||||
async def _pdf_process_handler(request, exc):
|
||||
return JSONResponse(status_code=500, content={"error": exc.message})
|
||||
|
||||
@app.exception_handler(ConflictError)
|
||||
async def _conflict_handler(request, exc):
|
||||
return JSONResponse(status_code=409, content={"error": exc.message})
|
||||
|
||||
@app.exception_handler(AppError)
|
||||
async def _app_error_handler(request, exc):
|
||||
return JSONResponse(status_code=500, content={"error": exc.message})
|
||||
|
||||
# 安全警告
|
||||
if settings.SECRET_KEY == "change-me":
|
||||
logger.warning(
|
||||
"⚠️ SECRET_KEY is the default value 'change-me'. Please change it in .env!"
|
||||
)
|
||||
if not settings.ADMIN_PASSWORD:
|
||||
logger.warning(
|
||||
"⚠️ ADMIN_PASSWORD is empty. Please set it in .env!"
|
||||
)
|
||||
logger.warning("⚠️ ADMIN_PASSWORD is empty. Please set it in .env!")
|
||||
|
||||
# 静态文件
|
||||
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
||||
|
||||
+17
-4
@@ -12,6 +12,7 @@ from sqlalchemy import (
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import joinedload, relationship
|
||||
|
||||
@@ -93,7 +94,7 @@ class PaperAuthor(Base):
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
paper_id = Column(
|
||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name = Column(String, nullable=False)
|
||||
position = Column(Integer, default=0)
|
||||
@@ -108,7 +109,7 @@ class PaperTag(Base):
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
paper_id = Column(
|
||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
tag = Column(String, nullable=False)
|
||||
source = Column(String, default="hf")
|
||||
@@ -155,7 +156,7 @@ class SummaryStatus(Base):
|
||||
paper_id = Column(
|
||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
status = Column(String, nullable=False, default="pending")
|
||||
status = Column(String, nullable=False, default="pending", index=True)
|
||||
quality = Column(String)
|
||||
error_type = Column(String)
|
||||
error = Column(Text)
|
||||
@@ -219,7 +220,7 @@ class UserReadingStatus(Base):
|
||||
paper_id = Column(
|
||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
status = Column(String, nullable=False, default="unread")
|
||||
status = Column(String, nullable=False, default="unread", index=True)
|
||||
updated_at = Column(DateTime, nullable=False)
|
||||
|
||||
paper = relationship("Paper", back_populates="reading_status")
|
||||
@@ -271,3 +272,15 @@ PAPER_FULL_LOAD = (
|
||||
joinedload(Paper.bookmark),
|
||||
joinedload(Paper.reading_status),
|
||||
)
|
||||
|
||||
|
||||
def get_paper_by_arxiv_id(db, arxiv_id: str, *, load=PAPER_DEFAULT_LOAD):
    """Look up a single paper by its arXiv ID, applying the given loader options.

    ``load`` is unpacked into ``.options()``. The result is passed through
    ``.unique()`` before extraction (presumably because the loader options
    include joined eager loads; verify against ``PAPER_DEFAULT_LOAD``).

    Returns the matching ``Paper``, or ``None`` when no row matches.
    """
    query = select(Paper).options(*load).where(Paper.arxiv_id == arxiv_id)
    result = db.execute(query)
    return result.unique().scalar_one_or_none()
|
||||
|
||||
|
||||
def get_paper_by_id(db, paper_id: int, *, load=PAPER_DEFAULT_LOAD):
    """Look up a single paper by primary key, applying the given loader options.

    Returns the matching ``Paper``, or ``None`` when no row matches.
    """
    by_pk = Paper.id == paper_id
    rows = db.execute(select(Paper).where(by_pk).options(*load))
    # De-duplicate rows before pulling out the single entity.
    deduplicated = rows.unique()
    return deduplicated.scalar_one_or_none()
|
||||
|
||||
+94
-144
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
from datetime import date
|
||||
@@ -10,7 +11,7 @@ from datetime import date
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy import bindparam, func, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
@@ -22,15 +23,15 @@ from app.models import (
|
||||
PaperTag,
|
||||
SummaryState,
|
||||
SummaryStatus,
|
||||
TaskLock,
|
||||
)
|
||||
from app.services import admin as admin_svc
|
||||
from app.services.admin import get_admin_stats
|
||||
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
||||
from app.services.crawler import crawl_daily, refresh_upvotes
|
||||
from app.services.pipeline import run_pipeline
|
||||
from app.services.crawler import refresh_upvotes
|
||||
from app.services.pipeline import run_crawl, run_pipeline
|
||||
from app.services.scheduler import get_scheduler
|
||||
from app.services.summarizer import summarize_batch, summarize_single
|
||||
from app.utils import release_lock, templates, today_str, utc_now
|
||||
from app.utils import templates, today_str, utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,14 +42,15 @@ router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
def _check_password(password: str) -> bool:
    """Check a submitted admin password using constant-time comparison.

    The configured ``settings.ADMIN_PASSWORD`` may hold either the plain
    password or its SHA-256 hex digest; both forms are accepted.

    Args:
        password: The password submitted by the client.

    Returns:
        ``True`` when the password matches the stored value in either form,
        ``False`` otherwise (including when no admin password is configured).
    """
    stored = settings.ADMIN_PASSWORD
    if not stored:
        return False

    # hmac.compare_digest() accepts ``str`` operands only when they are pure
    # ASCII and raises TypeError otherwise, so a non-ASCII login attempt would
    # surface as a server error. Comparing UTF-8 bytes avoids that.
    stored_bytes = stored.encode("utf-8")
    candidate = password.encode("utf-8")
    if hmac.compare_digest(candidate, stored_bytes):
        return True

    # Also accept a stored SHA-256 hex digest of the password.
    hashed = hashlib.sha256(candidate).hexdigest()
    return hmac.compare_digest(hashed.encode("ascii"), stored_bytes)
|
||||
|
||||
|
||||
async def verify_admin(request: Request) -> None:
|
||||
@@ -204,32 +206,12 @@ async def admin_crawl(
|
||||
):
|
||||
"""手动抓取指定日期,默认今天。"""
|
||||
target_date = date or today_str()
|
||||
|
||||
# TaskLock 防重入
|
||||
now = utc_now()
|
||||
lock = TaskLock(
|
||||
task="crawl",
|
||||
lock_key=target_date,
|
||||
status="running",
|
||||
owner="admin_crawl",
|
||||
acquired_at=now,
|
||||
)
|
||||
try:
|
||||
db.add(lock)
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409, detail=f"Crawl already running for {target_date}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await crawl_daily(db, target_date)
|
||||
return result
|
||||
return await run_crawl(db, target_date, owner="admin_crawl")
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
finally:
|
||||
release_lock(db, lock)
|
||||
|
||||
|
||||
# ── 总结 ──────────────────────────────────────────────────────────────
|
||||
@@ -241,12 +223,7 @@ async def admin_summarize_batch(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""批量总结所有 pending 论文。"""
|
||||
result = await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
|
||||
if result.get("status") == "conflict":
|
||||
raise HTTPException(
|
||||
status_code=409, detail=result.get("error", "batch already running")
|
||||
)
|
||||
return result
|
||||
return await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
|
||||
|
||||
|
||||
@router.post("/summarize/{arxiv_id}")
|
||||
@@ -256,10 +233,9 @@ async def admin_summarize_single(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""总结或重跑单篇论文。"""
|
||||
result = await summarize_single(db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE)
|
||||
if result.get("status") == "not_found":
|
||||
raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
|
||||
return result
|
||||
return await summarize_single(
|
||||
db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE
|
||||
)
|
||||
|
||||
|
||||
# ── 清理 ──────────────────────────────────────────────────────────────
|
||||
@@ -284,10 +260,13 @@ async def admin_cleanup(
|
||||
result = cleanup_tmp()
|
||||
log_entry.status = "success"
|
||||
log_entry.completed_at = utc_now()
|
||||
log_entry.details_json = json.dumps({
|
||||
"scanned": result.get("scanned", 0),
|
||||
"removed": result.get("removed", 0),
|
||||
}, ensure_ascii=False)
|
||||
log_entry.details_json = json.dumps(
|
||||
{
|
||||
"scanned": result.get("scanned", 0),
|
||||
"removed": result.get("removed", 0),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if result.get("errors"):
|
||||
log_entry.error = "; ".join(result["errors"])[:2000]
|
||||
db.commit()
|
||||
@@ -358,19 +337,34 @@ async def admin_logs(
|
||||
|
||||
# 总结状态统计概要
|
||||
summary_total = db.scalar(select(func.count(Paper.id))) or 0
|
||||
summary_done = db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(SummaryStatus.status == SummaryState.DONE)
|
||||
) or 0
|
||||
summary_pending = db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.PROCESSING])
|
||||
summary_done = (
|
||||
db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status == SummaryState.DONE
|
||||
)
|
||||
)
|
||||
) or 0
|
||||
summary_failed = db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE])
|
||||
or 0
|
||||
)
|
||||
summary_pending = (
|
||||
db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status.in_(
|
||||
[SummaryState.PENDING, SummaryState.PROCESSING]
|
||||
)
|
||||
)
|
||||
)
|
||||
) or 0
|
||||
or 0
|
||||
)
|
||||
summary_failed = (
|
||||
db.scalar(
|
||||
select(func.count(SummaryStatus.id)).where(
|
||||
SummaryStatus.status.in_(
|
||||
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||
)
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
@@ -414,13 +408,8 @@ async def admin_summary_status(
|
||||
else:
|
||||
query = query.where(SummaryStatus.status == status)
|
||||
|
||||
total = db.scalar(
|
||||
select(func.count()).select_from(query.subquery())
|
||||
)
|
||||
results = (
|
||||
db.execute(query.offset((page - 1) * per_page).limit(per_page))
|
||||
.all()
|
||||
)
|
||||
total = db.scalar(select(func.count()).select_from(query.subquery()))
|
||||
results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
|
||||
|
||||
# 判断是否 HTMX 请求
|
||||
is_htmx = request.headers.get("HX-Request") == "true"
|
||||
@@ -465,7 +454,11 @@ async def admin_summary_retry_failed(
|
||||
db.execute(
|
||||
select(Paper.arxiv_id)
|
||||
.join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
|
||||
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
|
||||
.where(
|
||||
SummaryStatus.status.in_(
|
||||
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
@@ -477,7 +470,11 @@ async def admin_summary_retry_failed(
|
||||
# 重置失败任务的状态为 pending
|
||||
db.execute(
|
||||
SummaryStatus.__table__.update()
|
||||
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
|
||||
.where(
|
||||
SummaryStatus.status.in_(
|
||||
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||
)
|
||||
)
|
||||
.values(status=SummaryState.PENDING, error=None, error_type=None)
|
||||
)
|
||||
db.commit()
|
||||
@@ -492,15 +489,6 @@ async def admin_summary_retry_failed(
|
||||
# ── 论文管理 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# 排序映射
|
||||
_SORT_MAP = {
|
||||
"date_desc": Paper.paper_date.desc(),
|
||||
"date_asc": Paper.paper_date.asc(),
|
||||
"upvotes_desc": Paper.upvotes.desc(),
|
||||
"title_asc": Paper.title_en.asc(),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/papers")
|
||||
async def admin_papers(
|
||||
request: Request,
|
||||
@@ -516,66 +504,18 @@ async def admin_papers(
|
||||
per_page: int = Query(20, ge=1, le=100),
|
||||
):
|
||||
"""论文管理列表页面。"""
|
||||
query = select(Paper)
|
||||
|
||||
# 搜索
|
||||
if q.strip():
|
||||
query = query.where(
|
||||
Paper.title_en.ilike(f"%{q}%")
|
||||
| Paper.title_zh.ilike(f"%{q}%")
|
||||
| Paper.abstract.ilike(f"%{q}%")
|
||||
)
|
||||
|
||||
# 日期范围
|
||||
if date_from:
|
||||
query = query.where(Paper.paper_date >= date_from)
|
||||
if date_to:
|
||||
query = query.where(Paper.paper_date <= date_to)
|
||||
|
||||
# 标签筛选
|
||||
if tag:
|
||||
query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
|
||||
PaperTag.tag == tag
|
||||
)
|
||||
|
||||
# 总结状态筛选
|
||||
if summary_status != "all":
|
||||
if summary_status == "none":
|
||||
query = query.outerjoin(
|
||||
SummaryStatus, SummaryStatus.paper_id == Paper.id
|
||||
).where(SummaryStatus.paper_id == None) # noqa: E711
|
||||
else:
|
||||
query = query.join(
|
||||
SummaryStatus, SummaryStatus.paper_id == Paper.id
|
||||
).where(SummaryStatus.status == summary_status)
|
||||
|
||||
# 排序
|
||||
order = _SORT_MAP.get(sort, Paper.paper_date.desc())
|
||||
query = query.order_by(order)
|
||||
|
||||
# 计数
|
||||
total = db.scalar(select(func.count()).select_from(query.subquery()))
|
||||
|
||||
# 分页
|
||||
papers = (
|
||||
db.execute(query.offset((page - 1) * per_page).limit(per_page))
|
||||
.scalars()
|
||||
.all()
|
||||
papers, total, statuses = admin_svc.query_papers(
|
||||
db,
|
||||
q=q,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
tag=tag,
|
||||
summary_status=summary_status,
|
||||
sort=sort,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
)
|
||||
|
||||
# 获取每篇论文的总结状态
|
||||
paper_ids = [p.id for p in papers]
|
||||
statuses = {}
|
||||
if paper_ids:
|
||||
rows = db.execute(
|
||||
select(SummaryStatus.paper_id, SummaryStatus.status).where(
|
||||
SummaryStatus.paper_id.in_(paper_ids)
|
||||
)
|
||||
).all()
|
||||
paper_id_to_arxiv = {p.id: p.arxiv_id for p in papers}
|
||||
for pid, st in rows:
|
||||
statuses[paper_id_to_arxiv.get(pid, "")] = st
|
||||
|
||||
# 构建分页 URL 辅助函数
|
||||
def pagination_url(p: int) -> str:
|
||||
params = dict(request.query_params)
|
||||
@@ -588,7 +528,7 @@ async def admin_papers(
|
||||
{
|
||||
"papers": papers,
|
||||
"paper_summary_statuses": statuses,
|
||||
"total": total or 0,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"current_status": summary_status,
|
||||
@@ -615,7 +555,9 @@ async def admin_paper_delete(
|
||||
|
||||
# 清理 FTS 索引
|
||||
try:
|
||||
db.execute(text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id})
|
||||
db.execute(
|
||||
text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id}
|
||||
)
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
|
||||
@@ -646,9 +588,11 @@ async def admin_papers_batch_action(
|
||||
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
|
||||
|
||||
if body.action == "delete":
|
||||
papers = db.execute(
|
||||
select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids))
|
||||
).scalars().all()
|
||||
papers = (
|
||||
db.execute(select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids)))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
count = 0
|
||||
for paper in papers:
|
||||
@@ -658,21 +602,27 @@ async def admin_papers_batch_action(
|
||||
|
||||
# 清理 FTS 索引
|
||||
try:
|
||||
db.execute(
|
||||
text("DELETE FROM papers_fts WHERE arxiv_id IN :ids"),
|
||||
{"ids": tuple(body.arxiv_ids)},
|
||||
stmt = text("DELETE FROM papers_fts WHERE arxiv_id IN :ids").bindparams(
|
||||
bindparam("ids", expanding=True)
|
||||
)
|
||||
db.execute(stmt, {"ids": body.arxiv_ids})
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.warning("Failed to clean FTS index for batch delete", exc_info=True)
|
||||
|
||||
return {"status": "success", "message": f"已删除 {count} 篇论文", "count": count}
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"已删除 {count} 篇论文",
|
||||
"count": count,
|
||||
}
|
||||
|
||||
elif body.action == "summarize":
|
||||
# 将选中论文的总结状态重置为 pending
|
||||
paper_ids = db.execute(
|
||||
select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids))
|
||||
).scalars().all()
|
||||
paper_ids = (
|
||||
db.execute(select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids)))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
if paper_ids:
|
||||
# 删除旧的 status 记录让其重新进入 pipeline
|
||||
|
||||
@@ -12,6 +12,8 @@ from app.utils import templates
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_COMPARE_PAPERS = 5
|
||||
|
||||
|
||||
@router.get("/compare")
|
||||
def compare_page(
|
||||
@@ -33,9 +35,9 @@ def compare_page(
|
||||
|
||||
arxiv_ids = [i.strip() for i in ids.split(",") if i.strip()]
|
||||
|
||||
# 最多 5 篇
|
||||
if len(arxiv_ids) > 5:
|
||||
arxiv_ids = arxiv_ids[:5]
|
||||
# 最多 MAX_COMPARE_PAPERS 篇
|
||||
if len(arxiv_ids) > MAX_COMPARE_PAPERS:
|
||||
arxiv_ids = arxiv_ids[:MAX_COMPARE_PAPERS]
|
||||
|
||||
if not arxiv_ids:
|
||||
return templates.TemplateResponse(
|
||||
|
||||
+2
-99
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
@@ -15,6 +14,7 @@ from sqlalchemy.orm import Session, joinedload
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models import PAPER_FULL_LOAD, Paper
|
||||
from app.services.pdf_image_extractor import link_figures_with_images
|
||||
from app.utils import (
|
||||
PAPERS_DIR,
|
||||
safe_json_loads,
|
||||
@@ -120,7 +120,7 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
|
||||
paper.summary.figures_json if paper.summary else None, default=[]
|
||||
)
|
||||
|
||||
linked_figures = _link_figures_with_images(figures_raw, images, arxiv_id)
|
||||
linked_figures = link_figures_with_images(figures_raw, images, arxiv_id)
|
||||
|
||||
# 拆分图片到对应展示区域:
|
||||
# table_figures → 实验结果区域(Table 截图,不变)
|
||||
@@ -279,100 +279,3 @@ def _get_paper_images(arxiv_id: str) -> list[dict]:
|
||||
}
|
||||
)
|
||||
return images
|
||||
|
||||
|
||||
def _link_figures_with_images(
    figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
    """Attach an ``image_url`` to summary figure entries from extracted images.

    Matching runs in two passes:

    1. Exact ID match against ``images/manifest.json`` (the ``label`` field,
       plus the legacy ``figures`` / ``tables`` lists).
    2. Positional fallback for entries still without a URL: the N-th
       unmatched figure gets the N-th non-table image, and the N-th unmatched
       table gets the N-th table image.

    Args:
        figures: Figure metadata dicts from the summary; updated in place.
        images: Extracted image dicts; ``name`` and ``url`` keys are read.
        arxiv_id: Paper identifier used to locate the manifest and build URLs.

    Returns:
        The same ``figures`` list that was passed in.
    """
    if not figures or not images:
        return figures

    # -- Pass 1: exact ID match via the manifest --
    id_to_url = _manifest_id_urls(arxiv_id)
    for fig in figures:
        # ``or ""`` also covers an explicit ``"id": null`` in the stored JSON.
        normalized = _normalize_figure_id(fig.get("id") or "")
        if normalized in id_to_url:
            fig["image_url"] = id_to_url[normalized]

    # -- Pass 2: positional fallback for whatever is still unmatched --
    unmatched = [fig for fig in figures if not fig.get("image_url")]
    if not unmatched:
        return figures

    unmatched_figures = [f for f in unmatched if _is_figure_type(f.get("id") or "")]
    unmatched_tables = [f for f in unmatched if not _is_figure_type(f.get("id") or "")]

    # Images are split by whether the file name mentions "table", then ordered
    # by the number embedded in the file name.
    # NOTE(review): images already claimed in pass 1 are not excluded here, so
    # one image may end up attached to two entries; confirm this is intended.
    figure_images = sorted(
        (img for img in images if "table" not in img["name"].lower()),
        key=lambda img: _image_sort_key(img["name"]),
    )
    table_images = sorted(
        (img for img in images if "table" in img["name"].lower()),
        key=lambda img: _image_sort_key(img["name"]),
    )

    # zip() stops at the shorter side, so surplus entries simply stay unlinked.
    for fig, img in zip(unmatched_figures, figure_images):
        fig["image_url"] = img["url"]
    for fig, img in zip(unmatched_tables, table_images):
        fig["image_url"] = img["url"]

    return figures


def _manifest_id_urls(arxiv_id: str) -> dict[str, str]:
    """Map figure/table IDs recorded in a paper's image manifest to image URLs."""
    manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError, TypeError):
        # Missing, unreadable or malformed manifest: fall back to positional matching.
        return {}
    if not isinstance(manifest, dict):
        return {}

    id_to_url: dict[str, str] = {}
    for filename, info in manifest.items():
        if not isinstance(info, dict):
            continue
        url = f"/papers/{arxiv_id}/images/{filename}"
        # New format: a single ``label`` per image takes priority.
        label = info.get("label", "")
        if label:
            id_to_url[label] = url
        # Legacy format: ``figures`` / ``tables`` ID lists; never override a label.
        for fig_id in (info.get("figures") or []) + (info.get("tables") or []):
            id_to_url.setdefault(fig_id, url)
    return id_to_url


def _image_sort_key(name: str) -> tuple[int, int]:
    """Return an ordering key from the numbering embedded in an image file name."""
    # New format: figure_1.jpg, table_1.jpg
    new_style = re.search(r"(?:figure|table)_(\d+)", name)
    if new_style:
        return (0, int(new_style.group(1)))
    # Legacy format: page2_img1.png, page5_table1.png
    legacy = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
    if legacy:
        return (int(legacy.group(1)), int(legacy.group(2)))
    return (0, 0)
|
||||
|
||||
|
||||
def _normalize_figure_id(raw_id: str) -> str:
    """Normalize a figure/table ID to the canonical ``Figure N`` / ``Table N`` form.

    ``'Fig.1'``, ``'fig 1'`` and ``'Figure 1'`` all become ``'Figure 1'``;
    ``'table2'`` becomes ``'Table 2'``. Leading whitespace is ignored when
    matching. IDs that match neither pattern are returned unchanged.

    Args:
        raw_id: The ID as stored in the figure metadata. ``None`` is treated
            as an empty ID and other non-string values are stringified,
            instead of raising ``TypeError``.

    Returns:
        The normalized ID, or the (stringified) input when nothing matched.
    """
    if raw_id is None:
        return ""
    text = raw_id if isinstance(raw_id, str) else str(raw_id)
    candidate = text.lstrip()

    figure = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", candidate, re.IGNORECASE)
    if figure:
        return f"Figure {figure.group(1)}"
    table = re.match(r"Table\s*(\d+)", candidate, re.IGNORECASE)
    if table:
        return f"Table {table.group(1)}"
    return text
|
||||
|
||||
|
||||
def _is_figure_type(fig_id: str) -> bool:
    """Return ``True`` unless the ID denotes a table (``Table N``).

    Matching is case-insensitive and ignores leading whitespace. Missing or
    non-string IDs count as figures rather than raising ``TypeError``.
    """
    candidate = fig_id.lstrip() if isinstance(fig_id, str) else ""
    return re.match(r"Table\s*\d+", candidate, re.IGNORECASE) is None
|
||||
|
||||
+5
-23
@@ -2,12 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.exceptions import NotFoundError
|
||||
from app.services.user_data import (
|
||||
get_note,
|
||||
save_note,
|
||||
@@ -37,9 +38,6 @@ def bookmark_toggle(arxiv_id: str, request: Request, db: Session = Depends(get_d
|
||||
"""切换收藏状态。支持 HTMX 局部刷新和 JSON 响应。"""
|
||||
result = toggle_bookmark(db, arxiv_id)
|
||||
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=404, detail=result["error"])
|
||||
|
||||
# HTMX 请求 → 返回 HTML 片段
|
||||
if request.headers.get("HX-Request"):
|
||||
star = "★" if result["bookmarked"] else "☆"
|
||||
@@ -66,18 +64,7 @@ def reading_status_update(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新阅读状态。"""
|
||||
result = set_reading_status(db, arxiv_id, body.status)
|
||||
|
||||
if "error" in result:
|
||||
if result["error"] == "not_found":
|
||||
raise HTTPException(status_code=404, detail="Paper not found")
|
||||
elif result["error"] == "invalid_status":
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Invalid status. Valid: {result['valid']}",
|
||||
)
|
||||
|
||||
return result
|
||||
return set_reading_status(db, arxiv_id, body.status)
|
||||
|
||||
|
||||
# ── 笔记 ──────────────────────────────────────────────────────────────
|
||||
@@ -88,16 +75,11 @@ def note_get(arxiv_id: str, db: Session = Depends(get_db)):
|
||||
"""获取笔记。"""
|
||||
result = get_note(db, arxiv_id)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="Paper not found")
|
||||
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/note/{arxiv_id}")
|
||||
def note_save(arxiv_id: str, body: NoteRequest, db: Session = Depends(get_db)):
|
||||
"""保存笔记。"""
|
||||
result = save_note(db, arxiv_id, body.content)
|
||||
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=404, detail=result["error"])
|
||||
|
||||
return result
|
||||
return save_note(db, arxiv_id, body.content)
|
||||
|
||||
+94
-12
@@ -9,10 +9,18 @@ from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import CrawlLog, Paper, SummaryState, TaskLock
|
||||
from app.models import CrawlLog, Paper, PaperTag, SummaryState, SummaryStatus, TaskLock
|
||||
from app.services.scheduler import get_scheduler
|
||||
from app.utils import PAPERS_DIR, TMP_DIR
|
||||
|
||||
# admin_papers 排序映射
|
||||
SORT_MAP = {
|
||||
"date_desc": Paper.paper_date.desc(),
|
||||
"date_asc": Paper.paper_date.asc(),
|
||||
"upvotes_desc": Paper.upvotes.desc(),
|
||||
"title_asc": Paper.title_en.asc(),
|
||||
}
|
||||
|
||||
|
||||
def _dir_size(path: Path) -> int:
|
||||
"""递归计算目录总字节数。"""
|
||||
@@ -52,7 +60,11 @@ def get_admin_stats(db: Session) -> dict:
|
||||
status_counts = {row[0]: row[1] for row in summary_rows}
|
||||
|
||||
# ── 存储概况 ──────────────────────────────────────────────────────
|
||||
db_size = _fmt_size(settings.db_path.stat().st_size) if settings.db_path.exists() else "0 B"
|
||||
db_size = (
|
||||
_fmt_size(settings.db_path.stat().st_size)
|
||||
if settings.db_path.exists()
|
||||
else "0 B"
|
||||
)
|
||||
papers_size = _fmt_size(_dir_size(PAPERS_DIR))
|
||||
tmp_size = _fmt_size(_dir_size(TMP_DIR))
|
||||
|
||||
@@ -68,22 +80,14 @@ def get_admin_stats(db: Session) -> dict:
|
||||
|
||||
# ── 最近日志(5 条) ──────────────────────────────────────────────
|
||||
recent_logs = (
|
||||
db.execute(
|
||||
select(CrawlLog)
|
||||
.order_by(CrawlLog.started_at.desc())
|
||||
.limit(5)
|
||||
)
|
||||
db.execute(select(CrawlLog).order_by(CrawlLog.started_at.desc()).limit(5))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
# ── 活跃锁 ────────────────────────────────────────────────────────
|
||||
active_locks = (
|
||||
db.execute(
|
||||
select(TaskLock).where(TaskLock.status == "running")
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
db.execute(select(TaskLock).where(TaskLock.status == "running")).scalars().all()
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -108,3 +112,81 @@ def get_admin_stats(db: Session) -> dict:
|
||||
"active_locks": active_locks,
|
||||
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
|
||||
}
|
||||
|
||||
|
||||
def query_papers(
    db: Session,
    *,
    q: str = "",
    date_from: str | None = None,
    date_to: str | None = None,
    tag: str = "",
    summary_status: str = "all",
    sort: str = "date_desc",
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[Paper], int, dict[str, str]]:
    """Query papers for the admin list: filter, count, sort and paginate.

    Args:
        db: SQLAlchemy session used for all queries.
        q: Free-text search term, matched case-insensitively against
            title_en, title_zh and abstract. Surrounding whitespace is
            ignored; a blank term disables the search filter.
        date_from: Inclusive lower bound on Paper.paper_date, if given.
        date_to: Inclusive upper bound on Paper.paper_date, if given.
        tag: When non-empty, keep only papers that have this PaperTag.tag.
        summary_status: "all" disables the filter; "none" keeps papers that
            have no SummaryStatus row; any other value is matched against
            SummaryStatus.status.
        sort: Key into SORT_MAP; unknown keys fall back to paper_date
            descending.
        page: 1-based page number; values below 1 are treated as page 1.
        per_page: Page size.

    Returns:
        (papers, total, statuses): the papers on the requested page, the
        total number of papers matching the filters, and a mapping
        {arxiv_id: summary_status} for the papers on the page.
    """
    query = select(Paper)

    # Search: use the stripped term so "  foo " matches the same rows as "foo".
    # NOTE: "%" and "_" inside the term still act as LIKE wildcards.
    term = q.strip()
    if term:
        pattern = f"%{term}%"
        query = query.where(
            Paper.title_en.ilike(pattern)
            | Paper.title_zh.ilike(pattern)
            | Paper.abstract.ilike(pattern)
        )

    # Date range (both bounds inclusive)
    if date_from:
        query = query.where(Paper.paper_date >= date_from)
    if date_to:
        query = query.where(Paper.paper_date <= date_to)

    # Tag filter
    if tag:
        query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
            PaperTag.tag == tag
        )

    # Summary-status filter
    if summary_status == "none":
        # Anti-join: keep papers that have no SummaryStatus row at all.
        query = query.outerjoin(
            SummaryStatus, SummaryStatus.paper_id == Paper.id
        ).where(SummaryStatus.paper_id.is_(None))
    elif summary_status != "all":
        query = query.join(SummaryStatus, SummaryStatus.paper_id == Paper.id).where(
            SummaryStatus.status == summary_status
        )

    # Count on the unordered query: ordering cannot change the row count.
    total = db.scalar(select(func.count()).select_from(query.subquery())) or 0

    # Sort + paginate; clamp the offset so page < 1 never yields a negative OFFSET.
    order = SORT_MAP.get(sort, Paper.paper_date.desc())
    offset = max((page - 1) * per_page, 0)
    papers = (
        db.execute(query.order_by(order).offset(offset).limit(per_page))
        .scalars()
        .all()
    )

    # Summary status for each paper on this page, keyed by arxiv_id
    statuses: dict[str, str] = {}
    if papers:
        arxiv_by_id = {p.id: p.arxiv_id for p in papers}
        rows = db.execute(
            select(SummaryStatus.paper_id, SummaryStatus.status).where(
                SummaryStatus.paper_id.in_(list(arxiv_by_id))
            )
        ).all()
        statuses = {arxiv_by_id.get(pid, ""): st for pid, st in rows}

    return papers, total, statuses
|
||||
|
||||
@@ -207,11 +207,14 @@ async def delete_papers_by_date_range(
|
||||
completed_at=utc_now(),
|
||||
papers_found=total,
|
||||
papers_new=deleted,
|
||||
details_json=json.dumps({
|
||||
"total_before": total,
|
||||
"deleted": deleted,
|
||||
"failed": len(failed_items),
|
||||
}, ensure_ascii=False),
|
||||
details_json=json.dumps(
|
||||
{
|
||||
"total_before": total,
|
||||
"deleted": deleted,
|
||||
"failed": len(failed_items),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
error=job_error,
|
||||
)
|
||||
db.add(log_entry)
|
||||
|
||||
@@ -189,11 +189,15 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
paper = db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.arxiv_id == paper_id)
|
||||
.options(joinedload(Paper.tags), joinedload(Paper.summary))
|
||||
).unique().scalar_one_or_none()
|
||||
paper = (
|
||||
db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.arxiv_id == paper_id)
|
||||
.options(joinedload(Paper.tags), joinedload(Paper.summary))
|
||||
)
|
||||
.unique()
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
if not paper:
|
||||
logger.warning("Paper %s not found for indexing", paper_id)
|
||||
return False
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
"""PicoDet-S_layout_3cls 布局检测 — 纯 ONNX Runtime 推理.
|
||||
|
||||
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。
|
||||
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框。
|
||||
|
||||
输入:
|
||||
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图片
|
||||
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
|
||||
|
||||
输出:
|
||||
fetch_name_0: (N, 6) float32 — [class_id, score, xmin, ymin, xmax, ymax]
|
||||
fetch_name_1: (1,) int32 — 有效框数量 N
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import pymupdf
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 模型输入尺寸
|
||||
_MODEL_SIZE = 480
|
||||
# ImageNet normalize
|
||||
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
||||
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||
# PicoDet label → 内部 boxclass
|
||||
_LABEL_MAP: dict[int, str] = {
|
||||
0: "picture", # PicoDet "image" → "picture"
|
||||
1: "table",
|
||||
# 2: seal — 忽略
|
||||
}
|
||||
# 最小 bbox 尺寸(PDF 点)
|
||||
_MIN_BOX_SIZE = 20
|
||||
|
||||
|
||||
@dataclass
class LayoutBox:
    """A detected layout region, compatible with the existing _process_page code."""

    # Bounding box corners: (x0, y0) is the top-left, (x1, y1) the bottom-right.
    # Boxes produced by the detector in this module are expressed in PDF points.
    x0: float
    y0: float
    x1: float
    y1: float
    boxclass: str  # "picture" | "table"
|
||||
|
||||
|
||||
class _LayoutDetector:
    """Singleton that manages the lifecycle of the ONNX InferenceSession."""

    def __init__(self) -> None:
        # Loaded lazily by _init_session() on first use.
        self._session: ort.InferenceSession | None = None

    def _init_session(self) -> ort.InferenceSession:
        """Return the cached session, loading the model on the first call.

        Raises:
            FileNotFoundError: if settings.LAYOUT_MODEL_PATH does not exist.
        """
        if self._session is not None:
            return self._session

        model_path = Path(settings.LAYOUT_MODEL_PATH)
        if not model_path.exists():
            raise FileNotFoundError(
                f"Layout model not found: {model_path}. "
                "Run scripts/export_picodet_onnx.py first."
            )

        logger.info("Loading ONNX layout model: %s", model_path)
        self._session = ort.InferenceSession(
            str(model_path), providers=["CPUExecutionProvider"]
        )
        logger.info("ONNX layout model loaded")
        return self._session

    def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
        """Detect figure / table regions on a single PDF page.

        Steps:
            1. Render the page with pymupdf at _MODEL_SIZE x _MODEL_SIZE pixels.
            2. Apply ImageNet normalization and convert to NCHW.
            3. Run ONNX inference; the outputs are read as already decoded,
               NMS-filtered boxes.
            4. Map pixel coordinates back to PDF points.
            5. Drop classes outside _LABEL_MAP, low-confidence boxes and
               boxes smaller than _MIN_BOX_SIZE.

        Args:
            page: pymupdf Page object.

        Returns:
            List of LayoutBox with coordinates in PDF points; empty when
            nothing is detected or the page has no area.
        """
        session = self._init_session()

        page_w = page.rect.width
        page_h = page.rect.height
        # A degenerate page cannot be rendered, and the zoom factors below
        # would divide by zero.
        if page_w <= 0 or page_h <= 0:
            return []

        # 1. Render the page to _MODEL_SIZE x _MODEL_SIZE
        zoom_x = _MODEL_SIZE / page_w
        zoom_y = _MODEL_SIZE / page_h
        mat = pymupdf.Matrix(zoom_x, zoom_y)
        pix = page.get_pixmap(matrix=mat)

        # 2. Preprocess
        img = (
            np.frombuffer(pix.samples, dtype=np.uint8)
            .reshape(pix.height, pix.width, pix.n)
            .astype(np.float32)
            / 255.0
        )
        # Drop the alpha channel, if present
        if img.shape[2] == 4:
            img = img[:, :, :3]
        img = (img - _MEAN) / _STD
        img = img.transpose(2, 0, 1)[np.newaxis]  # (1, 3, H, W)

        # Fed as the second model input when the model declares one. The
        # identity value is presumably what keeps the returned boxes in
        # rendered-pixel coordinates — confirm against the exported model.
        scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)

        # 3. Inference
        input_names = [i.name for i in session.get_inputs()]
        feed = {input_names[0]: img}
        if len(input_names) > 1:
            feed[input_names[1]] = scale_factor

        outputs = session.run(None, feed)
        boxes_raw = outputs[0]  # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
        num_boxes = int(outputs[1][0])  # number of valid boxes

        if num_boxes == 0:
            return []

        # 4. Pixel -> PDF points. Scale by the size of the pixmap that was
        # actually rendered, so a rounding difference from _MODEL_SIZE does
        # not skew the coordinates.
        sx = page_w / pix.width
        sy = page_h / pix.height

        result: list[LayoutBox] = []
        for i in range(min(num_boxes, len(boxes_raw))):
            cls_id, score, xmin, ymin, xmax, ymax = boxes_raw[i]
            cls_id = int(cls_id)

            # 5. Skip the seal class (absent from _LABEL_MAP) and low-confidence boxes
            if cls_id not in _LABEL_MAP:
                continue
            if score < settings.LAYOUT_THRESHOLD:
                continue

            # Convert to plain Python floats so LayoutBox never carries numpy scalars.
            x0, y0 = float(xmin * sx), float(ymin * sy)
            x1, y1 = float(xmax * sx), float(ymax * sy)

            # Skip very small regions
            if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
                continue

            result.append(
                LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=_LABEL_MAP[cls_id])
            )

        return result
|
||||
|
||||
|
||||
# 模块级单例
|
||||
_detector = _LayoutDetector()
|
||||
|
||||
|
||||
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
    """Detect figure / table regions on a PDF page.

    Delegates to the module-level _detector instance.

    Returns:
        List of LayoutBox with coordinates in PDF points, containing only
        picture/table boxes.
    """
    return _detector.detect_page(page)
|
||||
@@ -9,6 +9,7 @@ from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from app.config import settings
|
||||
from app.utils import PAPERS_DIR, TMP_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -51,6 +52,14 @@ def _get_session() -> requests.Session:
|
||||
return _http_session
|
||||
|
||||
|
||||
def close_http_session() -> None:
    """Close the global HTTP session; meant to be called at application shutdown.

    Does nothing when no session is currently held.
    """
    global _http_session
    session = _http_session
    if session is None:
        return
    session.close()
    # Forget the closed session only after close() has succeeded.
    _http_session = None
|
||||
|
||||
|
||||
async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
|
||||
if not pdf_url:
|
||||
@@ -62,10 +71,16 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||
|
||||
try:
|
||||
session = _get_session()
|
||||
resp = session.get(pdf_url, timeout=120, allow_redirects=True)
|
||||
resp = session.get(pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True)
|
||||
resp.raise_for_status()
|
||||
dest.write_bytes(resp.content)
|
||||
except Exception as exc:
|
||||
# 清理残留的部分文件
|
||||
if dest.exists():
|
||||
try:
|
||||
dest.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
|
||||
|
||||
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
|
||||
|
||||
+383
-462
@@ -1,12 +1,12 @@
|
||||
"""PDF 图片与表格提取 — 基于 pymupdf4llm layout analysis。
|
||||
"""PDF 图片与表格提取 — 两阶段流水线。
|
||||
|
||||
用 pymupdf4llm 的 layout analysis 检测 table / picture 区域,
|
||||
再通过 caption 文字匹配确定 Figure/Table 编号,渲染为 JPEG。
|
||||
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
|
||||
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
|
||||
|
||||
相比旧方案(caption 正则 + pdfplumber/find_tables/文本块扫描三套策略):
|
||||
- layout analysis 直接给出区域 bbox,不存在相邻表格互相侵入的问题
|
||||
- 无需手动调参(最大高度、间隙阈值等)
|
||||
- 页面级 caption 匹配:每个 caption 只分配给最近的 box,避免上下相邻表格抢夺同一个 caption
|
||||
相比旧方案(正则匹配 caption):
|
||||
- 不再依赖正则,用 LLM 输出的 ID 直接搜索 PDF 文本
|
||||
- page.search_for() 精确搜索 + 空间距离过滤,避免正文引用误匹配
|
||||
- 通用标签兜底,LLM 没提到的图表不会被丢弃
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -17,44 +17,30 @@ import re
|
||||
from pathlib import Path
|
||||
|
||||
import pymupdf
|
||||
import pymupdf4llm.helpers.document_layout as dl
|
||||
|
||||
from app.services.layout_detector import LayoutBox, detect_page_layout
|
||||
from app.services.pdf_downloader import paper_dir
|
||||
from app.utils import TMP_DIR
|
||||
from app.utils import PAPERS_DIR, TMP_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Caption 正则 ───────────────────────────────────────────────────────
|
||||
|
||||
# 用于从 caption 文字中提取 Figure/Table 编号
|
||||
_FIGURE_CAPTION_RE = re.compile(
|
||||
r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_TABLE_CAPTION_RE = re.compile(
|
||||
r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# caption 与 table/picture 的最大匹配距离(点)
|
||||
_CAPTION_MATCH_DISTANCE = 100
|
||||
# 截图区域的外边距
|
||||
# 截图区域的外边距(单位: pt)
|
||||
_REGION_PADDING = 5
|
||||
# 3x 渲染,保证清晰度
|
||||
# 渲染倍率(3x 保证清晰度)
|
||||
_RENDER_ZOOM = 3
|
||||
# 相邻 box 聚类间距(点)— 同一 figure/table 的碎片间距通常 < 15pt
|
||||
# 相邻 box 聚类间距(单位: pt)— 同一 figure/table 的碎片间距通常 < 15pt
|
||||
_CLUSTER_GAP = 15
|
||||
# 最小 bbox 面积(单位: pt²)— 过滤 icon/logo 等微小误检
|
||||
_MIN_BOX_AREA = 2000
|
||||
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
|
||||
_LABEL_MATCH_DISTANCE = 100
|
||||
|
||||
|
||||
# ── Box 聚类 ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _BoxCluster:
|
||||
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。
|
||||
|
||||
pymupdf4llm 有时将一个大图拆成多个小 picture box(如视频帧网格),
|
||||
聚类后用整体 bbox 作为渲染区域。
|
||||
"""
|
||||
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。"""
|
||||
|
||||
__slots__ = ("x0", "y0", "x1", "y1", "boxclass")
|
||||
|
||||
@@ -63,17 +49,12 @@ class _BoxCluster:
|
||||
self.y0 = min(b.y0 for b in boxes)
|
||||
self.x1 = max(b.x1 for b in boxes)
|
||||
self.y1 = max(b.y1 for b in boxes)
|
||||
# table-fallback 归一化为 table(layout model 检测到表格但无法提取结构)
|
||||
raw = boxes[0].boxclass
|
||||
self.boxclass = "table" if raw == "table-fallback" else raw
|
||||
|
||||
|
||||
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
||||
"""将相邻的同类型 box 合并为聚类。
|
||||
|
||||
用 union-find 将间距 ≤ gap 的同类型 box 归为一组,
|
||||
每组生成一个 _BoxCluster(整体 bbox)。
|
||||
"""
|
||||
"""将相邻的同类型 box 合并为聚类。"""
|
||||
if not boxes:
|
||||
return []
|
||||
|
||||
@@ -111,242 +92,58 @@ def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
||||
return [_BoxCluster(members) for members in groups.values()]
|
||||
|
||||
|
||||
# ── 页面级 Caption 查找与匹配 ──────────────────────────────────────────
|
||||
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _find_page_captions(page) -> list[dict]:
|
||||
"""查找页面上所有 Figure/Table caption 文字块。"""
|
||||
blocks = page.get_text("blocks")
|
||||
captions = []
|
||||
for b in blocks:
|
||||
if len(b) < 5:
|
||||
continue
|
||||
bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
|
||||
text = str(b[4]).strip()
|
||||
first_line = text.split("\n")[0].strip()
|
||||
|
||||
cap_type = None
|
||||
m = _TABLE_CAPTION_RE.match(first_line)
|
||||
if m:
|
||||
cap_type = "table"
|
||||
else:
|
||||
m = _FIGURE_CAPTION_RE.match(first_line)
|
||||
if m:
|
||||
cap_type = "figure"
|
||||
if m is None:
|
||||
continue
|
||||
|
||||
captions.append(
|
||||
{
|
||||
"label": f"{'Table' if cap_type == 'table' else 'Figure'} {m.group(1)}",
|
||||
"type": cap_type,
|
||||
"caption_text": text,
|
||||
"caption_y0": by0,
|
||||
"caption_y1": by1,
|
||||
"caption_x0": bx0,
|
||||
"caption_x1": bx1,
|
||||
}
|
||||
)
|
||||
return captions
|
||||
|
||||
|
||||
def _vertical_distance(cap_y0, cap_y1, box_y0, box_y1) -> float | None:
|
||||
"""计算 caption 到 box 的垂直距离。不邻接时返回 None。
|
||||
|
||||
三种情况:caption 完全在 box 上方、完全在下方、与 box 有垂直重叠。
|
||||
重叠(含部分溢出)视为 distance=0,确保 caption 延伸到 box 边界外时不会丢失。
|
||||
"""
|
||||
# Caption 完全在 box 上方
|
||||
if cap_y1 <= box_y0:
|
||||
dist = box_y0 - cap_y1
|
||||
return dist if dist <= _CAPTION_MATCH_DISTANCE else None
|
||||
# Caption 完全在 box 下方
|
||||
if cap_y0 >= box_y1:
|
||||
dist = cap_y0 - box_y1
|
||||
return dist if dist <= _CAPTION_MATCH_DISTANCE else None
|
||||
# Caption 与 box 有垂直重叠(内部、部分溢出都算)→ 距离 0
|
||||
return 0
|
||||
|
||||
|
||||
def _same_column(cap: dict, box, page_width: float) -> bool:
|
||||
"""判断 caption 和 box 是否在同一列。
|
||||
|
||||
双栏论文中左右栏间距有限,简单的水平重叠检查会跨列匹配。
|
||||
策略:用中心 X 坐标判断各自在哪半边,只有同半边才算同列。
|
||||
跨栏图表(caption 或 box 宽度 >65% 页宽)不受此限制。
|
||||
"""
|
||||
cap_w = cap["caption_x1"] - cap["caption_x0"]
|
||||
box_w = box.x1 - box.x0
|
||||
|
||||
# 跨栏元素:宽度超过页面的 65%
|
||||
if cap_w > page_width * 0.65 or box_w > page_width * 0.65:
|
||||
return True
|
||||
|
||||
cap_cx = (cap["caption_x0"] + cap["caption_x1"]) / 2
|
||||
box_cx = (box.x0 + box.x1) / 2
|
||||
mid = page_width / 2
|
||||
|
||||
# 同在左半边或同在右半边
|
||||
return (cap_cx < mid) == (box_cx < mid)
|
||||
|
||||
|
||||
def _match_captions_to_boxes(
|
||||
page_boxes: list, captions: list[dict], page_width: float
|
||||
) -> list[tuple[list[int], list[dict]]]:
|
||||
"""将 caption 分配给 box,允许一个 caption 匹配多个同类型 box。
|
||||
|
||||
典型场景:
|
||||
- Figure 由左右两个 picture box 组成,caption 同时靠近两者
|
||||
- Table 的视觉内容被 layout analysis 误分类为 picture,需要跨类型匹配
|
||||
|
||||
Returns:
|
||||
[(box_indices, captions), ...] 每组是一个独立的渲染任务
|
||||
"""
|
||||
# 每个 caption 找到所有距离在阈值内的 box
|
||||
# 优先匹配同类型;如果找不到,再匹配任意 table/picture box
|
||||
cap_to_boxes: dict[int, list[tuple[int, float]]] = {}
|
||||
|
||||
for ci, cap in enumerate(captions):
|
||||
same_type: list[tuple[int, float]] = []
|
||||
any_type: list[tuple[int, float]] = []
|
||||
expected = "table" if cap["type"] == "table" else "picture"
|
||||
|
||||
for bi, box in enumerate(page_boxes):
|
||||
# 列感知:双栏论文中只匹配同栏的 box
|
||||
if not _same_column(cap, box, page_width):
|
||||
continue
|
||||
# 水平重叠检查(同列内仍需有重叠)
|
||||
if not (
|
||||
cap["caption_x1"] > box.x0 - 5 and cap["caption_x0"] < box.x1 + 5
|
||||
):
|
||||
continue
|
||||
dist = _vertical_distance(
|
||||
cap["caption_y0"], cap["caption_y1"], box.y0, box.y1
|
||||
)
|
||||
if dist is None:
|
||||
continue
|
||||
entry = (bi, dist)
|
||||
any_type.append(entry)
|
||||
if box.boxclass == expected:
|
||||
same_type.append(entry)
|
||||
|
||||
# 优先用同类型匹配;没有时回退到任意类型;都没有则跳过
|
||||
if same_type:
|
||||
cap_to_boxes[ci] = same_type
|
||||
elif any_type:
|
||||
cap_to_boxes[ci] = any_type
|
||||
# else: 该 caption 无匹配 box,不加入 cap_to_boxes
|
||||
|
||||
# 每个 caption → 最近的 box(用于分组),但记录所有匹配的 box
|
||||
cap_primary: dict[int, int] = {} # caption → primary box index
|
||||
cap_all_boxes: dict[int, list[int]] = {} # caption → all matched box indices
|
||||
for ci, matches in cap_to_boxes.items():
|
||||
matches.sort(key=lambda x: x[1])
|
||||
cap_primary[ci] = matches[0][0]
|
||||
# 所有距离最近的同组 box(距离差 < 20pt 视为同一组)
|
||||
best_dist = matches[0][1]
|
||||
cap_all_boxes[ci] = [bi for bi, d in matches if d <= best_dist + 20]
|
||||
|
||||
# 按 primary box 分组
|
||||
box_to_caps: dict[int, list[int]] = {}
|
||||
for ci, bi in cap_primary.items():
|
||||
box_to_caps.setdefault(bi, []).append(ci)
|
||||
|
||||
# 构建渲染组:每个 caption 独立成组(共享 box 但各自渲染)
|
||||
# 同类型同 label 的 caption 会合并;不同类型则分开
|
||||
used_captions: set[int] = set()
|
||||
groups: list[tuple[list[int], list[dict]]] = []
|
||||
|
||||
for bi in sorted(box_to_caps.keys()):
|
||||
cis = box_to_caps[bi]
|
||||
for ci in cis:
|
||||
if ci in used_captions:
|
||||
continue
|
||||
used_captions.add(ci)
|
||||
|
||||
all_box_indices = set(cap_all_boxes.get(ci, [bi]))
|
||||
# 只合并同 label 的 caption(同 figure/table 的重复 caption)
|
||||
merged_captions = [captions[ci]]
|
||||
for other_bi in all_box_indices:
|
||||
if other_bi in box_to_caps:
|
||||
for other_ci in box_to_caps[other_bi]:
|
||||
if other_ci not in used_captions:
|
||||
other_cap = captions[other_ci]
|
||||
if other_cap["label"] == captions[ci]["label"]:
|
||||
used_captions.add(other_ci)
|
||||
merged_captions.append(other_cap)
|
||||
groups.append((sorted(all_box_indices), merged_captions))
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
# ── 单页处理 ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _render_and_save(
|
||||
def _render_box(
|
||||
page,
|
||||
clip: pymupdf.Rect,
|
||||
box: _BoxCluster,
|
||||
images_dest: Path,
|
||||
manifest: dict,
|
||||
label: str,
|
||||
filename: str,
|
||||
cap_type: str,
|
||||
caption_text: str,
|
||||
page_num_1based: int,
|
||||
arxiv_id: str,
|
||||
page_num: int,
|
||||
) -> bool:
|
||||
"""渲染页面区域并保存 JPEG,写入 manifest。成功返回 True。"""
|
||||
"""渲染单个 box 区域并保存 JPEG,成功返回 True。"""
|
||||
page_width = page.rect.width
|
||||
clip = pymupdf.Rect(
|
||||
max(0, box.x0 - _REGION_PADDING),
|
||||
max(0, box.y0 - _REGION_PADDING),
|
||||
min(page_width, box.x1 + _REGION_PADDING),
|
||||
box.y1 + _REGION_PADDING,
|
||||
)
|
||||
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
|
||||
try:
|
||||
pix = page.get_pixmap(matrix=mat, clip=clip)
|
||||
except Exception:
|
||||
logger.debug("Failed to render %s for %s", label, arxiv_id)
|
||||
return False
|
||||
|
||||
filename = f"{label.replace(' ', '_').lower()}.jpg"
|
||||
(images_dest / filename).write_bytes(pix.tobytes("jpeg"))
|
||||
|
||||
manifest[filename] = {
|
||||
"page": page_num_1based,
|
||||
"type": cap_type,
|
||||
"label": label,
|
||||
"caption_text": caption_text[:200] if caption_text else "",
|
||||
"figures" if cap_type == "figure" else "tables": [label],
|
||||
}
|
||||
logger.debug(
|
||||
"Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) → %s",
|
||||
label,
|
||||
page_num_1based,
|
||||
clip.x0,
|
||||
clip.y0,
|
||||
clip.x1,
|
||||
clip.y1,
|
||||
filename,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _process_page(
|
||||
doc,
|
||||
page_idx: int,
|
||||
page_layout,
|
||||
page_boxes: list[LayoutBox],
|
||||
images_dest: Path,
|
||||
manifest: dict,
|
||||
seen_labels: set,
|
||||
arxiv_id: str,
|
||||
) -> int:
|
||||
"""处理单页:caption 匹配 + orphan 兜底,返回本页提取数量。"""
|
||||
"""处理单页:检测 → 聚类 → 渲染,全部用通用标签。"""
|
||||
page = doc[page_idx]
|
||||
page_width = page.rect.width
|
||||
page_num = page_idx + 1
|
||||
orphan_fig_counter = 0
|
||||
orphan_tbl_counter = 0
|
||||
fig_counter = 0
|
||||
tbl_counter = 0
|
||||
|
||||
# 收集本页的 table/picture box(跳过极小区域)
|
||||
raw_boxes = []
|
||||
for box in page_layout.boxes:
|
||||
for box in page_boxes:
|
||||
if box.boxclass not in ("table", "table-fallback", "picture"):
|
||||
continue
|
||||
if (box.x1 - box.x0) < 20 or (box.y1 - box.y0) < 20:
|
||||
w = box.x1 - box.x0
|
||||
h = box.y1 - box.y0
|
||||
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
|
||||
continue
|
||||
raw_boxes.append(box)
|
||||
|
||||
@@ -354,153 +151,48 @@ def _process_page(
|
||||
return 0
|
||||
|
||||
# 聚类:将同一 figure/table 的碎片 box 合并
|
||||
page_boxes = _cluster_boxes(raw_boxes)
|
||||
clusters = _cluster_boxes(raw_boxes)
|
||||
|
||||
# 页面级匹配:查找所有 caption,分配给 box
|
||||
captions = _find_page_captions(page)
|
||||
groups = _match_captions_to_boxes(page_boxes, captions, page_width)
|
||||
|
||||
# 只合并同 label 的 group(同一个 figure/table 的重复 caption)
|
||||
# 不同 label 的 group 即使共享 box 也不合并(如 Figure 7 和 Figure 8),
|
||||
# 渲染时用 caption 位置切割区域
|
||||
_merged_groups: set[int] = set()
|
||||
merged_groups: list[tuple[list[int], list[dict]]] = []
|
||||
for gi, (box_indices, caps) in enumerate(groups):
|
||||
if gi in _merged_groups:
|
||||
continue
|
||||
this_labels = {c["label"] for c in caps}
|
||||
all_box_set = set(box_indices)
|
||||
merge_targets = {gi}
|
||||
for other_gi, (other_bi, other_caps) in enumerate(groups):
|
||||
if other_gi <= gi or other_gi in _merged_groups:
|
||||
continue
|
||||
other_labels = {c["label"] for c in other_caps}
|
||||
# 只在 label 有交集时合并(同一个 figure/table)
|
||||
if this_labels & other_labels and all_box_set & set(other_bi):
|
||||
merge_targets.add(other_gi)
|
||||
all_box_set |= set(other_bi)
|
||||
all_caps = []
|
||||
for mgi in sorted(merge_targets):
|
||||
_merged_groups.add(mgi)
|
||||
all_caps.extend(groups[mgi][1])
|
||||
merged_groups.append((sorted(all_box_set), all_caps))
|
||||
groups = merged_groups
|
||||
|
||||
# ── 阶段 1:渲染有 caption 匹配的图/表 ──
|
||||
matched_box_indices: set[int] = set()
|
||||
extracted = 0
|
||||
|
||||
for box_indices, caps in groups:
|
||||
matched_box_indices.update(box_indices)
|
||||
|
||||
# 去重同一 label,跳过已处理的
|
||||
unique_caps = []
|
||||
for cap in caps:
|
||||
if cap["label"] not in seen_labels:
|
||||
seen_labels.add(cap["label"])
|
||||
unique_caps.append(cap)
|
||||
if not unique_caps:
|
||||
continue
|
||||
|
||||
# 合并所有关联 box 的 bbox
|
||||
bx0 = min(page_boxes[i].x0 for i in box_indices)
|
||||
by0 = min(page_boxes[i].y0 for i in box_indices)
|
||||
bx1 = max(page_boxes[i].x1 for i in box_indices)
|
||||
by1 = max(page_boxes[i].y1 for i in box_indices)
|
||||
|
||||
# 渲染区域:box + caption
|
||||
all_cap_y0 = min(c["caption_y0"] for c in unique_caps)
|
||||
all_cap_y1 = max(c["caption_y1"] for c in unique_caps)
|
||||
all_cap_x0 = min(c["caption_x0"] for c in unique_caps)
|
||||
all_cap_x1 = max(c["caption_x1"] for c in unique_caps)
|
||||
|
||||
top = max(0, min(by0, all_cap_y0) - _REGION_PADDING)
|
||||
bottom = max(by1, all_cap_y1) + _REGION_PADDING
|
||||
rx0 = max(0, min(bx0, all_cap_x0) - _REGION_PADDING)
|
||||
rx1 = min(page_width, max(bx1, all_cap_x1) + _REGION_PADDING)
|
||||
|
||||
clip = pymupdf.Rect(rx0, top, rx1, bottom)
|
||||
# 多个 caption 可能共享同一区域(如 subfigure),只需渲染一次
|
||||
jpeg_bytes = None
|
||||
for cap in unique_caps:
|
||||
if jpeg_bytes is None:
|
||||
if not _render_and_save(
|
||||
page,
|
||||
clip,
|
||||
images_dest,
|
||||
manifest,
|
||||
cap["label"],
|
||||
cap["type"],
|
||||
cap["caption_text"],
|
||||
page_num,
|
||||
arxiv_id,
|
||||
):
|
||||
break
|
||||
# 读取刚写入的 bytes 供后续同名 caption 复用
|
||||
filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
|
||||
jpeg_bytes = (images_dest / filename).read_bytes()
|
||||
extracted += 1
|
||||
else:
|
||||
# 同区域的不同 caption(如 subfigure),复用图片
|
||||
filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
|
||||
(images_dest / filename).write_bytes(jpeg_bytes)
|
||||
cap_preview = cap["caption_text"][:200]
|
||||
manifest[filename] = {
|
||||
"page": page_num,
|
||||
"type": cap["type"],
|
||||
"label": cap["label"],
|
||||
"caption_text": cap_preview,
|
||||
"figures" if cap["type"] == "figure" else "tables": [cap["label"]],
|
||||
}
|
||||
extracted += 1
|
||||
|
||||
# ── 阶段 2:渲染无 caption 匹配的图/表(orphan boxes) ──
|
||||
orphan_indices = set(range(len(page_boxes))) - matched_box_indices
|
||||
for bi in sorted(orphan_indices):
|
||||
box = page_boxes[bi]
|
||||
cap_type = "figure" if box.boxclass == "picture" else "table"
|
||||
for cluster in clusters:
|
||||
cap_type = "figure" if cluster.boxclass == "picture" else "table"
|
||||
|
||||
if cap_type == "figure":
|
||||
orphan_fig_counter += 1
|
||||
label = f"Figure (p{page_num}-{orphan_fig_counter})"
|
||||
fig_counter += 1
|
||||
label = f"Figure (p{page_num}-{fig_counter})"
|
||||
else:
|
||||
orphan_tbl_counter += 1
|
||||
label = f"Table (p{page_num}-{orphan_tbl_counter})"
|
||||
tbl_counter += 1
|
||||
label = f"Table (p{page_num}-{tbl_counter})"
|
||||
|
||||
if label in seen_labels:
|
||||
continue
|
||||
seen_labels.add(label)
|
||||
|
||||
clip = pymupdf.Rect(
|
||||
max(0, box.x0 - _REGION_PADDING),
|
||||
max(0, box.y0 - _REGION_PADDING),
|
||||
min(page_width, box.x1 + _REGION_PADDING),
|
||||
box.y1 + _REGION_PADDING,
|
||||
)
|
||||
if _render_and_save(
|
||||
page,
|
||||
clip,
|
||||
images_dest,
|
||||
manifest,
|
||||
label,
|
||||
cap_type,
|
||||
"",
|
||||
page_num,
|
||||
arxiv_id,
|
||||
):
|
||||
extracted += 1
|
||||
filename = f"{label.replace(' ', '_').lower()}.jpg"
|
||||
if not _render_box(page, cluster, images_dest, filename, cap_type, page_num):
|
||||
continue
|
||||
|
||||
manifest[filename] = {
|
||||
"page": page_num,
|
||||
"type": cap_type,
|
||||
"label": label,
|
||||
"box": [
|
||||
round(float(cluster.x0), 1),
|
||||
round(float(cluster.y0), 1),
|
||||
round(float(cluster.x1), 1),
|
||||
round(float(cluster.y1), 1),
|
||||
],
|
||||
}
|
||||
extracted += 1
|
||||
|
||||
return extracted
|
||||
|
||||
|
||||
# ── 核心提取 ───────────────────────────────────────────────────────────
|
||||
# ── Phase 1 核心入口 ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
||||
"""从 PDF 提取 Figure/Table 截图,生成 manifest。
|
||||
|
||||
用 pymupdf4llm layout analysis 检测 table/picture 区域,
|
||||
再通过 caption 文字确定编号,渲染为 JPEG。
|
||||
"""Phase 1: 从 PDF 提取 Figure/Table 截图,生成通用标签的 manifest。
|
||||
|
||||
Args:
|
||||
arxiv_id: 论文 ID
|
||||
@@ -526,45 +218,31 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
||||
if (images_dest / "manifest.json").exists():
|
||||
(images_dest / "manifest.json").unlink()
|
||||
|
||||
doc = pymupdf.open(str(pdf_path))
|
||||
with pymupdf.open(str(pdf_path)) as doc:
|
||||
extracted = 0
|
||||
manifest: dict[str, dict] = {}
|
||||
seen_labels: set[str] = set()
|
||||
|
||||
# layout analysis
|
||||
try:
|
||||
parsed = dl.parse_document(
|
||||
doc, filename=str(pdf_path), use_ocr=dl.OCRMode.NEVER
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"pymupdf4llm layout analysis failed for %s", arxiv_id, exc_info=True
|
||||
)
|
||||
doc.close()
|
||||
return 0
|
||||
|
||||
extracted = 0
|
||||
manifest: dict[str, dict] = {}
|
||||
seen_labels: set[str] = set()
|
||||
|
||||
for page_idx, page_layout in enumerate(parsed.pages):
|
||||
try:
|
||||
extracted += _process_page(
|
||||
doc,
|
||||
page_idx,
|
||||
page_layout,
|
||||
images_dest=images_dest,
|
||||
manifest=manifest,
|
||||
seen_labels=seen_labels,
|
||||
arxiv_id=arxiv_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to process page %d for %s",
|
||||
page_idx + 1,
|
||||
arxiv_id,
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
doc.close()
|
||||
for page_idx in range(doc.page_count):
|
||||
try:
|
||||
page_boxes = detect_page_layout(doc[page_idx])
|
||||
extracted += _process_page(
|
||||
doc,
|
||||
page_idx,
|
||||
page_boxes,
|
||||
images_dest=images_dest,
|
||||
manifest=manifest,
|
||||
seen_labels=seen_labels,
|
||||
arxiv_id=arxiv_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to process page %d for %s",
|
||||
page_idx + 1,
|
||||
arxiv_id,
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
# 保存 manifest
|
||||
manifest_path = images_dest / "manifest.json"
|
||||
@@ -580,78 +258,321 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
||||
return extracted
|
||||
|
||||
|
||||
# ── 按 summary 过滤 ────────────────────────────────────────────────────
|
||||
# ── Phase 2: 用 summary 的 figures ID 定位并重命名 ─────────────────────
|
||||
|
||||
|
||||
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
|
||||
"""根据 summary 中的 figures 字段过滤提取的图片/表格。
|
||||
def _distance_text_to_box(
    rect: pymupdf.Rect,
    box: list[float],
    *,
    h_margin: float = 20.0,
    max_distance: float | None = None,
) -> float | None:
    """Return the vertical distance from a text hit to a layout box.

    The hit is accepted only if the horizontal centre of ``rect`` lies within
    the box's horizontal extent widened by ``h_margin`` on both sides. The
    distance is then measured from the vertical centre of ``rect`` to the
    nearest horizontal edge of the box (0 when the centre is inside the box).

    Args:
        rect: Rectangle of the matched text; only ``x0``, ``x1``, ``y0`` and
            ``y1`` are read.
        box: Box coordinates, unpacked as ``(x0, y0, x1, y1)``.
        h_margin: Horizontal tolerance applied on each side of the box.
        max_distance: Largest accepted vertical distance. Defaults to the
            module-level ``_LABEL_MATCH_DISTANCE``.

    Returns:
        The vertical distance, or ``None`` when the hit is horizontally out of
        range or farther away than ``max_distance``.
    """
    if max_distance is None:
        max_distance = _LABEL_MATCH_DISTANCE

    rect_cx = (rect.x0 + rect.x1) / 2
    rect_cy = (rect.y0 + rect.y1) / 2
    bx0, by0, bx1, by1 = box

    # Horizontal gate: the text centre must sit over (or close to) the box.
    if not (bx0 - h_margin <= rect_cx <= bx1 + h_margin):
        return None

    # Vertical distance to the nearest edge; 0 when the centre is inside.
    if rect_cy < by0:
        dist = by0 - rect_cy
    elif rect_cy > by1:
        dist = rect_cy - by1
    else:
        dist = 0

    return dist if dist <= max_distance else None
|
||||
|
||||
|
||||
def _search_variants(fig_id: str) -> list[str]:
    """Generate the text variants to search for a figure/table ID.

    The original ID always comes first. For figure-style IDs the three common
    spellings are appended, and duplicates are dropped while keeping order:

        "Figure 1" -> ["Figure 1", "Fig. 1", "Fig 1"]
        "Fig. 1"   -> ["Fig. 1", "Figure 1", "Fig 1"]
        "Fig.1"    -> ["Fig.1", "Figure 1", "Fig. 1", "Fig 1"]
        "Table A1" -> ["Table A1"]

    Args:
        fig_id: Figure or table identifier as written in the summary.

    Returns:
        Ordered list of unique search strings.
    """
    variants = [fig_id]

    # ``\s*`` (not ``\s+``) so compact IDs such as "Fig.1" also expand.
    m = re.match(r"(Fig\.?|Figure)\s*(\d+.*)", fig_id, re.IGNORECASE)
    if m:
        num_part = m.group(2)
        variants += [
            f"Figure {num_part}",
            f"Fig. {num_part}",
            f"Fig {num_part}",
        ]

    # dict preserves insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys(variants))
|
||||
|
||||
|
||||
def label_images_by_summary(
|
||||
arxiv_id: str,
|
||||
figures: list[dict],
|
||||
pdf_path: Path | None = None,
|
||||
) -> int:
|
||||
"""Phase 2: 用 summary 的 figures ID 在 PDF 中搜索定位,重命名图片。
|
||||
|
||||
对 summary 中的每个 figure/table ID:
|
||||
1. page.search_for(id) 在所有页面搜索文本位置
|
||||
2. 计算搜索位置与 manifest 中 box 坐标的距离
|
||||
3. 最近匹配 → 重命名文件、更新 manifest
|
||||
|
||||
Args:
|
||||
arxiv_id: 论文 ID
|
||||
figures: summary 的 figures 列表,每项含 id/caption/description 等
|
||||
pdf_path: PDF 路径
|
||||
|
||||
Returns:
|
||||
成功重命名的图片数量
|
||||
"""
|
||||
if not figures:
|
||||
return 0
|
||||
|
||||
images_dir = paper_dir(arxiv_id) / "images"
|
||||
manifest_path = images_dir / "manifest.json"
|
||||
|
||||
if not images_dir.exists() or not manifest_path.exists():
|
||||
if pdf_path is None:
|
||||
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
|
||||
if not pdf_path.exists():
|
||||
return 0
|
||||
|
||||
all_files = [
|
||||
f for f in images_dir.iterdir() if f.suffix.lower() in (".png", ".jpg", ".jpeg")
|
||||
]
|
||||
if not all_files:
|
||||
images_dest = paper_dir(arxiv_id) / "images"
|
||||
manifest_path = images_dest / "manifest.json"
|
||||
if not manifest_path.exists():
|
||||
return 0
|
||||
|
||||
manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||
manifest: dict[str, dict] = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||
if not manifest:
|
||||
return 0
|
||||
|
||||
# 收集 summary 中引用的所有 Figure/Table ID(归一化)
|
||||
referenced_ids: set[str] = set()
|
||||
for fig in figures:
|
||||
fig_id = fig.get("id", "")
|
||||
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", fig_id, re.IGNORECASE)
|
||||
if m:
|
||||
referenced_ids.add(f"Figure {m.group(1)}")
|
||||
m2 = re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
|
||||
if m2:
|
||||
referenced_ids.add(f"Table {m2.group(1)}")
|
||||
# 构建候选列表:只对通用标签的条目做匹配
|
||||
candidates: dict[str, dict] = {} # filename → {page, box, ...}
|
||||
for fname, info in manifest.items():
|
||||
if "(p" in info.get("label", ""):
|
||||
candidates[fname] = info
|
||||
|
||||
if not referenced_ids:
|
||||
logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
|
||||
return len(all_files)
|
||||
if not candidates:
|
||||
return 0
|
||||
|
||||
# 根据 manifest 的 label 字段匹配
|
||||
keep_filenames: set[str] = set()
|
||||
for filename, info in manifest.items():
|
||||
label = info.get("label", "")
|
||||
if label in referenced_ids:
|
||||
keep_filenames.add(filename)
|
||||
with pymupdf.open(str(pdf_path)) as doc:
|
||||
# 收集所有匹配候选:(fig_id, fig_index, filename, distance)
|
||||
matches: list[tuple[str, int, str, float]] = []
|
||||
|
||||
for fig_idx, fig in enumerate(figures):
|
||||
fig_id = fig.get("id", "")
|
||||
if not fig_id:
|
||||
continue
|
||||
|
||||
# 生成搜索变体:Figure 1 / Fig. 1 / Fig 1 等
|
||||
search_terms = _search_variants(fig_id)
|
||||
|
||||
# 在所有页面搜索该文本(含变体)
|
||||
search_hits: list[tuple[int, pymupdf.Rect]] = [] # (page_num_1based, Rect)
|
||||
for page_idx in range(doc.page_count):
|
||||
page = doc[page_idx]
|
||||
seen_rects: set[tuple[float, float]] = set()
|
||||
for term in search_terms:
|
||||
for r in page.search_for(term):
|
||||
key = (round(r.x0, 1), round(r.y0, 1))
|
||||
if key not in seen_rects:
|
||||
seen_rects.add(key)
|
||||
search_hits.append((page_idx + 1, r))
|
||||
|
||||
if not search_hits:
|
||||
continue
|
||||
|
||||
# 对每个候选 manifest 条目,找最近的搜索命中
|
||||
for fname, info in candidates.items():
|
||||
box = info.get("box")
|
||||
if not box:
|
||||
continue
|
||||
manifest_page = info.get("page", 0)
|
||||
|
||||
best_dist: float | None = None
|
||||
for hit_page, rect in search_hits:
|
||||
# 只匹配同页面
|
||||
if hit_page != manifest_page:
|
||||
continue
|
||||
dist = _distance_text_to_box(rect, box)
|
||||
if dist is not None and (best_dist is None or dist < best_dist):
|
||||
best_dist = dist
|
||||
|
||||
if best_dist is not None:
|
||||
matches.append((fig_id, fig_idx, fname, best_dist))
|
||||
|
||||
if not matches:
|
||||
logger.info("No label matches for %s", arxiv_id)
|
||||
return 0
|
||||
|
||||
# 去冲突:按距离排序,每个 fig_id 和每个 filename 只匹配一次
|
||||
matches.sort(key=lambda x: x[3])
|
||||
used_fig_ids: set[int] = set()
|
||||
used_filenames: set[str] = set()
|
||||
renames: list[tuple[str, str, str]] = [] # (old_fname, new_fname, fig_id)
|
||||
|
||||
for fig_id, fig_idx, fname, dist in matches:
|
||||
if fig_idx in used_fig_ids or fname in used_filenames:
|
||||
continue
|
||||
for ref in info.get("figures", []) + info.get("tables", []):
|
||||
if ref in referenced_ids:
|
||||
keep_filenames.add(filename)
|
||||
used_fig_ids.add(fig_idx)
|
||||
used_filenames.add(fname)
|
||||
new_fname = f"{fig_id.replace(' ', '_').lower()}.jpg"
|
||||
renames.append((fname, new_fname, fig_id))
|
||||
|
||||
# 执行重命名
|
||||
labeled = 0
|
||||
new_manifest: dict[str, dict] = {}
|
||||
|
||||
for fname, info in manifest.items():
|
||||
if fname in used_filenames:
|
||||
continue
|
||||
# 未匹配的保持原样
|
||||
new_manifest[fname] = info
|
||||
|
||||
for old_fname, new_fname, fig_id in renames:
|
||||
old_path = images_dest / old_fname
|
||||
new_path = images_dest / new_fname
|
||||
if not old_path.exists():
|
||||
continue
|
||||
|
||||
# 搬运 manifest 信息
|
||||
info = manifest[old_fname].copy()
|
||||
cap_type = info.get("type", "figure")
|
||||
|
||||
# 读取 caption 文本(从 figures 列表)
|
||||
caption_text = ""
|
||||
for fig in figures:
|
||||
if fig.get("id") == fig_id:
|
||||
caption_text = fig.get("caption", "")
|
||||
break
|
||||
|
||||
if not keep_filenames:
|
||||
logger.warning(
|
||||
"No manifest matches for %s (refs=%s), keeping all",
|
||||
arxiv_id,
|
||||
referenced_ids,
|
||||
info["label"] = fig_id
|
||||
info["caption_text"] = caption_text[:200] if caption_text else ""
|
||||
info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
|
||||
fig_id
|
||||
)
|
||||
return len(all_files)
|
||||
|
||||
removed = 0
|
||||
for f in all_files:
|
||||
if f.name not in keep_filenames:
|
||||
f.unlink()
|
||||
removed += 1
|
||||
# 重命名文件
|
||||
if new_fname != old_fname:
|
||||
old_path.rename(new_path)
|
||||
new_manifest[new_fname] = info
|
||||
labeled += 1
|
||||
|
||||
# 写回 manifest
|
||||
manifest_path.write_text(json.dumps(new_manifest, ensure_ascii=False, indent=2))
|
||||
|
||||
kept = len(all_files) - removed
|
||||
logger.info(
|
||||
"Filtered images for %s: kept %d, removed %d (refs=%s)",
|
||||
"Labeled %d/%d images for %s using summary figures",
|
||||
labeled,
|
||||
len(manifest),
|
||||
arxiv_id,
|
||||
kept,
|
||||
removed,
|
||||
referenced_ids,
|
||||
)
|
||||
return kept
|
||||
return labeled
|
||||
|
||||
|
||||
# ── Figure ↔ Image 关联 ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _normalize_figure_id(raw_id: str) -> str:
    """Normalize a figure/table ID to its canonical spelling.

    "Figure 1", "Fig.1" and "fig. 1" all become "Figure 1"; "table3" becomes
    "Table 3". Surrounding whitespace is ignored. IDs that match neither
    pattern are returned stripped but otherwise unchanged, and a missing ID
    (``None``) yields an empty string.

    Args:
        raw_id: Identifier as written in the summary.

    Returns:
        The canonical ID, or the stripped input when it is not recognised.
    """
    text = (raw_id or "").strip()

    m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", text, re.IGNORECASE)
    if m:
        return f"Figure {m.group(1)}"

    m = re.match(r"Table\s*(\d+)", text, re.IGNORECASE)
    if m:
        return f"Table {m.group(1)}"

    return text
|
||||
|
||||
|
||||
def _is_figure_type(fig_id: str) -> bool:
    """Return True when the ID denotes a figure, i.e. anything but a table.

    An ID counts as a table when it starts with the word "Table" (any case),
    whether or not a plain number follows, so appendix-style IDs such as
    "Table A1" are recognised as tables too. A missing ID is treated as a
    figure.
    """
    # The negative look-ahead keeps words like "Tableau" out of the table set.
    return re.match(r"\s*Table(?![a-z])", fig_id or "", re.IGNORECASE) is None
|
||||
|
||||
|
||||
def _image_sort_key(name: str) -> tuple[int, int]:
    """Return a sort key for an extracted image file name.

    Labeled names (``figure_1.jpg``, ``table_1.jpg``, and ``fig._1.jpg`` /
    ``fig_1.jpg`` as produced from IDs written "Fig. 1" / "Fig 1") sort by
    their number as ``(0, number)``. Page-based names (``page2_img1.png``,
    ``page5_table1.png``) sort as ``(page, index)``. Anything else sorts
    first as ``(0, 0)``.
    """
    # Labeled format; the look-behind stops "config_1" matching as "fig_1".
    m = re.search(r"(?:figure|table|(?<![a-z])fig\.?)_(\d+)", name)
    if m:
        return (0, int(m.group(1)))

    # Page-based format: page<N>_img<M> / page<N>_table<M>.
    m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
    if m2:
        return (int(m2.group(1)), int(m2.group(2)))

    return (0, 0)
|
||||
|
||||
|
||||
def link_figures_with_images(
    figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
    """Attach an ``image_url`` to summary figures using the extracted images.

    Strategy:
      1. Match by ID against ``manifest.json``: each entry's ``label`` (both
         as written and in normalized form) and its ``figures`` / ``tables``
         lists map to that file's URL.
      2. Figures still without a URL fall back to position: the N-th
         unmatched figure gets the N-th figure image, and likewise for
         tables, with images ordered by ``_image_sort_key``.

    The figure dicts are updated in place.

    Args:
        figures: Summary figure entries; ``id`` is read, ``image_url`` is set.
        images: Extracted image entries; ``name`` and ``url`` are read.
        arxiv_id: Paper ID used to locate the manifest and build URLs.

    Returns:
        The same ``figures`` list that was passed in.
    """
    if not figures or not images:
        return figures

    manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"

    # ── Strategy 1: match by manifest ID ──
    manifest: dict = {}
    if manifest_path.exists():
        try:
            loaded = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (OSError, ValueError, TypeError):
            loaded = None
        # A manifest that is not a JSON object is treated as absent.
        if isinstance(loaded, dict):
            manifest = loaded

    id_to_url: dict[str, str] = {}
    for filename, info in manifest.items():
        if not isinstance(info, dict):
            continue
        url = f"/papers/{arxiv_id}/images/{filename}"
        # The label field takes precedence over the figures/tables lists.
        label = info.get("label", "")
        if label:
            id_to_url[label] = url
            if isinstance(label, str):
                # Labels are stored as written ("Fig. 1") while lookups use
                # the normalized ID ("Figure 1"), so index both spellings.
                id_to_url.setdefault(_normalize_figure_id(label), url)
        for fig_id in (info.get("figures") or []) + (info.get("tables") or []):
            if fig_id not in id_to_url:
                id_to_url[fig_id] = url

    for fig in figures:
        raw_id = fig.get("id") or ""
        normalized = _normalize_figure_id(raw_id)
        if normalized in id_to_url:
            fig["image_url"] = id_to_url[normalized]
        elif raw_id in id_to_url:
            fig["image_url"] = id_to_url[raw_id]

    # ── Strategy 2: positional fallback for whatever is still unmatched ──
    unmatched = [f for f in figures if not f.get("image_url")]
    if not unmatched:
        return figures

    # Split unmatched entries by kind: figure vs. table.
    fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
    table_type_unmatched = [
        f for f in unmatched if not _is_figure_type(f.get("id", ""))
    ]

    # Split the extracted images the same way, ordered by file-name number.
    fig_images = sorted(
        [img for img in images if "table" not in img["name"].lower()],
        key=lambda img: _image_sort_key(img["name"]),
    )
    table_images = sorted(
        [img for img in images if "table" in img["name"].lower()],
        key=lambda img: _image_sort_key(img["name"]),
    )

    for i, fig in enumerate(fig_type_unmatched):
        if i < len(fig_images):
            fig["image_url"] = fig_images[i]["url"]

    for i, fig in enumerate(table_type_unmatched):
        if i < len(table_images):
            fig["image_url"] = table_images[i]["url"]

    return figures
|
||||
|
||||
+25
-10
@@ -11,6 +11,7 @@ import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from app.config import settings
|
||||
from app.utils import truncate_error
|
||||
from app.services.summary_utils import (
|
||||
JsonNotFoundError,
|
||||
build_prompt,
|
||||
@@ -21,6 +22,9 @@ from app.services.summary_utils import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# PDF 全文注入模式的字符上限 — 超过此阈值自动切换到 search 模式
|
||||
_PDF_MAX_CHARS = 80_000
|
||||
|
||||
# 重新导出,保持向后兼容
|
||||
__all__ = [
|
||||
"PiTimeoutError",
|
||||
@@ -45,7 +49,7 @@ class PiProcessError(Exception):
|
||||
def __init__(self, returncode: int, stderr: str):
|
||||
self.returncode = returncode
|
||||
self.stderr = stderr
|
||||
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
|
||||
super().__init__(f"pi exited with code {returncode}: {truncate_error(stderr)}")
|
||||
|
||||
|
||||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||||
@@ -72,23 +76,27 @@ async def call_pi(
|
||||
|
||||
actual_mode = pdf_mode
|
||||
if pdf_mode == "auto":
|
||||
if txt_size > 80_000:
|
||||
if txt_size > _PDF_MAX_CHARS:
|
||||
actual_mode = "search"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars > 80k → search", arxiv_id, txt_size
|
||||
"Auto mode: %s text=%d chars > %dk → search",
|
||||
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||
)
|
||||
else:
|
||||
actual_mode = "inject"
|
||||
logger.info(
|
||||
"Auto mode: %s text=%d chars ≤ 80k → inject", arxiv_id, txt_size
|
||||
"Auto mode: %s text=%d chars ≤ %dk → inject",
|
||||
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||
)
|
||||
|
||||
# inject 模式需要截断过长的文本(避免撑爆 context)
|
||||
if actual_mode == "inject" and txt_size > 80_000:
|
||||
if actual_mode == "inject" and txt_size > _PDF_MAX_CHARS:
|
||||
body = txt_path.read_text(encoding="utf-8")
|
||||
trimmed = body[:80_000].rstrip()
|
||||
trimmed = body[:_PDF_MAX_CHARS].rstrip()
|
||||
txt_path.write_text(trimmed, encoding="utf-8")
|
||||
logger.info("Truncated %s for inject: %d → %d chars", arxiv_id, txt_size, len(trimmed))
|
||||
logger.info(
|
||||
"Truncated %s for inject: %d → %d chars", arxiv_id, txt_size, len(trimmed)
|
||||
)
|
||||
|
||||
prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
|
||||
|
||||
@@ -101,7 +109,8 @@ async def call_pi(
|
||||
cmd = [
|
||||
settings.PI_BIN,
|
||||
"-p",
|
||||
"--tools", tools,
|
||||
"--tools",
|
||||
tools,
|
||||
]
|
||||
if fix_errors:
|
||||
cmd += ["--session", session_id, "--continue"]
|
||||
@@ -118,10 +127,14 @@ async def call_pi(
|
||||
|
||||
logger.info(
|
||||
"Calling pi for %s (fix=%s, session=%s, mode=%s)",
|
||||
arxiv_id, bool(fix_errors), session_id, actual_mode,
|
||||
arxiv_id,
|
||||
bool(fix_errors),
|
||||
session_id,
|
||||
actual_mode,
|
||||
)
|
||||
|
||||
import time as _time
|
||||
|
||||
_t_sub_start = _time.monotonic()
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
@@ -151,7 +164,9 @@ async def call_pi(
|
||||
|
||||
logger.info(
|
||||
"pi subprocess for %s: %.2fs%s",
|
||||
arxiv_id, _t_sub_end - _t_sub_start, _file_info,
|
||||
arxiv_id,
|
||||
_t_sub_end - _t_sub_start,
|
||||
_file_info,
|
||||
)
|
||||
|
||||
if proc.returncode != 0:
|
||||
|
||||
+56
-11
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from datetime import date as date_type
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
@@ -15,11 +16,50 @@ from app.models import CrawlLog, TaskLock
|
||||
from app.services.cleaner import cleanup_tmp
|
||||
from app.services.crawler import crawl_daily
|
||||
from app.services.summarizer import summarize_batch
|
||||
from app.utils import utc_now, yesterday_str
|
||||
from app.utils import release_lock, truncate_error, utc_now, yesterday_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock:
    """Acquire a TaskLock row, raising RuntimeError when it is already held.

    Shared entry point for operations that must not run re-entrantly
    (crawl, pipeline, ...).

    Args:
        db: Database session used to insert and commit the lock row.
        task: Task name stored on the lock.
        lock_key: Key stored on the lock (e.g. the target date).
        owner: Identifier of the caller taking the lock.

    Returns:
        The committed ``TaskLock`` in ``running`` status.

    Raises:
        RuntimeError: The insert hit an ``IntegrityError``, which is treated
            as "a lock for this task/key already exists".
    """
    lock = TaskLock(
        task=task,
        lock_key=lock_key,
        status="running",
        owner=owner,
        acquired_at=utc_now(),
    )
    try:
        db.add(lock)
        db.commit()
    except IntegrityError as exc:
        # Leave the session usable for the caller before reporting the conflict.
        db.rollback()
        raise RuntimeError(f"{task} already running for {lock_key}") from exc
    return lock
|
||||
|
||||
|
||||
async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") -> dict:
    """Run one crawl for ``target_date`` while holding a re-entrancy lock.

    Args:
        db: Database session.
        target_date: Target date, YYYY-MM-DD.
        owner: Identifier of the caller.

    Returns:
        The value returned by ``crawl_daily()``, unmodified.
    """
    held_lock = acquire_lock(db, "crawl", target_date, owner)
    try:
        crawl_result = await crawl_daily(db, target_date)
    finally:
        # Release the lock whether the crawl succeeded or raised.
        release_lock(db, held_lock)
    return crawl_result
|
||||
|
||||
|
||||
async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
||||
"""执行完整流水线:crawl → summarize → cleanup。
|
||||
|
||||
@@ -47,7 +87,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
||||
try:
|
||||
db.add(lock)
|
||||
db.commit()
|
||||
except Exception:
|
||||
except IntegrityError:
|
||||
db.rollback()
|
||||
raise RuntimeError(f"Pipeline already running for {target_date}")
|
||||
|
||||
@@ -66,9 +106,13 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
||||
try:
|
||||
# Step 1: 抓取(先试今天,无数据则回退昨天)
|
||||
crawl_result = await crawl_daily(db, target_date)
|
||||
logger.info("Pipeline [%s]: crawl %s, found=%d new=%d",
|
||||
owner, target_date,
|
||||
crawl_result.get("found", 0), crawl_result.get("new", 0))
|
||||
logger.info(
|
||||
"Pipeline [%s]: crawl %s, found=%d new=%d",
|
||||
owner,
|
||||
target_date,
|
||||
crawl_result.get("found", 0),
|
||||
crawl_result.get("new", 0),
|
||||
)
|
||||
|
||||
if crawl_result.get("status") == "success" and crawl_result.get("found") == 0:
|
||||
yesterday = yesterday_str()
|
||||
@@ -81,8 +125,11 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
||||
|
||||
# Step 3: 清理
|
||||
cleanup_result = cleanup_tmp()
|
||||
logger.info("Pipeline [%s]: cleanup done, removed=%d",
|
||||
owner, cleanup_result.get("removed", 0))
|
||||
logger.info(
|
||||
"Pipeline [%s]: cleanup done, removed=%d",
|
||||
owner,
|
||||
cleanup_result.get("removed", 0),
|
||||
)
|
||||
|
||||
log_entry.status = "success"
|
||||
log_entry.papers_found = crawl_result.get("found", 0)
|
||||
@@ -91,7 +138,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
||||
except Exception as exc:
|
||||
logger.exception("Pipeline [%s] failed", owner)
|
||||
log_entry.status = "failed"
|
||||
error_msg = str(exc)[:2000]
|
||||
error_msg = truncate_error(exc, limit=2000)
|
||||
|
||||
finally:
|
||||
log_entry.completed_at = utc_now()
|
||||
@@ -99,9 +146,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
||||
log_entry.error = error_msg
|
||||
db.commit()
|
||||
|
||||
lock.status = "finished"
|
||||
lock.released_at = utc_now()
|
||||
db.commit()
|
||||
release_lock(db, lock)
|
||||
|
||||
if error_msg:
|
||||
return {"status": "failed", "error": error_msg}
|
||||
|
||||
@@ -90,6 +90,7 @@ class SummarySchema(BaseModel):
|
||||
|
||||
# ── 质量评估 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def assess_quality(schema: SummarySchema) -> str:
|
||||
"""评估总结质量:normal / degraded / low。"""
|
||||
# low:内容空洞的启发式判断
|
||||
|
||||
@@ -213,11 +213,7 @@ def _search_semantic(
|
||||
arxiv_ids = [c["arxiv_id"] for c in candidates]
|
||||
distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
|
||||
|
||||
stmt = (
|
||||
select(Paper)
|
||||
.where(Paper.arxiv_id.in_(arxiv_ids))
|
||||
.options(*PAPER_FULL_LOAD)
|
||||
)
|
||||
stmt = select(Paper).where(Paper.arxiv_id.in_(arxiv_ids)).options(*PAPER_FULL_LOAD)
|
||||
if tag:
|
||||
stmt = stmt.where(Paper.tags.any(tag=tag))
|
||||
|
||||
@@ -298,9 +294,7 @@ def _load_papers_by_ids(
|
||||
|
||||
papers = (
|
||||
db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.id.in_(paper_ids))
|
||||
.options(*PAPER_FULL_LOAD)
|
||||
select(Paper).where(Paper.id.in_(paper_ids)).options(*PAPER_FULL_LOAD)
|
||||
)
|
||||
.unique()
|
||||
.scalars()
|
||||
|
||||
+39
-502
@@ -1,233 +1,42 @@
|
||||
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。"""
|
||||
"""AI 总结编排服务 — 协调生成器、持久化、批量处理的顶层入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.database import SessionLocal
|
||||
from app.exceptions import ConflictError, NotFoundError
|
||||
from app.models import (
|
||||
PAPER_DEFAULT_LOAD,
|
||||
CrawlLog,
|
||||
Paper,
|
||||
PaperSummary,
|
||||
PaperTag,
|
||||
SummaryState,
|
||||
SummaryStatus,
|
||||
TaskLock,
|
||||
get_paper_by_arxiv_id,
|
||||
get_paper_by_id,
|
||||
)
|
||||
from app.services.pdf_downloader import (
|
||||
PdfDownloadError,
|
||||
cleanup_tmp,
|
||||
download_pdf,
|
||||
paper_dir,
|
||||
from app.services.pdf_downloader import download_pdf
|
||||
from app.services.summary_utils import write_meta_json
|
||||
from app.services.summary_generator import (
|
||||
_generate_with_retry,
|
||||
)
|
||||
from app.services.summary_utils import (
|
||||
JsonNotFoundError,
|
||||
build_prompt,
|
||||
extract_json,
|
||||
write_meta_json,
|
||||
extract_pdf_text,
|
||||
from app.services.summary_persister import (
|
||||
_cleanup_old_images,
|
||||
_handle_summary_failure,
|
||||
_persist_summary,
|
||||
)
|
||||
from app.services.pi_client import (
|
||||
PiProcessError,
|
||||
PiTimeoutError,
|
||||
call_pi,
|
||||
)
|
||||
from app.services import claude_backend
|
||||
from app.services.schemas import (
|
||||
SummarySchema,
|
||||
assess_quality,
|
||||
classify_validation_error,
|
||||
flatten_for_db,
|
||||
)
|
||||
from app.utils import TMP_DIR, release_lock, utc_now
|
||||
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _classify_error(exc: Exception) -> str:
    """Map an exception to its ``error_type`` string.

    Known exception classes are checked in a fixed order; pydantic
    ``ValidationError`` is delegated to ``classify_validation_error``;
    anything else is reported as ``"unknown"``.
    """
    # Order matters: the first matching class wins.
    fixed_labels = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_class, label in fixed_labels:
        if isinstance(exc, exc_class):
            return label

    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)
    return "unknown"
|
||||
|
||||
|
||||
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_fts_summary_text(schema: SummarySchema) -> str:
    """Join the summary fields indexed by FTS5 into one space-separated string.

    Empty or missing fields are skipped.
    """
    indexed_fields = (
        schema.one_line,
        schema.motivation.problem,
        schema.motivation.goal,
        schema.method.overview,
        schema.method.key_idea,
        schema.results.main_findings,
    )
    # filter(None, ...) drops None and empty strings alike.
    return " ".join(filter(None, indexed_fields))
|
||||
|
||||
|
||||
# ── DB 更新 ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Write a validated summary to the DB and commit.

    Touches four places: the ``PaperSummary`` row (insert or update), fields
    on ``paper``, AI-sourced ``PaperTag`` rows, and the ``papers_fts`` row.

    Args:
        db: Database session; committed at the end of this function.
        paper: Paper whose summary is being stored.
        schema: Validated summary.
        quality: Quality label stored on the paper.
        raw_output: Not used in this function's body.
    """
    from sqlalchemy import text

    # 1. paper_summaries: upsert keyed by the paper's id
    existing = db.get(PaperSummary, paper.id)
    flat = flatten_for_db(schema)
    if existing:
        for k, v in flat.items():
            setattr(existing, k, v)
    else:
        db.add(PaperSummary(paper_id=paper.id, **flat))

    # 2. papers table: title, quality and the on-disk artefact paths
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    p_dir = paper_dir(paper.arxiv_id)
    paper.summary_path = str(p_dir / "summary.json")
    paper.raw_output_path = str(p_dir / "raw_output.txt")

    # 3. AI tags: add only tags the paper does not already have
    existing_tag_names = {t.tag for t in paper.tags}
    for tag_name in schema.tags:
        if tag_name not in existing_tag_names:
            db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
            # Track it so a tag repeated within schema.tags is added once.
            existing_tag_names.add(tag_name)

    # 4. FTS5: refresh the indexed title and summary text for this paper
    summary_text = _build_fts_summary_text(schema)
    db.execute(
        text(
            "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
            "WHERE rowid=:paper_id"
        ),
        {
            "title_zh": schema.title_zh,
            "summary_text": summary_text,
            "paper_id": paper.id,
        },
    )

    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
|
||||
|
||||
|
||||
# ── JSON 验证 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
    """Check summary JSON against the structural rules; return the errors.

    A section that is present but not a JSON object (``null``, a string, a
    list) is treated like a missing section, so each of its fields is
    reported as an error instead of the check raising.

    Args:
        json_data: Parsed JSON produced by the AI backend.
        arxiv_id: Paper ID; not used by the checks themselves.

    Returns:
        A list of human-readable error messages; empty when validation passes.
    """
    if not isinstance(json_data, dict):
        return ["顶层必须是 JSON 对象"]

    def _section_dict(name: str) -> dict:
        """Return the named section if it is an object, else an empty dict."""
        value = json_data.get(name)
        return value if isinstance(value, dict) else {}

    errors: list[str] = []

    # Required top-level fields must be present and non-empty.
    for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
        if not json_data.get(f):
            errors.append(f"缺少必填字段: {f}")

    # tags must be a non-empty array.
    tags = json_data.get("tags")
    if not isinstance(tags, list) or len(tags) == 0:
        errors.append("tags 必须是非空数组")

    # Paragraph fields: must be a str of at least 50 characters.
    string_fields = [
        ("motivation", "problem"), ("motivation", "goal"), ("motivation", "gap"),
        ("method", "overview"), ("method", "key_idea"), ("method", "steps"),
        ("method", "novelty"),
        ("results", "main_findings"), ("results", "limitations"),
        ("improvements", "weaknesses"), ("improvements", "future_work"),
        ("improvements", "reproducibility"),
    ]
    for section, field in string_fields:
        val = _section_dict(section).get(field)
        if isinstance(val, list):
            errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
        elif not isinstance(val, str) or len(val.strip()) < 50:
            errors.append(
                f"{section}.{field} 必须是详细段落(≥50字),"
                f"当前: {type(val).__name__} ({len(str(val))}字)"
            )

    # benchmarks, when given, must be an array.
    benchmarks = _section_dict("results").get("benchmarks")
    if benchmarks is not None and not isinstance(benchmarks, list):
        errors.append("results.benchmarks 必须是数组")

    # prerequisites.concepts, when given, must be a non-empty array of
    # objects that each carry a term.
    concepts = _section_dict("prerequisites").get("concepts")
    if concepts is not None:
        if not isinstance(concepts, list):
            errors.append("prerequisites.concepts 必须是数组")
        elif len(concepts) == 0:
            errors.append("prerequisites.concepts 不能为空")
        else:
            for i, c in enumerate(concepts):
                if isinstance(c, str):
                    errors.append(f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串")
                elif isinstance(c, dict) and not c.get("term"):
                    errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")

    # figures, when given, must be an array whose object entries have an id.
    figures = json_data.get("figures")
    if figures is not None:
        if not isinstance(figures, list):
            errors.append("figures 必须是数组")
        else:
            for i, fig in enumerate(figures):
                if isinstance(fig, dict) and not fig.get("id"):
                    errors.append(f"figures[{i}] 缺少 id 字段")

    return errors
|
||||
|
||||
|
||||
# ── 文件操作 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
    """Write the paper's output files into its paper directory.

    ``raw_output.txt`` is always written; ``summary.json`` only when a
    schema is given. The directory is created if needed.
    """
    target_dir = paper_dir(arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    if schema:
        serialized = schema.model_dump_json(ensure_ascii=False, indent=2)
        summary_file = target_dir / "summary.json"
        summary_file.write_text(serialized, encoding="utf-8")

    raw_file = target_dir / "raw_output.txt"
    raw_file.write_text(raw_output, encoding="utf-8")
|
||||
|
||||
|
||||
# ── 单篇总结 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -264,277 +73,7 @@ async def summarize_one(
|
||||
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
|
||||
|
||||
|
||||
async def _generate_with_retry(
    arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
    """Call the AI backend to produce a summary, with up to 4 validation rounds.

    ``settings.SUMMARY_BACKEND`` selects the backend: ``"claude"`` goes through
    ``claude_backend.call_claude``; any other value goes through ``call_pi``.
    After each call the JSON is taken from ``summary.json`` if that file now
    exists, otherwise extracted from the raw output, and then checked with
    ``_validate_summary``. Errors from a failed round are passed to the next
    call as ``fix_errors``.

    Args:
        arxiv_id: Paper ID.
        meta_path: Path to the paper's metadata file.
        pdf_path: Path to the paper's PDF.
        pdf_mode: Forwarded to ``call_pi``; not used by the claude branch.

    Returns:
        ``(json_data, raw_output)`` from the first round that validates.

    Raises:
        ValueError: Validation still fails after 4 rounds. The exception
            carries the last raw output as ``exc.raw_output``.
    """
    import time as _time

    backend = settings.SUMMARY_BACKEND
    validation_errors: list[str] = []
    json_data: dict | None = None
    raw_output = ""
    session_id = None

    summary_file = paper_dir(arxiv_id) / "summary.json"

    # The claude backend needs the prompt built up front
    # (the pi backend builds it inside call_pi).
    claude_prompt: str | None = None
    if backend == "claude":
        _t0 = _time.monotonic()
        txt_path = extract_pdf_text(pdf_path, max_chars=None)
        body = txt_path.read_text(encoding="utf-8")
        # Cap the injected text at 80,000 characters, rewriting the file in place.
        if len(body) > 80_000:
            trimmed = body[:80_000].rstrip()
            txt_path.write_text(trimmed, encoding="utf-8")
        claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
        logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)

    for attempt in range(1, 5):
        # Remove any file left by the previous round, so that its presence
        # after the call means this round wrote it.
        if summary_file.exists():
            summary_file.unlink()

        # Start of the AI call, for timing.
        _t_call_start = _time.monotonic()

        if backend == "claude":
            if attempt == 1:
                raw_output, session_id = await claude_backend.call_claude(
                    claude_prompt, session_id=None,
                )
            else:
                # Retry: rebuild the prompt with the previous round's errors
                # and continue the same session.
                retry_prompt = build_prompt(
                    arxiv_id, meta_path,
                    extract_pdf_text(pdf_path, max_chars=80000),
                    "inject", fix_errors=validation_errors,
                )
                raw_output, session_id = await claude_backend.call_claude(
                    retry_prompt, session_id=session_id, fix_errors=validation_errors,
                )
        else:
            if attempt == 1:
                raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
            else:
                raw_output, session_id = await call_pi(
                    meta_path, pdf_path,
                    fix_errors=validation_errors,
                    session_id=session_id,
                    pdf_mode=pdf_mode,
                )

        _t_call_end = _time.monotonic()

        # Check whether summary.json now exists (it was removed before the call).
        file_written_by_ai = summary_file.exists()
        file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
        file_size = summary_file.stat().st_size if file_written_by_ai else 0

        logger.info(
            " [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
            arxiv_id, attempt,
            _t_call_end - _t_call_start,
            f"已写入({file_size}B)" if file_written_by_ai else "未写入",
            f" mtime={file_mtime:.2f}" if file_mtime else "",
        )

        # Extract the JSON: prefer the written file, else parse the raw output.
        _t_json_start = _time.monotonic()
        try:
            if file_written_by_ai:
                json_data = json.loads(summary_file.read_text(encoding="utf-8"))
                logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
            else:
                json_data = extract_json(raw_output)
        except (json.JSONDecodeError, JsonNotFoundError) as exc:
            _t_json_end = _time.monotonic()
            logger.warning(
                " [%s] JSON提取失败: %.2fs %s",
                arxiv_id, _t_json_end - _t_json_start, str(exc)[:200],
            )
            # Feed the extraction failure to the next round as the error to fix.
            validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
            continue
        _t_json_end = _time.monotonic()

        # Validate
        _t_val_start = _time.monotonic()
        validation_errors = _validate_summary(json_data, arxiv_id)
        _t_val_end = _time.monotonic()

        if not validation_errors:
            logger.info(
                " [%s] JSON提取: %.2fs 验证: %.2fs ✅",
                arxiv_id,
                _t_json_end - _t_json_start,
                _t_val_end - _t_val_start,
            )
            break
        logger.warning(
            " [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
            arxiv_id,
            _t_json_end - _t_json_start,
            _t_val_end - _t_val_start,
            "; ".join(validation_errors)[:200],
        )

    # Still non-empty here means every round failed.
    if validation_errors:
        exc = ValueError(
            f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}"
        )
        exc.raw_output = raw_output  # for the caller's _handle_summary_failure
        raise exc

    return json_data, raw_output
|
||||
|
||||
|
||||
def _persist_summary(
    db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
    """Validate a generated summary, store it, and run the optional post-steps.

    The payload is validated with ``SummarySchema`` and scored with
    ``assess_quality``. The files and database rows are then written, the
    paper's status row is marked ``DONE`` and committed, and finally image
    extraction and ChromaDB indexing are attempted.

    Args:
        db: Session used for the status update and commit.
        paper: Paper the summary belongs to.
        json_data: Parsed summary payload.
        raw_output: Raw backend output, saved next to the summary.

    Returns:
        The quality value produced by ``assess_quality``.
    """
    import time as _time

    now = _time.monotonic
    paper_key = paper.arxiv_id

    # Stage 1: schema validation and quality scoring.
    started = now()
    schema = SummarySchema.model_validate(json_data)
    quality = assess_quality(schema)
    validated = now()

    # Stage 2: files on disk.
    _save_files(paper_key, schema, raw_output)
    files_done = now()

    # Stage 3: summary rows in the database.
    _update_summary_in_db(db, paper, schema, quality, raw_output)
    rows_done = now()

    # Stage 4: flip the status row to DONE.
    status_row = paper.summary_status
    status_row.status = SummaryState.DONE
    status_row.quality = quality
    status_row.completed_at = utc_now()
    status_row.raw_output_saved = True
    db.commit()
    committed = now()

    logger.info(
        " [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
        paper_key,
        validated - started,
        files_done - validated,
        rows_done - files_done,
        committed - rows_done,
    )

    # Optional enrichment steps; both helpers are called unconditionally.
    enrich_started = now()
    _maybe_extract_images(paper_key, schema)
    images_done = now()
    _maybe_index_chroma(paper_key, paper, schema)
    chroma_done = now()

    logger.info(
        " [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
        paper_key,
        images_done - enrich_started,
        chroma_done - images_done,
    )

    return quality
|
||||
|
||||
|
||||
def _handle_summary_failure(
    db: Session, paper: Paper, exc: Exception, raw_output: str,
) -> dict:
    """Record a failed summarization attempt on the paper's status row.

    Saves the raw backend output (when there is any), bumps the retry
    counter, stores the classified error, and moves the status to
    ``PERMANENT_FAILURE`` once the counter reaches
    ``settings.SUMMARY_MAX_RETRIES + 1``; otherwise back to ``PENDING``.

    Args:
        db: Session the paper is attached to.
        paper: Paper whose summarization failed.
        exc: The exception that aborted the attempt.
        raw_output: Raw backend output to keep for debugging; may be empty.

    Returns:
        A result dict with ``arxiv_id``, ``status`` ("failed"),
        ``error_type``, a shortened ``error`` and the new ``retry_count``.
    """
    # If the failure happened during a flush/commit, the session refuses
    # further work until rolled back. Rolling back also discards any
    # half-written summary changes so they are not committed together
    # with the failure record below.
    db.rollback()

    error_type = _classify_error(exc)
    logger.error(
        "Summarize failed: %s error_type=%s %s",
        paper.arxiv_id, error_type, str(exc)[:200],
    )

    status = paper.summary_status
    if raw_output:
        try:
            _save_files(paper.arxiv_id, None, raw_output)
        except OSError:
            # Keeping the raw output is best-effort; the failure itself
            # must still be recorded.
            logger.warning(
                "Failed to save raw output for %s", paper.arxiv_id, exc_info=True
            )
        else:
            status.raw_output_saved = True

    status.retry_count = (status.retry_count or 0) + 1
    status.error_type = error_type
    status.error = str(exc)[:2000]

    if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
        status.status = SummaryState.PERMANENT_FAILURE
    else:
        status.status = SummaryState.PENDING

    status.completed_at = utc_now()
    db.commit()

    return {
        "arxiv_id": paper.arxiv_id,
        "status": "failed",
        "error_type": error_type,
        "error": str(exc)[:200],
        "retry_count": status.retry_count,
    }
|
||||
|
||||
|
||||
def _cleanup_old_images(db: Session, paper: Paper) -> None:
    """Remove leftover image files and the stored ``figures_json`` of a paper.

    Deletes image files (png/jpg/jpeg/gif/svg) and ``manifest.json`` from
    the paper's ``images`` directory, then clears ``figures_json`` on the
    paper's summary row and commits if it was set.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".svg")
    folder = paper_dir(paper.arxiv_id) / "images"

    if folder.exists():
        for entry in folder.iterdir():
            is_image = entry.suffix.lower() in image_suffixes
            if is_image or entry.name == "manifest.json":
                entry.unlink(missing_ok=True)

    # Drop the figure metadata stored in the database.
    if paper.summary and paper.summary.figures_json:
        paper.summary.figures_json = None
        db.commit()
|
||||
|
||||
|
||||
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
    """Extract figures and tables from the paper PDF (best-effort).

    Any exception, including a failing import of the extractor module, is
    logged as a warning and swallowed so it cannot fail the summary.
    """
    try:
        from app.services import pdf_image_extractor as extractor

        source_pdf = TMP_DIR / arxiv_id / "paper.pdf"
        extractor.extract_images_from_pdf(arxiv_id, source_pdf)
        wanted_figures = schema.figures
        if wanted_figures:
            extractor.filter_images_by_summary(arxiv_id, wanted_figures)
    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
    """Index the paper's key texts in ChromaDB (best-effort).

    Any exception, including a failing import of the embedder module, is
    logged as a warning and swallowed so it cannot fail the summary.
    """
    try:
        from app.services.embedder import index_paper

        document = dict(
            arxiv_id=arxiv_id,
            title_zh=schema.title_zh or "",
            title_en=paper.title_en or "",
            tags=" ".join(t.tag for t in paper.tags) if paper.tags else "",
            one_line=schema.one_line or "",
            motivation_problem=schema.motivation.problem or "",
            method_key_idea=schema.method.key_idea or "",
            paper_date=paper.paper_date.isoformat() if paper.paper_date else "",
        )
        index_paper(arxiv_id, document)
    except Exception:
        logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
async def _do_summarize_one(
|
||||
db: Session, paper: Paper, pdf_mode: str = "auto"
|
||||
) -> dict:
|
||||
async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -> dict:
|
||||
"""实际的单篇总结执行(在 semaphore 保护下)。"""
|
||||
arxiv_id = paper.arxiv_id
|
||||
title_short = (paper.title_en or "")[:50]
|
||||
@@ -548,6 +87,7 @@ async def _do_summarize_one(
|
||||
|
||||
# 清理旧的图片文件和 figures_json,避免重新总结时残留
|
||||
import time as _time
|
||||
|
||||
_t_cleanup_start = _time.monotonic()
|
||||
_cleanup_old_images(db, paper)
|
||||
_t_cleanup_end = _time.monotonic()
|
||||
@@ -567,7 +107,9 @@ async def _do_summarize_one(
|
||||
|
||||
logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
|
||||
json_data, raw_output = await _generate_with_retry(
|
||||
arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf",
|
||||
arxiv_id,
|
||||
meta_path,
|
||||
TMP_DIR / arxiv_id / "paper.pdf",
|
||||
pdf_mode=pdf_mode,
|
||||
)
|
||||
_t3 = _time.monotonic()
|
||||
@@ -577,7 +119,9 @@ async def _do_summarize_one(
|
||||
_t4 = _time.monotonic()
|
||||
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
|
||||
|
||||
logger.info("✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0)
|
||||
logger.info(
|
||||
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0
|
||||
)
|
||||
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
|
||||
|
||||
except Exception as exc:
|
||||
@@ -586,7 +130,7 @@ async def _do_summarize_one(
|
||||
return _handle_summary_failure(db, paper, exc, fail_output)
|
||||
|
||||
finally:
|
||||
cleanup_tmp(arxiv_id)
|
||||
pass # cleanup_tmp(arxiv_id) # 暂时禁用,保留 PDF 用于调试图片提取
|
||||
|
||||
|
||||
# ── 单篇入口 ────────────────────────────────────────────────────────────
|
||||
@@ -604,25 +148,19 @@ async def summarize_single(
|
||||
|
||||
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
|
||||
"""
|
||||
paper = db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.arxiv_id == arxiv_id)
|
||||
.options(*PAPER_DEFAULT_LOAD)
|
||||
).unique().scalar_one_or_none()
|
||||
paper = get_paper_by_arxiv_id(db, arxiv_id)
|
||||
if not paper:
|
||||
return {"status": "not_found", "arxiv_id": arxiv_id}
|
||||
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||
|
||||
make_session = _session_factory or SessionLocal
|
||||
|
||||
# 每篇用独立 session 避免并发问题
|
||||
paper_db = make_session()
|
||||
try:
|
||||
paper_in_new_session = paper_db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.arxiv_id == arxiv_id)
|
||||
.options(*PAPER_DEFAULT_LOAD)
|
||||
).unique().scalar_one_or_none()
|
||||
result = await summarize_one(paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode)
|
||||
paper_in_new_session = get_paper_by_arxiv_id(paper_db, arxiv_id)
|
||||
result = await summarize_one(
|
||||
paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode
|
||||
)
|
||||
finally:
|
||||
paper_db.close()
|
||||
|
||||
@@ -656,10 +194,10 @@ async def summarize_batch(
|
||||
try:
|
||||
db.add(lock)
|
||||
db.commit()
|
||||
except Exception:
|
||||
except IntegrityError:
|
||||
db.rollback()
|
||||
logger.warning("Summarize batch already running (lock conflict)")
|
||||
return {"status": "conflict", "error": "summarize batch already running"}
|
||||
raise ConflictError("summarize batch already running")
|
||||
|
||||
# CrawlLog
|
||||
log_entry = CrawlLog(
|
||||
@@ -717,19 +255,18 @@ async def summarize_batch(
|
||||
break
|
||||
paper_db = make_session()
|
||||
try:
|
||||
p = paper_db.execute(
|
||||
select(Paper)
|
||||
.where(Paper.id == paper.id)
|
||||
.options(*PAPER_DEFAULT_LOAD)
|
||||
).unique().scalar_one_or_none()
|
||||
p = get_paper_by_id(paper_db, paper.id)
|
||||
result = await summarize_one(paper_db, p, pdf_mode=pdf_mode)
|
||||
status = result.get("status", "failed")
|
||||
progress[status] = progress.get(status, 0) + 1
|
||||
finished = sum(progress.values())
|
||||
logger.info(
|
||||
"📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s",
|
||||
finished, total,
|
||||
progress["done"], progress["failed"], progress["skipped"],
|
||||
finished,
|
||||
total,
|
||||
progress["done"],
|
||||
progress["failed"],
|
||||
progress["skipped"],
|
||||
paper.arxiv_id,
|
||||
)
|
||||
results.append(result)
|
||||
@@ -785,10 +322,10 @@ async def summarize_batch(
|
||||
except Exception as exc:
|
||||
logger.exception("Summarize batch failed")
|
||||
log_entry.status = "failed"
|
||||
log_entry.error = str(exc)[:2000]
|
||||
log_entry.error = truncate_error(exc, limit=2000)
|
||||
log_entry.completed_at = utc_now()
|
||||
db.commit()
|
||||
return {"status": "failed", "error": str(exc)}
|
||||
return {"status": "failed", "error": truncate_error(exc)}
|
||||
|
||||
finally:
|
||||
release_lock(db, lock)
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
"""AI 总结生成器 — AI 后端调用、重试循环、JSON 验证、错误分类。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.config import settings
|
||||
from app.services.pdf_downloader import (
|
||||
PdfDownloadError,
|
||||
paper_dir,
|
||||
)
|
||||
from app.services.summary_utils import (
|
||||
JsonNotFoundError,
|
||||
build_prompt,
|
||||
extract_json,
|
||||
extract_pdf_text,
|
||||
)
|
||||
from app.services.pi_client import (
|
||||
PiProcessError,
|
||||
PiTimeoutError,
|
||||
call_pi,
|
||||
)
|
||||
from app.services import claude_backend
|
||||
from app.services.schemas import classify_validation_error
|
||||
from app.utils import truncate_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _classify_error(exc: Exception) -> str:
    """Map an exception to the ``error_type`` string stored with a failure.

    Known exception classes are checked in a fixed order and the first
    match wins. Pydantic validation errors are delegated to
    ``classify_validation_error``; anything else is ``"unknown"``.
    """
    # Order matters: the first matching class decides the label.
    fixed_labels = (
        (PdfDownloadError, "pdf_download_failed"),
        (PiTimeoutError, "timeout"),
        (PiProcessError, "process_error"),
        (JsonNotFoundError, "json_not_found"),
        (json.JSONDecodeError, "json_invalid"),
    )
    for exc_class, label in fixed_labels:
        if isinstance(exc, exc_class):
            return label

    if isinstance(exc, ValidationError):
        return classify_validation_error(exc)
    return "unknown"
|
||||
|
||||
|
||||
# ── JSON 验证 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
|
||||
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
|
||||
errors: list[str] = []
|
||||
|
||||
if not isinstance(json_data, dict):
|
||||
return ["顶层必须是 JSON 对象"]
|
||||
|
||||
# 必填字段
|
||||
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
|
||||
if f not in json_data or not json_data[f]:
|
||||
errors.append(f"缺少必填字段: {f}")
|
||||
|
||||
# tags 必须是非空数组
|
||||
tags = json_data.get("tags")
|
||||
if not isinstance(tags, list) or len(tags) == 0:
|
||||
errors.append("tags 必须是非空数组")
|
||||
|
||||
# 字符串段落字段(必须是 str 且 ≥50 字)
|
||||
string_fields = [
|
||||
("motivation", "problem"),
|
||||
("motivation", "goal"),
|
||||
("motivation", "gap"),
|
||||
("method", "overview"),
|
||||
("method", "key_idea"),
|
||||
("method", "steps"),
|
||||
("method", "novelty"),
|
||||
("results", "main_findings"),
|
||||
("results", "limitations"),
|
||||
("improvements", "weaknesses"),
|
||||
("improvements", "future_work"),
|
||||
("improvements", "reproducibility"),
|
||||
]
|
||||
for section, field in string_fields:
|
||||
val = json_data.get(section, {}).get(field)
|
||||
if isinstance(val, list):
|
||||
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
|
||||
elif not isinstance(val, str) or len(val.strip()) < 50:
|
||||
errors.append(
|
||||
f"{section}.{field} 必须是详细段落(≥50字),"
|
||||
f"当前: {type(val).__name__} ({len(str(val))}字)"
|
||||
)
|
||||
|
||||
# benchmarks 必须是数组
|
||||
benchmarks = json_data.get("results", {}).get("benchmarks")
|
||||
if benchmarks is not None and not isinstance(benchmarks, list):
|
||||
errors.append("results.benchmarks 必须是数组")
|
||||
|
||||
# prerequisites.concepts 必须是对象数组,每个有 term
|
||||
concepts = json_data.get("prerequisites", {}).get("concepts")
|
||||
if concepts is not None:
|
||||
if not isinstance(concepts, list):
|
||||
errors.append("prerequisites.concepts 必须是数组")
|
||||
elif len(concepts) == 0:
|
||||
errors.append("prerequisites.concepts 不能为空")
|
||||
else:
|
||||
for i, c in enumerate(concepts):
|
||||
if isinstance(c, str):
|
||||
errors.append(
|
||||
f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串"
|
||||
)
|
||||
elif isinstance(c, dict) and not c.get("term"):
|
||||
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
|
||||
|
||||
# figures 必须是数组,每个元素应有 id
|
||||
figures = json_data.get("figures")
|
||||
if figures is not None:
|
||||
if not isinstance(figures, list):
|
||||
errors.append("figures 必须是数组")
|
||||
else:
|
||||
for i, fig in enumerate(figures):
|
||||
if isinstance(fig, dict) and not fig.get("id"):
|
||||
errors.append(f"figures[{i}] 缺少 id 字段")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# ── AI 调用 + 重试 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _generate_with_retry(
    arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
    """Call the configured AI backend until its output passes validation.

    The backend comes from ``settings.SUMMARY_BACKEND``: ``"claude"`` uses
    ``claude_backend.call_claude`` with a prompt built here, anything else
    goes through ``call_pi``. At most ``settings.SUMMARY_MAX_RETRIES``
    attempts are made. From the second attempt on, the previous attempt's
    validation errors and the backend session id are passed back so the
    model can repair its answer.

    After each call the summary is read from ``summary.json`` in the paper
    directory if the backend wrote one, otherwise it is extracted from the
    raw output, and then checked with ``_validate_summary``.

    Args:
        arxiv_id: Paper identifier, used for paths and log lines.
        meta_path: Path to the paper metadata handed to the backend.
        pdf_path: Path to the paper PDF.
        pdf_mode: Passed through to ``call_pi``.

    Returns:
        ``(json_data, raw_output)`` of the first attempt that validates.

    Raises:
        ValueError: No attempt produced a valid summary. The last raw
            output is attached to the exception as ``raw_output``.
    """
    import time as _time

    backend = settings.SUMMARY_BACKEND
    validation_errors: list[str] = []
    json_data: dict | None = None
    raw_output = ""
    session_id = None

    summary_file = paper_dir(arxiv_id) / "summary.json"

    # The claude backend needs the prompt built up front; the pi backend
    # receives the paths and handles the prompt itself.
    claude_prompt: str | None = None
    if backend == "claude":
        _t0 = _time.monotonic()
        txt_path = extract_pdf_text(pdf_path, max_chars=None)
        body = txt_path.read_text(encoding="utf-8")
        if len(body) > 80_000:
            trimmed = body[:80_000].rstrip()
            txt_path.write_text(trimmed, encoding="utf-8")
        claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
        logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)

    for attempt in range(1, settings.SUMMARY_MAX_RETRIES + 1):
        # Remove whatever a previous round left behind. missing_ok avoids
        # the exists()/unlink() race.
        summary_file.unlink(missing_ok=True)

        _t_call_start = _time.monotonic()

        if backend == "claude":
            if attempt == 1:
                raw_output, session_id = await claude_backend.call_claude(
                    claude_prompt,
                    session_id=None,
                )
            else:
                retry_prompt = build_prompt(
                    arxiv_id,
                    meta_path,
                    extract_pdf_text(pdf_path, max_chars=80000),
                    "inject",
                    fix_errors=validation_errors,
                )
                raw_output, session_id = await claude_backend.call_claude(
                    retry_prompt,
                    session_id=session_id,
                    fix_errors=validation_errors,
                )
        else:
            if attempt == 1:
                raw_output, session_id = await call_pi(
                    meta_path, pdf_path, pdf_mode=pdf_mode
                )
            else:
                raw_output, session_id = await call_pi(
                    meta_path,
                    pdf_path,
                    fix_errors=validation_errors,
                    session_id=session_id,
                    pdf_mode=pdf_mode,
                )

        _t_call_end = _time.monotonic()

        # Did the AI subprocess write summary.json itself? One stat() call
        # answers existence, size and mtime without a window in between.
        try:
            file_stat = summary_file.stat()
        except FileNotFoundError:
            file_stat = None
        file_written_by_ai = file_stat is not None

        logger.info(
            " [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
            arxiv_id,
            attempt,
            _t_call_end - _t_call_start,
            f"已写入({file_stat.st_size}B)" if file_written_by_ai else "未写入",
            f" mtime={file_stat.st_mtime:.2f}" if file_written_by_ai else "",
        )

        # Obtain the JSON payload.
        _t_json_start = _time.monotonic()
        try:
            if file_written_by_ai:
                json_data = json.loads(summary_file.read_text(encoding="utf-8"))
                logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
            else:
                json_data = extract_json(raw_output)
        except (json.JSONDecodeError, UnicodeDecodeError, JsonNotFoundError) as exc:
            # A file that is not valid UTF-8 is treated like invalid JSON:
            # report it to the model on the next round instead of aborting.
            _t_json_end = _time.monotonic()
            logger.warning(
                " [%s] JSON提取失败: %.2fs %s",
                arxiv_id,
                _t_json_end - _t_json_start,
                str(exc)[:200],
            )
            validation_errors = [f"无法提取有效 JSON: {truncate_error(exc)}"]
            continue
        _t_json_end = _time.monotonic()

        # Validate the payload.
        _t_val_start = _time.monotonic()
        validation_errors = _validate_summary(json_data, arxiv_id)
        _t_val_end = _time.monotonic()

        if not validation_errors:
            logger.info(
                " [%s] JSON提取: %.2fs 验证: %.2fs ✅",
                arxiv_id,
                _t_json_end - _t_json_start,
                _t_val_end - _t_val_start,
            )
            break
        logger.warning(
            " [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
            arxiv_id,
            _t_json_end - _t_json_start,
            _t_val_end - _t_val_start,
            "; ".join(validation_errors)[:200],
        )

    if json_data is None and not validation_errors:
        # Only reachable when the loop never ran (attempt count below 1);
        # fail loudly rather than return (None, "").
        validation_errors = ["no generation attempt was made"]

    if validation_errors:
        exc = ValueError(
            f"Summary validation failed after {settings.SUMMARY_MAX_RETRIES} attempts: {'; '.join(validation_errors)}"
        )
        exc.raw_output = raw_output  # read by the caller's failure handling
        raise exc

    return json_data, raw_output
|
||||
@@ -0,0 +1,273 @@
|
||||
"""AI 总结持久化 — DB 写入、文件保存、FTS 索引、图片提取、ChromaDB 索引。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import (
|
||||
Paper,
|
||||
PaperSummary,
|
||||
PaperTag,
|
||||
SummaryState,
|
||||
)
|
||||
from app.services.pdf_downloader import paper_dir
|
||||
from app.services.schemas import (
|
||||
SummarySchema,
|
||||
assess_quality,
|
||||
flatten_for_db,
|
||||
)
|
||||
from app.services.summary_generator import _classify_error
|
||||
from app.utils import TMP_DIR, truncate_error, utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_fts_summary_text(schema: SummarySchema) -> str:
|
||||
"""拼接用于 FTS5 索引的总结文本。"""
|
||||
parts = [
|
||||
schema.one_line or "",
|
||||
schema.motivation.problem or "",
|
||||
schema.motivation.goal or "",
|
||||
schema.method.overview or "",
|
||||
schema.method.key_idea or "",
|
||||
schema.results.main_findings or "",
|
||||
]
|
||||
return " ".join(p for p in parts if p)
|
||||
|
||||
|
||||
# ── DB 更新 ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _update_summary_in_db(
    db: Session,
    paper: Paper,
    schema: SummarySchema,
    quality: str,
    raw_output: str,
) -> None:
    """Write a validated summary to the database and commit.

    Touches four places: the ``PaperSummary`` row (updated or inserted),
    the paper's own columns, AI-sourced ``PaperTag`` rows for tags not yet
    on the paper, and the ``papers_fts`` row. ``raw_output`` is not used
    in this function.
    """
    # 1. paper_summaries: update the existing row or insert a new one.
    summary_row = db.get(PaperSummary, paper.id)
    columns = flatten_for_db(schema)
    if not summary_row:
        db.add(PaperSummary(paper_id=paper.id, **columns))
    else:
        for column, value in columns.items():
            setattr(summary_row, column, value)

    # 2. papers table.
    paper.title_zh = schema.title_zh
    paper.summary_quality = quality
    base_dir = paper_dir(paper.arxiv_id)
    paper.summary_path = str(base_dir / "summary.json")
    paper.raw_output_path = str(base_dir / "raw_output.txt")

    # 3. AI tags: add only the names the paper does not have yet, and
    #    guard against duplicates inside schema.tags itself.
    known_tags = {row.tag for row in paper.tags}
    for name in schema.tags:
        if name in known_tags:
            continue
        db.add(PaperTag(paper_id=paper.id, tag=name, source="ai"))
        known_tags.add(name)

    # 4. FTS5 row.
    fts_statement = text(
        "UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
        "WHERE rowid=:paper_id"
    )
    fts_params = {
        "title_zh": schema.title_zh,
        "summary_text": _build_fts_summary_text(schema),
        "paper_id": paper.id,
    }
    db.execute(fts_statement, fts_params)

    db.commit()
    logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
|
||||
|
||||
|
||||
# ── 文件操作 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
    """Write the raw output, and the summary when one is given, to the paper directory.

    The directory is created if needed. ``summary.json`` is only written
    when ``schema`` is truthy; ``raw_output.txt`` is always written.
    """
    target_dir = paper_dir(arxiv_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    if schema:
        summary_json = schema.model_dump_json(ensure_ascii=False, indent=2)
        (target_dir / "summary.json").write_text(summary_json, encoding="utf-8")

    raw_file = target_dir / "raw_output.txt"
    raw_file.write_text(raw_output, encoding="utf-8")
|
||||
|
||||
|
||||
# ── 失败处理 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _handle_summary_failure(
    db: Session,
    paper: Paper,
    exc: Exception,
    raw_output: str,
) -> dict:
    """Record a failed summarization attempt on the paper's status row.

    Saves the raw backend output (when there is any), bumps the retry
    counter, stores the classified error, and moves the status to
    ``PERMANENT_FAILURE`` once the counter reaches
    ``settings.SUMMARY_MAX_RETRIES + 1``; otherwise back to ``PENDING``.

    Args:
        db: Session the paper is attached to.
        paper: Paper whose summarization failed.
        exc: The exception that aborted the attempt.
        raw_output: Raw backend output to keep for debugging; may be empty.

    Returns:
        A result dict with ``arxiv_id``, ``status`` ("failed"),
        ``error_type``, a truncated ``error`` and the new ``retry_count``.
    """
    from app.config import settings

    # If the failure happened during a flush/commit, the session refuses
    # further work until rolled back. Rolling back also discards any
    # half-written summary changes so they are not committed together
    # with the failure record below.
    db.rollback()

    error_type = _classify_error(exc)
    logger.error(
        "Summarize failed: %s error_type=%s %s",
        paper.arxiv_id,
        error_type,
        truncate_error(exc),
    )

    status = paper.summary_status
    if raw_output:
        try:
            _save_files(paper.arxiv_id, None, raw_output)
        except OSError:
            # Keeping the raw output is best-effort; the failure itself
            # must still be recorded.
            logger.warning(
                "Failed to save raw output for %s", paper.arxiv_id, exc_info=True
            )
        else:
            status.raw_output_saved = True

    status.retry_count = (status.retry_count or 0) + 1
    status.error_type = error_type
    status.error = truncate_error(exc, limit=2000)

    if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
        status.status = SummaryState.PERMANENT_FAILURE
    else:
        status.status = SummaryState.PENDING

    status.completed_at = utc_now()
    db.commit()

    return {
        "arxiv_id": paper.arxiv_id,
        "status": "failed",
        "error_type": error_type,
        "error": truncate_error(exc),
        "retry_count": status.retry_count,
    }
|
||||
|
||||
|
||||
# ── 持久化 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _persist_summary(
    db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
    """Validate, score, save and index one generated summary.

    Steps, in order: ``SummarySchema`` validation and ``assess_quality``,
    file output, database update, status row set to ``DONE`` and
    committed, then the two enrichment helpers (image extraction and
    ChromaDB indexing). Timings of each step are logged.

    Returns:
        The quality value produced by ``assess_quality``.
    """
    import time as _time

    tick = _time.monotonic
    arxiv_id = paper.arxiv_id

    # Each mark closes one phase; durations are the pairwise differences.
    marks = [tick()]

    schema = SummarySchema.model_validate(json_data)
    quality = assess_quality(schema)
    marks.append(tick())

    _save_files(arxiv_id, schema, raw_output)
    marks.append(tick())

    _update_summary_in_db(db, paper, schema, quality, raw_output)
    marks.append(tick())

    # Status -> done.
    state = paper.summary_status
    state.status = SummaryState.DONE
    state.quality = quality
    state.completed_at = utc_now()
    state.raw_output_saved = True
    db.commit()
    marks.append(tick())

    logger.info(
        " [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
        arxiv_id,
        *(later - earlier for earlier, later in zip(marks, marks[1:])),
    )

    # Enrichment steps; both helpers are called unconditionally.
    post_marks = [tick()]
    _maybe_extract_images(arxiv_id, schema)
    post_marks.append(tick())
    _maybe_index_chroma(arxiv_id, paper, schema)
    post_marks.append(tick())

    logger.info(
        " [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
        arxiv_id,
        post_marks[1] - post_marks[0],
        post_marks[2] - post_marks[1],
    )

    return quality
|
||||
|
||||
|
||||
# ── 清理 ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _cleanup_old_images(db: Session, paper: Paper) -> None:
    """Delete a paper's previous image files and reset its ``figures_json``.

    Image files (png/jpg/jpeg/gif/svg) and ``manifest.json`` inside the
    paper's ``images`` directory are removed; a set ``figures_json`` on
    the summary row is cleared and committed.
    """

    def _is_stale(entry) -> bool:
        """Return True for image files and the manifest."""
        if entry.name == "manifest.json":
            return True
        return entry.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg")

    images_dir = paper_dir(paper.arxiv_id) / "images"
    if images_dir.exists():
        for entry in images_dir.iterdir():
            if _is_stale(entry):
                entry.unlink(missing_ok=True)

    # Clear the figure metadata stored in the database.
    if paper.summary and paper.summary.figures_json:
        paper.summary.figures_json = None
        db.commit()
|
||||
|
||||
|
||||
# ── 触发性增强 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
    """Extract figures and tables from the paper PDF (best-effort).

    Two calls into ``pdf_image_extractor``:

    1. ``extract_images_from_pdf`` — per the original notes, PicoDet
       detection plus rendered crops with generic labels.
    2. ``label_images_by_summary`` — only when the summary lists figures;
       per the original notes, locates the summary's figure ids in the
       PDF and renames the crops accordingly.

    Any exception, including a failing import, is logged as a warning and
    swallowed so it cannot fail the summary.
    """
    try:
        from app.services import pdf_image_extractor as extractor

        source_pdf = TMP_DIR / arxiv_id / "paper.pdf"
        extractor.extract_images_from_pdf(arxiv_id, source_pdf)
        summary_figures = schema.figures
        if summary_figures:
            extractor.label_images_by_summary(arxiv_id, summary_figures, source_pdf)
    except Exception:
        logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
    """Write the paper's key texts to the ChromaDB semantic index (best-effort).

    Any exception, including a failing import of the embedder module, is
    logged as a warning and swallowed so it cannot fail the summary.
    """
    try:
        from app.services.embedder import index_paper

        tag_rows = paper.tags
        joined_tags = " ".join(row.tag for row in tag_rows) if tag_rows else ""
        published = paper.paper_date

        document = {
            "arxiv_id": arxiv_id,
            "title_zh": schema.title_zh or "",
            "title_en": paper.title_en or "",
            "tags": joined_tags,
            "one_line": schema.one_line or "",
            "motivation_problem": schema.motivation.problem or "",
            "method_key_idea": schema.method.key_idea or "",
            "paper_date": published.isoformat() if published else "",
        }
        index_paper(arxiv_id, document)
    except Exception:
        logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
|
||||
@@ -80,11 +80,16 @@ def _trim_body(text: str, max_chars: int | None = None) -> str:
|
||||
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
|
||||
if ack_match:
|
||||
# 只删 Acknowledgments 本身,不删后面的内容
|
||||
next_section = re.search(r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start():])
|
||||
next_section = re.search(
|
||||
r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start() :]
|
||||
)
|
||||
if next_section:
|
||||
text = text[:ack_match.start()] + text[ack_match.start() + next_section.start():]
|
||||
text = (
|
||||
text[: ack_match.start()]
|
||||
+ text[ack_match.start() + next_section.start() :]
|
||||
)
|
||||
else:
|
||||
text = text[:ack_match.start()].rstrip()
|
||||
text = text[: ack_match.start()].rstrip()
|
||||
|
||||
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
|
||||
if max_chars is not None and len(text) > max_chars:
|
||||
@@ -105,10 +110,9 @@ def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
|
||||
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
|
||||
return txt_path
|
||||
|
||||
doc = pymupdf.open(str(pdf_path))
|
||||
# sort=True 启用阅读顺序检测,避免双栏论文中跨栏错位
|
||||
raw_text = "\n\n".join(page.get_text(sort=True) for page in doc)
|
||||
doc.close()
|
||||
with pymupdf.open(str(pdf_path)) as doc:
|
||||
# sort=True 启用阅读顺序检测,避免双栏论文中跨栏错位
|
||||
raw_text = "\n\n".join(page.get_text(sort=True) for page in doc)
|
||||
|
||||
body = _trim_body(raw_text, max_chars=max_chars)
|
||||
txt_path.write_text(body, encoding="utf-8")
|
||||
@@ -160,7 +164,8 @@ def build_prompt(
|
||||
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
|
||||
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要","section":"method"},'
|
||||
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要","section":"results"}]'
|
||||
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table,id 严格使用 \"Figure N\" 或 \"Table N\" 格式。"
|
||||
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table。"
|
||||
'id 必须严格复用论文原文的写法(原文用 "Fig. 1" 就写 "Fig. 1",用 "Figure A1" 就写 "Figure A1",用 "Table 1" 就写 "Table 1")。'
|
||||
"section 必须是 motivation/method/results/limitations 之一,表示该图最适合展示在哪个章节。"
|
||||
"}"
|
||||
)
|
||||
|
||||
+28
-11
@@ -5,7 +5,15 @@ from __future__ import annotations
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.models import PAPER_FULL_LOAD, Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
|
||||
from app.exceptions import NotFoundError, ValidationError
|
||||
from app.models import (
|
||||
PAPER_FULL_LOAD,
|
||||
Paper,
|
||||
PaperTag,
|
||||
UserBookmark,
|
||||
UserNote,
|
||||
UserReadingStatus,
|
||||
)
|
||||
from app.utils import utc_now
|
||||
|
||||
# ── 收藏 ──────────────────────────────────────────────────────────────
|
||||
@@ -13,9 +21,11 @@ from app.utils import utc_now
|
||||
|
||||
def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
|
||||
"""切换收藏状态。返回 {"bookmarked": bool, "arxiv_id": str}。"""
|
||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
||||
paper = db.execute(
|
||||
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||
).scalar_one_or_none()
|
||||
if not paper:
|
||||
return {"error": "not_found"}
|
||||
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||
|
||||
existing = db.execute(
|
||||
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
||||
@@ -42,11 +52,15 @@ VALID_STATUSES = {"unread", "skimmed", "read_summary", "read_full"}
|
||||
def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
|
||||
"""设置阅读状态。status 必须是 unread/skimmed/read_summary/read_full。"""
|
||||
if status not in VALID_STATUSES:
|
||||
return {"error": "invalid_status", "valid": sorted(VALID_STATUSES)}
|
||||
raise ValidationError(
|
||||
f"Invalid reading status: {status}. Valid: {', '.join(sorted(VALID_STATUSES))}"
|
||||
)
|
||||
|
||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
||||
paper = db.execute(
|
||||
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||
).scalar_one_or_none()
|
||||
if not paper:
|
||||
return {"error": "not_found"}
|
||||
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||
|
||||
now = utc_now()
|
||||
existing = db.execute(
|
||||
@@ -72,7 +86,9 @@ def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
|
||||
|
||||
def get_note(db: Session, arxiv_id: str) -> dict | None:
|
||||
"""获取笔记。返回 {"arxiv_id", "content", "updated_at"} 或 None(论文不存在时)。"""
|
||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
||||
paper = db.execute(
|
||||
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||
).scalar_one_or_none()
|
||||
if not paper:
|
||||
return None
|
||||
|
||||
@@ -91,9 +107,11 @@ def get_note(db: Session, arxiv_id: str) -> dict | None:
|
||||
|
||||
def save_note(db: Session, arxiv_id: str, content: str) -> dict:
|
||||
"""创建或更新笔记。返回 {"arxiv_id", "content", "updated_at"}。"""
|
||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
||||
paper = db.execute(
|
||||
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||
).scalar_one_or_none()
|
||||
if not paper:
|
||||
return {"error": "not_found"}
|
||||
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||
|
||||
now = utc_now()
|
||||
existing = db.execute(
|
||||
@@ -154,8 +172,7 @@ def query_reading_list(
|
||||
stmt.options(
|
||||
joinedload(Paper.note),
|
||||
*PAPER_FULL_LOAD,
|
||||
)
|
||||
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
|
||||
).order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
|
||||
)
|
||||
.unique()
|
||||
.scalars()
|
||||
|
||||
+42
-6
@@ -137,12 +137,35 @@ def safe_json_loads(text: str | None, default: Any = None) -> Any:
|
||||
|
||||
# AI 生成内容中允许的 HTML 标签和属性
|
||||
_ALLOWED_TAGS = {
|
||||
"p", "br", "strong", "b", "em", "i", "u", "s", "del",
|
||||
"h3", "h4", "h5", "h6",
|
||||
"ul", "ol", "li",
|
||||
"a", "code", "pre", "blockquote",
|
||||
"table", "thead", "tbody", "tr", "th", "td",
|
||||
"sup", "sub", "span",
|
||||
"p",
|
||||
"br",
|
||||
"strong",
|
||||
"b",
|
||||
"em",
|
||||
"i",
|
||||
"u",
|
||||
"s",
|
||||
"del",
|
||||
"h3",
|
||||
"h4",
|
||||
"h5",
|
||||
"h6",
|
||||
"ul",
|
||||
"ol",
|
||||
"li",
|
||||
"a",
|
||||
"code",
|
||||
"pre",
|
||||
"blockquote",
|
||||
"table",
|
||||
"thead",
|
||||
"tbody",
|
||||
"tr",
|
||||
"th",
|
||||
"td",
|
||||
"sup",
|
||||
"sub",
|
||||
"span",
|
||||
}
|
||||
_ALLOWED_ATTRS = {
|
||||
"a": {"href", "title"},
|
||||
@@ -167,3 +190,16 @@ def sanitize_html(text: str | None) -> str:
|
||||
strip=True,
|
||||
)
|
||||
return cleaned
|
||||
|
||||
|
||||
# ── 错误消息截断 ────────────────────────────────────────────────────────
|
||||
|
||||
_ERROR_TRUNCATE_LIMIT = 500


def truncate_error(exc: Exception | str, limit: int = _ERROR_TRUNCATE_LIMIT) -> str:
    """Render an exception or string as a bounded error message.

    Messages of at most ``limit`` characters are returned unchanged.
    Longer ones are cut to their first ``limit`` characters and followed
    by a marker giving the full length, so the result is slightly longer
    than ``limit``.
    """
    message = str(exc)
    total = len(message)
    if total > limit:
        return f"{message[:limit]}... ({total} chars total)"
    return message
|
||||
|
||||
Reference in New Issue
Block a user