feat: refactor summarizer and PDF extraction pipeline

- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
This commit is contained in:
2026-06-13 13:16:47 +08:00
parent e2f0e1a8be
commit 21f16e6756
43 changed files with 3304 additions and 1494 deletions
+5
View File
@@ -46,3 +46,8 @@ EMBED_API_BASE=https://api.siliconflow.cn/v1/embeddings
EMBED_API_KEY=your_api_key_here
EMBED_MODEL=Qwen/Qwen3-Embedding-4B
EMBED_DIMENSIONS=2560
# ─── 布局检测 ─────────────────────────────
# ONNX 模型路径(首次运行前执行 scripts/export_picodet_onnx.py 导出)
# LAYOUT_MODEL_PATH=data/models/picodet_layout_3cls.onnx
# LAYOUT_THRESHOLD=0.5
+468
View File
@@ -0,0 +1,468 @@
# 项目设计审查与优化建议
本文档汇总对当前项目的系统设计、流程设计和代码结构的审查结论。重点不在局部代码风格,而在后续稳定运行、失败恢复、数据一致性、可维护性和扩展性。
## 总体评价
项目当前结构整体清晰:FastAPI 路由层较薄,主要业务放在 `app/services/`;抓取、总结、搜索、个人数据、管理后台等能力已有模块化拆分;SQLite + FTS5 对本地个人论文导览站是合理选择。
主要风险集中在以下几类:
- 长任务生命周期没有统一抽象。
- Pipeline 不是显式状态机,阶段结果难以恢复和追踪。
- 数据库、文件系统、FTS5、ChromaDB 之间存在多处手工同步。
- Web、CLI、Scheduler 三个入口的业务策略存在分叉风险。
- 失败恢复和可观测性不足。
- 部分同步阻塞任务混入 async Web 流程。
## 高优先级问题
### 1. 缺少统一的任务系统
当前有三套入口都能触发重任务:
- Web 管理后台:`/admin/trigger-pipeline``/admin/summarize`
- APScheduler:每日自动 pipeline
- CLI`app.cli crawl/summarize`
它们共享部分 service,但没有统一的“任务创建、排队、运行、重试、取消、状态查询、失败恢复”模型。
影响:
- Web 请求会长时间挂住。
- CLI、Web、Scheduler 的行为可能逐渐分叉。
- 任务卡死后只能看到粗粒度的 running lock。
- 前端难以展示阶段进度。
- 后续增加重跑、取消、恢复、并发限制会越来越复杂。
建议:
引入统一 Job 模型:
```text
入口层 -> 创建 Job -> Worker 执行 Job -> JobEvent 记录进度 -> 页面/API 查询状态
```
建议任务类型:
- `crawl_daily`
- `summarize_one`
- `summarize_batch`
- `pipeline_daily`
- `refresh_upvotes`
- `delete_range`
- `reindex_fts`
- `reindex_chroma`
- `extract_images`
建议表结构方向:
- `jobs`: 任务主记录,包含 `id/type/status/owner/created_at/started_at/completed_at/error`
- `job_events`: 阶段事件,包含 `job_id/stage/status/message/payload_json/created_at`
- `paper_processing_runs`: 单篇论文处理记录,适合跟踪总结、图片提取、索引状态
### 2. Pipeline 不是显式状态机
当前 pipeline 基本是线性函数:
```text
crawl -> summarize -> cleanup
```
阶段输出没有被建模成可查询、可重放、可补偿的状态。
影响:
- crawl 成功但 summarize 部分失败时,整体状态表达不够精确。
- cleanup 失败是否影响 pipeline 成败语义不清。
- “今天无数据回退昨天”是隐式逻辑,不是可配置策略。
- 图片提取、ChromaDB 索引失败被吞掉,用户无法知道哪些增强能力缺失。
建议:
将 pipeline 改为显式阶段状态:
```text
created
-> crawling
-> crawled
-> summarizing
-> summarized | summarized_partial
-> postprocessing
-> completed | failed | cancelled
```
每个阶段记录:
- 输入参数
- 输出统计
- 错误类型
- 错误摘要
- 开始/结束时间
- 是否可重试
这样可以支持只重跑失败阶段,例如只重跑 ChromaDB 索引、只重跑图片提取、只重跑失败论文总结。
### 3. 数据生命周期没有清晰分层
项目中至少有四类数据:
- 主数据:`papers``authors``tags``summaries`、用户笔记/收藏/阅读状态
- 派生索引:FTS5、ChromaDB
- 长期资产:`data/papers/{arxiv_id}` 下的 summary、raw output、images
- 临时资产:`data/tmp/{arxiv_id}` 下的 PDF、源码、中间文件
现在这些数据由不同流程手动维护。删除、总结、重建索引、清理临时目录都分散在不同模块中。
影响:
- DB 与文件系统可能不一致。
- DB 与 FTS5/ChromaDB 可能不一致。
- 删除流程中一部分资源删除成功、一部分失败时难以恢复。
- 很难判断某篇论文是否“完整可用”。
建议:
明确数据权威源:
- DB 是权威源。
- FTS5、ChromaDB、summary.json、raw_output.txt、images 都应视为可重建派生物。
- 删除主数据优先保证 DB 状态正确,派生物清理可以异步补偿。
- 提供派生数据重建任务。
建议增加命令或任务:
```text
rebuild-derived --fts
rebuild-derived --chroma
rebuild-derived --files
rebuild-derived --images
```
### 4. 数据库与文件系统存在双写一致性风险
总结持久化流程会写文件、写 DB、更新状态、提取图片、写 ChromaDB。任一步失败都可能留下不一致状态。
典型风险:
- 文件已写入,但 DB 未提交。
- DB 标记 done,但 summary.json 缺失。
- DB done,但图片未提取或 ChromaDB 未索引。
- 重新总结时旧图片或旧 figures 残留。
建议:
- 将结构化总结以 DB 为准,JSON 文件作为导出/缓存。
- 文件写入使用临时文件加原子 rename。
- DB 中记录派生物状态,例如:
- `summary_persisted_at`
- `fts_indexed_at`
- `chroma_indexed_at`
- `images_extracted_at`
- `derived_error`
- 提供一致性检查任务,扫描 DB 与文件/索引差异并修复。
### 5. 失败恢复设计偏弱
当前失败主要依赖:
- `summary_status.status`
- `retry_count`
- `task_locks`
- `crawl_logs`
但对以下情况支持不足:
- 进程崩溃后 `processing` 永久残留。
- `task_locks` 中 running lock 永久残留。
- 任务执行到一半,部分文件或索引已经写入。
- 某些后处理失败但主任务显示成功。
建议:
- `task_locks` 增加 `heartbeat_at``expires_at`
- 应用启动时扫描 stale running task,标记为 `failed/stale`
- `summary_status` 增加更细阶段字段,例如:
- `downloading_pdf`
- `generating`
- `validating`
- `persisting`
- `extracting_images`
- `indexing`
- 区分主流程失败和派生增强失败。
### 6. 长任务直接跑在 Web 请求里
管理端触发 pipeline、批量总结、刷新 upvotes、删除数据时,会在请求生命周期内直接执行重任务。
影响:
- HTTP 请求容易超时。
- 用户关闭页面后任务状态不清晰。
- Web 进程被重任务占用。
- 后续很难做取消、进度展示和失败恢复。
建议:
- Web 只创建 job 并返回 `task_id`
- 后台 worker 执行任务。
- 管理页面轮询 `/admin/jobs/{id}` 或使用 SSE 展示进度。
## 中优先级问题
### 7. FTS5 索引手工维护,容易漂移
`papers_fts` 是独立虚拟表。插入论文、更新总结、删除论文时手动同步。
影响:
- 已有论文的 title、abstract、authors、tags 更新时可能不同步。
- AI 标签新增后,如果没有统一重建,搜索结果可能缺字段。
- 删除或回滚失败后索引可能残留。
建议:
- 封装统一的 `reindex_paper(db, paper_id)`
- 所有修改论文、标签、总结的路径最后调用它。
- 提供全量 `reindex_fts` 任务。
- 或用 SQLite trigger 维护基础字段,但总结字段仍建议通过统一函数更新。
### 8. ChromaDB 索引缺少系统级一致性策略
ChromaDB 当前是可选增强能力,索引失败会被日志吞掉,不影响总结主流程。
影响:
- 语义搜索可用性不透明。
- 某些论文 DB done,但未进入 ChromaDB。
- 删除或重建时可能出现残留索引。
建议:
- 将 ChromaDB 明确视为派生索引。
- DB 中记录 `chroma_indexed_at``chroma_error`
- 提供 `reindex_chroma` 任务。
- 搜索页或管理后台展示语义索引健康状态。
### 9. 同步阻塞操作混入 async 主流程
部分 async 函数内部调用同步阻塞操作,例如 PDF 下载使用 `requests`,图片提取和 PDF 解析也都是同步重任务。
影响:
- FastAPI event loop 可能被阻塞。
- 多篇并发总结时吞吐不可控。
- Web 服务响应受后台任务影响。
建议:
- 最佳方案:重任务移到 worker 进程。
- 过渡方案:对同步阻塞段使用 `asyncio.to_thread()`
- PDF 下载统一使用 `httpx.AsyncClient` 或明确放入同步 worker。
### 10. 调度策略和业务流程耦合
Scheduler 固定每日 pipeline 加 30 分钟 upvote refreshpipeline 内部固定“今天无数据回退昨天”。
影响:
- 策略不易测试和调整。
- CLI、Web、Scheduler 可能各自实现不同 fallback。
- 后续支持日期范围、补抓、只总结新增论文会变复杂。
建议将策略配置化:
- `crawl_target_policy`: `today` / `today_then_yesterday` / `date_range`
- `summarize_scope`: `new_only` / `pending_and_failed` / `date_range`
- `cleanup_policy`: `after_success` / `always` / `manual`
- `upvote_refresh_window_days`
- `pipeline_on_partial_failure`: `continue` / `stop`
### 11. 内嵌 APScheduler 对部署形态敏感
当前调度器嵌入 Web 进程,且多 worker 时只是打印警告。
影响:
- 多 worker、多进程、reload 模式下可能重复或漏跑任务。
- Web 进程重启会影响调度可靠性。
- 调度状态和任务执行状态耦合。
建议:
- 本地单机可以保留内嵌调度器。
- 长期运行建议拆为独立 scheduler 进程。
- scheduler 只负责创建 job,不直接执行重任务。
- Web 管理后台只展示 scheduler/job 状态。
### 12. 运行时迁移能力不足
当前 `_migrate()` 只支持补列。
影响:
- 无法可靠处理索引、约束、字段改名、数据回填。
- 数据库结构演进不可追踪。
- 生产或长期本地数据升级风险高。
建议:
- 引入 Alembic。
- 将 FTS5、索引、部分约束和数据回填纳入迁移脚本。
- 启动时不要静默做复杂迁移,改为显式执行迁移命令。
## 低到中优先级问题
### 13. 配置校验不足
当前配置字段主要是裸类型,缺少合法值和组合校验。
风险示例:
- `SUMMARY_BACKEND` 填错。
- `SUMMARY_PDF_MODE` 填错。
- `SCHEDULE_MINUTE + 30` 超过 59。
- `APP_WORKERS > 1``SCHEDULER_ENABLED=true`
- `CHROMA_ENABLED=true` 但 embedding 配置缺失。
建议:
- 使用 Pydantic validator 校验枚举值和组合条件。
- 对危险组合启动时报错,而不是只 warning。
### 14. 私有函数跨模块调用说明模块边界不稳定
例如 `summarizer.py` 调用 `_generate_with_retry``_persist_summary` 等下划线函数。
影响:
- 模块公开边界不清晰。
- 后续重构容易误伤。
- 测试和替换后端不方便。
建议:
- 为 summary 生成、持久化、后处理定义明确公开 API。
- 下划线函数保留为模块内部实现细节。
- 引入更明确的 service 类或函数边界,例如:
- `SummaryGenerator.generate()`
- `SummaryRepository.save()`
- `DerivedAssetBuilder.extract_images()`
### 15. 外部依赖缺少统一 adapter
项目依赖多个外部系统:
- HuggingFace API
- arXiv PDF/source 下载
- `pi` CLI
- `claude` CLI
- Embedding API
- ChromaDB
- ONNX layout detector
建议:
- 为每类外部依赖定义 adapter。
- 统一 timeout、retry、error classification、metrics。
- 测试中替换 adapter,而不是 patch 深层函数。
### 16. 删除流程事务边界不理想
删除流程中同时删除 FTS5、ChromaDB、本地文件、临时文件、ORM 数据。
影响:
- 文件删除成功后 DB rollback,会造成数据仍在但文件丢失。
- DB 删除成功后 ChromaDB 删除失败,会造成残留索引。
- 单篇失败时 rollback 可能影响之前 flush 的状态,逻辑不直观。
建议:
- 删除主数据与删除派生物分两阶段。
- DB 删除成功后创建派生清理 job。
- 派生物清理失败可重试,不影响主数据一致性。
## 建议目标架构
如果项目定位是个人本地工具,可以采用轻量目标架构:
```text
FastAPI Web
- 页面和 API
- 管理后台
- 创建 job
- 查询 job 状态
Scheduler
- 按策略创建 job
- 不直接跑重任务
Worker
- crawl
- summarize
- PDF download
- image extraction
- FTS/Chroma reindex
- cleanup/delete
Storage/Repository
- DB 权威数据
- 文件资产管理
- 派生索引重建
- 一致性检查
```
如果暂时不想引入独立 worker,也可以先在单进程内实现 Job 表和后台任务执行器。这样至少能统一任务状态和恢复机制。
## 推荐实施顺序
### 第一阶段:先解决任务与状态
目标:降低长任务不可控风险。
- 新增 `jobs``job_events` 表。
- Web 管理动作改为创建 job 并返回 task_id。
- Scheduler 改为创建 job。
- CLI 复用同一套 job runner。
- 加 stale running job 恢复逻辑。
### 第二阶段:统一派生数据管理
目标:降低 DB、文件、FTS、Chroma 不一致风险。
- 明确 DB 为权威源。
- 封装 `reindex_paper()`
- 增加 `reindex_fts``reindex_chroma` 任务。
- 增加派生数据状态字段或健康检查。
- 删除流程改为 DB 主删除 + 派生清理 job。
### 第三阶段:改造 pipeline 为状态机
目标:让任务可恢复、可补偿、可观测。
- 拆分 pipeline stages。
- 每个 stage 记录输入/输出/错误/耗时。
- 支持从失败阶段重跑。
- 将 fallback 策略配置化。
### 第四阶段:提升运行可靠性
目标:长期运行更稳。
- 引入 Alembic。
- 配置加 validator。
- 同步阻塞操作移出 Web event loop。
- 外部依赖 adapter 化。
- 管理后台展示任务、派生索引、失败类型、耗时统计。
## 最小可行改造方案
如果只做最小但收益最大的改动,建议优先做这三项:
1. 增加统一 Job 表和 job runner。
2. DB 作为权威源,FTS/Chroma/文件全部按派生物处理。
3. 增加 stale task 恢复和派生数据重建命令。
这三项能显著降低后续复杂度,也不会强迫项目马上拆成多个服务。
+93 -52
View File
@@ -7,7 +7,7 @@
## 功能特性
- **每日抓取**:按日期拉取 HuggingFace Daily Papers,提取元数据并入库,自动去重与重试。
- **AI 中文总结**:下载 PDF调用 `pi` CLI 为每篇论文生成结构化中文解读(动机、方法、结果、局限性等),完成后清理临时文件。
- **AI 中文总结**:下载 PDF通过 `pi` `claude` 后端为每篇论文生成结构化中文解读(动机、方法、结果、局限性等),完成后清理临时文件。
- **浏览与详情**:首页按日期导航、论文详情页展示元数据与总结,提供未总结论文的英文原文回退。
- **搜索**:基于 SQLite FTS5 的关键词搜索(BM25 排序、片段高亮),覆盖标题、摘要、作者、标签与总结正文。
- **语义搜索**(可选):ChromaDB 向量数据库实现相似度搜索,优雅降级至 FTS5。
@@ -15,9 +15,10 @@
- **趋势看板**:Chart.js 驱动的可视化统计(日论文量、Top 标签、投票分布、总结完成率)。
- **个人化**:收藏、阅读状态、个人笔记与阅读列表。
- **RSS 订阅**:最近 7 天论文的 RSS 2.0 输出,支持标签过滤。
- **管理后台**Token 鉴权的手动抓取、总结、清扫、删除与日志查看接口
- **定时调度**APScheduler 内嵌调度,默认每日 08:00 自动抓取与总结(TaskLock 防重)。
- **管理后台**Session 认证的 Web 管理界面(仪表盘、论文管理、日志查看、手动操作)
- **定时调度**APScheduler 内嵌调度,默认每日自动抓取与总结(TaskLock 防重)。
- **LaTeX 图片提取**:下载 arXiv 源码,扫描 `.tex` 文件提取论文图片用于详情页展示。
- **布局检测**(可选):ONNX 模型识别 PDF 页面布局,提升图片提取精度。
- **HTMX 局部更新**:收藏切换等操作无需整页刷新。
- **键盘快捷键**`Ctrl+K``/` 聚焦搜索框。
@@ -31,7 +32,7 @@
| 模板 | Jinja2(服务端渲染) |
| 前端 | HTMX · 原生 JS · Chart.js · 自定义 CSS"kami" 纸质风格) |
| 数据库 | SQLite + SQLAlchemy · SQLite FTS5(全文搜索) |
| AI 总结 | `pi` CLI(外部工具 |
| AI 总结 | `pi` CLI`claude` CLI(可配置后端 |
| 语义搜索 | ChromaDB(可选) |
| 调度 | APScheduler(内嵌单进程) |
| CLI | Typer |
@@ -53,33 +54,39 @@ paper/
│ ├── main.py # FastAPI 入口(lifespan 管理)
│ ├── config.py # pydantic-settings 配置加载
│ ├── database.py # SQLAlchemy 引擎、会话与 FTS5
│ ├── models.py # 11 个 ORM 模型
│ ├── models.py # 11 个 ORM 模型 + 1 枚举
│ ├── utils.py # 共享工具函数
│ ├── exceptions.py # 统一业务异常体系
│ ├── cli.py # Typer CLIcrawl / summarize / init-db
│ │
│ ├── routes/ # 页面与 API 路由
│ │ ├── __init__.py
│ │ ├── pages.py # 首页、日期页、论文详情
│ │ ├── admin.py # Token 鉴权管理接口
│ │ ├── pages.py # 首页、日期页、论文详情、相似推荐
│ │ ├── admin.py # Session 认证管理后台
│ │ ├── search.py # 搜索、阅读列表、RSS
│ │ ├── user.py # 收藏、阅读状态、笔记 API
│ │ ├── trends.py # 趋势看板
│ │ └── compare.py # 论文对比页
│ │
│ ├── services/ # 业务逻辑层
│ │ ├── __init__.py
│ │ ├── crawler.py # HuggingFace API 爬虫
│ │ ├── summarizer.py # AI 总结编排
│ │ ├── summarizer.py # AI 总结编排(调度层)
│ │ ├── summary_generator.py # 总结生成与重试
│ │ ├── summary_persister.py # 总结持久化与文件管理
│ │ ├── summary_utils.py # 总结工具(prompt 构建、PDF 提取)
│ │ ├── pi_client.py # pi CLI 封装 + JSON 提取
│ │ ├── claude_backend.py # claude CLI 后端
│ │ ├── searcher.py # FTS5 + 语义搜索
│ │ ├── schemas.py # Pydantic 总结校验
│ │ ├── cleaner.py # 临时文件清理 + 日期范围删除
│ │ ├── scheduler.py # APScheduler 每日管线
│ │ ├── pipeline.py # 抓取 + 总结流水线编排
│ │ ├── admin.py # 管理后台查询与统计
│ │ ├── user_data.py # 收藏、阅读状态、笔记
│ │ ├── embedder.py # ChromaDB 向量索引
│ │ ├── trends.py # 趋势统计聚合
│ │ ├── pdf_downloader.py # PDF + LaTeX 源码下载
│ │ ├── pi_client.py # pi CLI 封装 + JSON 提取
│ │ └── image_extractor.py # LaTeX 图片提取
│ │ ├── pdf_image_extractor.py # LaTeX 图片提取 + 图表关联
│ │ └── layout_detector.py # ONNX 布局检测(可选)
│ │
│ ├── templates/ # Jinja2 模板
│ │ ├── base.html
@@ -89,24 +96,40 @@ paper/
│ │ ├── reading_list.html
│ │ ├── compare.html
│ │ ├── trends.html
│ │ ├── login.html
│ │ ├── admin_dashboard.html
│ │ ├── admin_papers.html
│ │ ├── admin_logs.html
│ │ └── partials/paper_card.html
│ │ └── partials/
│ │ ├── admin_subnav.html
│ │ ├── paper_card.html
│ │ └── summary_list.html
│ │
│ └── static/
│ ├── css/style.css # 自定义 CSSkami 风格)
└── js/app.js # 键盘快捷键
│ ├── css/
│ ├── style.css # 自定义 CSSkami 风格)
│ │ └── admin.css # 管理后台样式
│ ├── js/
│ │ ├── app.js # 键盘快捷键
│ │ ├── date-picker.js # 日期导航
│ │ └── lightbox.js # 图片灯箱
│ └── favicon.svg
├── data/ # 运行时数据(已 gitignore
│ ├── db/papers.db # SQLite 数据库
│ ├── papers/{arxiv_id}/ # 长期资产(meta.json / summary.json / 图片)
│ ├── tmp/{arxiv_id}/ # 临时下载(流程完成后清理)
── chroma/ # ChromaDB 向量库(可选)
── chroma/ # ChromaDB 向量库(可选)
│ └── models/ # ONNX 模型(布局检测)
├── scripts/
│ ├── init_db.py # 数据库初始化
── manual_crawl.py # 手动抓取脚本
── manual_crawl.py # 手动抓取脚本
│ ├── export_picodet_onnx.py # 导出布局检测 ONNX 模型
│ ├── reextract_images.py # 批量重新提取图片
│ └── validate_summary.py # 校验总结 JSON 结构
├── tests/ # 9 个测试模块
├── tests/ # 13 个测试模块
│ ├── conftest.py # 测试夹具(内存 DB、样本数据)
│ └── test_*.py # 各模块测试
@@ -120,21 +143,23 @@ paper/
### 1. 准备环境
- Python **3.12+**
- 可选:[`pi`](https://www.npmjs.com/package/@mariozechner/pi-coding-agent) CLI(用于 AI 总结)
- [uv](https://docs.astral.sh/uv/) 包管理器
- 可选:[`pi`](https://www.npmjs.com/package/@mariozechner/pi-coding-agent) CLI 或 [`claude`](https://claude.ai/code) CLI(用于 AI 总结)
### 2. 安装依赖
```bash
python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev]"
cp .env.example .env
uv sync
```
### 3. 配置环境变量
```bash
cp .env.example .env
# 编辑 .env,至少修改 ADMIN_TOKEN
# 编辑 .env,至少修改以下三项
ADMIN_USERNAME=admin
ADMIN_PASSWORD=your_secure_password
SECRET_KEY=your_random_secret_key
```
关键配置项:
@@ -145,19 +170,26 @@ cp .env.example .env
| `APP_DEBUG` | `false` | 调试模式(开启 uvicorn reload |
| `BASE_URL` | `http://127.0.0.1:8000` | 站点根 URL(用于 RSS 生成) |
| `APP_TIMEZONE` | `Asia/Shanghai` | 时区 |
| `ADMIN_TOKEN` | `change-me` | **必须修改** — 管理接口鉴权 |
| `ADMIN_USERNAME` | `admin` | 管理后台用户名 |
| `ADMIN_PASSWORD` | — | 管理后台密码 |
| `SECRET_KEY` | `change-me` | Session 签名密钥 |
| `HF_API_BASE` | `https://huggingface.co/api` | HuggingFace API 地址 |
| `HF_PROXY` | — | HTTP 代理 |
| `TOP_N` | `20` | 每日抓取 Top N 论文 |
| `HTTP_TIMEOUT_SECONDS` | `30` | HTTP 请求超时 |
| `HTTP_MAX_RETRIES` | `3` | HTTP 最大重试次数 |
| `SUMMARY_BACKEND` | `pi` | 总结后端:`pi``claude` |
| `PI_BIN` | — | `pi` CLI 路径 |
| `CLAUDE_BIN` | `claude` | `claude` CLI 路径 |
| `SUMMARY_SKILL` | `daily-paper-summary` | pi 总结技能名 |
| `SUMMARY_CONCURRENCY` | `3` | 最大并行总结数 |
| `SUMMARY_TIMEOUT_SECONDS` | `300` | 单篇总结超时 |
| `SUMMARY_MAX_RETRIES` | `1` | 总结最大重试次数 |
| `SUMMARY_TIMEOUT_SECONDS` | `1200` | 单篇总结超时 |
| `SUMMARY_MAX_RETRIES` | `2` | 总结最大重试次数 |
| `SUMMARY_PDF_MODE` | `auto` | PDF 传递方式:`auto` / `inject` / `search` |
| `UPVOTE_REFRESH_DAYS` | `7` | 自动刷新最近 N 天论文的 upvotes |
| `PDF_DOWNLOAD_TIMEOUT` | `120` | PDF 下载超时(秒) |
| `SCHEDULER_ENABLED` | `false` | 启用每日自动抓取 |
| `SCHEDULE_HOUR` / `SCHEDULE_MINUTE` | `8` / `0` | 定时任务时间(APP_TIMEZONE |
| `SCHEDULE_HOUR` / `SCHEDULE_MINUTE` | `4` / `0` | 定时任务时间(APP_TIMEZONE |
| `APP_WORKERS` | `1` | Uvicorn worker 数(必须为 1 |
| `DATABASE_URL` | `sqlite:///data/db/papers.db` | 数据库路径 |
| `CHROMA_ENABLED` | `false` | 启用语义搜索 |
@@ -166,18 +198,19 @@ cp .env.example .env
| `EMBED_API_KEY` | — | Embedding API Key |
| `EMBED_MODEL` | — | Embedding 模型名 |
| `EMBED_DIMENSIONS` | `0` | 向量维度 |
| `LAYOUT_MODEL_PATH` | `data/models/picodet_layout_3cls.onnx` | ONNX 布局检测模型路径(可选) |
| `LAYOUT_THRESHOLD` | `0.5` | 布局检测置信度阈值(可选) |
### 4. 初始化数据库
```bash
python scripts/init_db.py
# 或:python -m app.cli init-db
uv run python -m app.cli init-db
```
### 5. 启动服务
```bash
uvicorn app.main:app --host 127.0.0.1 --port 8000
uv run python -m app.main
```
> 调度器依赖单 worker:不可使用 `--workers > 1`,否则每日任务会被重复触发。
@@ -188,51 +221,59 @@ uvicorn app.main:app --host 127.0.0.1 --port 8000
## 常用命令
### 手动抓取指定日期
### 手动抓取
```bash
python scripts/manual_crawl.py 2025-01-15
# 或
python -m app.cli crawl 2025-01-15 --top 20
# 自动探测今天/昨天
uv run python -m app.cli crawl
# 指定日期
uv run python -m app.cli crawl 2025-01-15 --top 20
# 强制重抓(即使已有数据)
uv run python -m app.cli crawl --force
```
### 手动触发总结
```bash
# 单篇
python -m app.cli summarize 2401.01234
uv run python -m app.cli summarize 2401.01234
# 批量(所有待总结论文)
python -m app.cli summarize
uv run python -m app.cli summarize
# 指定后端和 PDF 模式
uv run python -m app.cli summarize --backend claude --pdf-mode inject
```
### 管理接口(Token 鉴权)
### 管理后台
```bash
# 抓取今日论文
curl -X POST "http://127.0.0.1:8000/admin/crawl" \
-H "Authorization: Bearer $ADMIN_TOKEN"
打开浏览器访问 `http://127.0.0.1:8000/admin/login`,使用 `.env` 中配置的用户名密码登录。
# 批量总结
curl -X POST "http://127.0.0.1:8000/admin/summarize" \
-H "Authorization: Bearer $ADMIN_TOKEN"
# 单篇总结
curl -X POST "http://127.0.0.1:8000/admin/summarize/2401.01234" \
-H "Authorization: Bearer $ADMIN_TOKEN"
```
管理后台包含:
- **仪表盘**:统计卡、调度器控制、存储信息、最近活动
- **论文管理**:搜索、筛选、单篇/批量删除
- **日志**:运行日志、总结状态、失败重试
### 运行测试
```bash
pytest
uv run pytest
```
### 代码检查
```bash
uv run ruff format # 格式化代码
uv run ruff check --fix # 自动修复 lint 问题
```
---
## 安全提示
- `ADMIN_TOKEN` 是管理接口的唯一鉴权凭证,请使用强随机值并妥善保管
- 管理后台使用 Session 认证,请务必在 `.env` 中设置强密码和随机 `SECRET_KEY`
- 默认仅监听 `127.0.0.1`,如需内网访问请配合反向代理与 HTTPS。
- 项目面向本地 / 内网部署,不包含多用户账号体系与公网防护。
+26 -13
View File
@@ -42,9 +42,16 @@ def crawl(
try:
# 检查是否已抓取过(非 force 模式)
if not force and not date_str:
existing = db.scalar(select(func.count(Paper.id)).where(Paper.paper_date == target)) or 0
existing = (
db.scalar(
select(func.count(Paper.id)).where(Paper.paper_date == target)
)
or 0
)
if existing > 0:
typer.echo(f"⏭️ {target} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)")
typer.echo(
f"⏭️ {target} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
)
return
typer.echo(f"📡 开始抓取 {target} ...")
@@ -56,7 +63,12 @@ def crawl(
)
if need_fallback:
fallback = yesterday_str()
existing = db.scalar(select(func.count(Paper.id)).where(Paper.paper_date == fallback)) or 0
existing = (
db.scalar(
select(func.count(Paper.id)).where(Paper.paper_date == fallback)
)
or 0
)
if existing > 0:
typer.echo(
f"⏭️ {fallback} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
@@ -103,7 +115,9 @@ def summarize(
import os
if pdf_mode not in ("auto", "inject", "search"):
typer.echo(f"❌ 无效的 pdf_mode: {pdf_mode},只支持 auto / inject / search", err=True)
typer.echo(
f"❌ 无效的 pdf_mode: {pdf_mode},只支持 auto / inject / search", err=True
)
raise typer.Exit(code=1)
if backend:
@@ -122,6 +136,8 @@ def summarize(
datefmt="%H:%M:%S",
)
from app.exceptions import ConflictError, NotFoundError
db = SessionLocal()
try:
if arxiv_id:
@@ -131,16 +147,13 @@ def summarize(
typer.echo(f"🤖 开始批量总结 pending 论文 (mode={pdf_mode}) ...")
result = asyncio.run(summarize_batch(db, pdf_mode=pdf_mode))
if result.get("status") in ("success", "done"):
typer.echo(f"✅ 总结完成:{result}")
elif result.get("status") == "conflict":
typer.echo("⚠️ 已有批量总结任务在运行中", err=True)
raise typer.Exit(code=1)
elif result.get("status") == "not_found":
typer.echo(f"❌ 论文未找到:{arxiv_id}", err=True)
raise typer.Exit(code=1)
else:
typer.echo(f"⚠️ 总结结果:{result}", err=True)
except NotFoundError as exc:
typer.echo(f"{exc.message}", err=True)
raise typer.Exit(code=1) from exc
except ConflictError as exc:
typer.echo(f"⚠️ {exc.message}", err=True)
raise typer.Exit(code=1) from exc
finally:
db.close()
+8 -1
View File
@@ -27,6 +27,7 @@ class Settings(BaseSettings):
HTTP_TIMEOUT_SECONDS: int = 30
HTTP_MAX_RETRIES: int = 3
HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1"
PDF_DOWNLOAD_TIMEOUT: int = 120
# AI 总结
SUMMARY_BACKEND: str = "pi" # "pi" | "claude"
@@ -36,7 +37,9 @@ class Settings(BaseSettings):
SUMMARY_CONCURRENCY: int = 3
SUMMARY_TIMEOUT_SECONDS: int = 1200
SUMMARY_MAX_RETRIES: int = 2
SUMMARY_PDF_MODE: str = "auto" # "auto" = ≤80k 用 inject>80k 用 search;也可强制 "inject" / "search"
SUMMARY_PDF_MODE: str = (
"auto" # "auto" = ≤80k 用 inject>80k 用 search;也可强制 "inject" / "search"
)
# 调度
SCHEDULER_ENABLED: bool = False
@@ -56,6 +59,10 @@ class Settings(BaseSettings):
EMBED_MODEL: str = ""
EMBED_DIMENSIONS: int = 0
# 布局检测
LAYOUT_MODEL_PATH: str = "data/models/picodet_layout_3cls.onnx"
LAYOUT_THRESHOLD: float = 0.5
model_config = {
"env_file": str(BASE_DIR / ".env"),
"env_file_encoding": "utf-8",
+2 -5
View File
@@ -82,15 +82,12 @@ def _migrate(engine) -> None:
for table, columns in _MIGRATIONS.items():
# 获取已有列名
existing = {
row[1]
for row in conn.execute(text(f"PRAGMA table_info({table})"))
row[1] for row in conn.execute(text(f"PRAGMA table_info({table})"))
}
for col_name, col_type in columns:
if col_name not in existing:
conn.execute(
text(
f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}"
)
text(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")
)
logger.info("Migrated: %s.%s added", table, col_name)
conn.commit()
+35
View File
@@ -0,0 +1,35 @@
"""业务异常体系 — 统一错误类型,供路由层和 service 层使用。
路由层通过 main.py 的 @app.exception_handler(AppError) 统一捕获,
转为对应 HTTP 状态码 + JSON 响应。
"""
from __future__ import annotations
class AppError(Exception):
"""所有业务异常的基类。"""
def __init__(self, message: str = "", *, detail: str = ""):
self.message = message or detail or self.__class__.__name__
super().__init__(self.message)
class NotFoundError(AppError):
"""资源不存在(404)。"""
class ValidationError(AppError):
"""请求参数校验失败(400)。"""
class ExternalAPIError(AppError):
"""外部 API 调用失败(502)。"""
class PdfProcessError(AppError):
"""PDF 处理错误(500)。"""
class ConflictError(AppError):
"""资源冲突(409)— 如锁冲突、并发任务冲突。"""
+30 -3
View File
@@ -5,10 +5,12 @@ import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from starlette.middleware.sessions import SessionMiddleware
from app.config import settings
from app.exceptions import AppError, ConflictError, ExternalAPIError, NotFoundError, PdfProcessError, ValidationError
from app.database import engine, init_db
from app.routes.admin import router as admin_router
from app.routes.compare import router as compare_router
@@ -38,8 +40,10 @@ async def lifespan(app: FastAPI):
# ── shutdown ──
from app.services.scheduler import stop_scheduler
from app.services.pdf_downloader import close_http_session
stop_scheduler()
close_http_session()
def create_app() -> FastAPI:
@@ -60,15 +64,38 @@ def create_app() -> FastAPI:
# Session 中间件
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY)
# ── 统一业务异常处理 ──
@app.exception_handler(NotFoundError)
async def _not_found_handler(request, exc):
return JSONResponse(status_code=404, content={"error": exc.message})
@app.exception_handler(ValidationError)
async def _validation_handler(request, exc):
return JSONResponse(status_code=400, content={"error": exc.message})
@app.exception_handler(ExternalAPIError)
async def _external_api_handler(request, exc):
return JSONResponse(status_code=502, content={"error": exc.message})
@app.exception_handler(PdfProcessError)
async def _pdf_process_handler(request, exc):
return JSONResponse(status_code=500, content={"error": exc.message})
@app.exception_handler(ConflictError)
async def _conflict_handler(request, exc):
return JSONResponse(status_code=409, content={"error": exc.message})
@app.exception_handler(AppError)
async def _app_error_handler(request, exc):
return JSONResponse(status_code=500, content={"error": exc.message})
# 安全警告
if settings.SECRET_KEY == "change-me":
logger.warning(
"⚠️ SECRET_KEY is the default value 'change-me'. Please change it in .env!"
)
if not settings.ADMIN_PASSWORD:
logger.warning(
"⚠️ ADMIN_PASSWORD is empty. Please set it in .env!"
)
logger.warning("⚠️ ADMIN_PASSWORD is empty. Please set it in .env!")
# 静态文件
app.mount("/static", StaticFiles(directory="app/static"), name="static")
+17 -4
View File
@@ -12,6 +12,7 @@ from sqlalchemy import (
String,
Text,
UniqueConstraint,
select,
)
from sqlalchemy.orm import joinedload, relationship
@@ -93,7 +94,7 @@ class PaperAuthor(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
paper_id = Column(
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False, index=True
)
name = Column(String, nullable=False)
position = Column(Integer, default=0)
@@ -108,7 +109,7 @@ class PaperTag(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
paper_id = Column(
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False, index=True
)
tag = Column(String, nullable=False)
source = Column(String, default="hf")
@@ -155,7 +156,7 @@ class SummaryStatus(Base):
paper_id = Column(
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
)
status = Column(String, nullable=False, default="pending")
status = Column(String, nullable=False, default="pending", index=True)
quality = Column(String)
error_type = Column(String)
error = Column(Text)
@@ -219,7 +220,7 @@ class UserReadingStatus(Base):
paper_id = Column(
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
)
status = Column(String, nullable=False, default="unread")
status = Column(String, nullable=False, default="unread", index=True)
updated_at = Column(DateTime, nullable=False)
paper = relationship("Paper", back_populates="reading_status")
@@ -271,3 +272,15 @@ PAPER_FULL_LOAD = (
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
)
def get_paper_by_arxiv_id(db, arxiv_id: str, *, load=PAPER_DEFAULT_LOAD):
"""按 arxiv_id 查询论文(带关联加载),未找到返回 None。"""
stmt = select(Paper).where(Paper.arxiv_id == arxiv_id).options(*load)
return db.execute(stmt).unique().scalar_one_or_none()
def get_paper_by_id(db, paper_id: int, *, load=PAPER_DEFAULT_LOAD):
"""按主键查询论文(带关联加载),未找到返回 None。"""
stmt = select(Paper).where(Paper.id == paper_id).options(*load)
return db.execute(stmt).unique().scalar_one_or_none()
+90 -140
View File
@@ -3,6 +3,7 @@
from __future__ import annotations
import hashlib
import hmac
import json
import logging
from datetime import date
@@ -10,7 +11,7 @@ from datetime import date
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, field_validator
from sqlalchemy import func, select, text
from sqlalchemy import bindparam, func, select, text
from sqlalchemy.orm import Session
from app.config import settings
@@ -22,15 +23,15 @@ from app.models import (
PaperTag,
SummaryState,
SummaryStatus,
TaskLock,
)
from app.services import admin as admin_svc
from app.services.admin import get_admin_stats
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
from app.services.crawler import crawl_daily, refresh_upvotes
from app.services.pipeline import run_pipeline
from app.services.crawler import refresh_upvotes
from app.services.pipeline import run_crawl, run_pipeline
from app.services.scheduler import get_scheduler
from app.services.summarizer import summarize_batch, summarize_single
from app.utils import release_lock, templates, today_str, utc_now
from app.utils import templates, today_str, utc_now
logger = logging.getLogger(__name__)
@@ -41,14 +42,15 @@ router = APIRouter(prefix="/admin", tags=["admin"])
def _check_password(password: str) -> bool:
"""校验密码,支持明文或 sha256 哈希。"""
"""校验密码,支持明文或 sha256 哈希(常量时间比较)"""
stored = settings.ADMIN_PASSWORD
if not stored:
return False
if password == stored:
if hmac.compare_digest(password, stored):
return True
# 也支持存 sha256 哈希
return hashlib.sha256(password.encode()).hexdigest() == stored
hashed = hashlib.sha256(password.encode()).hexdigest()
return hmac.compare_digest(hashed, stored)
async def verify_admin(request: Request) -> None:
@@ -204,32 +206,12 @@ async def admin_crawl(
):
"""手动抓取指定日期,默认今天。"""
target_date = date or today_str()
# TaskLock 防重入
now = utc_now()
lock = TaskLock(
task="crawl",
lock_key=target_date,
status="running",
owner="admin_crawl",
acquired_at=now,
)
try:
db.add(lock)
db.commit()
except Exception:
db.rollback()
raise HTTPException(
status_code=409, detail=f"Crawl already running for {target_date}"
)
try:
result = await crawl_daily(db, target_date)
return result
return await run_crawl(db, target_date, owner="admin_crawl")
except RuntimeError as exc:
raise HTTPException(status_code=409, detail=str(exc))
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
finally:
release_lock(db, lock)
# ── 总结 ──────────────────────────────────────────────────────────────
@@ -241,12 +223,7 @@ async def admin_summarize_batch(
db: Session = Depends(get_db),
):
"""批量总结所有 pending 论文。"""
result = await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
if result.get("status") == "conflict":
raise HTTPException(
status_code=409, detail=result.get("error", "batch already running")
)
return result
return await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
@router.post("/summarize/{arxiv_id}")
@@ -256,10 +233,9 @@ async def admin_summarize_single(
db: Session = Depends(get_db),
):
"""总结或重跑单篇论文。"""
result = await summarize_single(db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE)
if result.get("status") == "not_found":
raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
return result
return await summarize_single(
db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE
)
# ── 清理 ──────────────────────────────────────────────────────────────
@@ -284,10 +260,13 @@ async def admin_cleanup(
result = cleanup_tmp()
log_entry.status = "success"
log_entry.completed_at = utc_now()
log_entry.details_json = json.dumps({
log_entry.details_json = json.dumps(
{
"scanned": result.get("scanned", 0),
"removed": result.get("removed", 0),
}, ensure_ascii=False)
},
ensure_ascii=False,
)
if result.get("errors"):
log_entry.error = "; ".join(result["errors"])[:2000]
db.commit()
@@ -358,19 +337,34 @@ async def admin_logs(
# 总结状态统计概要
summary_total = db.scalar(select(func.count(Paper.id))) or 0
summary_done = db.scalar(
select(func.count(SummaryStatus.id)).where(SummaryStatus.status == SummaryState.DONE)
) or 0
summary_pending = db.scalar(
summary_done = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.PROCESSING])
SummaryStatus.status == SummaryState.DONE
)
) or 0
summary_failed = db.scalar(
)
or 0
)
summary_pending = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE])
SummaryStatus.status.in_(
[SummaryState.PENDING, SummaryState.PROCESSING]
)
)
)
or 0
)
summary_failed = (
db.scalar(
select(func.count(SummaryStatus.id)).where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
)
or 0
)
) or 0
return templates.TemplateResponse(
request,
@@ -414,13 +408,8 @@ async def admin_summary_status(
else:
query = query.where(SummaryStatus.status == status)
total = db.scalar(
select(func.count()).select_from(query.subquery())
)
results = (
db.execute(query.offset((page - 1) * per_page).limit(per_page))
.all()
)
total = db.scalar(select(func.count()).select_from(query.subquery()))
results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
# 判断是否 HTMX 请求
is_htmx = request.headers.get("HX-Request") == "true"
@@ -465,7 +454,11 @@ async def admin_summary_retry_failed(
db.execute(
select(Paper.arxiv_id)
.join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
.where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
)
.scalars()
.all()
@@ -477,7 +470,11 @@ async def admin_summary_retry_failed(
# 重置失败任务的状态为 pending
db.execute(
SummaryStatus.__table__.update()
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
.where(
SummaryStatus.status.in_(
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
)
)
.values(status=SummaryState.PENDING, error=None, error_type=None)
)
db.commit()
@@ -492,15 +489,6 @@ async def admin_summary_retry_failed(
# ── 论文管理 ────────────────────────────────────────────────────────
# 排序映射
_SORT_MAP = {
"date_desc": Paper.paper_date.desc(),
"date_asc": Paper.paper_date.asc(),
"upvotes_desc": Paper.upvotes.desc(),
"title_asc": Paper.title_en.asc(),
}
@router.get("/papers")
async def admin_papers(
request: Request,
@@ -516,66 +504,18 @@ async def admin_papers(
per_page: int = Query(20, ge=1, le=100),
):
"""论文管理列表页面。"""
query = select(Paper)
# 搜索
if q.strip():
query = query.where(
Paper.title_en.ilike(f"%{q}%")
| Paper.title_zh.ilike(f"%{q}%")
| Paper.abstract.ilike(f"%{q}%")
papers, total, statuses = admin_svc.query_papers(
db,
q=q,
date_from=date_from,
date_to=date_to,
tag=tag,
summary_status=summary_status,
sort=sort,
page=page,
per_page=per_page,
)
# 日期范围
if date_from:
query = query.where(Paper.paper_date >= date_from)
if date_to:
query = query.where(Paper.paper_date <= date_to)
# 标签筛选
if tag:
query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
PaperTag.tag == tag
)
# 总结状态筛选
if summary_status != "all":
if summary_status == "none":
query = query.outerjoin(
SummaryStatus, SummaryStatus.paper_id == Paper.id
).where(SummaryStatus.paper_id == None) # noqa: E711
else:
query = query.join(
SummaryStatus, SummaryStatus.paper_id == Paper.id
).where(SummaryStatus.status == summary_status)
# 排序
order = _SORT_MAP.get(sort, Paper.paper_date.desc())
query = query.order_by(order)
# 计数
total = db.scalar(select(func.count()).select_from(query.subquery()))
# 分页
papers = (
db.execute(query.offset((page - 1) * per_page).limit(per_page))
.scalars()
.all()
)
# 获取每篇论文的总结状态
paper_ids = [p.id for p in papers]
statuses = {}
if paper_ids:
rows = db.execute(
select(SummaryStatus.paper_id, SummaryStatus.status).where(
SummaryStatus.paper_id.in_(paper_ids)
)
).all()
paper_id_to_arxiv = {p.id: p.arxiv_id for p in papers}
for pid, st in rows:
statuses[paper_id_to_arxiv.get(pid, "")] = st
# 构建分页 URL 辅助函数
def pagination_url(p: int) -> str:
params = dict(request.query_params)
@@ -588,7 +528,7 @@ async def admin_papers(
{
"papers": papers,
"paper_summary_statuses": statuses,
"total": total or 0,
"total": total,
"page": page,
"per_page": per_page,
"current_status": summary_status,
@@ -615,7 +555,9 @@ async def admin_paper_delete(
# 清理 FTS 索引
try:
db.execute(text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id})
db.execute(
text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id}
)
db.commit()
except Exception:
logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
@@ -646,9 +588,11 @@ async def admin_papers_batch_action(
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
if body.action == "delete":
papers = db.execute(
select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids))
).scalars().all()
papers = (
db.execute(select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids)))
.scalars()
.all()
)
count = 0
for paper in papers:
@@ -658,21 +602,27 @@ async def admin_papers_batch_action(
# 清理 FTS 索引
try:
db.execute(
text("DELETE FROM papers_fts WHERE arxiv_id IN :ids"),
{"ids": tuple(body.arxiv_ids)},
stmt = text("DELETE FROM papers_fts WHERE arxiv_id IN :ids").bindparams(
bindparam("ids", expanding=True)
)
db.execute(stmt, {"ids": body.arxiv_ids})
db.commit()
except Exception:
logger.warning("Failed to clean FTS index for batch delete", exc_info=True)
return {"status": "success", "message": f"已删除 {count} 篇论文", "count": count}
return {
"status": "success",
"message": f"已删除 {count} 篇论文",
"count": count,
}
elif body.action == "summarize":
# 将选中论文的总结状态重置为 pending
paper_ids = db.execute(
select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids))
).scalars().all()
paper_ids = (
db.execute(select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids)))
.scalars()
.all()
)
if paper_ids:
# 删除旧的 status 记录让其重新进入 pipeline
+5 -3
View File
@@ -12,6 +12,8 @@ from app.utils import templates
router = APIRouter()
MAX_COMPARE_PAPERS = 5
@router.get("/compare")
def compare_page(
@@ -33,9 +35,9 @@ def compare_page(
arxiv_ids = [i.strip() for i in ids.split(",") if i.strip()]
# 最多 5
if len(arxiv_ids) > 5:
arxiv_ids = arxiv_ids[:5]
# 最多 MAX_COMPARE_PAPERS
if len(arxiv_ids) > MAX_COMPARE_PAPERS:
arxiv_ids = arxiv_ids[:MAX_COMPARE_PAPERS]
if not arxiv_ids:
return templates.TemplateResponse(
+2 -99
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
import json
import logging
import re
from datetime import date, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, Request
@@ -15,6 +14,7 @@ from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import PAPER_FULL_LOAD, Paper
from app.services.pdf_image_extractor import link_figures_with_images
from app.utils import (
PAPERS_DIR,
safe_json_loads,
@@ -120,7 +120,7 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
paper.summary.figures_json if paper.summary else None, default=[]
)
linked_figures = _link_figures_with_images(figures_raw, images, arxiv_id)
linked_figures = link_figures_with_images(figures_raw, images, arxiv_id)
# 拆分图片到对应展示区域:
# table_figures → 实验结果区域(Table 截图,不变)
@@ -279,100 +279,3 @@ def _get_paper_images(arxiv_id: str) -> list[dict]:
}
)
return images
def _link_figures_with_images(
figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
"""将 summary figures 元数据与提取的图片文件关联。
策略:
1. 优先用 manifest.json 的 label 做 ID 精确匹配
2. 未匹配的 figure 用序号兜底:第 N 个 Figure → 第 N 张提取图
"""
if not figures or not images:
return figures
manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
# ── 策略 1manifest ID 精确匹配 ──
id_to_url: dict[str, str] = {}
if manifest_path.exists():
try:
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
except (ValueError, TypeError):
manifest = {}
for filename, info in manifest.items():
url = f"/papers/{arxiv_id}/images/{filename}"
# 优先用 label 字段(新格式)
label = info.get("label", "")
if label:
id_to_url[label] = url
# 也兼容 figures/tables 列表(旧格式)
for fig_id in info.get("figures", []) + info.get("tables", []):
if fig_id not in id_to_url:
id_to_url[fig_id] = url
for fig in figures:
raw_id = fig.get("id", "")
normalized = _normalize_figure_id(raw_id)
if normalized in id_to_url:
fig["image_url"] = id_to_url[normalized]
# ── 策略 2:序号兜底(manifest 匹配不到时) ──
unmatched = [f for f in figures if not f.get("image_url")]
if not unmatched:
return figures
# 按类型分流:Figure vs Table
fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
table_type_unmatched = [
f for f in unmatched if not _is_figure_type(f.get("id", ""))
]
# 提取的图片按类型分流,按文件名中的编号排序
def _sort_key(name: str) -> tuple[int, int]:
# 新格式:figure_1.jpg, table_1.jpg
m = re.search(r"(?:figure|table)_(\d+)", name)
if m:
return (0, int(m.group(1)))
# 旧格式:page2_img1.png, page5_table1.png, figure_1.png
m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
if m2:
return (int(m2.group(1)), int(m2.group(2)))
return (0, 0)
fig_images = sorted(
[img for img in images if "table" not in img["name"].lower()],
key=lambda img: _sort_key(img["name"]),
)
table_images = sorted(
[img for img in images if "table" in img["name"].lower()],
key=lambda img: _sort_key(img["name"]),
)
for i, fig in enumerate(fig_type_unmatched):
if i < len(fig_images):
fig["image_url"] = fig_images[i]["url"]
for i, fig in enumerate(table_type_unmatched):
if i < len(table_images):
fig["image_url"] = table_images[i]["url"]
return figures
def _normalize_figure_id(raw_id: str) -> str:
"""归一化 Figure/Table ID'Figure 1'/'Fig.1''Figure 1'"""
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
if m:
return f"Figure {m.group(1)}"
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
if m2:
return f"Table {m2.group(1)}"
return raw_id
def _is_figure_type(fig_id: str) -> bool:
"""判断是否为 Figure 类型(非 Table)。"""
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
+5 -23
View File
@@ -2,12 +2,13 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.exceptions import NotFoundError
from app.services.user_data import (
get_note,
save_note,
@@ -37,9 +38,6 @@ def bookmark_toggle(arxiv_id: str, request: Request, db: Session = Depends(get_d
"""切换收藏状态。支持 HTMX 局部刷新和 JSON 响应。"""
result = toggle_bookmark(db, arxiv_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
# HTMX 请求 → 返回 HTML 片段
if request.headers.get("HX-Request"):
star = "" if result["bookmarked"] else ""
@@ -66,18 +64,7 @@ def reading_status_update(
db: Session = Depends(get_db),
):
"""更新阅读状态。"""
result = set_reading_status(db, arxiv_id, body.status)
if "error" in result:
if result["error"] == "not_found":
raise HTTPException(status_code=404, detail="Paper not found")
elif result["error"] == "invalid_status":
raise HTTPException(
status_code=422,
detail=f"Invalid status. Valid: {result['valid']}",
)
return result
return set_reading_status(db, arxiv_id, body.status)
# ── 笔记 ──────────────────────────────────────────────────────────────
@@ -88,16 +75,11 @@ def note_get(arxiv_id: str, db: Session = Depends(get_db)):
"""获取笔记。"""
result = get_note(db, arxiv_id)
if result is None:
raise HTTPException(status_code=404, detail="Paper not found")
raise NotFoundError(f"Paper not found: {arxiv_id}")
return result
@router.post("/note/{arxiv_id}")
def note_save(arxiv_id: str, body: NoteRequest, db: Session = Depends(get_db)):
"""保存笔记。"""
result = save_note(db, arxiv_id, body.content)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
return save_note(db, arxiv_id, body.content)
+94 -12
View File
@@ -9,10 +9,18 @@ from sqlalchemy import func, select, text
from sqlalchemy.orm import Session
from app.config import settings
from app.models import CrawlLog, Paper, SummaryState, TaskLock
from app.models import CrawlLog, Paper, PaperTag, SummaryState, SummaryStatus, TaskLock
from app.services.scheduler import get_scheduler
from app.utils import PAPERS_DIR, TMP_DIR
# admin_papers 排序映射
SORT_MAP = {
"date_desc": Paper.paper_date.desc(),
"date_asc": Paper.paper_date.asc(),
"upvotes_desc": Paper.upvotes.desc(),
"title_asc": Paper.title_en.asc(),
}
def _dir_size(path: Path) -> int:
"""递归计算目录总字节数。"""
@@ -52,7 +60,11 @@ def get_admin_stats(db: Session) -> dict:
status_counts = {row[0]: row[1] for row in summary_rows}
# ── 存储概况 ──────────────────────────────────────────────────────
db_size = _fmt_size(settings.db_path.stat().st_size) if settings.db_path.exists() else "0 B"
db_size = (
_fmt_size(settings.db_path.stat().st_size)
if settings.db_path.exists()
else "0 B"
)
papers_size = _fmt_size(_dir_size(PAPERS_DIR))
tmp_size = _fmt_size(_dir_size(TMP_DIR))
@@ -68,22 +80,14 @@ def get_admin_stats(db: Session) -> dict:
# ── 最近日志(5 条) ──────────────────────────────────────────────
recent_logs = (
db.execute(
select(CrawlLog)
.order_by(CrawlLog.started_at.desc())
.limit(5)
)
db.execute(select(CrawlLog).order_by(CrawlLog.started_at.desc()).limit(5))
.scalars()
.all()
)
# ── 活跃锁 ────────────────────────────────────────────────────────
active_locks = (
db.execute(
select(TaskLock).where(TaskLock.status == "running")
)
.scalars()
.all()
db.execute(select(TaskLock).where(TaskLock.status == "running")).scalars().all()
)
return {
@@ -108,3 +112,81 @@ def get_admin_stats(db: Session) -> dict:
"active_locks": active_locks,
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
}
def query_papers(
db: Session,
*,
q: str = "",
date_from: str | None = None,
date_to: str | None = None,
tag: str = "",
summary_status: str = "all",
sort: str = "date_desc",
page: int = 1,
per_page: int = 20,
) -> tuple[list[Paper], int, dict[str, str]]:
"""论文管理查询 — 构建过滤、排序、分页。
Returns:
(papers, total, statuses) — 论文列表、总数、{arxiv_id: summary_status}
"""
query = select(Paper)
# 搜索
if q.strip():
query = query.where(
Paper.title_en.ilike(f"%{q}%")
| Paper.title_zh.ilike(f"%{q}%")
| Paper.abstract.ilike(f"%{q}%")
)
# 日期范围
if date_from:
query = query.where(Paper.paper_date >= date_from)
if date_to:
query = query.where(Paper.paper_date <= date_to)
# 标签筛选
if tag:
query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
PaperTag.tag == tag
)
# 总结状态筛选
if summary_status != "all":
if summary_status == "none":
query = query.outerjoin(
SummaryStatus, SummaryStatus.paper_id == Paper.id
).where(SummaryStatus.paper_id == None) # noqa: E711
else:
query = query.join(SummaryStatus, SummaryStatus.paper_id == Paper.id).where(
SummaryStatus.status == summary_status
)
# 排序
order = SORT_MAP.get(sort, Paper.paper_date.desc())
query = query.order_by(order)
# 计数
total = db.scalar(select(func.count()).select_from(query.subquery()))
# 分页
papers = (
db.execute(query.offset((page - 1) * per_page).limit(per_page)).scalars().all()
)
# 每篇论文的总结状态
paper_ids = [p.id for p in papers]
statuses: dict[str, str] = {}
if paper_ids:
rows = db.execute(
select(SummaryStatus.paper_id, SummaryStatus.status).where(
SummaryStatus.paper_id.in_(paper_ids)
)
).all()
paper_id_to_arxiv = {p.id: p.arxiv_id for p in papers}
for pid, st in rows:
statuses[paper_id_to_arxiv.get(pid, "")] = st
return papers, total or 0, statuses
+5 -2
View File
@@ -207,11 +207,14 @@ async def delete_papers_by_date_range(
completed_at=utc_now(),
papers_found=total,
papers_new=deleted,
details_json=json.dumps({
details_json=json.dumps(
{
"total_before": total,
"deleted": deleted,
"failed": len(failed_items),
}, ensure_ascii=False),
},
ensure_ascii=False,
),
error=job_error,
)
db.add(log_entry)
+6 -2
View File
@@ -189,11 +189,15 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
db = SessionLocal()
try:
paper = db.execute(
paper = (
db.execute(
select(Paper)
.where(Paper.arxiv_id == paper_id)
.options(joinedload(Paper.tags), joinedload(Paper.summary))
).unique().scalar_one_or_none()
)
.unique()
.scalar_one_or_none()
)
if not paper:
logger.warning("Paper %s not found for indexing", paper_id)
return False
+174
View File
@@ -0,0 +1,174 @@
"""PicoDet-S_layout_3cls 布局检测 — 纯 ONNX Runtime 推理.
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框。
输入:
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图片
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
输出:
fetch_name_0: (N, 6) float32 — [xmin, ymin, xmax, ymax, score, class_id]
fetch_name_1: (1,) int32 — 有效框数量 N
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import onnxruntime as ort
import pymupdf
from app.config import settings
logger = logging.getLogger(__name__)
# 模型输入尺寸
_MODEL_SIZE = 480
# ImageNet normalize
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# PicoDet label → 内部 boxclass
_LABEL_MAP: dict[int, str] = {
0: "picture", # PicoDet "image" → "picture"
1: "table",
# 2: seal — 忽略
}
# 最小 bbox 尺寸(PDF 点)
_MIN_BOX_SIZE = 20
@dataclass
class LayoutBox:
"""检测到的布局区域,兼容现有 _process_page 代码。"""
x0: float
y0: float
x1: float
y1: float
boxclass: str # "picture" | "table"
class _LayoutDetector:
"""单例:管理 ONNX InferenceSession 生命周期。"""
def __init__(self) -> None:
self._session: ort.InferenceSession | None = None
def _init_session(self) -> ort.InferenceSession:
if self._session is not None:
return self._session
model_path = Path(settings.LAYOUT_MODEL_PATH)
if not model_path.exists():
raise FileNotFoundError(
f"Layout model not found: {model_path}. "
"Run scripts/export_picodet_onnx.py first."
)
logger.info("Loading ONNX layout model: %s", model_path)
self._session = ort.InferenceSession(
str(model_path), providers=["CPUExecutionProvider"]
)
logger.info("ONNX layout model loaded")
return self._session
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
"""检测单页 PDF 的 figure / table 区域。
流程:
1. pymupdf 以 480×480 渲染页面
2. ImageNet normalize → NCHW
3. ONNX 推理 → 得到已解码+NMS 的检测框
4. 像素坐标 → PDF 点坐标
5. 过滤 seal 类和低置信度框
Args:
page: pymupdf Page 对象
Returns:
LayoutBox 列表,坐标为 PDF 点
"""
session = self._init_session()
page_w = page.rect.width
page_h = page.rect.height
# 1. 渲染页面到 _MODEL_SIZE × _MODEL_SIZE
zoom_x = _MODEL_SIZE / page_w
zoom_y = _MODEL_SIZE / page_h
mat = pymupdf.Matrix(zoom_x, zoom_y)
pix = page.get_pixmap(matrix=mat)
# 2. 预处理
img = (
np.frombuffer(pix.samples, dtype=np.uint8)
.reshape(pix.height, pix.width, pix.n)
.astype(np.float32)
/ 255.0
)
# 去掉 alpha 通道(如有)
if img.shape[2] == 4:
img = img[:, :, :3]
img = (img - _MEAN) / _STD
img = img.transpose(2, 0, 1)[np.newaxis] # (1, 3, H, W)
# scale_factor 用于坐标还原(模型内部可能用)
scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)
# 3. 推理
input_names = [i.name for i in session.get_inputs()]
feed = {input_names[0]: img}
if len(input_names) > 1:
feed[input_names[1]] = scale_factor
outputs = session.run(None, feed)
boxes_raw = outputs[0] # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
num_boxes = int(outputs[1][0]) # 有效框数
if num_boxes == 0:
return []
# 4. 像素 → PDF 点坐标
sx = page_w / _MODEL_SIZE
sy = page_h / _MODEL_SIZE
result: list[LayoutBox] = []
for i in range(min(num_boxes, len(boxes_raw))):
cls_id, score, xmin, ymin, xmax, ymax = boxes_raw[i]
cls_id = int(cls_id)
# 跳过 seal 类和低置信度
if cls_id not in _LABEL_MAP:
continue
if score < settings.LAYOUT_THRESHOLD:
continue
x0, y0 = xmin * sx, ymin * sy
x1, y1 = xmax * sx, ymax * sy
# 跳过极小区域
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
continue
result.append(
LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=_LABEL_MAP[cls_id])
)
return result
# 模块级单例
_detector = _LayoutDetector()
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
"""检测 PDF 页面中的 figure / table 区域。
Returns:
LayoutBox 列表,坐标为 PDF 点,仅含 picture/table。
"""
return _detector.detect_page(page)
+16 -1
View File
@@ -9,6 +9,7 @@ from pathlib import Path
import requests
from app.config import settings
from app.utils import PAPERS_DIR, TMP_DIR
logger = logging.getLogger(__name__)
@@ -51,6 +52,14 @@ def _get_session() -> requests.Session:
return _http_session
def close_http_session() -> None:
"""关闭全局 HTTP Session,供应用 shutdown 时调用。"""
global _http_session
if _http_session is not None:
_http_session.close()
_http_session = None
async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
if not pdf_url:
@@ -62,10 +71,16 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
try:
session = _get_session()
resp = session.get(pdf_url, timeout=120, allow_redirects=True)
resp = session.get(pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True)
resp.raise_for_status()
dest.write_bytes(resp.content)
except Exception as exc:
# 清理残留的部分文件
if dest.exists():
try:
dest.unlink()
except OSError:
pass
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
+363 -442
View File
@@ -1,12 +1,12 @@
"""PDF 图片与表格提取 — 基于 pymupdf4llm layout analysis
"""PDF 图片与表格提取 — 两阶段流水线
用 pymupdf4llm 的 layout analysis 检测 table / picture 区域,
再通过 caption 文字匹配确定 Figure/Table 编号,渲染为 JPEG。
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
相比旧方案(caption 正则 + pdfplumber/find_tables/文本块扫描三套策略):
- layout analysis 直接给出区域 bbox,不存在相邻表格互相侵入的问题
- 无需手动调参(最大高度、间隙阈值等)
- 页面级 caption 匹配:每个 caption 只分配给最近的 box,避免上下相邻表格抢夺同一个 caption
相比旧方案(正则匹配 caption):
- 不再依赖正则,用 LLM 输出的 ID 直接搜索 PDF 文本
- page.search_for() 精确搜索 + 空间距离过滤,避免正文引用误匹配
- 通用标签兜底,LLM 没提到的图表不会被丢弃
"""
from __future__ import annotations
@@ -17,44 +17,30 @@ import re
from pathlib import Path
import pymupdf
import pymupdf4llm.helpers.document_layout as dl
from app.services.layout_detector import LayoutBox, detect_page_layout
from app.services.pdf_downloader import paper_dir
from app.utils import TMP_DIR
from app.utils import PAPERS_DIR, TMP_DIR
logger = logging.getLogger(__name__)
# ── Caption 正则 ───────────────────────────────────────────────────────
# 用于从 caption 文字中提取 Figure/Table 编号
_FIGURE_CAPTION_RE = re.compile(
r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
re.IGNORECASE,
)
_TABLE_CAPTION_RE = re.compile(
r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
re.IGNORECASE,
)
# caption 与 table/picture 的最大匹配距离(点)
_CAPTION_MATCH_DISTANCE = 100
# 截图区域的外边距
# 截图区域的外边距(单位: pt
_REGION_PADDING = 5
# 3x 渲染,保证清晰度
# 渲染倍率(3x 保证清晰度
_RENDER_ZOOM = 3
# 相邻 box 聚类间距()— 同一 figure/table 的碎片间距通常 < 15pt
# 相邻 box 聚类间距(单位: pt)— 同一 figure/table 的碎片间距通常 < 15pt
_CLUSTER_GAP = 15
# 最小 bbox 面积(单位: pt²)— 过滤 icon/logo 等微小误检
_MIN_BOX_AREA = 2000
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
_LABEL_MATCH_DISTANCE = 100
# ── Box 聚类 ─────────────────────────────────────────────────────────
class _BoxCluster:
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。
pymupdf4llm 有时将一个大图拆成多个小 picture box(如视频帧网格),
聚类后用整体 bbox 作为渲染区域。
"""
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。"""
__slots__ = ("x0", "y0", "x1", "y1", "boxclass")
@@ -63,17 +49,12 @@ class _BoxCluster:
self.y0 = min(b.y0 for b in boxes)
self.x1 = max(b.x1 for b in boxes)
self.y1 = max(b.y1 for b in boxes)
# table-fallback 归一化为 tablelayout model 检测到表格但无法提取结构)
raw = boxes[0].boxclass
self.boxclass = "table" if raw == "table-fallback" else raw
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
"""将相邻的同类型 box 合并为聚类。
用 union-find 将间距 ≤ gap 的同类型 box 归为一组,
每组生成一个 _BoxCluster(整体 bbox)。
"""
"""将相邻的同类型 box 合并为聚类。"""
if not boxes:
return []
@@ -111,242 +92,58 @@ def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
return [_BoxCluster(members) for members in groups.values()]
# ── 页面级 Caption 查找与匹配 ──────────────────────────────────────────
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
def _find_page_captions(page) -> list[dict]:
"""查找页面上所有 Figure/Table caption 文字块。"""
blocks = page.get_text("blocks")
captions = []
for b in blocks:
if len(b) < 5:
continue
bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
text = str(b[4]).strip()
first_line = text.split("\n")[0].strip()
cap_type = None
m = _TABLE_CAPTION_RE.match(first_line)
if m:
cap_type = "table"
else:
m = _FIGURE_CAPTION_RE.match(first_line)
if m:
cap_type = "figure"
if m is None:
continue
captions.append(
{
"label": f"{'Table' if cap_type == 'table' else 'Figure'} {m.group(1)}",
"type": cap_type,
"caption_text": text,
"caption_y0": by0,
"caption_y1": by1,
"caption_x0": bx0,
"caption_x1": bx1,
}
)
return captions
def _vertical_distance(cap_y0, cap_y1, box_y0, box_y1) -> float | None:
"""计算 caption 到 box 的垂直距离。不邻接时返回 None。
三种情况:caption 完全在 box 上方、完全在下方、与 box 有垂直重叠。
重叠(含部分溢出)视为 distance=0,确保 caption 延伸到 box 边界外时不会丢失。
"""
# Caption 完全在 box 上方
if cap_y1 <= box_y0:
dist = box_y0 - cap_y1
return dist if dist <= _CAPTION_MATCH_DISTANCE else None
# Caption 完全在 box 下方
if cap_y0 >= box_y1:
dist = cap_y0 - box_y1
return dist if dist <= _CAPTION_MATCH_DISTANCE else None
# Caption 与 box 有垂直重叠(内部、部分溢出都算)→ 距离 0
return 0
def _same_column(cap: dict, box, page_width: float) -> bool:
"""判断 caption 和 box 是否在同一列。
双栏论文中左右栏间距有限,简单的水平重叠检查会跨列匹配。
策略:用中心 X 坐标判断各自在哪半边,只有同半边才算同列。
跨栏图表(caption 或 box 宽度 >65% 页宽)不受此限制。
"""
cap_w = cap["caption_x1"] - cap["caption_x0"]
box_w = box.x1 - box.x0
# 跨栏元素:宽度超过页面的 65%
if cap_w > page_width * 0.65 or box_w > page_width * 0.65:
return True
cap_cx = (cap["caption_x0"] + cap["caption_x1"]) / 2
box_cx = (box.x0 + box.x1) / 2
mid = page_width / 2
# 同在左半边或同在右半边
return (cap_cx < mid) == (box_cx < mid)
def _match_captions_to_boxes(
page_boxes: list, captions: list[dict], page_width: float
) -> list[tuple[list[int], list[dict]]]:
"""将 caption 分配给 box,允许一个 caption 匹配多个同类型 box。
典型场景:
- Figure 由左右两个 picture box 组成,caption 同时靠近两者
- Table 的视觉内容被 layout analysis 误分类为 picture,需要跨类型匹配
Returns:
[(box_indices, captions), ...] 每组是一个独立的渲染任务
"""
# 每个 caption 找到所有距离在阈值内的 box
# 优先匹配同类型;如果找不到,再匹配任意 table/picture box
cap_to_boxes: dict[int, list[tuple[int, float]]] = {}
for ci, cap in enumerate(captions):
same_type: list[tuple[int, float]] = []
any_type: list[tuple[int, float]] = []
expected = "table" if cap["type"] == "table" else "picture"
for bi, box in enumerate(page_boxes):
# 列感知:双栏论文中只匹配同栏的 box
if not _same_column(cap, box, page_width):
continue
# 水平重叠检查(同列内仍需有重叠)
if not (
cap["caption_x1"] > box.x0 - 5 and cap["caption_x0"] < box.x1 + 5
):
continue
dist = _vertical_distance(
cap["caption_y0"], cap["caption_y1"], box.y0, box.y1
)
if dist is None:
continue
entry = (bi, dist)
any_type.append(entry)
if box.boxclass == expected:
same_type.append(entry)
# 优先用同类型匹配;没有时回退到任意类型;都没有则跳过
if same_type:
cap_to_boxes[ci] = same_type
elif any_type:
cap_to_boxes[ci] = any_type
# else: 该 caption 无匹配 box,不加入 cap_to_boxes
# 每个 caption → 最近的 box(用于分组),但记录所有匹配的 box
cap_primary: dict[int, int] = {} # caption → primary box index
cap_all_boxes: dict[int, list[int]] = {} # caption → all matched box indices
for ci, matches in cap_to_boxes.items():
matches.sort(key=lambda x: x[1])
cap_primary[ci] = matches[0][0]
# 所有距离最近的同组 box(距离差 < 20pt 视为同一组)
best_dist = matches[0][1]
cap_all_boxes[ci] = [bi for bi, d in matches if d <= best_dist + 20]
# 按 primary box 分组
box_to_caps: dict[int, list[int]] = {}
for ci, bi in cap_primary.items():
box_to_caps.setdefault(bi, []).append(ci)
# 构建渲染组:每个 caption 独立成组(共享 box 但各自渲染)
# 同类型同 label 的 caption 会合并;不同类型则分开
used_captions: set[int] = set()
groups: list[tuple[list[int], list[dict]]] = []
for bi in sorted(box_to_caps.keys()):
cis = box_to_caps[bi]
for ci in cis:
if ci in used_captions:
continue
used_captions.add(ci)
all_box_indices = set(cap_all_boxes.get(ci, [bi]))
# 只合并同 label 的 caption(同 figure/table 的重复 caption
merged_captions = [captions[ci]]
for other_bi in all_box_indices:
if other_bi in box_to_caps:
for other_ci in box_to_caps[other_bi]:
if other_ci not in used_captions:
other_cap = captions[other_ci]
if other_cap["label"] == captions[ci]["label"]:
used_captions.add(other_ci)
merged_captions.append(other_cap)
groups.append((sorted(all_box_indices), merged_captions))
return groups
# ── 单页处理 ─────────────────────────────────────────────────────────
def _render_and_save(
def _render_box(
page,
clip: pymupdf.Rect,
box: _BoxCluster,
images_dest: Path,
manifest: dict,
label: str,
filename: str,
cap_type: str,
caption_text: str,
page_num_1based: int,
arxiv_id: str,
page_num: int,
) -> bool:
"""渲染页面区域并保存 JPEG写入 manifest。成功返回 True。"""
"""渲染单个 box 区域并保存 JPEG,成功返回 True。"""
page_width = page.rect.width
clip = pymupdf.Rect(
max(0, box.x0 - _REGION_PADDING),
max(0, box.y0 - _REGION_PADDING),
min(page_width, box.x1 + _REGION_PADDING),
box.y1 + _REGION_PADDING,
)
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
try:
pix = page.get_pixmap(matrix=mat, clip=clip)
except Exception:
logger.debug("Failed to render %s for %s", label, arxiv_id)
return False
filename = f"{label.replace(' ', '_').lower()}.jpg"
(images_dest / filename).write_bytes(pix.tobytes("jpeg"))
manifest[filename] = {
"page": page_num_1based,
"type": cap_type,
"label": label,
"caption_text": caption_text[:200] if caption_text else "",
"figures" if cap_type == "figure" else "tables": [label],
}
logger.debug(
"Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) → %s",
label,
page_num_1based,
clip.x0,
clip.y0,
clip.x1,
clip.y1,
filename,
)
return True
def _process_page(
doc,
page_idx: int,
page_layout,
page_boxes: list[LayoutBox],
images_dest: Path,
manifest: dict,
seen_labels: set,
arxiv_id: str,
) -> int:
"""处理单页:caption 匹配 + orphan 兜底,返回本页提取数量"""
"""处理单页:检测 → 聚类 → 渲染,全部用通用标签"""
page = doc[page_idx]
page_width = page.rect.width
page_num = page_idx + 1
orphan_fig_counter = 0
orphan_tbl_counter = 0
fig_counter = 0
tbl_counter = 0
# 收集本页的 table/picture box(跳过极小区域)
raw_boxes = []
for box in page_layout.boxes:
for box in page_boxes:
if box.boxclass not in ("table", "table-fallback", "picture"):
continue
if (box.x1 - box.x0) < 20 or (box.y1 - box.y0) < 20:
w = box.x1 - box.x0
h = box.y1 - box.y0
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
continue
raw_boxes.append(box)
@@ -354,153 +151,48 @@ def _process_page(
return 0
# 聚类:将同一 figure/table 的碎片 box 合并
page_boxes = _cluster_boxes(raw_boxes)
clusters = _cluster_boxes(raw_boxes)
# 页面级匹配:查找所有 caption,分配给 box
captions = _find_page_captions(page)
groups = _match_captions_to_boxes(page_boxes, captions, page_width)
# 只合并同 label 的 group(同一个 figure/table 的重复 caption
# 不同 label 的 group 即使共享 box 也不合并(如 Figure 7 和 Figure 8),
# 渲染时用 caption 位置切割区域
_merged_groups: set[int] = set()
merged_groups: list[tuple[list[int], list[dict]]] = []
for gi, (box_indices, caps) in enumerate(groups):
if gi in _merged_groups:
continue
this_labels = {c["label"] for c in caps}
all_box_set = set(box_indices)
merge_targets = {gi}
for other_gi, (other_bi, other_caps) in enumerate(groups):
if other_gi <= gi or other_gi in _merged_groups:
continue
other_labels = {c["label"] for c in other_caps}
# 只在 label 有交集时合并(同一个 figure/table
if this_labels & other_labels and all_box_set & set(other_bi):
merge_targets.add(other_gi)
all_box_set |= set(other_bi)
all_caps = []
for mgi in sorted(merge_targets):
_merged_groups.add(mgi)
all_caps.extend(groups[mgi][1])
merged_groups.append((sorted(all_box_set), all_caps))
groups = merged_groups
# ── 阶段 1:渲染有 caption 匹配的图/表 ──
matched_box_indices: set[int] = set()
extracted = 0
for box_indices, caps in groups:
matched_box_indices.update(box_indices)
# 去重同一 label,跳过已处理的
unique_caps = []
for cap in caps:
if cap["label"] not in seen_labels:
seen_labels.add(cap["label"])
unique_caps.append(cap)
if not unique_caps:
continue
# 合并所有关联 box 的 bbox
bx0 = min(page_boxes[i].x0 for i in box_indices)
by0 = min(page_boxes[i].y0 for i in box_indices)
bx1 = max(page_boxes[i].x1 for i in box_indices)
by1 = max(page_boxes[i].y1 for i in box_indices)
# 渲染区域:box + caption
all_cap_y0 = min(c["caption_y0"] for c in unique_caps)
all_cap_y1 = max(c["caption_y1"] for c in unique_caps)
all_cap_x0 = min(c["caption_x0"] for c in unique_caps)
all_cap_x1 = max(c["caption_x1"] for c in unique_caps)
top = max(0, min(by0, all_cap_y0) - _REGION_PADDING)
bottom = max(by1, all_cap_y1) + _REGION_PADDING
rx0 = max(0, min(bx0, all_cap_x0) - _REGION_PADDING)
rx1 = min(page_width, max(bx1, all_cap_x1) + _REGION_PADDING)
clip = pymupdf.Rect(rx0, top, rx1, bottom)
# 多个 caption 可能共享同一区域(如 subfigure),只需渲染一次
jpeg_bytes = None
for cap in unique_caps:
if jpeg_bytes is None:
if not _render_and_save(
page,
clip,
images_dest,
manifest,
cap["label"],
cap["type"],
cap["caption_text"],
page_num,
arxiv_id,
):
break
# 读取刚写入的 bytes 供后续同名 caption 复用
filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
jpeg_bytes = (images_dest / filename).read_bytes()
extracted += 1
else:
# 同区域的不同 caption(如 subfigure),复用图片
filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
(images_dest / filename).write_bytes(jpeg_bytes)
cap_preview = cap["caption_text"][:200]
manifest[filename] = {
"page": page_num,
"type": cap["type"],
"label": cap["label"],
"caption_text": cap_preview,
"figures" if cap["type"] == "figure" else "tables": [cap["label"]],
}
extracted += 1
# ── 阶段 2:渲染无 caption 匹配的图/表(orphan boxes ──
orphan_indices = set(range(len(page_boxes))) - matched_box_indices
for bi in sorted(orphan_indices):
box = page_boxes[bi]
cap_type = "figure" if box.boxclass == "picture" else "table"
for cluster in clusters:
cap_type = "figure" if cluster.boxclass == "picture" else "table"
if cap_type == "figure":
orphan_fig_counter += 1
label = f"Figure (p{page_num}-{orphan_fig_counter})"
fig_counter += 1
label = f"Figure (p{page_num}-{fig_counter})"
else:
orphan_tbl_counter += 1
label = f"Table (p{page_num}-{orphan_tbl_counter})"
tbl_counter += 1
label = f"Table (p{page_num}-{tbl_counter})"
if label in seen_labels:
continue
seen_labels.add(label)
clip = pymupdf.Rect(
max(0, box.x0 - _REGION_PADDING),
max(0, box.y0 - _REGION_PADDING),
min(page_width, box.x1 + _REGION_PADDING),
box.y1 + _REGION_PADDING,
)
if _render_and_save(
page,
clip,
images_dest,
manifest,
label,
cap_type,
"",
page_num,
arxiv_id,
):
filename = f"{label.replace(' ', '_').lower()}.jpg"
if not _render_box(page, cluster, images_dest, filename, cap_type, page_num):
continue
manifest[filename] = {
"page": page_num,
"type": cap_type,
"label": label,
"box": [
round(float(cluster.x0), 1),
round(float(cluster.y0), 1),
round(float(cluster.x1), 1),
round(float(cluster.y1), 1),
],
}
extracted += 1
return extracted
# ── 核心提取 ───────────────────────────────────────────────────────────
# ── Phase 1 核心入口 ───────────────────────────────────────────────────
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
"""从 PDF 提取 Figure/Table 截图,生成 manifest。
用 pymupdf4llm layout analysis 检测 table/picture 区域,
再通过 caption 文字确定编号,渲染为 JPEG。
"""Phase 1: 从 PDF 提取 Figure/Table 截图,生成通用标签的 manifest。
Args:
arxiv_id: 论文 ID
@@ -526,30 +218,18 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
if (images_dest / "manifest.json").exists():
(images_dest / "manifest.json").unlink()
doc = pymupdf.open(str(pdf_path))
# layout analysis
try:
parsed = dl.parse_document(
doc, filename=str(pdf_path), use_ocr=dl.OCRMode.NEVER
)
except Exception:
logger.warning(
"pymupdf4llm layout analysis failed for %s", arxiv_id, exc_info=True
)
doc.close()
return 0
with pymupdf.open(str(pdf_path)) as doc:
extracted = 0
manifest: dict[str, dict] = {}
seen_labels: set[str] = set()
for page_idx, page_layout in enumerate(parsed.pages):
for page_idx in range(doc.page_count):
try:
page_boxes = detect_page_layout(doc[page_idx])
extracted += _process_page(
doc,
page_idx,
page_layout,
page_boxes,
images_dest=images_dest,
manifest=manifest,
seen_labels=seen_labels,
@@ -564,8 +244,6 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
)
continue
doc.close()
# 保存 manifest
manifest_path = images_dest / "manifest.json"
manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2))
@@ -580,78 +258,321 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
return extracted
# ── 按 summary 过滤 ────────────────────────────────────────────────────
# ── Phase 2: 用 summary 的 figures ID 定位并重命名 ─────────────────────
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
"""根据 summary 中的 figures 字段过滤提取的图片/表格
def _distance_text_to_box(rect: pymupdf.Rect, box: list[float]) -> float | None:
"""计算搜索到的文本 rect 到 box 的距离。超出阈值返回 None
用 manifest.json 中的 label 匹配,保留被 AI 总结引用的图片
判断逻辑:rect 中心与 box 的垂直距离 + 水平重叠检查
"""
rect_cx = (rect.x0 + rect.x1) / 2
rect_cy = (rect.y0 + rect.y1) / 2
bx0, by0, bx1, by1 = box
# 水平重叠:rect 中心在 box 水平范围内(或接近)
if not (bx0 - 20 <= rect_cx <= bx1 + 20):
return None
# 垂直距离
if rect_cy < by0:
dist = by0 - rect_cy
elif rect_cy > by1:
dist = rect_cy - by1
else:
dist = 0
return dist if dist <= _LABEL_MATCH_DISTANCE else None
def _search_variants(fig_id: str) -> list[str]:
"""为 figure/table ID 生成搜索变体。
"Figure 1" → ["Figure 1", "Fig. 1", "Fig 1"]
"Fig. 1" → ["Fig. 1", "Figure 1", "Fig 1"]
"Table A1" → ["Table A1"]
"""
variants = [fig_id]
m = re.match(r"(Fig\.?|Figure)\s+(\d+.*)", fig_id, re.IGNORECASE)
if m:
num_part = m.group(2)
variants.extend(
[
f"Figure {num_part}",
f"Fig. {num_part}",
f"Fig {num_part}",
]
)
# 去重保序
seen = set()
result = []
for v in variants:
if v not in seen:
seen.add(v)
result.append(v)
return result
def label_images_by_summary(
arxiv_id: str,
figures: list[dict],
pdf_path: Path | None = None,
) -> int:
"""Phase 2: 用 summary 的 figures ID 在 PDF 中搜索定位,重命名图片。
对 summary 中的每个 figure/table ID
1. page.search_for(id) 在所有页面搜索文本位置
2. 计算搜索位置与 manifest 中 box 坐标的距离
3. 最近匹配 → 重命名文件、更新 manifest
Args:
arxiv_id: 论文 ID
figures: summary 的 figures 列表,每项含 id/caption/description 等
pdf_path: PDF 路径
Returns:
成功重命名的图片数量
"""
if not figures:
return 0
images_dir = paper_dir(arxiv_id) / "images"
manifest_path = images_dir / "manifest.json"
if not images_dir.exists() or not manifest_path.exists():
if pdf_path is None:
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
if not pdf_path.exists():
return 0
all_files = [
f for f in images_dir.iterdir() if f.suffix.lower() in (".png", ".jpg", ".jpeg")
]
if not all_files:
images_dest = paper_dir(arxiv_id) / "images"
manifest_path = images_dest / "manifest.json"
if not manifest_path.exists():
return 0
manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))
manifest: dict[str, dict] = json.loads(manifest_path.read_text(encoding="utf-8"))
if not manifest:
return 0
# 收集 summary 中引用的所有 Figure/Table ID(归一化)
referenced_ids: set[str] = set()
for fig in figures:
# 构建候选列表:只对通用标签的条目做匹配
candidates: dict[str, dict] = {} # filename → {page, box, ...}
for fname, info in manifest.items():
if "(p" in info.get("label", ""):
candidates[fname] = info
if not candidates:
return 0
with pymupdf.open(str(pdf_path)) as doc:
# 收集所有匹配候选:(fig_id, fig_index, filename, distance)
matches: list[tuple[str, int, str, float]] = []
for fig_idx, fig in enumerate(figures):
fig_id = fig.get("id", "")
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", fig_id, re.IGNORECASE)
if m:
referenced_ids.add(f"Figure {m.group(1)}")
m2 = re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
if m2:
referenced_ids.add(f"Table {m2.group(1)}")
if not referenced_ids:
logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
return len(all_files)
# 根据 manifest 的 label 字段匹配
keep_filenames: set[str] = set()
for filename, info in manifest.items():
label = info.get("label", "")
if label in referenced_ids:
keep_filenames.add(filename)
if not fig_id:
continue
for ref in info.get("figures", []) + info.get("tables", []):
if ref in referenced_ids:
keep_filenames.add(filename)
# 生成搜索变体:Figure 1 / Fig. 1 / Fig 1 等
search_terms = _search_variants(fig_id)
# 在所有页面搜索该文本(含变体)
search_hits: list[tuple[int, pymupdf.Rect]] = [] # (page_num_1based, Rect)
for page_idx in range(doc.page_count):
page = doc[page_idx]
seen_rects: set[tuple[float, float]] = set()
for term in search_terms:
for r in page.search_for(term):
key = (round(r.x0, 1), round(r.y0, 1))
if key not in seen_rects:
seen_rects.add(key)
search_hits.append((page_idx + 1, r))
if not search_hits:
continue
# 对每个候选 manifest 条目,找最近的搜索命中
for fname, info in candidates.items():
box = info.get("box")
if not box:
continue
manifest_page = info.get("page", 0)
best_dist: float | None = None
for hit_page, rect in search_hits:
# 只匹配同页面
if hit_page != manifest_page:
continue
dist = _distance_text_to_box(rect, box)
if dist is not None and (best_dist is None or dist < best_dist):
best_dist = dist
if best_dist is not None:
matches.append((fig_id, fig_idx, fname, best_dist))
if not matches:
logger.info("No label matches for %s", arxiv_id)
return 0
# 去冲突:按距离排序,每个 fig_id 和每个 filename 只匹配一次
matches.sort(key=lambda x: x[3])
used_fig_ids: set[int] = set()
used_filenames: set[str] = set()
renames: list[tuple[str, str, str]] = [] # (old_fname, new_fname, fig_id)
for fig_id, fig_idx, fname, dist in matches:
if fig_idx in used_fig_ids or fname in used_filenames:
continue
used_fig_ids.add(fig_idx)
used_filenames.add(fname)
new_fname = f"{fig_id.replace(' ', '_').lower()}.jpg"
renames.append((fname, new_fname, fig_id))
# 执行重命名
labeled = 0
new_manifest: dict[str, dict] = {}
for fname, info in manifest.items():
if fname in used_filenames:
continue
# 未匹配的保持原样
new_manifest[fname] = info
for old_fname, new_fname, fig_id in renames:
old_path = images_dest / old_fname
new_path = images_dest / new_fname
if not old_path.exists():
continue
# 搬运 manifest 信息
info = manifest[old_fname].copy()
cap_type = info.get("type", "figure")
# 读取 caption 文本(从 figures 列表)
caption_text = ""
for fig in figures:
if fig.get("id") == fig_id:
caption_text = fig.get("caption", "")
break
if not keep_filenames:
logger.warning(
"No manifest matches for %s (refs=%s), keeping all",
arxiv_id,
referenced_ids,
info["label"] = fig_id
info["caption_text"] = caption_text[:200] if caption_text else ""
info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
fig_id
)
return len(all_files)
removed = 0
for f in all_files:
if f.name not in keep_filenames:
f.unlink()
removed += 1
# 重命名文件
if new_fname != old_fname:
old_path.rename(new_path)
new_manifest[new_fname] = info
labeled += 1
# 写回 manifest
manifest_path.write_text(json.dumps(new_manifest, ensure_ascii=False, indent=2))
kept = len(all_files) - removed
logger.info(
"Filtered images for %s: kept %d, removed %d (refs=%s)",
"Labeled %d/%d images for %s using summary figures",
labeled,
len(manifest),
arxiv_id,
kept,
removed,
referenced_ids,
)
return kept
return labeled
# ── Figure ↔ Image 关联 ────────────────────────────────────────────────
def _normalize_figure_id(raw_id: str) -> str:
"""归一化 Figure/Table ID'Figure 1'/'Fig.1''Figure 1'"""
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
if m:
return f"Figure {m.group(1)}"
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
if m2:
return f"Table {m2.group(1)}"
return raw_id
def _is_figure_type(fig_id: str) -> bool:
"""判断是否为 Figure 类型(非 Table)。"""
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
def _image_sort_key(name: str) -> tuple[int, int]:
"""按文件名中的编号排序提取的图片。"""
# 新格式:figure_1.jpg, table_1.jpg
m = re.search(r"(?:figure|table)_(\d+)", name)
if m:
return (0, int(m.group(1)))
# 旧格式:page2_img1.png, page5_table1.png, figure_1.png
m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
if m2:
return (int(m2.group(1)), int(m2.group(2)))
return (0, 0)
def link_figures_with_images(
figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
"""将 summary figures 元数据与提取的图片文件关联。
策略:
1. 优先用 manifest.json 的 label 做 ID 精确匹配
2. 未匹配的 figure 用序号兜底:第 N 个 Figure → 第 N 张提取图
"""
if not figures or not images:
return figures
manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
# ── 策略 1manifest ID 精确匹配 ──
id_to_url: dict[str, str] = {}
if manifest_path.exists():
try:
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
except (ValueError, TypeError):
manifest = {}
for filename, info in manifest.items():
url = f"/papers/{arxiv_id}/images/{filename}"
# 优先用 label 字段(新格式)
label = info.get("label", "")
if label:
id_to_url[label] = url
# 也兼容 figures/tables 列表(旧格式)
for fig_id in info.get("figures", []) + info.get("tables", []):
if fig_id not in id_to_url:
id_to_url[fig_id] = url
for fig in figures:
raw_id = fig.get("id", "")
normalized = _normalize_figure_id(raw_id)
if normalized in id_to_url:
fig["image_url"] = id_to_url[normalized]
# ── 策略 2:序号兜底(manifest 匹配不到时) ──
unmatched = [f for f in figures if not f.get("image_url")]
if not unmatched:
return figures
# 按类型分流:Figure vs Table
fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
table_type_unmatched = [
f for f in unmatched if not _is_figure_type(f.get("id", ""))
]
# 提取的图片按类型分流,按文件名中的编号排序
fig_images = sorted(
[img for img in images if "table" not in img["name"].lower()],
key=lambda img: _image_sort_key(img["name"]),
)
table_images = sorted(
[img for img in images if "table" in img["name"].lower()],
key=lambda img: _image_sort_key(img["name"]),
)
for i, fig in enumerate(fig_type_unmatched):
if i < len(fig_images):
fig["image_url"] = fig_images[i]["url"]
for i, fig in enumerate(table_type_unmatched):
if i < len(table_images):
fig["image_url"] = table_images[i]["url"]
return figures
+25 -10
View File
@@ -11,6 +11,7 @@ import uuid
from pathlib import Path
from app.config import settings
from app.utils import truncate_error
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
@@ -21,6 +22,9 @@ from app.services.summary_utils import (
logger = logging.getLogger(__name__)
# PDF 全文注入模式的字符上限 — 超过此阈值自动切换到 search 模式
_PDF_MAX_CHARS = 80_000
# 重新导出,保持向后兼容
__all__ = [
"PiTimeoutError",
@@ -45,7 +49,7 @@ class PiProcessError(Exception):
def __init__(self, returncode: int, stderr: str):
self.returncode = returncode
self.stderr = stderr
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
super().__init__(f"pi exited with code {returncode}: {truncate_error(stderr)}")
# ── pi CLI 调用 ────────────────────────────────────────────────────────
@@ -72,23 +76,27 @@ async def call_pi(
actual_mode = pdf_mode
if pdf_mode == "auto":
if txt_size > 80_000:
if txt_size > _PDF_MAX_CHARS:
actual_mode = "search"
logger.info(
"Auto mode: %s text=%d chars > 80k → search", arxiv_id, txt_size
"Auto mode: %s text=%d chars > %dk → search",
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
)
else:
actual_mode = "inject"
logger.info(
"Auto mode: %s text=%d chars ≤ 80k → inject", arxiv_id, txt_size
"Auto mode: %s text=%d chars ≤ %dk → inject",
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
)
# inject 模式需要截断过长的文本(避免撑爆 context)
if actual_mode == "inject" and txt_size > 80_000:
if actual_mode == "inject" and txt_size > _PDF_MAX_CHARS:
body = txt_path.read_text(encoding="utf-8")
trimmed = body[:80_000].rstrip()
trimmed = body[:_PDF_MAX_CHARS].rstrip()
txt_path.write_text(trimmed, encoding="utf-8")
logger.info("Truncated %s for inject: %d%d chars", arxiv_id, txt_size, len(trimmed))
logger.info(
"Truncated %s for inject: %d%d chars", arxiv_id, txt_size, len(trimmed)
)
prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
@@ -101,7 +109,8 @@ async def call_pi(
cmd = [
settings.PI_BIN,
"-p",
"--tools", tools,
"--tools",
tools,
]
if fix_errors:
cmd += ["--session", session_id, "--continue"]
@@ -118,10 +127,14 @@ async def call_pi(
logger.info(
"Calling pi for %s (fix=%s, session=%s, mode=%s)",
arxiv_id, bool(fix_errors), session_id, actual_mode,
arxiv_id,
bool(fix_errors),
session_id,
actual_mode,
)
import time as _time
_t_sub_start = _time.monotonic()
proc = await asyncio.create_subprocess_exec(
@@ -151,7 +164,9 @@ async def call_pi(
logger.info(
"pi subprocess for %s: %.2fs%s",
arxiv_id, _t_sub_end - _t_sub_start, _file_info,
arxiv_id,
_t_sub_end - _t_sub_start,
_file_info,
)
if proc.returncode != 0:
+56 -11
View File
@@ -8,6 +8,7 @@ from __future__ import annotations
import logging
from datetime import date as date_type
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.config import settings
@@ -15,11 +16,50 @@ from app.models import CrawlLog, TaskLock
from app.services.cleaner import cleanup_tmp
from app.services.crawler import crawl_daily
from app.services.summarizer import summarize_batch
from app.utils import utc_now, yesterday_str
from app.utils import release_lock, truncate_error, utc_now, yesterday_str
logger = logging.getLogger(__name__)
def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock:
"""获取 TaskLock,锁冲突时抛出 RuntimeError。
供需要防重入的操作(crawl、pipeline 等)统一调用。
"""
lock = TaskLock(
task=task,
lock_key=lock_key,
status="running",
owner=owner,
acquired_at=utc_now(),
)
try:
db.add(lock)
db.commit()
except IntegrityError:
db.rollback()
raise RuntimeError(f"{task} already running for {lock_key}")
return lock
async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") -> dict:
"""执行单次抓取(带防重入锁)。
Args:
db: 数据库 session
target_date: 目标日期 YYYY-MM-DD
owner: 调用者标识
Returns:
crawl_daily() 的原始返回值
"""
lock = acquire_lock(db, "crawl", target_date, owner)
try:
return await crawl_daily(db, target_date)
finally:
release_lock(db, lock)
async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
"""执行完整流水线:crawl → summarize → cleanup。
@@ -47,7 +87,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
try:
db.add(lock)
db.commit()
except Exception:
except IntegrityError:
db.rollback()
raise RuntimeError(f"Pipeline already running for {target_date}")
@@ -66,9 +106,13 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
try:
# Step 1: 抓取(先试今天,无数据则回退昨天)
crawl_result = await crawl_daily(db, target_date)
logger.info("Pipeline [%s]: crawl %s, found=%d new=%d",
owner, target_date,
crawl_result.get("found", 0), crawl_result.get("new", 0))
logger.info(
"Pipeline [%s]: crawl %s, found=%d new=%d",
owner,
target_date,
crawl_result.get("found", 0),
crawl_result.get("new", 0),
)
if crawl_result.get("status") == "success" and crawl_result.get("found") == 0:
yesterday = yesterday_str()
@@ -81,8 +125,11 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
# Step 3: 清理
cleanup_result = cleanup_tmp()
logger.info("Pipeline [%s]: cleanup done, removed=%d",
owner, cleanup_result.get("removed", 0))
logger.info(
"Pipeline [%s]: cleanup done, removed=%d",
owner,
cleanup_result.get("removed", 0),
)
log_entry.status = "success"
log_entry.papers_found = crawl_result.get("found", 0)
@@ -91,7 +138,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
except Exception as exc:
logger.exception("Pipeline [%s] failed", owner)
log_entry.status = "failed"
error_msg = str(exc)[:2000]
error_msg = truncate_error(exc, limit=2000)
finally:
log_entry.completed_at = utc_now()
@@ -99,9 +146,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
log_entry.error = error_msg
db.commit()
lock.status = "finished"
lock.released_at = utc_now()
db.commit()
release_lock(db, lock)
if error_msg:
return {"status": "failed", "error": error_msg}
+1
View File
@@ -90,6 +90,7 @@ class SummarySchema(BaseModel):
# ── 质量评估 ────────────────────────────────────────────────────────────
def assess_quality(schema: SummarySchema) -> str:
"""评估总结质量:normal / degraded / low。"""
# low:内容空洞的启发式判断
+2 -8
View File
@@ -213,11 +213,7 @@ def _search_semantic(
arxiv_ids = [c["arxiv_id"] for c in candidates]
distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
stmt = (
select(Paper)
.where(Paper.arxiv_id.in_(arxiv_ids))
.options(*PAPER_FULL_LOAD)
)
stmt = select(Paper).where(Paper.arxiv_id.in_(arxiv_ids)).options(*PAPER_FULL_LOAD)
if tag:
stmt = stmt.where(Paper.tags.any(tag=tag))
@@ -298,9 +294,7 @@ def _load_papers_by_ids(
papers = (
db.execute(
select(Paper)
.where(Paper.id.in_(paper_ids))
.options(*PAPER_FULL_LOAD)
select(Paper).where(Paper.id.in_(paper_ids)).options(*PAPER_FULL_LOAD)
)
.unique()
.scalars()
+39 -502
View File
@@ -1,233 +1,42 @@
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引"""
"""AI 总结编排服务 — 协调生成器、持久化、批量处理的顶层入口"""
from __future__ import annotations
import asyncio
import json
import logging
from pathlib import Path
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.config import settings
from app.database import SessionLocal
from app.exceptions import ConflictError, NotFoundError
from app.models import (
PAPER_DEFAULT_LOAD,
CrawlLog,
Paper,
PaperSummary,
PaperTag,
SummaryState,
SummaryStatus,
TaskLock,
get_paper_by_arxiv_id,
get_paper_by_id,
)
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp,
download_pdf,
paper_dir,
from app.services.pdf_downloader import download_pdf
from app.services.summary_utils import write_meta_json
from app.services.summary_generator import (
_generate_with_retry,
)
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
extract_json,
write_meta_json,
extract_pdf_text,
from app.services.summary_persister import (
_cleanup_old_images,
_handle_summary_failure,
_persist_summary,
)
from app.services.pi_client import (
PiProcessError,
PiTimeoutError,
call_pi,
)
from app.services import claude_backend
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.utils import TMP_DIR, release_lock, utc_now
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
logger = logging.getLogger(__name__)
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
"""将异常映射到 error_type 枚举值。"""
if isinstance(exc, PdfDownloadError):
return "pdf_download_failed"
if isinstance(exc, PiTimeoutError):
return "timeout"
if isinstance(exc, PiProcessError):
return "process_error"
if isinstance(exc, JsonNotFoundError):
return "json_not_found"
if isinstance(exc, json.JSONDecodeError):
return "json_invalid"
if isinstance(exc, ValidationError):
return classify_validation_error(exc)
return "unknown"
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
def _build_fts_summary_text(schema: SummarySchema) -> str:
"""拼接用于 FTS5 索引的总结文本。"""
parts = [
schema.one_line or "",
schema.motivation.problem or "",
schema.motivation.goal or "",
schema.method.overview or "",
schema.method.key_idea or "",
schema.results.main_findings or "",
]
return " ".join(p for p in parts if p)
# ── DB 更新 ─────────────────────────────────────────────────────────────
def _update_summary_in_db(
db: Session,
paper: Paper,
schema: SummarySchema,
quality: str,
raw_output: str,
) -> None:
"""将校验后的总结写入 DBpaper_summaries + papers + paper_tags + FTS5。"""
from sqlalchemy import text
# 1. paper_summariesupsert
existing = db.get(PaperSummary, paper.id)
flat = flatten_for_db(schema)
if existing:
for k, v in flat.items():
setattr(existing, k, v)
else:
db.add(PaperSummary(paper_id=paper.id, **flat))
# 2. papers 表
paper.title_zh = schema.title_zh
paper.summary_quality = quality
p_dir = paper_dir(paper.arxiv_id)
paper.summary_path = str(p_dir / "summary.json")
paper.raw_output_path = str(p_dir / "raw_output.txt")
# 3. AI 标签
existing_tag_names = {t.tag for t in paper.tags}
for tag_name in schema.tags:
if tag_name not in existing_tag_names:
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
existing_tag_names.add(tag_name)
# 4. FTS5 更新
summary_text = _build_fts_summary_text(schema)
db.execute(
text(
"UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
"WHERE rowid=:paper_id"
),
{
"title_zh": schema.title_zh,
"summary_text": summary_text,
"paper_id": paper.id,
},
)
db.commit()
logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── JSON 验证 ──────────────────────────────────────────────────────────
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
errors: list[str] = []
if not isinstance(json_data, dict):
return ["顶层必须是 JSON 对象"]
# 必填字段
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
if f not in json_data or not json_data[f]:
errors.append(f"缺少必填字段: {f}")
# tags 必须是非空数组
tags = json_data.get("tags")
if not isinstance(tags, list) or len(tags) == 0:
errors.append("tags 必须是非空数组")
# 字符串段落字段(必须是 str 且 ≥50 字)
string_fields = [
("motivation", "problem"), ("motivation", "goal"), ("motivation", "gap"),
("method", "overview"), ("method", "key_idea"), ("method", "steps"),
("method", "novelty"),
("results", "main_findings"), ("results", "limitations"),
("improvements", "weaknesses"), ("improvements", "future_work"),
("improvements", "reproducibility"),
]
for section, field in string_fields:
val = json_data.get(section, {}).get(field)
if isinstance(val, list):
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
elif not isinstance(val, str) or len(val.strip()) < 50:
errors.append(
f"{section}.{field} 必须是详细段落(≥50字),"
f"当前: {type(val).__name__} ({len(str(val))}字)"
)
# benchmarks 必须是数组
benchmarks = json_data.get("results", {}).get("benchmarks")
if benchmarks is not None and not isinstance(benchmarks, list):
errors.append("results.benchmarks 必须是数组")
# prerequisites.concepts 必须是对象数组,每个有 term
concepts = json_data.get("prerequisites", {}).get("concepts")
if concepts is not None:
if not isinstance(concepts, list):
errors.append("prerequisites.concepts 必须是数组")
elif len(concepts) == 0:
errors.append("prerequisites.concepts 不能为空")
else:
for i, c in enumerate(concepts):
if isinstance(c, str):
errors.append(f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串")
elif isinstance(c, dict) and not c.get("term"):
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
# figures 必须是数组,每个元素应有 id
figures = json_data.get("figures")
if figures is not None:
if not isinstance(figures, list):
errors.append("figures 必须是数组")
else:
for i, fig in enumerate(figures):
if isinstance(fig, dict) and not fig.get("id"):
errors.append(f"figures[{i}] 缺少 id 字段")
return errors
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
if schema:
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
encoding="utf-8",
)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
# ── 单篇总结 ────────────────────────────────────────────────────────────
@@ -264,277 +73,7 @@ async def summarize_one(
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
async def _generate_with_retry(
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
"""调用 AI 后端生成总结,最多 4 轮验证循环。
根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。
Returns:
(json_data, raw_output)
Raises:
ValueError: 4 轮验证仍未通过
"""
import time as _time
backend = settings.SUMMARY_BACKEND
validation_errors: list[str] = []
json_data: dict | None = None
raw_output = ""
session_id = None
summary_file = paper_dir(arxiv_id) / "summary.json"
# claude 后端需要预构建 promptpi 后端在 call_pi 内部构建)
claude_prompt: str | None = None
if backend == "claude":
_t0 = _time.monotonic()
txt_path = extract_pdf_text(pdf_path, max_chars=None)
body = txt_path.read_text(encoding="utf-8")
if len(body) > 80_000:
trimmed = body[:80_000].rstrip()
txt_path.write_text(trimmed, encoding="utf-8")
claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)
for attempt in range(1, 5):
# 清理上一轮写入的不完整文件
if summary_file.exists():
summary_file.unlink()
# 记录 AI 调用开始时间
_t_call_start = _time.monotonic()
if backend == "claude":
if attempt == 1:
raw_output, session_id = await claude_backend.call_claude(
claude_prompt, session_id=None,
)
else:
retry_prompt = build_prompt(
arxiv_id, meta_path,
extract_pdf_text(pdf_path, max_chars=80000),
"inject", fix_errors=validation_errors,
)
raw_output, session_id = await claude_backend.call_claude(
retry_prompt, session_id=session_id, fix_errors=validation_errors,
)
else:
if attempt == 1:
raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
else:
raw_output, session_id = await call_pi(
meta_path, pdf_path,
fix_errors=validation_errors,
session_id=session_id,
pdf_mode=pdf_mode,
)
_t_call_end = _time.monotonic()
# 检查 summary.json 是否由 AI 子进程写入
file_written_by_ai = summary_file.exists()
file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
file_size = summary_file.stat().st_size if file_written_by_ai else 0
logger.info(
" [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
arxiv_id, attempt,
_t_call_end - _t_call_start,
f"已写入({file_size}B)" if file_written_by_ai else "未写入",
f" mtime={file_mtime:.2f}" if file_mtime else "",
)
# 提取 JSON
_t_json_start = _time.monotonic()
try:
if file_written_by_ai:
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
else:
json_data = extract_json(raw_output)
except (json.JSONDecodeError, JsonNotFoundError) as exc:
_t_json_end = _time.monotonic()
logger.warning(
" [%s] JSON提取失败: %.2fs %s",
arxiv_id, _t_json_end - _t_json_start, str(exc)[:200],
)
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
continue
_t_json_end = _time.monotonic()
# 验证
_t_val_start = _time.monotonic()
validation_errors = _validate_summary(json_data, arxiv_id)
_t_val_end = _time.monotonic()
if not validation_errors:
logger.info(
" [%s] JSON提取: %.2fs 验证: %.2fs ✅",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
)
break
logger.warning(
" [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
"; ".join(validation_errors)[:200],
)
if validation_errors:
exc = ValueError(
f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}"
)
exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用
raise exc
return json_data, raw_output
def _persist_summary(
db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
import time as _time
arxiv_id = paper.arxiv_id
_t0 = _time.monotonic()
schema = SummarySchema.model_validate(json_data)
quality = assess_quality(schema)
_t1 = _time.monotonic()
_save_files(arxiv_id, schema, raw_output)
_t2 = _time.monotonic()
_update_summary_in_db(db, paper, schema, quality, raw_output)
_t3 = _time.monotonic()
# 状态 → done
paper.summary_status.status = SummaryState.DONE
paper.summary_status.quality = quality
paper.summary_status.completed_at = utc_now()
paper.summary_status.raw_output_saved = True
db.commit()
_t4 = _time.monotonic()
logger.info(
" [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
arxiv_id,
_t1 - _t0,
_t2 - _t1,
_t3 - _t2,
_t4 - _t3,
)
# 触发性增强(失败不影响总结)
_t5 = _time.monotonic()
_maybe_extract_images(arxiv_id, schema)
_t6 = _time.monotonic()
_maybe_index_chroma(arxiv_id, paper, schema)
_t7 = _time.monotonic()
logger.info(
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
arxiv_id,
_t6 - _t5,
_t7 - _t6,
)
return quality
def _handle_summary_failure(
db: Session, paper: Paper, exc: Exception, raw_output: str,
) -> dict:
"""记录失败:保存 raw_output、重试计数、错误分类。"""
error_type = _classify_error(exc)
logger.error(
"Summarize failed: %s error_type=%s %s",
paper.arxiv_id, error_type, str(exc)[:200],
)
status = paper.summary_status
if raw_output:
_save_files(paper.arxiv_id, None, raw_output)
status.raw_output_saved = True
status.retry_count = (status.retry_count or 0) + 1
status.error_type = error_type
status.error = str(exc)[:2000]
if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
status.status = SummaryState.PERMANENT_FAILURE
else:
status.status = SummaryState.PENDING
status.completed_at = utc_now()
db.commit()
return {
"arxiv_id": paper.arxiv_id,
"status": "failed",
"error_type": error_type,
"error": str(exc)[:200],
"retry_count": status.retry_count,
}
def _cleanup_old_images(db: Session, paper: Paper) -> None:
"""清理旧的图片文件和 figures_json,避免重新总结时残留。"""
arxiv_id = paper.arxiv_id
images_dir = paper_dir(arxiv_id) / "images"
if images_dir.exists():
for old_file in images_dir.iterdir():
if old_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg") or old_file.name == "manifest.json":
old_file.unlink(missing_ok=True)
# 清除数据库中的 figures_json
if paper.summary and paper.summary.figures_json:
paper.summary.figures_json = None
db.commit()
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
"""从 PDF 提取图片和表格(失败不影响总结)。"""
try:
from app.services.pdf_image_extractor import (
extract_images_from_pdf,
filter_images_by_summary,
)
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
extract_images_from_pdf(arxiv_id, pdf_path)
if schema.figures:
filter_images_by_summary(arxiv_id, schema.figures)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
"""写入 ChromaDB 语义索引(失败不影响总结)。"""
try:
from app.services.embedder import index_paper
texts_dict = {
"arxiv_id": arxiv_id,
"title_zh": schema.title_zh or "",
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"one_line": schema.one_line or "",
"motivation_problem": schema.motivation.problem or "",
"method_key_idea": schema.method.key_idea or "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
index_paper(arxiv_id, texts_dict)
except Exception:
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
async def _do_summarize_one(
db: Session, paper: Paper, pdf_mode: str = "auto"
) -> dict:
async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
arxiv_id = paper.arxiv_id
title_short = (paper.title_en or "")[:50]
@@ -548,6 +87,7 @@ async def _do_summarize_one(
# 清理旧的图片文件和 figures_json,避免重新总结时残留
import time as _time
_t_cleanup_start = _time.monotonic()
_cleanup_old_images(db, paper)
_t_cleanup_end = _time.monotonic()
@@ -567,7 +107,9 @@ async def _do_summarize_one(
logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
json_data, raw_output = await _generate_with_retry(
arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf",
arxiv_id,
meta_path,
TMP_DIR / arxiv_id / "paper.pdf",
pdf_mode=pdf_mode,
)
_t3 = _time.monotonic()
@@ -577,7 +119,9 @@ async def _do_summarize_one(
_t4 = _time.monotonic()
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
logger.info("✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0)
logger.info(
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0
)
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
except Exception as exc:
@@ -586,7 +130,7 @@ async def _do_summarize_one(
return _handle_summary_failure(db, paper, exc, fail_output)
finally:
cleanup_tmp(arxiv_id)
pass # cleanup_tmp(arxiv_id) # 暂时禁用,保留 PDF 用于调试图片提取
# ── 单篇入口 ────────────────────────────────────────────────────────────
@@ -604,25 +148,19 @@ async def summarize_single(
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
paper = db.execute(
select(Paper)
.where(Paper.arxiv_id == arxiv_id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
paper = get_paper_by_arxiv_id(db, arxiv_id)
if not paper:
return {"status": "not_found", "arxiv_id": arxiv_id}
raise NotFoundError(f"Paper not found: {arxiv_id}")
make_session = _session_factory or SessionLocal
# 每篇用独立 session 避免并发问题
paper_db = make_session()
try:
paper_in_new_session = paper_db.execute(
select(Paper)
.where(Paper.arxiv_id == arxiv_id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
result = await summarize_one(paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode)
paper_in_new_session = get_paper_by_arxiv_id(paper_db, arxiv_id)
result = await summarize_one(
paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode
)
finally:
paper_db.close()
@@ -656,10 +194,10 @@ async def summarize_batch(
try:
db.add(lock)
db.commit()
except Exception:
except IntegrityError:
db.rollback()
logger.warning("Summarize batch already running (lock conflict)")
return {"status": "conflict", "error": "summarize batch already running"}
raise ConflictError("summarize batch already running")
# CrawlLog
log_entry = CrawlLog(
@@ -717,19 +255,18 @@ async def summarize_batch(
break
paper_db = make_session()
try:
p = paper_db.execute(
select(Paper)
.where(Paper.id == paper.id)
.options(*PAPER_DEFAULT_LOAD)
).unique().scalar_one_or_none()
p = get_paper_by_id(paper_db, paper.id)
result = await summarize_one(paper_db, p, pdf_mode=pdf_mode)
status = result.get("status", "failed")
progress[status] = progress.get(status, 0) + 1
finished = sum(progress.values())
logger.info(
"📊 进度: %d/%d (✅%d%d ⏭️%d) — %s",
finished, total,
progress["done"], progress["failed"], progress["skipped"],
finished,
total,
progress["done"],
progress["failed"],
progress["skipped"],
paper.arxiv_id,
)
results.append(result)
@@ -785,10 +322,10 @@ async def summarize_batch(
except Exception as exc:
logger.exception("Summarize batch failed")
log_entry.status = "failed"
log_entry.error = str(exc)[:2000]
log_entry.error = truncate_error(exc, limit=2000)
log_entry.completed_at = utc_now()
db.commit()
return {"status": "failed", "error": str(exc)}
return {"status": "failed", "error": truncate_error(exc)}
finally:
release_lock(db, lock)
+275
View File
@@ -0,0 +1,275 @@
"""AI 总结生成器 — AI 后端调用、重试循环、JSON 验证、错误分类。"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from pydantic import ValidationError
from app.config import settings
from app.services.pdf_downloader import (
PdfDownloadError,
paper_dir,
)
from app.services.summary_utils import (
JsonNotFoundError,
build_prompt,
extract_json,
extract_pdf_text,
)
from app.services.pi_client import (
PiProcessError,
PiTimeoutError,
call_pi,
)
from app.services import claude_backend
from app.services.schemas import classify_validation_error
from app.utils import truncate_error
logger = logging.getLogger(__name__)
# ── 错误分类 ────────────────────────────────────────────────────────────
def _classify_error(exc: Exception) -> str:
"""将异常映射到 error_type 枚举值。"""
if isinstance(exc, PdfDownloadError):
return "pdf_download_failed"
if isinstance(exc, PiTimeoutError):
return "timeout"
if isinstance(exc, PiProcessError):
return "process_error"
if isinstance(exc, JsonNotFoundError):
return "json_not_found"
if isinstance(exc, json.JSONDecodeError):
return "json_invalid"
if isinstance(exc, ValidationError):
return classify_validation_error(exc)
return "unknown"
# ── JSON 验证 ──────────────────────────────────────────────────────────
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
errors: list[str] = []
if not isinstance(json_data, dict):
return ["顶层必须是 JSON 对象"]
# 必填字段
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
if f not in json_data or not json_data[f]:
errors.append(f"缺少必填字段: {f}")
# tags 必须是非空数组
tags = json_data.get("tags")
if not isinstance(tags, list) or len(tags) == 0:
errors.append("tags 必须是非空数组")
# 字符串段落字段(必须是 str 且 ≥50 字)
string_fields = [
("motivation", "problem"),
("motivation", "goal"),
("motivation", "gap"),
("method", "overview"),
("method", "key_idea"),
("method", "steps"),
("method", "novelty"),
("results", "main_findings"),
("results", "limitations"),
("improvements", "weaknesses"),
("improvements", "future_work"),
("improvements", "reproducibility"),
]
for section, field in string_fields:
val = json_data.get(section, {}).get(field)
if isinstance(val, list):
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
elif not isinstance(val, str) or len(val.strip()) < 50:
errors.append(
f"{section}.{field} 必须是详细段落(≥50字),"
f"当前: {type(val).__name__} ({len(str(val))}字)"
)
# benchmarks 必须是数组
benchmarks = json_data.get("results", {}).get("benchmarks")
if benchmarks is not None and not isinstance(benchmarks, list):
errors.append("results.benchmarks 必须是数组")
# prerequisites.concepts 必须是对象数组,每个有 term
concepts = json_data.get("prerequisites", {}).get("concepts")
if concepts is not None:
if not isinstance(concepts, list):
errors.append("prerequisites.concepts 必须是数组")
elif len(concepts) == 0:
errors.append("prerequisites.concepts 不能为空")
else:
for i, c in enumerate(concepts):
if isinstance(c, str):
errors.append(
f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串"
)
elif isinstance(c, dict) and not c.get("term"):
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
# figures 必须是数组,每个元素应有 id
figures = json_data.get("figures")
if figures is not None:
if not isinstance(figures, list):
errors.append("figures 必须是数组")
else:
for i, fig in enumerate(figures):
if isinstance(fig, dict) and not fig.get("id"):
errors.append(f"figures[{i}] 缺少 id 字段")
return errors
# ── AI 调用 + 重试 ──────────────────────────────────────────────────────
async def _generate_with_retry(
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
) -> tuple[dict, str]:
"""调用 AI 后端生成总结,最多 4 轮验证循环。
根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。
Returns:
(json_data, raw_output)
Raises:
ValueError: 4 轮验证仍未通过
"""
import time as _time
backend = settings.SUMMARY_BACKEND
validation_errors: list[str] = []
json_data: dict | None = None
raw_output = ""
session_id = None
summary_file = paper_dir(arxiv_id) / "summary.json"
# claude 后端需要预构建 promptpi 后端在 call_pi 内部构建)
claude_prompt: str | None = None
if backend == "claude":
_t0 = _time.monotonic()
txt_path = extract_pdf_text(pdf_path, max_chars=None)
body = txt_path.read_text(encoding="utf-8")
if len(body) > 80_000:
trimmed = body[:80_000].rstrip()
txt_path.write_text(trimmed, encoding="utf-8")
claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)
for attempt in range(1, settings.SUMMARY_MAX_RETRIES + 1):
# 清理上一轮写入的不完整文件
if summary_file.exists():
summary_file.unlink()
# 记录 AI 调用开始时间
_t_call_start = _time.monotonic()
if backend == "claude":
if attempt == 1:
raw_output, session_id = await claude_backend.call_claude(
claude_prompt,
session_id=None,
)
else:
retry_prompt = build_prompt(
arxiv_id,
meta_path,
extract_pdf_text(pdf_path, max_chars=80000),
"inject",
fix_errors=validation_errors,
)
raw_output, session_id = await claude_backend.call_claude(
retry_prompt,
session_id=session_id,
fix_errors=validation_errors,
)
else:
if attempt == 1:
raw_output, session_id = await call_pi(
meta_path, pdf_path, pdf_mode=pdf_mode
)
else:
raw_output, session_id = await call_pi(
meta_path,
pdf_path,
fix_errors=validation_errors,
session_id=session_id,
pdf_mode=pdf_mode,
)
_t_call_end = _time.monotonic()
# 检查 summary.json 是否由 AI 子进程写入
file_written_by_ai = summary_file.exists()
file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
file_size = summary_file.stat().st_size if file_written_by_ai else 0
logger.info(
" [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
arxiv_id,
attempt,
_t_call_end - _t_call_start,
f"已写入({file_size}B)" if file_written_by_ai else "未写入",
f" mtime={file_mtime:.2f}" if file_mtime else "",
)
# 提取 JSON
_t_json_start = _time.monotonic()
try:
if file_written_by_ai:
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
else:
json_data = extract_json(raw_output)
except (json.JSONDecodeError, JsonNotFoundError) as exc:
_t_json_end = _time.monotonic()
logger.warning(
" [%s] JSON提取失败: %.2fs %s",
arxiv_id,
_t_json_end - _t_json_start,
str(exc)[:200],
)
validation_errors = [f"无法提取有效 JSON: {truncate_error(exc)}"]
continue
_t_json_end = _time.monotonic()
# 验证
_t_val_start = _time.monotonic()
validation_errors = _validate_summary(json_data, arxiv_id)
_t_val_end = _time.monotonic()
if not validation_errors:
logger.info(
" [%s] JSON提取: %.2fs 验证: %.2fs ✅",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
)
break
logger.warning(
" [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
arxiv_id,
_t_json_end - _t_json_start,
_t_val_end - _t_val_start,
"; ".join(validation_errors)[:200],
)
if validation_errors:
exc = ValueError(
f"Summary validation failed after {settings.SUMMARY_MAX_RETRIES} attempts: {'; '.join(validation_errors)}"
)
exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用
raise exc
return json_data, raw_output
+273
View File
@@ -0,0 +1,273 @@
"""AI 总结持久化 — DB 写入、文件保存、FTS 索引、图片提取、ChromaDB 索引。"""
from __future__ import annotations
import logging
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.models import (
Paper,
PaperSummary,
PaperTag,
SummaryState,
)
from app.services.pdf_downloader import paper_dir
from app.services.schemas import (
SummarySchema,
assess_quality,
flatten_for_db,
)
from app.services.summary_generator import _classify_error
from app.utils import TMP_DIR, truncate_error, utc_now
logger = logging.getLogger(__name__)
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
def _build_fts_summary_text(schema: SummarySchema) -> str:
"""拼接用于 FTS5 索引的总结文本。"""
parts = [
schema.one_line or "",
schema.motivation.problem or "",
schema.motivation.goal or "",
schema.method.overview or "",
schema.method.key_idea or "",
schema.results.main_findings or "",
]
return " ".join(p for p in parts if p)
# ── DB 更新 ─────────────────────────────────────────────────────────────
def _update_summary_in_db(
db: Session,
paper: Paper,
schema: SummarySchema,
quality: str,
raw_output: str,
) -> None:
"""将校验后的总结写入 DBpaper_summaries + papers + paper_tags + FTS5。"""
# 1. paper_summariesupsert
existing = db.get(PaperSummary, paper.id)
flat = flatten_for_db(schema)
if existing:
for k, v in flat.items():
setattr(existing, k, v)
else:
db.add(PaperSummary(paper_id=paper.id, **flat))
# 2. papers 表
paper.title_zh = schema.title_zh
paper.summary_quality = quality
p_dir = paper_dir(paper.arxiv_id)
paper.summary_path = str(p_dir / "summary.json")
paper.raw_output_path = str(p_dir / "raw_output.txt")
# 3. AI 标签
existing_tag_names = {t.tag for t in paper.tags}
for tag_name in schema.tags:
if tag_name not in existing_tag_names:
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
existing_tag_names.add(tag_name)
# 4. FTS5 更新
summary_text = _build_fts_summary_text(schema)
db.execute(
text(
"UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
"WHERE rowid=:paper_id"
),
{
"title_zh": schema.title_zh,
"summary_text": summary_text,
"paper_id": paper.id,
},
)
db.commit()
logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
# ── 文件操作 ────────────────────────────────────────────────────────────
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
if schema:
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
encoding="utf-8",
)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
# ── 失败处理 ────────────────────────────────────────────────────────────
def _handle_summary_failure(
db: Session,
paper: Paper,
exc: Exception,
raw_output: str,
) -> dict:
"""记录失败:保存 raw_output、重试计数、错误分类。"""
from app.config import settings
error_type = _classify_error(exc)
logger.error(
"Summarize failed: %s error_type=%s %s",
paper.arxiv_id,
error_type,
truncate_error(exc),
)
status = paper.summary_status
if raw_output:
_save_files(paper.arxiv_id, None, raw_output)
status.raw_output_saved = True
status.retry_count = (status.retry_count or 0) + 1
status.error_type = error_type
status.error = truncate_error(exc, limit=2000)
if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
status.status = SummaryState.PERMANENT_FAILURE
else:
status.status = SummaryState.PENDING
status.completed_at = utc_now()
db.commit()
return {
"arxiv_id": paper.arxiv_id,
"status": "failed",
"error_type": error_type,
"error": truncate_error(exc),
"retry_count": status.retry_count,
}
# ── 持久化 ──────────────────────────────────────────────────────────────
def _persist_summary(
db: Session, paper: Paper, json_data: dict, raw_output: str
) -> str:
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
import time as _time
arxiv_id = paper.arxiv_id
_t0 = _time.monotonic()
schema = SummarySchema.model_validate(json_data)
quality = assess_quality(schema)
_t1 = _time.monotonic()
_save_files(arxiv_id, schema, raw_output)
_t2 = _time.monotonic()
_update_summary_in_db(db, paper, schema, quality, raw_output)
_t3 = _time.monotonic()
# 状态 → done
paper.summary_status.status = SummaryState.DONE
paper.summary_status.quality = quality
paper.summary_status.completed_at = utc_now()
paper.summary_status.raw_output_saved = True
db.commit()
_t4 = _time.monotonic()
logger.info(
" [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
arxiv_id,
_t1 - _t0,
_t2 - _t1,
_t3 - _t2,
_t4 - _t3,
)
# 触发性增强(失败不影响总结)
_t5 = _time.monotonic()
_maybe_extract_images(arxiv_id, schema)
_t6 = _time.monotonic()
_maybe_index_chroma(arxiv_id, paper, schema)
_t7 = _time.monotonic()
logger.info(
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
arxiv_id,
_t6 - _t5,
_t7 - _t6,
)
return quality
# ── 清理 ────────────────────────────────────────────────────────────────
def _cleanup_old_images(db: Session, paper: Paper) -> None:
"""清理旧的图片文件和 figures_json,避免重新总结时残留。"""
arxiv_id = paper.arxiv_id
images_dir = paper_dir(arxiv_id) / "images"
if images_dir.exists():
for old_file in images_dir.iterdir():
if (
old_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg")
or old_file.name == "manifest.json"
):
old_file.unlink(missing_ok=True)
# 清除数据库中的 figures_json
if paper.summary and paper.summary.figures_json:
paper.summary.figures_json = None
db.commit()
# ── 触发性增强 ──────────────────────────────────────────────────────────
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
"""从 PDF 提取图片和表格(失败不影响总结)。
两阶段流水线:
1. PicoDet 检测 + 渲染截图(通用标签)
2. 用 summary 的 figures ID 在 PDF 中搜索定位 → 重命名
"""
try:
from app.services.pdf_image_extractor import (
extract_images_from_pdf,
label_images_by_summary,
)
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
extract_images_from_pdf(arxiv_id, pdf_path)
if schema.figures:
label_images_by_summary(arxiv_id, schema.figures, pdf_path)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
"""写入 ChromaDB 语义索引(失败不影响总结)。"""
try:
from app.services.embedder import index_paper
texts_dict = {
"arxiv_id": arxiv_id,
"title_zh": schema.title_zh or "",
"title_en": paper.title_en or "",
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
"one_line": schema.one_line or "",
"motivation_problem": schema.motivation.problem or "",
"method_key_idea": schema.method.key_idea or "",
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
}
index_paper(arxiv_id, texts_dict)
except Exception:
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
+10 -5
View File
@@ -80,9 +80,14 @@ def _trim_body(text: str, max_chars: int | None = None) -> str:
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
if ack_match:
# 只删 Acknowledgments 本身,不删后面的内容
next_section = re.search(r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start():])
next_section = re.search(
r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start() :]
)
if next_section:
text = text[:ack_match.start()] + text[ack_match.start() + next_section.start():]
text = (
text[: ack_match.start()]
+ text[ack_match.start() + next_section.start() :]
)
else:
text = text[: ack_match.start()].rstrip()
@@ -105,10 +110,9 @@ def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
return txt_path
doc = pymupdf.open(str(pdf_path))
with pymupdf.open(str(pdf_path)) as doc:
# sort=True 启用阅读顺序检测,避免双栏论文中跨栏错位
raw_text = "\n\n".join(page.get_text(sort=True) for page in doc)
doc.close()
body = _trim_body(raw_text, max_chars=max_chars)
txt_path.write_text(body, encoding="utf-8")
@@ -160,7 +164,8 @@ def build_prompt(
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要","section":"method"},'
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要","section":"results"}]'
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Tableid 严格使用 \"Figure N\"\"Table N\" 格式"
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table。"
'id 必须严格复用论文原文的写法(原文用 "Fig. 1" 就写 "Fig. 1",用 "Figure A1" 就写 "Figure A1",用 "Table 1" 就写 "Table 1")。'
"section 必须是 motivation/method/results/limitations 之一,表示该图最适合展示在哪个章节。"
"}"
)
+28 -11
View File
@@ -5,7 +5,15 @@ from __future__ import annotations
from sqlalchemy import or_, select
from sqlalchemy.orm import Session, joinedload
from app.models import PAPER_FULL_LOAD, Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
from app.exceptions import NotFoundError, ValidationError
from app.models import (
PAPER_FULL_LOAD,
Paper,
PaperTag,
UserBookmark,
UserNote,
UserReadingStatus,
)
from app.utils import utc_now
# ── 收藏 ──────────────────────────────────────────────────────────────
@@ -13,9 +21,11 @@ from app.utils import utc_now
def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
"""切换收藏状态。返回 {"bookmarked": bool, "arxiv_id": str}。"""
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
paper = db.execute(
select(Paper).where(Paper.arxiv_id == arxiv_id)
).scalar_one_or_none()
if not paper:
return {"error": "not_found"}
raise NotFoundError(f"Paper not found: {arxiv_id}")
existing = db.execute(
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
@@ -42,11 +52,15 @@ VALID_STATUSES = {"unread", "skimmed", "read_summary", "read_full"}
def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
"""设置阅读状态。status 必须是 unread/skimmed/read_summary/read_full。"""
if status not in VALID_STATUSES:
return {"error": "invalid_status", "valid": sorted(VALID_STATUSES)}
raise ValidationError(
f"Invalid reading status: {status}. Valid: {', '.join(sorted(VALID_STATUSES))}"
)
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
paper = db.execute(
select(Paper).where(Paper.arxiv_id == arxiv_id)
).scalar_one_or_none()
if not paper:
return {"error": "not_found"}
raise NotFoundError(f"Paper not found: {arxiv_id}")
now = utc_now()
existing = db.execute(
@@ -72,7 +86,9 @@ def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
def get_note(db: Session, arxiv_id: str) -> dict | None:
"""获取笔记。返回 {"arxiv_id", "content", "updated_at"} 或 None(论文不存在时)。"""
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
paper = db.execute(
select(Paper).where(Paper.arxiv_id == arxiv_id)
).scalar_one_or_none()
if not paper:
return None
@@ -91,9 +107,11 @@ def get_note(db: Session, arxiv_id: str) -> dict | None:
def save_note(db: Session, arxiv_id: str, content: str) -> dict:
"""创建或更新笔记。返回 {"arxiv_id", "content", "updated_at"}。"""
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
paper = db.execute(
select(Paper).where(Paper.arxiv_id == arxiv_id)
).scalar_one_or_none()
if not paper:
return {"error": "not_found"}
raise NotFoundError(f"Paper not found: {arxiv_id}")
now = utc_now()
existing = db.execute(
@@ -154,8 +172,7 @@ def query_reading_list(
stmt.options(
joinedload(Paper.note),
*PAPER_FULL_LOAD,
)
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
).order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
)
.unique()
.scalars()
+42 -6
View File
@@ -137,12 +137,35 @@ def safe_json_loads(text: str | None, default: Any = None) -> Any:
# AI 生成内容中允许的 HTML 标签和属性
_ALLOWED_TAGS = {
"p", "br", "strong", "b", "em", "i", "u", "s", "del",
"h3", "h4", "h5", "h6",
"ul", "ol", "li",
"a", "code", "pre", "blockquote",
"table", "thead", "tbody", "tr", "th", "td",
"sup", "sub", "span",
"p",
"br",
"strong",
"b",
"em",
"i",
"u",
"s",
"del",
"h3",
"h4",
"h5",
"h6",
"ul",
"ol",
"li",
"a",
"code",
"pre",
"blockquote",
"table",
"thead",
"tbody",
"tr",
"th",
"td",
"sup",
"sub",
"span",
}
_ALLOWED_ATTRS = {
"a": {"href", "title"},
@@ -167,3 +190,16 @@ def sanitize_html(text: str | None) -> str:
strip=True,
)
return cleaned
# ── 错误消息截断 ────────────────────────────────────────────────────────
_ERROR_TRUNCATE_LIMIT = 500
def truncate_error(exc: Exception | str, limit: int = _ERROR_TRUNCATE_LIMIT) -> str:
"""将异常或字符串截断到指定长度,保持统一的错误消息格式。"""
text = str(exc)
if len(text) <= limit:
return text
return text[:limit] + f"... ({len(text)} chars total)"
+106
View File
@@ -0,0 +1,106 @@
import json
data = {
"arxiv_id": "2602.21760",
"title_zh": "基于条件引导调度的混合数据-流水线并行加速扩散模型",
"one_line": "提出混合并行框架,通过条件划分与自适应流水线切换加速扩散推理,实现2.31倍提速。",
"tags": ["Diffusion Models", "Distributed Inference", "Parallel Computing", "Image Generation"],
"difficulty": "进阶",
"prerequisites": {
"concepts": [
{
"term": "Diffusion Models",
"explanation": "扩散模型是一类基于去噪过程的生成模型。在正向过程中,它逐渐向数据添加高斯噪声直到变成纯噪声;在反向过程中,模型学习逐步去噪以恢复原始数据。这种迭代特性虽然能生成高质量的样本,但也导致了高昂的推理计算成本。",
"why_matters": "理解扩散模型的迭代去噪机制是理解本文如何通过并行化减少推理延迟的基础。"
},
{
"term": "Classifier-Free Guidance (CFG)",
"explanation": "无分类器引导是一种在推理时提升生成样本与文本条件一致性的技术。模型同时预测有条件噪声(给定文本提示)和无条件噪声(不给定提示),最终通过加权组合两者来获得最终预测。公式为 $\epsilon_{cfg} = \epsilon_\theta(x_t, c, t) + w (\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t))$,其中 $w$ 是引导强度。",
"why_matters": "本文的核心创新点在于利用CFG中存在的有条件和无条件双路径作为数据划分的基础。"
},
{
"term": "Distributed Inference",
"explanation": "分布式推理利用多个GPU并行处理计算任务以减少延迟。主要分为数据并行(如将图像切片处理)和流水线并行(如将模型层切分)。然而,现有的分布式方法在扩散模型中往往面临通信开销大或生成图像出现拼接伪影的问题。",
"why_matters": "本文提出的混合并行框架正是为了解决现有分布式推理方法中的这些痛点。"
}
]
},
"motivation": {
"problem": "现有的扩散模型加速方法,无论是单卡优化(如减少采样步数、模型剪枝)还是多卡分布式并行(如DistriFusion和AsyncDiff),都存在明显的局限性。单卡优化受限于硬件算力上限,而现有多卡并行方法通常只能实现次线性的加速比。例如,DistriFusion将图像切片并行处理,容易在拼接处产生明显的伪影;AsyncDiff采用异步流水线,虽然加速了但会引入估计误差,且通信开销巨大(在SDXL上高达9.83GB)。",
"goal": "本文旨在提出一种新颖的混合并行框架,在仅使用两张GPU的情况下,不仅能实现超过线性的加速比(即 $>2\times$),还要严格保持甚至提升生成图像的质量,同时将通信开销降到最低。",
"gap": "与以往将图像空间切片(Patch-based)的思路不同,本文独辟蹊径,利用无分类器引导(CFG)中天然存在的“有条件”和“无条件”两条路径作为新的数据划分维度(Condition-based Partitioning)。同时,作者发现这两条路径的预测误差差异在整个去噪过程中呈现出先大后小再变大的U型曲线,因此引入了自适应的并行切换策略,只在误差差异最小时才进行并行流水线处理。"
},
"method": {
"overview": "该框架的核心思想是将扩散推理过程划分为三个阶段:预热阶段(Warm-Up)、并行阶段(Parallelism)和完全连接阶段(Fully-Connecting)。在预热和完全连接阶段,使用“基于条件的划分”策略,即一张GPU处理有条件的预测,另一张处理无条件的预测。而在中间的并行阶段,由于两个预测结果非常接近,框架切换到“自适应流水线并行”,利用两张GPU交替执行推理步骤,从而大幅压缩时间。",
"key_idea": "核心创新在于不再将图片在空间上切片,而是沿“条件”维度切分数据。这保证了每个GPU都能看到整张图片的全局信息,从而避免了拼接伪影。此外,引入了“去噪差异度”(Denoising Discrepancy,即 rel-MAE)这一指标来动态评估两条路径的相似性,并以此自动决定何时开启和关闭流水线并行,实现了最优的加速-质量平衡。",
"steps": "1. 数据划分:输入潜变量同时送入GPU 1(有条件预测 $\epsilon_\theta(x_t, c, t)$)和GPU 2(无条件预测 $\epsilon_\theta(x_t, t)$)。2. 阶段判断:根据实时计算的“去噪差异度” $G_t$ 与阈值 $g_{slope}$ 的关系,确定切换点 $\tau_1$ 和 $\tau_2$。3. 混合执行:在 $[T, \tau_1]$ 阶段同步运行;在 $[\tau_1, \tau_2]$ 阶段启用流水线并行(如GPU 1处理 $t-1$ 步时GPU 2处理 $t$ 步);在 $[\tau_2, 0]$ 阶段重新恢复同步以精细调整细节。",
"novelty": "该方法的另一大新颖之处在于其“安全性”设计:通过设置 $\tau_{cap}$ 作为安全上限,确保即使自动算法失效,也不会在错误的时间点引入并行,从而保证了算法的鲁棒性。此外,该框架对U-Net(如SDXL)和DiT(如SD3)架构均具有良好的泛化性。"
},
"results": {
"main_findings": "实验在SDXL和SD3模型上进行,使用MS-COCO 2014验证集。结果显示,在SDXL上,该方法实现了2.31倍加速,延迟从16.49秒降至7.12秒,且FID指标与原始单卡模型持平(甚至略优)。相比此前最强的DistriFusion1.22倍)和AsyncDiff(1.31倍),提速效果显著。在通信开销方面,本方法仅为0.516GB,比AsyncDiff的9.83GB降低了19.6倍。在SD3模型上,同样实现了2.07倍的加速。",
"benchmarks": [
{
"task": "Text-to-Image (SDXL)",
"metric": "Speed-Up",
"this_work": "2.31x",
"baseline": "1.31x (AsyncDiff)",
"improvement": "1.0x (Extra speed)"
},
{
"task": "Text-to-Image (SDXL)",
"metric": "Comm. (GB)",
"this_work": "0.516",
"baseline": "9.830 (AsyncDiff)",
"improvement": "Reduced by 19.6x"
},
{
"task": "Text-to-Image (SD3)",
"metric": "Speed-Up",
"this_work": "2.07x",
"baseline": "1.97x (AsyncDiff)",
"improvement": "0.1x (Extra speed)"
}
],
"limitations": "尽管该方法在通用性上表现出色,但在处理极高分辨率(如4K以上)时,加速比会随分辨率提升而有所下降(从2.72x降至1.62x)。此外,目前的实现仅针对两张GPU进行了深度优化,虽然文中提出了多卡扩展策略,但在单个样本推理场景下,如何高效地扩展到四卡或更多卡仍是一个挑战。最后,参数 $k$ 的选取目前仍需人工根据经验设定。"
},
"improvements": {
"weaknesses": "主要弱点在于自适应切换参数(如 $k$ 和 $\tau_{cap}$)的确定目前仍偏向经验性,缺乏完全自动化的端到端学习机制。此外,虽然避免了图像切片,但条件分支的“信息量”并不总是完全对等的,特别是在极早期的噪声阶段,可能导致其中一张GPU负载不均衡。改进方向可以是结合动态负载均衡算法,根据当前步骤的预测难度动态分配计算资源。",
"future_work": "未来的研究方向包括:1. 将该混合并行策略扩展到视频生成模型(Video Diffusion)中,利用时间轴上的相关性进行更细粒度的流水线调度。2. 结合模型量化(Quantization)和蒸馏技术,在多卡并行的基础上进一步压缩单步推理时间。3. 探索在“去噪差异度”指标指导下自动学习最优的 $k$ 值和切换点。",
"reproducibility": "代码已在GitHub开源(https://github.com/kaist-dmlab/Hybridiff)。实验环境基于PyTorch,使用的GPU为NVIDIA GeForce 3090,硬件门槛相对较低。文中详细列出了关键超参数(如SDXL上的 $L=12, k=5, \tau_{cap}=15$),使得复现结果的难度较低。"
},
"figures": [
{
"id": "Figure 1",
"caption": "Summary of the proposed hybrid data-pipeline parallelism",
"description": "五维雷达图展示了该方法在速度、图像质量、通用性、高分辨率能力和通信开销五个方面均优于现有分布式框架。",
"reason": "直观概括了本文的核心优势,即全方位的性能提升。",
"section": "results"
},
{
"id": "Figure 2",
"caption": "Comparison of parallel strategies",
"description": "对比了三种并行策略:(a)基于切片的数据并行容易产生伪影,(b)流水线并行通信开销大,(c)本文提出的混合并行既保留全局一致性又实现了高效并行。",
"reason": "通过对比展示了本文方法设计的合理性和必要性。",
"section": "method"
},
{
"id": "Figure 3",
"caption": "Overview of the hybrid parallel framework",
"description": "详细展示了三个阶段(Warm-Up, Parallelism, Fully-Connecting)的数据流和通信模式,清晰地说明了自适应切换的动态过程。",
"reason": "这是理解整个算法执行流程的关键示意图。",
"section": "method"
},
{
"id": "Table 1",
"caption": "Quantitative comparison on SDXL and SD3",
"description": "表格列出了该方法与基线方法在延迟、加速比、通信开销及生成质量指标(FID, LPIPS, PSNR)上的详细对比数据。",
"reason": "提供了最核心的定量证据,证明了该方法的有效性。",
"section": "results"
}
]
}
with open("data/papers/2602.21760/summary.json", "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
print("summary.json created successfully.")
+1 -1
View File
@@ -19,7 +19,7 @@ dependencies = [
"pymupdf>=1.25",
"itsdangerous>=2.2.0",
"bleach>=6.4.0",
"pymupdf4llm>=1.27.2.3",
"onnxruntime>=1.17",
]
[project.optional-dependencies]
+172
View File
@@ -0,0 +1,172 @@
"""导出 PicoDet-S_layout_3cls 为 ONNX 格式.
一次性脚本,在独立 venv 中运行:
python -m venv .venv-export && source .venv-export/bin/activate
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle paddleocr paddle2onnx onnxruntime opencv-python-headless
HF_ENDPOINT=https://hf-mirror.com python scripts/export_picodet_onnx.py
输出:
data/models/picodet_layout_3cls.onnx
"""
from __future__ import annotations
import os
import shutil
import subprocess
import sys
from pathlib import Path
# hf-mirror
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
PROJECT_ROOT = Path(__file__).resolve().parent.parent
MODEL_DIR = PROJECT_ROOT / "data" / "models"
OUTPUT_PATH = MODEL_DIR / "picodet_layout_3cls.onnx"
MODEL_NAME = "PicoDet-S_layout_3cls"
def main() -> None:
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# ── Step 1: 用 PaddleOCR paddle_static 引擎加载模型,触发下载 ──
print(f"[1/4] Loading model '{MODEL_NAME}' (paddle_static engine, triggers download) ...")
from paddleocr import LayoutDetection
model = LayoutDetection(
model_name=MODEL_NAME,
engine="paddle_static",
device="cpu",
)
print(" ✓ Model loaded and cached")
# ── Step 2: 找到 PaddleX 缓存的 Paddle 模型文件 ────────────────
paddlex_cache = Path.home() / ".paddlex"
print(f"\n[2/4] Searching Paddle model cache in {paddlex_cache} ...")
# 搜索 layout 相关的缓存目录
candidates = []
for d in paddlex_cache.rglob("*"):
if d.is_dir() and (d / "inference.pdiparams").exists():
# 检查是否是 layout 模型
marker = d.name
parent_name = d.parent.name
if "layout" in marker.lower() or "layout" in parent_name.lower() or "picodet" in marker.lower():
candidates.append(d)
elif "PicoDet" in str(d):
candidates.append(d)
if not candidates:
# 如果没找到明确的 layout 目录,列出所有含 inference.pdiparams 的目录
all_model_dirs = [d for d in paddlex_cache.rglob("*") if d.is_dir() and (d / "inference.pdiparams").exists()]
print(" No layout-specific dir found. All model dirs with inference.pdiparams:")
for d in all_model_dirs:
files = [f.name for f in d.iterdir()]
print(f" {d} ({', '.join(files)})")
if all_model_dirs:
# 取最新的(刚下载的)
candidates = sorted(all_model_dirs, key=lambda d: (d / "inference.pdiparams").stat().st_mtime, reverse=True)[:1]
if not candidates:
print(" ✗ No cached model found")
sys.exit(1)
model_cache_dir = candidates[0]
files_in_dir = list(model_cache_dir.iterdir())
print(f" Using: {model_cache_dir}")
for f in files_in_dir:
print(f" {f.name} ({f.stat().st_size / 1024:.1f} KB)")
# ── Step 3: 用 paddle2onnx 转换 ─────────────────────────────────
print("\n[3/4] Converting to ONNX with paddle2onnx ...")
tmp_onnx = OUTPUT_PATH.with_suffix(".tmp.onnx")
# 确定 model_filename
pdmodel = model_cache_dir / "inference.pdmodel"
has_pdmodel = pdmodel.exists()
cmd = [
sys.executable, "-m", "paddle2onnx",
"--model_dir", str(model_cache_dir),
"--save_file", str(tmp_onnx),
"--opset_version", "11",
"--enable_onnx_checker", "True",
]
if has_pdmodel:
cmd.extend(["--model_filename", "inference.pdmodel"])
cmd.extend(["--params_filename", "inference.pdiparams"])
print(f" Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.stdout:
print(f" stdout: {result.stdout[:500]}")
if result.returncode != 0:
print(f" ✗ paddle2onnx failed (exit {result.returncode})")
print(f" stderr: {result.stderr[:500]}")
# 尝试不带 model_filenamecombined format
if has_pdmodel:
print(" Retrying without explicit model_filename ...")
cmd2 = [
sys.executable, "-m", "paddle2onnx",
"--model_dir", str(model_cache_dir),
"--params_filename", "inference.pdiparams",
"--save_file", str(tmp_onnx),
"--opset_version", "11",
]
result2 = subprocess.run(cmd2, capture_output=True, text=True)
if result2.returncode != 0:
print(f" ✗ Retry also failed: {result2.stderr[:500]}")
sys.exit(1)
if not tmp_onnx.exists() or tmp_onnx.stat().st_size < 1000:
print(" ✗ ONNX file not created or too small")
sys.exit(1)
shutil.move(str(tmp_onnx), str(OUTPUT_PATH))
print(f" ✓ ONNX saved ({OUTPUT_PATH.stat().st_size / 1024 / 1024:.2f} MB)")
# ── Step 4: 用 onnxruntime 验证 ─────────────────────────────────
print("\n[4/4] Verifying with onnxruntime ...")
_inspect_onnx(OUTPUT_PATH)
print(f"\n✓ Done! ONNX model saved to {OUTPUT_PATH}")
def _inspect_onnx(onnx_path: Path) -> None:
"""用 onnxruntime 加载模型,打印输入输出信息."""
import numpy as np
import onnxruntime as ort
session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
print(" Inputs:")
for inp in session.get_inputs():
print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
print(" Outputs:")
for out in session.get_outputs():
print(f" {out.name}: shape={out.shape}, dtype={out.type}")
# 试推理
input_info = session.get_inputs()[0]
input_name = input_info.name
batch_size = input_info.shape[0] if isinstance(input_info.shape[0], int) else 1
channels = input_info.shape[1] if isinstance(input_info.shape[1], int) else 3
height = input_info.shape[2] if isinstance(input_info.shape[2], int) else 480
width = input_info.shape[3] if isinstance(input_info.shape[3], int) else 480
dummy_input = np.random.rand(batch_size, channels, height, width).astype(np.float32)
outputs = session.run(None, {input_name: dummy_input})
print(" Inference test outputs:")
for i, (out_info, out_val) in enumerate(zip(session.get_outputs(), outputs)):
print(f" output[{i}] '{out_info.name}': shape={out_val.shape}, dtype={out_val.dtype}")
if out_val.size <= 20:
print(f" values: {out_val}")
print(" ✓ Inference OK")
if __name__ == "__main__":
main()
+212
View File
@@ -0,0 +1,212 @@
"""批量重新提取所有论文的图片 — 下载 PDF + PicoDet 检测 + caption 匹配.
用法:
PROXY_SERVER=http://... uv run python scripts/reextract_images.py
uv run python scripts/reextract_images.py --limit 10 # 只处理前 10 篇
uv run python scripts/reextract_images.py --id 2512.24880 # 只处理指定论文
"""
from __future__ import annotations
import json
import logging
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import requests
# 让脚本可以从项目根目录直接运行
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from dotenv import load_dotenv
load_dotenv()
from app.database import SessionLocal, init_db, engine # noqa: E402
from app.models import Paper # noqa: E402
from app.services.pdf_image_extractor import extract_images_from_pdf # noqa: E402
from app.utils import TMP_DIR # noqa: E402
from sqlalchemy import select # noqa: E402
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-5s %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
# 下载并发数
MAX_WORKERS = 3
# 下载超时(秒)
DOWNLOAD_TIMEOUT = 120
def _get_session() -> requests.Session:
"""创建带代理的 HTTP session。"""
sess = requests.Session()
sess.headers.update({"User-Agent": "hf-daily-papers/1.0"})
proxy = os.environ.get("PROXY_SERVER") or os.environ.get("HTTPS_PROXY")
if proxy:
sess.proxies = {"http": proxy, "https": proxy}
logger.info("使用代理: %s", proxy)
else:
logger.warning("未设置代理 (PROXY_SERVER / HTTPS_PROXY),直连 arxiv.org")
return sess
def download_pdf(session: requests.Session, arxiv_id: str, pdf_url: str) -> Path | None:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf,返回路径或 None。"""
dest_dir = TMP_DIR / arxiv_id
dest = dest_dir / "paper.pdf"
if dest.exists() and dest.stat().st_size > 1000:
return dest
dest_dir.mkdir(parents=True, exist_ok=True)
try:
resp = session.get(pdf_url, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
resp.raise_for_status()
dest.write_bytes(resp.content)
return dest
except Exception as exc:
logger.warning("下载失败 %s: %s", arxiv_id, exc)
return None
def process_one(session: requests.Session, arxiv_id: str, pdf_url: str) -> dict:
"""处理单篇论文:下载 → 提取图片 → 返回统计。"""
result = {"arxiv_id": arxiv_id, "downloaded": False, "extracted": 0, "error": None}
# 下载 PDF
pdf_path = download_pdf(session, arxiv_id, pdf_url)
if pdf_path is None:
result["error"] = "download_failed"
return result
result["downloaded"] = True
# 提取图片
try:
n = extract_images_from_pdf(arxiv_id, pdf_path)
result["extracted"] = n
except Exception as exc:
logger.warning("提取失败 %s: %s", arxiv_id, exc, exc_info=True)
result["error"] = f"extract_failed: {exc}"
return result
# 统计 matched / orphan
mf = Path(f"data/papers/{arxiv_id}/images/manifest.json")
if mf.exists():
m = json.loads(mf.read_text(encoding="utf-8"))
result["matched"] = sum(1 for v in m.values() if "(p" not in v.get("label", ""))
result["orphan"] = sum(1 for v in m.values() if "(p" in v.get("label", ""))
return result
def main():
import argparse
parser = argparse.ArgumentParser(description="批量重新提取论文图片")
parser.add_argument("--limit", type=int, default=0, help="只处理前 N 篇")
parser.add_argument("--id", dest="arxiv_id", help="只处理指定 arxiv_id")
parser.add_argument("--workers", type=int, default=MAX_WORKERS, help="并发数")
args = parser.parse_args()
# 初始化数据库
os.makedirs("data/db", exist_ok=True)
init_db(engine)
# 读取论文列表
db = SessionLocal()
try:
if args.arxiv_id:
papers = (
db.execute(select(Paper).where(Paper.arxiv_id == args.arxiv_id))
.scalars()
.all()
)
else:
papers = db.execute(select(Paper)).scalars().all()
finally:
db.close()
if args.limit > 0:
papers = papers[: args.limit]
total = len(papers)
logger.info("待处理论文: %d", total)
if total == 0:
return
session = _get_session()
# 统计
done = 0
failed = 0
total_extracted = 0
total_matched = 0
total_orphan = 0
t0 = time.time()
with ThreadPoolExecutor(max_workers=args.workers) as pool:
futures = {}
for p in papers:
f = pool.submit(process_one, session, p.arxiv_id, p.pdf_url)
futures[f] = p.arxiv_id
for f in as_completed(futures):
arxiv_id = futures[f]
try:
r = f.result()
except Exception as exc:
logger.error("异常 %s: %s", arxiv_id, exc)
failed += 1
done += 1
continue
done += 1
if r["error"]:
failed += 1
logger.info("[%d/%d] ✗ %s%s", done, total, arxiv_id, r["error"])
else:
total_extracted += r["extracted"]
total_matched += r.get("matched", 0)
total_orphan += r.get("orphan", 0)
matched = r.get("matched", 0)
orphan = r.get("orphan", 0)
elapsed = time.time() - t0
rate = done / elapsed if elapsed > 0 else 0
eta = (total - done) / rate if rate > 0 else 0
logger.info(
"[%d/%d] ✓ %s%d 张 (matched=%d, orphan=%d) ETA %.0fs",
done,
total,
arxiv_id,
r["extracted"],
matched,
orphan,
eta,
)
elapsed = time.time() - t0
logger.info("=" * 60)
logger.info(
"完成: %d/%d 成功, %d 失败, 耗时 %.1fs",
done - failed,
total,
failed,
elapsed,
)
logger.info(
"图片: %d 总计, %d matched, %d orphan (%.1f%%)",
total_extracted,
total_matched,
total_orphan,
total_orphan / total_extracted * 100 if total_extracted else 0,
)
if __name__ == "__main__":
main()
+7 -1
View File
@@ -161,7 +161,13 @@ def sample_summary_dict() -> dict:
"results": {
"main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。",
"benchmarks": [
{"task": "长文本摘要", "metric": "ROUGE-L", "this_work": "42.1", "baseline": "38.9", "improvement": "+3.2"},
{
"task": "长文本摘要",
"metric": "ROUGE-L",
"this_work": "42.1",
"baseline": "38.9",
"improvement": "+3.2",
},
],
"limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。",
},
+9 -13
View File
@@ -67,7 +67,7 @@ class TestAdminAuth:
def test_correct_session_accepted(self, auth_client):
"""已登录 session 应被接受(crawl 可能会失败但不是 303)。"""
with patch(
"app.routes.admin.crawl_daily", new_callable=AsyncMock
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
resp = auth_client.post("/admin/crawl")
@@ -83,9 +83,7 @@ class TestAdminAuth:
def test_correct_session_batch_summarize(self, auth_client):
"""已登录调用 batch summarizemock 掉服务层。"""
with patch(
"app.routes.admin.summarize_batch", new_callable=AsyncMock
) as mock:
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
mock.return_value = {
"status": "success",
"done": 0,
@@ -98,10 +96,12 @@ class TestAdminAuth:
def test_single_paper_not_found(self, auth_client):
"""单篇总结不存在的论文返回 404。"""
from app.exceptions import NotFoundError
with patch(
"app.routes.admin.summarize_single",
new_callable=AsyncMock,
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
side_effect=NotFoundError("Paper not found: nonexistent.99999"),
):
resp = auth_client.post("/admin/summarize/nonexistent.99999")
assert resp.status_code == 404
@@ -118,7 +118,7 @@ class TestAdminCrawl:
def test_crawl_default_today(self, auth_client):
"""不指定日期时默认抓取今天。"""
with patch(
"app.routes.admin.crawl_daily", new_callable=AsyncMock
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
resp = auth_client.post("/admin/crawl")
@@ -130,7 +130,7 @@ class TestAdminCrawl:
def test_crawl_specific_date(self, auth_client):
"""指定日期抓取。"""
with patch(
"app.routes.admin.crawl_daily", new_callable=AsyncMock
"app.routes.admin.run_crawl", new_callable=AsyncMock
) as mock_crawl:
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
resp = auth_client.post("/admin/crawl?date=2024-01-15")
@@ -194,9 +194,7 @@ class TestAdminDelete:
)
assert resp.status_code == 422
def test_delete_with_confirm(
self, auth_client, db_session, sample_papers_range
):
def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range):
"""confirm='DELETE' 时应执行删除。"""
resp = auth_client.post(
"/admin/delete",
@@ -255,9 +253,7 @@ class TestAdminLogs:
resp = client.get("/admin/logs", follow_redirects=False)
assert resp.status_code == 303
def test_logs_contains_data(
self, auth_client, db_session, sample_papers_range
):
def test_logs_contains_data(self, auth_client, db_session, sample_papers_range):
"""日志页面应包含日志数据。"""
# 先创建一条日志
now = utc_now()
+189
View File
@@ -0,0 +1,189 @@
"""爬虫服务测试 — _parse_paper、fetch_daily、upsert_papers、crawl_daily。"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.crawler import (
_parse_paper,
crawl_daily,
fetch_daily,
upsert_papers,
)
# ═══════════════════════════════════════════════════════════════════════
# _parse_paper
# ═══════════════════════════════════════════════════════════════════════
class TestParsePaper:
def test_normal_item(self):
item = {
"paper": {
"id": "2401.12345",
"title": "Test Paper",
"abstract": "Abstract text",
"publishedAt": "2024-01-15T00:00:00",
"authors": [{"name": "Alice"}, {"name": "Bob"}],
"tags": [{"name": "NLP"}, {"name": "LLM"}],
"upvotes": 42,
}
}
result = _parse_paper(item)
assert result["arxiv_id"] == "2401.12345"
assert result["title_en"] == "Test Paper"
assert len(result["authors"]) == 2
assert result["authors"] == ["Alice", "Bob"]
assert result["tags"] == ["NLP", "LLM"]
assert result["upvotes"] == 42
assert "huggingface.co" in result["hf_url"]
def test_empty_id(self):
item = {"paper": {"id": "", "authors": [], "tags": []}}
result = _parse_paper(item)
assert result["arxiv_id"] == ""
assert result["hf_url"] == ""
def test_missing_published_at(self):
item = {"paper": {"id": "2401.00001", "title": "T", "authors": [], "tags": []}}
result = _parse_paper(item)
assert result["published_at"] is None
def test_flat_structure_fallback(self):
"""无 paper 包装时直接从顶层取字段。"""
item = {"id": "2401.99999", "title": "Flat", "authors": [], "tags": []}
result = _parse_paper(item)
assert result["arxiv_id"] == "2401.99999"
assert result["title_en"] == "Flat"
# ═══════════════════════════════════════════════════════════════════════
# fetch_daily
# ═══════════════════════════════════════════════════════════════════════
class TestFetchDaily:
@pytest.mark.asyncio
async def test_returns_papers(self, monkeypatch):
fake_data = [{"paper": {"id": "2401.00001"}}]
mock_resp = MagicMock()
mock_resp.json.return_value = fake_data
mock_resp.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("app.services.crawler.make_http_client", return_value=mock_client):
result = await fetch_daily("2024-01-15")
assert len(result) == 1
assert result[0]["paper"]["id"] == "2401.00001"
@pytest.mark.asyncio
async def test_respects_top_n(self, monkeypatch):
fake_data = [{"paper": {"id": f"2401.{i:05d}"}} for i in range(10)]
mock_resp = MagicMock()
mock_resp.json.return_value = fake_data
mock_resp.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("app.services.crawler.make_http_client", return_value=mock_client):
result = await fetch_daily("2024-01-15", top_n=3)
assert len(result) == 3
# ═══════════════════════════════════════════════════════════════════════
# upsert_papers
# ═══════════════════════════════════════════════════════════════════════
class TestUpsertPapers:
def test_inserts_new_paper(self, db_session):
papers_raw = [
{
"paper": {
"id": "2401.00001",
"title": "New Paper",
"abstract": "Abstract",
"authors": [{"name": "Alice"}],
"tags": [{"name": "CV"}],
"upvotes": 5,
}
}
]
new = upsert_papers(db_session, papers_raw, "2024-01-15")
assert len(new) == 1
assert new[0].arxiv_id == "2401.00001"
assert new[0].title_en == "New Paper"
def test_updates_existing_upvotes(self, db_session, sample_paper):
papers_raw = [
{
"paper": {
"id": sample_paper.arxiv_id,
"title": sample_paper.title_en,
"upvotes": 999,
"authors": [],
"tags": [],
}
}
]
new = upsert_papers(db_session, papers_raw, "2024-01-15")
assert len(new) == 0 # 不新增
db_session.refresh(sample_paper)
assert sample_paper.upvotes == 999
def test_skips_empty_id(self, db_session):
papers_raw = [{"paper": {"id": "", "title": "Nope", "authors": [], "tags": []}}]
new = upsert_papers(db_session, papers_raw, "2024-01-15")
assert len(new) == 0
# ═══════════════════════════════════════════════════════════════════════
# crawl_daily
# ═══════════════════════════════════════════════════════════════════════
class TestCrawlDaily:
@pytest.mark.asyncio
async def test_success_flow(self, db_session):
with patch(
"app.services.crawler.fetch_daily",
new_callable=AsyncMock,
) as mock_fetch:
mock_fetch.return_value = [
{
"paper": {
"id": "2401.00001",
"title": "T",
"authors": [],
"tags": [],
"upvotes": 0,
}
}
]
result = await crawl_daily(db_session, "2024-01-15")
assert result["status"] == "success"
assert result["new"] == 1
assert result["found"] == 1
@pytest.mark.asyncio
async def test_failure_returns_failed(self, db_session):
with patch(
"app.services.crawler.fetch_daily",
new_callable=AsyncMock,
side_effect=ConnectionError("network error"),
):
result = await crawl_daily(db_session, "2024-01-15")
assert result["status"] == "failed"
assert "network error" in result["error"]
+77
View File
@@ -0,0 +1,77 @@
"""PDF 下载测试 — download_pdf、路径工具、错误处理。"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from app.services.pdf_downloader import (
PdfDownloadError,
download_pdf,
paper_dir,
tmp_dir,
)
from app.utils import PAPERS_DIR, TMP_DIR
# ═══════════════════════════════════════════════════════════════════════
# 路径工具
# ═══════════════════════════════════════════════════════════════════════
class TestPathHelpers:
def test_paper_dir(self):
assert paper_dir("2401.12345") == PAPERS_DIR / "2401.12345"
def test_tmp_dir(self):
assert tmp_dir("2401.12345") == TMP_DIR / "2401.12345"
# ═══════════════════════════════════════════════════════════════════════
# download_pdf
# ═══════════════════════════════════════════════════════════════════════
class TestDownloadPdf:
@pytest.mark.asyncio
async def test_success_download(self, tmp_path):
mock_resp = MagicMock()
mock_resp.content = b"%PDF-1.4 fake"
mock_resp.raise_for_status = MagicMock()
mock_session = MagicMock()
mock_session.get.return_value = mock_resp
with (
patch("app.services.pdf_downloader.TMP_DIR", tmp_path),
patch(
"app.services.pdf_downloader._get_session", return_value=mock_session
),
):
result = await download_pdf(
"2401.12345", "https://arxiv.org/pdf/2401.12345.pdf"
)
assert result.exists()
assert result.name == "paper.pdf"
assert result.read_bytes() == b"%PDF-1.4 fake"
@pytest.mark.asyncio
async def test_empty_pdf_url_raises(self):
with pytest.raises(PdfDownloadError, match="no pdf_url"):
await download_pdf("2401.12345", "")
@pytest.mark.asyncio
async def test_http_failure_raises(self, tmp_path):
mock_session = MagicMock()
mock_session.get.side_effect = ConnectionError("refused")
with (
patch("app.services.pdf_downloader.TMP_DIR", tmp_path),
patch(
"app.services.pdf_downloader._get_session", return_value=mock_session
),
):
with pytest.raises(PdfDownloadError, match="failed to download"):
await download_pdf("2401.12345", "https://bad.url/pdf.pdf")
+77
View File
@@ -0,0 +1,77 @@
"""流水线编排测试 — run_pipeline (crawl → summarize → cleanup)。"""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.models import TaskLock
from app.services.pipeline import run_pipeline
from app.utils import utc_now
class TestRunPipeline:
@pytest.mark.asyncio
async def test_full_pipeline_success(self, db_session):
with (
patch(
"app.services.pipeline.crawl_daily", new_callable=AsyncMock
) as mock_crawl,
patch(
"app.services.pipeline.summarize_batch", new_callable=AsyncMock
) as mock_summ,
patch("app.services.pipeline.cleanup_tmp") as mock_clean,
):
mock_crawl.return_value = {"status": "success", "found": 5, "new": 2}
mock_summ.return_value = {"status": "success", "done": 2, "failed": 0}
mock_clean.return_value = {"removed": 0}
result = await run_pipeline(db_session, "2024-01-15", "test")
assert result["status"] == "success"
mock_crawl.assert_called_once()
mock_summ.assert_called_once()
mock_clean.assert_called_once()
@pytest.mark.asyncio
async def test_pipeline_lock_prevents_reentry(self, db_session):
"""已有 running 锁时抛出 RuntimeError。"""
now = utc_now()
db_session.add(
TaskLock(
task="scheduler",
lock_key="pipeline-2024-01-15",
status="running",
owner="other",
acquired_at=now,
)
)
db_session.commit()
with pytest.raises(RuntimeError, match="already running"):
await run_pipeline(db_session, "2024-01-15", "test")
@pytest.mark.asyncio
async def test_crawl_failure_still_runs_summarize_and_cleanup(self, db_session):
"""crawl 失败时 pipeline 继续执行 summarize 和 cleanup。"""
with (
patch(
"app.services.pipeline.crawl_daily", new_callable=AsyncMock
) as mock_crawl,
patch(
"app.services.pipeline.summarize_batch", new_callable=AsyncMock
) as mock_summ,
patch("app.services.pipeline.cleanup_tmp") as mock_clean,
):
mock_crawl.side_effect = ConnectionError("timeout")
mock_summ.return_value = {"status": "success", "done": 0}
mock_clean.return_value = {"removed": 0}
result = await run_pipeline(db_session, "2024-01-15", "test")
# pipeline 捕获异常,返回 failed
assert result["status"] == "failed"
assert "timeout" in result["error"]
# summarize 和 cleanup 不会被调用(exception 跳出 try 块)
mock_summ.assert_not_called()
+1 -1
View File
@@ -20,7 +20,7 @@ from app.services.schemas import (
classify_validation_error,
flatten_for_db,
)
from app.services.summarizer import _classify_error
from app.services.summary_generator import _classify_error
# ═══════════════════════════════════════════════════════════════════════
+35 -22
View File
@@ -23,12 +23,8 @@ from app.services.pdf_downloader import (
)
from app.services.pi_client import PiTimeoutError
from app.services.schemas import SummarySchema
from app.services.summarizer import (
_save_files,
_update_summary_in_db,
summarize_batch,
summarize_one,
)
from app.services.summarizer import summarize_batch, summarize_one
from app.services.summary_persister import _save_files, _update_summary_in_db
from app.utils import utc_now
@@ -39,7 +35,14 @@ from app.utils import utc_now
def _summarize_tmp_paths(tmp_path):
"""将 data 目录重定向到 tmp_path(供 summarizer 测试使用)。"""
with (
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
patch(
"app.services.summary_persister.paper_dir",
lambda aid: tmp_path / "papers" / aid,
),
patch(
"app.services.summary_generator.paper_dir",
lambda aid: tmp_path / "papers" / aid,
),
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
@@ -134,7 +137,9 @@ class TestFileOperations:
def test_save_files(self, tmp_path, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
with patch(
"app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
):
_save_files("2401.12345", schema, "raw output text")
paper_dir = tmp_path / "2401.12345"
@@ -144,7 +149,9 @@ class TestFileOperations:
assert saved["title_zh"] == "测试论文中文标题"
def test_save_raw_output_only(self, tmp_path):
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
with patch(
"app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
):
_save_files("2401.12345", None, "raw output")
paper_dir = tmp_path / "2401.12345"
assert (paper_dir / "raw_output.txt").exists()
@@ -180,7 +187,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(mock_pi_output, "test-session-id"),
),
@@ -209,7 +216,9 @@ class TestSummarizeOneFlow:
assert fts_row[0] == "测试论文中文标题"
@pytest.mark.asyncio
async def test_pdf_download_failure(self, db_session, sample_paper, _summarize_tmp_paths):
async def test_pdf_download_failure(
self, db_session, sample_paper, _summarize_tmp_paths
):
"""PDF 下载失败 → error_type=pdf_download_failedtmp 被清理。"""
with (
patch(
@@ -233,7 +242,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
side_effect=PiTimeoutError("timeout after 300s"),
),
@@ -250,7 +259,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=("No JSON in this output at all.", "test-session-id"),
),
@@ -281,7 +290,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(bad_output, "test-session-id"),
),
@@ -300,7 +309,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=("Some output without JSON", "test-session-id"),
),
@@ -319,7 +328,7 @@ class TestSummarizeOneFlow:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(mock_pi_output, "test-session-id"),
),
@@ -347,7 +356,9 @@ class TestSummarizeOneFlow:
assert not tmp_paper.exists()
@pytest.mark.asyncio
async def test_skips_done_paper(self, db_session, sample_paper, _summarize_tmp_paths):
async def test_skips_done_paper(
self, db_session, sample_paper, _summarize_tmp_paths
):
"""已完成的论文跳过。"""
sample_paper.summary_status.status = "done"
db_session.commit()
@@ -393,7 +404,7 @@ class TestBatchSummarize:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(mock_pi_output, "test-session-id"),
),
@@ -446,7 +457,7 @@ class TestBatchSummarize:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.call_pi", side_effect=_mock_call_pi),
patch("app.services.summary_generator.call_pi", side_effect=_mock_call_pi),
):
result = await summarize_batch(db_session, _session_factory=_TestSession)
@@ -456,6 +467,8 @@ class TestBatchSummarize:
@pytest.mark.asyncio
async def test_task_lock_conflict(self, db_session, _summarize_tmp_paths):
"""TaskLock 防止并发 batch。"""
from app.exceptions import ConflictError
# 先插入一个 running 锁
db_session.add(
TaskLock(
@@ -467,8 +480,8 @@ class TestBatchSummarize:
)
db_session.commit()
result = await summarize_batch(db_session)
assert result["status"] == "conflict"
with pytest.raises(ConflictError):
await summarize_batch(db_session)
@pytest.mark.asyncio
async def test_task_lock_released(
@@ -482,7 +495,7 @@ class TestBatchSummarize:
with (
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer.call_pi",
"app.services.summary_generator.call_pi",
new_callable=AsyncMock,
return_value=(mock_pi_output, "test-session-id"),
),
+174
View File
@@ -0,0 +1,174 @@
"""summary_utils 测试 — PDF 文本提取、正文裁剪、JSON 提取、meta.json 写入、prompt 构建。"""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
from app.services.summary_utils import (
JsonNotFoundError,
_trim_body,
build_prompt,
extract_json,
extract_pdf_text,
write_meta_json,
)
# ═══════════════════════════════════════════════════════════════════════
# _trim_body 正文裁剪
# ═══════════════════════════════════════════════════════════════════════
class TestTrimBody:
def test_removes_references_section(self):
text = "Intro\n\nSome content here.\n\nReferences\n[1] Smith et al."
result = _trim_body(text)
assert "References" not in result
assert "Intro" in result
def test_removes_bibliography(self):
text = "Body\n\nBibliography\n[1] Smith"
result = _trim_body(text)
assert "Bibliography" not in result
def test_keeps_appendix_after_references(self):
text = "Body\n\nReferences\n[1] X\n\nAppendix\nExtra content"
result = _trim_body(text)
assert "Appendix" in result
assert "Extra content" in result
assert "References" not in result
def test_removes_acknowledgments(self):
text = "Body\n\nAcknowledgments\nThanks to everyone."
result = _trim_body(text)
assert "Acknowledgments" not in result
def test_max_chars_truncation(self):
text = "A" * 1000
result = _trim_body(text, max_chars=100)
assert len(result) <= 100
def test_no_truncation_when_none(self):
text = "A" * 500
result = _trim_body(text, max_chars=None)
assert len(result) == 500
# ═══════════════════════════════════════════════════════════════════════
# extract_pdf_text
# ═══════════════════════════════════════════════════════════════════════
class TestExtractPdfText:
def test_extracts_text_and_saves(self, tmp_path):
pdf_path = tmp_path / "test.pdf"
pdf_path.write_bytes(b"%PDF-fake")
mock_page = MagicMock()
mock_page.get_text.return_value = "Page 1 text"
mock_doc = MagicMock()
mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
mock_doc.__enter__ = MagicMock(return_value=mock_doc)
mock_doc.__exit__ = MagicMock(return_value=False)
with (
patch("pymupdf.open", return_value=mock_doc),
patch(
"app.services.summary_utils._trim_body", side_effect=lambda t, **kw: t
),
):
result_path = extract_pdf_text(pdf_path)
assert result_path.suffix == ".txt"
assert result_path.exists()
assert "Page 1 text" in result_path.read_text()
def test_uses_cached_txt(self, tmp_path):
pdf_path = tmp_path / "test.pdf"
pdf_path.write_bytes(b"%PDF-fake")
txt_path = tmp_path / "test.txt"
txt_path.write_text("cached", encoding="utf-8")
with patch("pymupdf.open") as mock_open:
result = extract_pdf_text(pdf_path)
mock_open.assert_not_called()
assert result == txt_path
# ═══════════════════════════════════════════════════════════════════════
# write_meta_json
# ═══════════════════════════════════════════════════════════════════════
class TestWriteMetaJson:
def test_writes_meta_json(self, tmp_path, sample_paper):
with patch("app.services.pdf_downloader.paper_dir", lambda aid: tmp_path / aid):
result = write_meta_json(sample_paper)
assert result.exists()
assert result.name == "meta.json"
data = json.loads(result.read_text(encoding="utf-8"))
assert data["arxiv_id"] == "2401.12345"
assert data["title_en"] == "Test Paper Title"
# ═══════════════════════════════════════════════════════════════════════
# build_prompt
# ═══════════════════════════════════════════════════════════════════════
class TestBuildPrompt:
def test_inject_mode_contains_schema(self, tmp_path):
prompt = build_prompt(
"2401.12345", tmp_path / "meta", tmp_path / "txt", "inject"
)
assert "title_zh" in prompt
assert "必须包含以下字段" in prompt
def test_search_mode_contains_read_instruction(self, tmp_path):
prompt = build_prompt(
"2401.12345", tmp_path / "meta", tmp_path / "txt", "search"
)
assert "read" in prompt.lower()
assert "title_zh" in prompt
def test_fix_errors_mode(self, tmp_path):
prompt = build_prompt(
"2401.12345",
tmp_path / "meta",
tmp_path / "txt",
"inject",
fix_errors=["字段缺失"],
)
assert "字段缺失" in prompt
assert "修正" in prompt
# ═══════════════════════════════════════════════════════════════════════
# extract_json
# ═══════════════════════════════════════════════════════════════════════
class TestExtractJson:
def test_direct_json(self, sample_summary_json):
result = extract_json(sample_summary_json)
assert result["title_zh"] == "测试论文中文标题"
def test_fenced_code_block(self, sample_summary_json):
raw = f"some text\n```json\n{sample_summary_json}\n```\nmore text"
result = extract_json(raw)
assert result["title_zh"] == "测试论文中文标题"
def test_brace_matching_fallback(self, sample_summary_dict):
json_str = json.dumps(sample_summary_dict, ensure_ascii=False)
raw = f"Here is the result: {json_str} end."
result = extract_json(raw)
assert result["title_zh"] == "测试论文中文标题"
def test_no_json_raises(self):
with pytest.raises(JsonNotFoundError):
extract_json("plain text no json here at all")
+13 -14
View File
@@ -2,6 +2,9 @@
from __future__ import annotations
import pytest
from app.exceptions import NotFoundError, ValidationError
from app.services.user_data import (
get_note,
save_note,
@@ -27,9 +30,8 @@ class TestBookmarkService:
assert result["bookmarked"] is False
def test_toggle_bookmark_not_found(self, db_session):
result = toggle_bookmark(db_session, "nonexistent")
assert "error" in result
assert result["error"] == "not_found"
with pytest.raises(NotFoundError):
toggle_bookmark(db_session, "nonexistent")
# ═══════════════════════════════════════════════════════════════════════
@@ -44,9 +46,8 @@ class TestReadingStatusService:
assert result["arxiv_id"] == "2401.12345"
def test_set_reading_status_invalid(self, db_session, sample_paper):
result = set_reading_status(db_session, "2401.12345", "invalid_status")
assert "error" in result
assert result["error"] == "invalid_status"
with pytest.raises(ValidationError):
set_reading_status(db_session, "2401.12345", "invalid_status")
def test_update_existing_status(self, db_session, sample_paper):
set_reading_status(db_session, "2401.12345", "skimmed")
@@ -54,9 +55,8 @@ class TestReadingStatusService:
assert result["status"] == "read_full"
def test_set_reading_status_not_found(self, db_session):
result = set_reading_status(db_session, "nonexistent", "unread")
assert "error" in result
assert result["error"] == "not_found"
with pytest.raises(NotFoundError):
set_reading_status(db_session, "nonexistent", "unread")
def test_all_valid_statuses(self, db_session, sample_paper):
for status in ("unread", "skimmed", "read_summary", "read_full"):
@@ -93,9 +93,8 @@ class TestNoteService:
assert result is None
def test_save_note_paper_not_found(self, db_session):
result = save_note(db_session, "nonexistent", "内容")
assert "error" in result
assert result["error"] == "not_found"
with pytest.raises(NotFoundError):
save_note(db_session, "nonexistent", "内容")
# ═══════════════════════════════════════════════════════════════════════
@@ -143,12 +142,12 @@ class TestUserDataRoutes:
assert data["status"] == "read_summary"
def test_reading_status_invalid(self, client, sample_paper):
"""无效状态返回 422"""
"""无效状态返回 400 (ValidationError)"""
resp = client.post(
"/api/reading-status/2401.12345",
json={"status": "invalid"},
)
assert resp.status_code == 422
assert resp.status_code == 400
def test_reading_status_not_found(self, client):
"""不存在的论文返回 404。"""
Generated
+2 -53
View File
@@ -709,10 +709,10 @@ dependencies = [
{ name = "httpx", extra = ["http2"] },
{ name = "itsdangerous" },
{ name = "jinja2" },
{ name = "onnxruntime" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pymupdf" },
{ name = "pymupdf4llm" },
{ name = "python-dotenv" },
{ name = "python-multipart" },
{ name = "sqlalchemy" },
@@ -741,10 +741,10 @@ requires-dist = [
{ name = "httpx", extras = ["http2"], specifier = ">=0.28" },
{ name = "itsdangerous", specifier = ">=2.2.0" },
{ name = "jinja2", specifier = ">=3.1" },
{ name = "onnxruntime", specifier = ">=1.17" },
{ name = "pydantic", specifier = ">=2.0" },
{ name = "pydantic-settings", specifier = ">=2.0" },
{ name = "pymupdf", specifier = ">=1.25" },
{ name = "pymupdf4llm", specifier = ">=1.27.2.3" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24" },
{ name = "python-dotenv", specifier = ">=1.0" },
@@ -1261,15 +1261,6 @@ wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" },
]
[[package]]
name = "networkx"
version = "3.6.1"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" }
wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" },
]
[[package]]
name = "numpy"
version = "2.4.6"
@@ -1889,39 +1880,6 @@ wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/53/a4/b9e91aac82293f9c954654c85581ee8212b5b05efadc534b581141241e6f/pymupdf-1.27.2.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:77691604c5d1d0233827139bbcdea61fd57879c84712b8e49b1f45520f7ab9c2", size = 25000393, upload-time = "2026-04-24T14:11:01.669Z" },
]
[[package]]
name = "pymupdf-layout"
version = "1.27.2.3"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [
{ name = "networkx" },
{ name = "numpy" },
{ name = "onnxruntime" },
{ name = "pymupdf" },
{ name = "pyyaml" },
]
wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bc/ee/067726c3ee5574ad5c605d00d7419e264ef509d626a726f99388111f8216/pymupdf_layout-1.27.2.3-cp310-abi3-macosx_10_9_x86_64.whl", hash = "sha256:75c2ab3c0e8830ac2bc50cfd32d375a30768a2610dac72a02f08265336e0834f", size = 15799844, upload-time = "2026-04-24T14:11:13.177Z" },
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0a/ba/46a7a36474722f9280d885f6eec878561a257d9378e52590b43d32ffb96c/pymupdf_layout-1.27.2.3-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:5656b09669dcd7c51f539afb6fdaf853602bab4cbc20479ee5ee1a85a4e32b60", size = 15795220, upload-time = "2026-04-24T14:11:23.17Z" },
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/84/87/bfdcca67346052943a4549814f2009b38f4d15ec025798cdf7dfa5f57c84/pymupdf_layout-1.27.2.3-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:fcf03aa815cbceebdb3263dd6a190de4547c46b1d168928836ec38738afe127d", size = 15805240, upload-time = "2026-04-24T14:11:33.465Z" },
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/32/e9/7ce6eaf97cebd46c3808593282e9eb99a60cddd6183e25a636980d5c7986/pymupdf_layout-1.27.2.3-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:303b9414216dfaf711ec7d807b6f1e4c3e0a92bbb4569340fcedd9d5593d16ca", size = 15806269, upload-time = "2026-04-24T14:11:43.481Z" },
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bf/61/3b2417d8f2cdfaa0f4749cd9dafa3379cb5cdaddf4233165f1ff81953c30/pymupdf_layout-1.27.2.3-cp310-abi3-win_amd64.whl", hash = "sha256:503b64d9b6b31ea3af79ef85cf7d36950c5048af468cb297684d2953553c62ad", size = 15809163, upload-time = "2026-04-24T14:11:53.956Z" },
]
[[package]]
name = "pymupdf4llm"
version = "1.27.2.3"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [
{ name = "pymupdf" },
{ name = "pymupdf-layout" },
{ name = "tabulate" },
]
sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/87/c0/e3830452d82032c3d82a9879616c05bf0c51e0dea03c1d80d57b3a6ec0d1/pymupdf4llm-1.27.2.3.tar.gz", hash = "sha256:42ec1a47ddc62be3f4f40c116d27618611c6f9fa366719016d9ddc3f3a3dc22b", size = 1406297, upload-time = "2026-04-24T14:13:18.843Z" }
wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e6/38/84bf29f4dd72e6c450546df6ca8f53021f764fd945ba67dcc235d39bc20e/pymupdf4llm-1.27.2.3-py3-none-any.whl", hash = "sha256:bd724b79fa3f06a5b28d7a65f7acfa8de56e04bdb603ac2d6dff315e0d151aaa", size = 77348, upload-time = "2026-04-24T14:11:04.305Z" },
]
[[package]]
name = "pypika"
version = "0.51.1"
@@ -2282,15 +2240,6 @@ wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1c/54/196d0c1db10af76baa4f64894448505d60d3cdf70ef92cbb35f46a4e4c71/starlette-1.2.1-py3-none-any.whl", hash = "sha256:4de0082d08c8f6764a85a54cf1120d6939507a19905c7768acad2a9f875d2b89", size = 73350, upload-time = "2026-05-31T01:07:50.09Z" },
]
[[package]]
name = "tabulate"
version = "0.10.0"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/46/58/8c37dea7bbf769b20d58e7ace7e5edfe65b849442b00ffcdd56be88697c6/tabulate-0.10.0.tar.gz", hash = "sha256:e2cfde8f79420f6deeffdeda9aaec3b6bc5abce947655d17ac662b126e48a60d", size = 91754, upload-time = "2026-03-04T18:55:34.402Z" }
wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl", hash = "sha256:f0b0622e567335c8fabaaa659f1b33bcb6ddfe2e496071b743aa113f8774f2d3", size = 39814, upload-time = "2026-03-04T18:55:31.284Z" },
]
[[package]]
name = "tenacity"
version = "9.1.4"