feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules - Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection - Add layout_detector service for PicoDet-S_layout_3cls integration - Add exceptions module with ConflictError and NotFoundError - Improve admin dashboard with better statistics and task management - Add design review document with system optimization suggestions - Add new tests for crawler, pdf_downloader, pipeline, and summary_utils - Update dependencies and configuration - Clean up dead code and improve error handling
This commit is contained in:
@@ -46,3 +46,8 @@ EMBED_API_BASE=https://api.siliconflow.cn/v1/embeddings
|
|||||||
EMBED_API_KEY=your_api_key_here
|
EMBED_API_KEY=your_api_key_here
|
||||||
EMBED_MODEL=Qwen/Qwen3-Embedding-4B
|
EMBED_MODEL=Qwen/Qwen3-Embedding-4B
|
||||||
EMBED_DIMENSIONS=2560
|
EMBED_DIMENSIONS=2560
|
||||||
|
|
||||||
|
# ─── 布局检测 ─────────────────────────────
|
||||||
|
# ONNX 模型路径(首次运行前执行 scripts/export_picodet_onnx.py 导出)
|
||||||
|
# LAYOUT_MODEL_PATH=data/models/picodet_layout_3cls.onnx
|
||||||
|
# LAYOUT_THRESHOLD=0.5
|
||||||
|
|||||||
@@ -0,0 +1,468 @@
|
|||||||
|
# 项目设计审查与优化建议
|
||||||
|
|
||||||
|
本文档汇总对当前项目的系统设计、流程设计和代码结构的审查结论。重点不在局部代码风格,而在后续稳定运行、失败恢复、数据一致性、可维护性和扩展性。
|
||||||
|
|
||||||
|
## 总体评价
|
||||||
|
|
||||||
|
项目当前结构整体清晰:FastAPI 路由层较薄,主要业务放在 `app/services/`;抓取、总结、搜索、个人数据、管理后台等能力已有模块化拆分;SQLite + FTS5 对本地个人论文导览站是合理选择。
|
||||||
|
|
||||||
|
主要风险集中在以下几类:
|
||||||
|
|
||||||
|
- 长任务生命周期没有统一抽象。
|
||||||
|
- Pipeline 不是显式状态机,阶段结果难以恢复和追踪。
|
||||||
|
- 数据库、文件系统、FTS5、ChromaDB 之间存在多处手工同步。
|
||||||
|
- Web、CLI、Scheduler 三个入口的业务策略存在分叉风险。
|
||||||
|
- 失败恢复和可观测性不足。
|
||||||
|
- 部分同步阻塞任务混入 async Web 流程。
|
||||||
|
|
||||||
|
## 高优先级问题
|
||||||
|
|
||||||
|
### 1. 缺少统一的任务系统
|
||||||
|
|
||||||
|
当前有三套入口都能触发重任务:
|
||||||
|
|
||||||
|
- Web 管理后台:`/admin/trigger-pipeline`、`/admin/summarize`
|
||||||
|
- APScheduler:每日自动 pipeline
|
||||||
|
- CLI:`app.cli crawl/summarize`
|
||||||
|
|
||||||
|
它们共享部分 service,但没有统一的“任务创建、排队、运行、重试、取消、状态查询、失败恢复”模型。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- Web 请求会长时间挂住。
|
||||||
|
- CLI、Web、Scheduler 的行为可能逐渐分叉。
|
||||||
|
- 任务卡死后只能看到粗粒度的 running lock。
|
||||||
|
- 前端难以展示阶段进度。
|
||||||
|
- 后续增加重跑、取消、恢复、并发限制会越来越复杂。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
引入统一 Job 模型:
|
||||||
|
|
||||||
|
```text
|
||||||
|
入口层 -> 创建 Job -> Worker 执行 Job -> JobEvent 记录进度 -> 页面/API 查询状态
|
||||||
|
```
|
||||||
|
|
||||||
|
建议任务类型:
|
||||||
|
|
||||||
|
- `crawl_daily`
|
||||||
|
- `summarize_one`
|
||||||
|
- `summarize_batch`
|
||||||
|
- `pipeline_daily`
|
||||||
|
- `refresh_upvotes`
|
||||||
|
- `delete_range`
|
||||||
|
- `reindex_fts`
|
||||||
|
- `reindex_chroma`
|
||||||
|
- `extract_images`
|
||||||
|
|
||||||
|
建议表结构方向:
|
||||||
|
|
||||||
|
- `jobs`: 任务主记录,包含 `id/type/status/owner/created_at/started_at/completed_at/error`
|
||||||
|
- `job_events`: 阶段事件,包含 `job_id/stage/status/message/payload_json/created_at`
|
||||||
|
- `paper_processing_runs`: 单篇论文处理记录,适合跟踪总结、图片提取、索引状态
|
||||||
|
|
||||||
|
### 2. Pipeline 不是显式状态机
|
||||||
|
|
||||||
|
当前 pipeline 基本是线性函数:
|
||||||
|
|
||||||
|
```text
|
||||||
|
crawl -> summarize -> cleanup
|
||||||
|
```
|
||||||
|
|
||||||
|
阶段输出没有被建模成可查询、可重放、可补偿的状态。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- crawl 成功但 summarize 部分失败时,整体状态表达不够精确。
|
||||||
|
- cleanup 失败是否影响 pipeline 成败语义不清。
|
||||||
|
- “今天无数据回退昨天”是隐式逻辑,不是可配置策略。
|
||||||
|
- 图片提取、ChromaDB 索引失败被吞掉,用户无法知道哪些增强能力缺失。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
将 pipeline 改为显式阶段状态:
|
||||||
|
|
||||||
|
```text
|
||||||
|
created
|
||||||
|
-> crawling
|
||||||
|
-> crawled
|
||||||
|
-> summarizing
|
||||||
|
-> summarized | summarized_partial
|
||||||
|
-> postprocessing
|
||||||
|
-> completed | failed | cancelled
|
||||||
|
```
|
||||||
|
|
||||||
|
每个阶段记录:
|
||||||
|
|
||||||
|
- 输入参数
|
||||||
|
- 输出统计
|
||||||
|
- 错误类型
|
||||||
|
- 错误摘要
|
||||||
|
- 开始/结束时间
|
||||||
|
- 是否可重试
|
||||||
|
|
||||||
|
这样可以支持只重跑失败阶段,例如只重跑 ChromaDB 索引、只重跑图片提取、只重跑失败论文总结。
|
||||||
|
|
||||||
|
### 3. 数据生命周期没有清晰分层
|
||||||
|
|
||||||
|
项目中至少有四类数据:
|
||||||
|
|
||||||
|
- 主数据:`papers`、`authors`、`tags`、`summaries`、用户笔记/收藏/阅读状态
|
||||||
|
- 派生索引:FTS5、ChromaDB
|
||||||
|
- 长期资产:`data/papers/{arxiv_id}` 下的 summary、raw output、images
|
||||||
|
- 临时资产:`data/tmp/{arxiv_id}` 下的 PDF、源码、中间文件
|
||||||
|
|
||||||
|
现在这些数据由不同流程手动维护。删除、总结、重建索引、清理临时目录都分散在不同模块中。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- DB 与文件系统可能不一致。
|
||||||
|
- DB 与 FTS5/ChromaDB 可能不一致。
|
||||||
|
- 删除流程中一部分资源删除成功、一部分失败时难以恢复。
|
||||||
|
- 很难判断某篇论文是否“完整可用”。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
明确数据权威源:
|
||||||
|
|
||||||
|
- DB 是权威源。
|
||||||
|
- FTS5、ChromaDB、summary.json、raw_output.txt、images 都应视为可重建派生物。
|
||||||
|
- 删除主数据优先保证 DB 状态正确,派生物清理可以异步补偿。
|
||||||
|
- 提供派生数据重建任务。
|
||||||
|
|
||||||
|
建议增加命令或任务:
|
||||||
|
|
||||||
|
```text
|
||||||
|
rebuild-derived --fts
|
||||||
|
rebuild-derived --chroma
|
||||||
|
rebuild-derived --files
|
||||||
|
rebuild-derived --images
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 数据库与文件系统存在双写一致性风险
|
||||||
|
|
||||||
|
总结持久化流程会写文件、写 DB、更新状态、提取图片、写 ChromaDB。任一步失败都可能留下不一致状态。
|
||||||
|
|
||||||
|
典型风险:
|
||||||
|
|
||||||
|
- 文件已写入,但 DB 未提交。
|
||||||
|
- DB 标记 done,但 summary.json 缺失。
|
||||||
|
- DB done,但图片未提取或 ChromaDB 未索引。
|
||||||
|
- 重新总结时旧图片或旧 figures 残留。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 将结构化总结以 DB 为准,JSON 文件作为导出/缓存。
|
||||||
|
- 文件写入使用临时文件加原子 rename。
|
||||||
|
- DB 中记录派生物状态,例如:
|
||||||
|
- `summary_persisted_at`
|
||||||
|
- `fts_indexed_at`
|
||||||
|
- `chroma_indexed_at`
|
||||||
|
- `images_extracted_at`
|
||||||
|
- `derived_error`
|
||||||
|
- 提供一致性检查任务,扫描 DB 与文件/索引差异并修复。
|
||||||
|
|
||||||
|
### 5. 失败恢复设计偏弱
|
||||||
|
|
||||||
|
当前失败主要依赖:
|
||||||
|
|
||||||
|
- `summary_status.status`
|
||||||
|
- `retry_count`
|
||||||
|
- `task_locks`
|
||||||
|
- `crawl_logs`
|
||||||
|
|
||||||
|
但对以下情况支持不足:
|
||||||
|
|
||||||
|
- 进程崩溃后 `processing` 永久残留。
|
||||||
|
- `task_locks` 中 running lock 永久残留。
|
||||||
|
- 任务执行到一半,部分文件或索引已经写入。
|
||||||
|
- 某些后处理失败但主任务显示成功。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- `task_locks` 增加 `heartbeat_at` 或 `expires_at`。
|
||||||
|
- 应用启动时扫描 stale running task,标记为 `failed/stale`。
|
||||||
|
- `summary_status` 增加更细阶段字段,例如:
|
||||||
|
- `downloading_pdf`
|
||||||
|
- `generating`
|
||||||
|
- `validating`
|
||||||
|
- `persisting`
|
||||||
|
- `extracting_images`
|
||||||
|
- `indexing`
|
||||||
|
- 区分主流程失败和派生增强失败。
|
||||||
|
|
||||||
|
### 6. 长任务直接跑在 Web 请求里
|
||||||
|
|
||||||
|
管理端触发 pipeline、批量总结、刷新 upvotes、删除数据时,会在请求生命周期内直接执行重任务。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- HTTP 请求容易超时。
|
||||||
|
- 用户关闭页面后任务状态不清晰。
|
||||||
|
- Web 进程被重任务占用。
|
||||||
|
- 后续很难做取消、进度展示和失败恢复。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- Web 只创建 job 并返回 `task_id`。
|
||||||
|
- 后台 worker 执行任务。
|
||||||
|
- 管理页面轮询 `/admin/jobs/{id}` 或使用 SSE 展示进度。
|
||||||
|
|
||||||
|
## 中优先级问题
|
||||||
|
|
||||||
|
### 7. FTS5 索引手工维护,容易漂移
|
||||||
|
|
||||||
|
`papers_fts` 是独立虚拟表。插入论文、更新总结、删除论文时手动同步。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- 已有论文的 title、abstract、authors、tags 更新时可能不同步。
|
||||||
|
- AI 标签新增后,如果没有统一重建,搜索结果可能缺字段。
|
||||||
|
- 删除或回滚失败后索引可能残留。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 封装统一的 `reindex_paper(db, paper_id)`。
|
||||||
|
- 所有修改论文、标签、总结的路径最后调用它。
|
||||||
|
- 提供全量 `reindex_fts` 任务。
|
||||||
|
- 或用 SQLite trigger 维护基础字段,但总结字段仍建议通过统一函数更新。
|
||||||
|
|
||||||
|
### 8. ChromaDB 索引缺少系统级一致性策略
|
||||||
|
|
||||||
|
ChromaDB 当前是可选增强能力,索引失败会被日志吞掉,不影响总结主流程。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- 语义搜索可用性不透明。
|
||||||
|
- 某些论文 DB done,但未进入 ChromaDB。
|
||||||
|
- 删除或重建时可能出现残留索引。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 将 ChromaDB 明确视为派生索引。
|
||||||
|
- DB 中记录 `chroma_indexed_at` 和 `chroma_error`。
|
||||||
|
- 提供 `reindex_chroma` 任务。
|
||||||
|
- 搜索页或管理后台展示语义索引健康状态。
|
||||||
|
|
||||||
|
### 9. 同步阻塞操作混入 async 主流程
|
||||||
|
|
||||||
|
部分 async 函数内部调用同步阻塞操作,例如 PDF 下载使用 `requests`,图片提取和 PDF 解析也都是同步重任务。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- FastAPI event loop 可能被阻塞。
|
||||||
|
- 多篇并发总结时吞吐不可控。
|
||||||
|
- Web 服务响应受后台任务影响。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 最佳方案:重任务移到 worker 进程。
|
||||||
|
- 过渡方案:对同步阻塞段使用 `asyncio.to_thread()`。
|
||||||
|
- PDF 下载统一使用 `httpx.AsyncClient` 或明确放入同步 worker。
|
||||||
|
|
||||||
|
### 10. 调度策略和业务流程耦合
|
||||||
|
|
||||||
|
Scheduler 固定每日 pipeline 加 30 分钟 upvote refresh;pipeline 内部固定“今天无数据回退昨天”。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- 策略不易测试和调整。
|
||||||
|
- CLI、Web、Scheduler 可能各自实现不同 fallback。
|
||||||
|
- 后续支持日期范围、补抓、只总结新增论文会变复杂。
|
||||||
|
|
||||||
|
建议将策略配置化:
|
||||||
|
|
||||||
|
- `crawl_target_policy`: `today` / `today_then_yesterday` / `date_range`
|
||||||
|
- `summarize_scope`: `new_only` / `pending_and_failed` / `date_range`
|
||||||
|
- `cleanup_policy`: `after_success` / `always` / `manual`
|
||||||
|
- `upvote_refresh_window_days`
|
||||||
|
- `pipeline_on_partial_failure`: `continue` / `stop`
|
||||||
|
|
||||||
|
### 11. 内嵌 APScheduler 对部署形态敏感
|
||||||
|
|
||||||
|
当前调度器嵌入 Web 进程,且多 worker 时只是打印警告。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- 多 worker、多进程、reload 模式下可能重复或漏跑任务。
|
||||||
|
- Web 进程重启会影响调度可靠性。
|
||||||
|
- 调度状态和任务执行状态耦合。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 本地单机可以保留内嵌调度器。
|
||||||
|
- 长期运行建议拆为独立 scheduler 进程。
|
||||||
|
- scheduler 只负责创建 job,不直接执行重任务。
|
||||||
|
- Web 管理后台只展示 scheduler/job 状态。
|
||||||
|
|
||||||
|
### 12. 运行时迁移能力不足
|
||||||
|
|
||||||
|
当前 `_migrate()` 只支持补列。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- 无法可靠处理索引、约束、字段改名、数据回填。
|
||||||
|
- 数据库结构演进不可追踪。
|
||||||
|
- 生产或长期本地数据升级风险高。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 引入 Alembic。
|
||||||
|
- 将 FTS5、索引、部分约束和数据回填纳入迁移脚本。
|
||||||
|
- 启动时不要静默做复杂迁移,改为显式执行迁移命令。
|
||||||
|
|
||||||
|
## 低到中优先级问题
|
||||||
|
|
||||||
|
### 13. 配置校验不足
|
||||||
|
|
||||||
|
当前配置字段主要是裸类型,缺少合法值和组合校验。
|
||||||
|
|
||||||
|
风险示例:
|
||||||
|
|
||||||
|
- `SUMMARY_BACKEND` 填错。
|
||||||
|
- `SUMMARY_PDF_MODE` 填错。
|
||||||
|
- `SCHEDULE_MINUTE + 30` 超过 59。
|
||||||
|
- `APP_WORKERS > 1` 且 `SCHEDULER_ENABLED=true`。
|
||||||
|
- `CHROMA_ENABLED=true` 但 embedding 配置缺失。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 使用 Pydantic validator 校验枚举值和组合条件。
|
||||||
|
- 对危险组合启动时报错,而不是只 warning。
|
||||||
|
|
||||||
|
### 14. 私有函数跨模块调用说明模块边界不稳定
|
||||||
|
|
||||||
|
例如 `summarizer.py` 调用 `_generate_with_retry`、`_persist_summary` 等下划线函数。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- 模块公开边界不清晰。
|
||||||
|
- 后续重构容易误伤。
|
||||||
|
- 测试和替换后端不方便。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 为 summary 生成、持久化、后处理定义明确公开 API。
|
||||||
|
- 下划线函数保留为模块内部实现细节。
|
||||||
|
- 引入更明确的 service 类或函数边界,例如:
|
||||||
|
- `SummaryGenerator.generate()`
|
||||||
|
- `SummaryRepository.save()`
|
||||||
|
- `DerivedAssetBuilder.extract_images()`
|
||||||
|
|
||||||
|
### 15. 外部依赖缺少统一 adapter
|
||||||
|
|
||||||
|
项目依赖多个外部系统:
|
||||||
|
|
||||||
|
- HuggingFace API
|
||||||
|
- arXiv PDF/source 下载
|
||||||
|
- `pi` CLI
|
||||||
|
- `claude` CLI
|
||||||
|
- Embedding API
|
||||||
|
- ChromaDB
|
||||||
|
- ONNX layout detector
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 为每类外部依赖定义 adapter。
|
||||||
|
- 统一 timeout、retry、error classification、metrics。
|
||||||
|
- 测试中替换 adapter,而不是 patch 深层函数。
|
||||||
|
|
||||||
|
### 16. 删除流程事务边界不理想
|
||||||
|
|
||||||
|
删除流程中同时删除 FTS5、ChromaDB、本地文件、临时文件、ORM 数据。
|
||||||
|
|
||||||
|
影响:
|
||||||
|
|
||||||
|
- 文件删除成功后 DB rollback,会造成数据仍在但文件丢失。
|
||||||
|
- DB 删除成功后 ChromaDB 删除失败,会造成残留索引。
|
||||||
|
- 单篇失败时 rollback 可能影响之前 flush 的状态,逻辑不直观。
|
||||||
|
|
||||||
|
建议:
|
||||||
|
|
||||||
|
- 删除主数据与删除派生物分两阶段。
|
||||||
|
- DB 删除成功后创建派生清理 job。
|
||||||
|
- 派生物清理失败可重试,不影响主数据一致性。
|
||||||
|
|
||||||
|
## 建议目标架构
|
||||||
|
|
||||||
|
如果项目定位是个人本地工具,可以采用轻量目标架构:
|
||||||
|
|
||||||
|
```text
|
||||||
|
FastAPI Web
|
||||||
|
- 页面和 API
|
||||||
|
- 管理后台
|
||||||
|
- 创建 job
|
||||||
|
- 查询 job 状态
|
||||||
|
|
||||||
|
Scheduler
|
||||||
|
- 按策略创建 job
|
||||||
|
- 不直接跑重任务
|
||||||
|
|
||||||
|
Worker
|
||||||
|
- crawl
|
||||||
|
- summarize
|
||||||
|
- PDF download
|
||||||
|
- image extraction
|
||||||
|
- FTS/Chroma reindex
|
||||||
|
- cleanup/delete
|
||||||
|
|
||||||
|
Storage/Repository
|
||||||
|
- DB 权威数据
|
||||||
|
- 文件资产管理
|
||||||
|
- 派生索引重建
|
||||||
|
- 一致性检查
|
||||||
|
```
|
||||||
|
|
||||||
|
如果暂时不想引入独立 worker,也可以先在单进程内实现 Job 表和后台任务执行器。这样至少能统一任务状态和恢复机制。
|
||||||
|
|
||||||
|
## 推荐实施顺序
|
||||||
|
|
||||||
|
### 第一阶段:先解决任务与状态
|
||||||
|
|
||||||
|
目标:降低长任务不可控风险。
|
||||||
|
|
||||||
|
- 新增 `jobs` 和 `job_events` 表。
|
||||||
|
- Web 管理动作改为创建 job 并返回 task_id。
|
||||||
|
- Scheduler 改为创建 job。
|
||||||
|
- CLI 复用同一套 job runner。
|
||||||
|
- 加 stale running job 恢复逻辑。
|
||||||
|
|
||||||
|
### 第二阶段:统一派生数据管理
|
||||||
|
|
||||||
|
目标:降低 DB、文件、FTS、Chroma 不一致风险。
|
||||||
|
|
||||||
|
- 明确 DB 为权威源。
|
||||||
|
- 封装 `reindex_paper()`。
|
||||||
|
- 增加 `reindex_fts`、`reindex_chroma` 任务。
|
||||||
|
- 增加派生数据状态字段或健康检查。
|
||||||
|
- 删除流程改为 DB 主删除 + 派生清理 job。
|
||||||
|
|
||||||
|
### 第三阶段:改造 pipeline 为状态机
|
||||||
|
|
||||||
|
目标:让任务可恢复、可补偿、可观测。
|
||||||
|
|
||||||
|
- 拆分 pipeline stages。
|
||||||
|
- 每个 stage 记录输入/输出/错误/耗时。
|
||||||
|
- 支持从失败阶段重跑。
|
||||||
|
- 将 fallback 策略配置化。
|
||||||
|
|
||||||
|
### 第四阶段:提升运行可靠性
|
||||||
|
|
||||||
|
目标:长期运行更稳。
|
||||||
|
|
||||||
|
- 引入 Alembic。
|
||||||
|
- 配置加 validator。
|
||||||
|
- 同步阻塞操作移出 Web event loop。
|
||||||
|
- 外部依赖 adapter 化。
|
||||||
|
- 管理后台展示任务、派生索引、失败类型、耗时统计。
|
||||||
|
|
||||||
|
## 最小可行改造方案
|
||||||
|
|
||||||
|
如果只做最小但收益最大的改动,建议优先做这三项:
|
||||||
|
|
||||||
|
1. 增加统一 Job 表和 job runner。
|
||||||
|
2. DB 作为权威源,FTS/Chroma/文件全部按派生物处理。
|
||||||
|
3. 增加 stale task 恢复和派生数据重建命令。
|
||||||
|
|
||||||
|
这三项能显著降低后续复杂度,也不会强迫项目马上拆成多个服务。
|
||||||
|
|
||||||
@@ -7,7 +7,7 @@
|
|||||||
## 功能特性
|
## 功能特性
|
||||||
|
|
||||||
- **每日抓取**:按日期拉取 HuggingFace Daily Papers,提取元数据并入库,自动去重与重试。
|
- **每日抓取**:按日期拉取 HuggingFace Daily Papers,提取元数据并入库,自动去重与重试。
|
||||||
- **AI 中文总结**:下载 PDF,调用 `pi` CLI 为每篇论文生成结构化中文解读(动机、方法、结果、局限性等),完成后清理临时文件。
|
- **AI 中文总结**:下载 PDF,通过 `pi` 或 `claude` 后端为每篇论文生成结构化中文解读(动机、方法、结果、局限性等),完成后清理临时文件。
|
||||||
- **浏览与详情**:首页按日期导航、论文详情页展示元数据与总结,提供未总结论文的英文原文回退。
|
- **浏览与详情**:首页按日期导航、论文详情页展示元数据与总结,提供未总结论文的英文原文回退。
|
||||||
- **搜索**:基于 SQLite FTS5 的关键词搜索(BM25 排序、片段高亮),覆盖标题、摘要、作者、标签与总结正文。
|
- **搜索**:基于 SQLite FTS5 的关键词搜索(BM25 排序、片段高亮),覆盖标题、摘要、作者、标签与总结正文。
|
||||||
- **语义搜索**(可选):ChromaDB 向量数据库实现相似度搜索,优雅降级至 FTS5。
|
- **语义搜索**(可选):ChromaDB 向量数据库实现相似度搜索,优雅降级至 FTS5。
|
||||||
@@ -15,9 +15,10 @@
|
|||||||
- **趋势看板**:Chart.js 驱动的可视化统计(日论文量、Top 标签、投票分布、总结完成率)。
|
- **趋势看板**:Chart.js 驱动的可视化统计(日论文量、Top 标签、投票分布、总结完成率)。
|
||||||
- **个人化**:收藏、阅读状态、个人笔记与阅读列表。
|
- **个人化**:收藏、阅读状态、个人笔记与阅读列表。
|
||||||
- **RSS 订阅**:最近 7 天论文的 RSS 2.0 输出,支持标签过滤。
|
- **RSS 订阅**:最近 7 天论文的 RSS 2.0 输出,支持标签过滤。
|
||||||
- **管理后台**:Token 鉴权的手动抓取、总结、清扫、删除与日志查看接口。
|
- **管理后台**:Session 认证的 Web 管理界面(仪表盘、论文管理、日志查看、手动操作)。
|
||||||
- **定时调度**:APScheduler 内嵌调度,默认每日 08:00 自动抓取与总结(TaskLock 防重)。
|
- **定时调度**:APScheduler 内嵌调度,默认每日自动抓取与总结(TaskLock 防重)。
|
||||||
- **LaTeX 图片提取**:下载 arXiv 源码,扫描 `.tex` 文件提取论文图片用于详情页展示。
|
- **LaTeX 图片提取**:下载 arXiv 源码,扫描 `.tex` 文件提取论文图片用于详情页展示。
|
||||||
|
- **布局检测**(可选):ONNX 模型识别 PDF 页面布局,提升图片提取精度。
|
||||||
- **HTMX 局部更新**:收藏切换等操作无需整页刷新。
|
- **HTMX 局部更新**:收藏切换等操作无需整页刷新。
|
||||||
- **键盘快捷键**:`Ctrl+K` 或 `/` 聚焦搜索框。
|
- **键盘快捷键**:`Ctrl+K` 或 `/` 聚焦搜索框。
|
||||||
|
|
||||||
@@ -31,7 +32,7 @@
|
|||||||
| 模板 | Jinja2(服务端渲染) |
|
| 模板 | Jinja2(服务端渲染) |
|
||||||
| 前端 | HTMX · 原生 JS · Chart.js · 自定义 CSS("kami" 纸质风格) |
|
| 前端 | HTMX · 原生 JS · Chart.js · 自定义 CSS("kami" 纸质风格) |
|
||||||
| 数据库 | SQLite + SQLAlchemy · SQLite FTS5(全文搜索) |
|
| 数据库 | SQLite + SQLAlchemy · SQLite FTS5(全文搜索) |
|
||||||
| AI 总结 | `pi` CLI(外部工具) |
|
| AI 总结 | `pi` CLI 或 `claude` CLI(可配置后端) |
|
||||||
| 语义搜索 | ChromaDB(可选) |
|
| 语义搜索 | ChromaDB(可选) |
|
||||||
| 调度 | APScheduler(内嵌单进程) |
|
| 调度 | APScheduler(内嵌单进程) |
|
||||||
| CLI | Typer |
|
| CLI | Typer |
|
||||||
@@ -53,33 +54,39 @@ paper/
|
|||||||
│ ├── main.py # FastAPI 入口(lifespan 管理)
|
│ ├── main.py # FastAPI 入口(lifespan 管理)
|
||||||
│ ├── config.py # pydantic-settings 配置加载
|
│ ├── config.py # pydantic-settings 配置加载
|
||||||
│ ├── database.py # SQLAlchemy 引擎、会话与 FTS5
|
│ ├── database.py # SQLAlchemy 引擎、会话与 FTS5
|
||||||
│ ├── models.py # 11 个 ORM 模型
|
│ ├── models.py # 11 个 ORM 模型 + 1 枚举
|
||||||
│ ├── utils.py # 共享工具函数
|
│ ├── utils.py # 共享工具函数
|
||||||
|
│ ├── exceptions.py # 统一业务异常体系
|
||||||
│ ├── cli.py # Typer CLI(crawl / summarize / init-db)
|
│ ├── cli.py # Typer CLI(crawl / summarize / init-db)
|
||||||
│ │
|
│ │
|
||||||
│ ├── routes/ # 页面与 API 路由
|
│ ├── routes/ # 页面与 API 路由
|
||||||
│ │ ├── __init__.py
|
│ │ ├── pages.py # 首页、日期页、论文详情、相似推荐
|
||||||
│ │ ├── pages.py # 首页、日期页、论文详情
|
│ │ ├── admin.py # Session 认证管理后台
|
||||||
│ │ ├── admin.py # Token 鉴权管理接口
|
|
||||||
│ │ ├── search.py # 搜索、阅读列表、RSS
|
│ │ ├── search.py # 搜索、阅读列表、RSS
|
||||||
│ │ ├── user.py # 收藏、阅读状态、笔记 API
|
│ │ ├── user.py # 收藏、阅读状态、笔记 API
|
||||||
│ │ ├── trends.py # 趋势看板
|
│ │ ├── trends.py # 趋势看板
|
||||||
│ │ └── compare.py # 论文对比页
|
│ │ └── compare.py # 论文对比页
|
||||||
│ │
|
│ │
|
||||||
│ ├── services/ # 业务逻辑层
|
│ ├── services/ # 业务逻辑层
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── crawler.py # HuggingFace API 爬虫
|
│ │ ├── crawler.py # HuggingFace API 爬虫
|
||||||
│ │ ├── summarizer.py # AI 总结编排
|
│ │ ├── summarizer.py # AI 总结编排(调度层)
|
||||||
|
│ │ ├── summary_generator.py # 总结生成与重试
|
||||||
|
│ │ ├── summary_persister.py # 总结持久化与文件管理
|
||||||
|
│ │ ├── summary_utils.py # 总结工具(prompt 构建、PDF 提取)
|
||||||
|
│ │ ├── pi_client.py # pi CLI 封装 + JSON 提取
|
||||||
|
│ │ ├── claude_backend.py # claude CLI 后端
|
||||||
│ │ ├── searcher.py # FTS5 + 语义搜索
|
│ │ ├── searcher.py # FTS5 + 语义搜索
|
||||||
│ │ ├── schemas.py # Pydantic 总结校验
|
│ │ ├── schemas.py # Pydantic 总结校验
|
||||||
│ │ ├── cleaner.py # 临时文件清理 + 日期范围删除
|
│ │ ├── cleaner.py # 临时文件清理 + 日期范围删除
|
||||||
│ │ ├── scheduler.py # APScheduler 每日管线
|
│ │ ├── scheduler.py # APScheduler 每日管线
|
||||||
|
│ │ ├── pipeline.py # 抓取 + 总结流水线编排
|
||||||
|
│ │ ├── admin.py # 管理后台查询与统计
|
||||||
│ │ ├── user_data.py # 收藏、阅读状态、笔记
|
│ │ ├── user_data.py # 收藏、阅读状态、笔记
|
||||||
│ │ ├── embedder.py # ChromaDB 向量索引
|
│ │ ├── embedder.py # ChromaDB 向量索引
|
||||||
│ │ ├── trends.py # 趋势统计聚合
|
│ │ ├── trends.py # 趋势统计聚合
|
||||||
│ │ ├── pdf_downloader.py # PDF + LaTeX 源码下载
|
│ │ ├── pdf_downloader.py # PDF + LaTeX 源码下载
|
||||||
│ │ ├── pi_client.py # pi CLI 封装 + JSON 提取
|
│ │ ├── pdf_image_extractor.py # LaTeX 图片提取 + 图表关联
|
||||||
│ │ └── image_extractor.py # LaTeX 图片提取
|
│ │ └── layout_detector.py # ONNX 布局检测(可选)
|
||||||
│ │
|
│ │
|
||||||
│ ├── templates/ # Jinja2 模板
|
│ ├── templates/ # Jinja2 模板
|
||||||
│ │ ├── base.html
|
│ │ ├── base.html
|
||||||
@@ -89,24 +96,40 @@ paper/
|
|||||||
│ │ ├── reading_list.html
|
│ │ ├── reading_list.html
|
||||||
│ │ ├── compare.html
|
│ │ ├── compare.html
|
||||||
│ │ ├── trends.html
|
│ │ ├── trends.html
|
||||||
|
│ │ ├── login.html
|
||||||
|
│ │ ├── admin_dashboard.html
|
||||||
|
│ │ ├── admin_papers.html
|
||||||
│ │ ├── admin_logs.html
|
│ │ ├── admin_logs.html
|
||||||
│ │ └── partials/paper_card.html
|
│ │ └── partials/
|
||||||
|
│ │ ├── admin_subnav.html
|
||||||
|
│ │ ├── paper_card.html
|
||||||
|
│ │ └── summary_list.html
|
||||||
│ │
|
│ │
|
||||||
│ └── static/
|
│ └── static/
|
||||||
│ ├── css/style.css # 自定义 CSS(kami 风格)
|
│ ├── css/
|
||||||
│ └── js/app.js # 键盘快捷键
|
│ │ ├── style.css # 自定义 CSS(kami 风格)
|
||||||
|
│ │ └── admin.css # 管理后台样式
|
||||||
|
│ ├── js/
|
||||||
|
│ │ ├── app.js # 键盘快捷键
|
||||||
|
│ │ ├── date-picker.js # 日期导航
|
||||||
|
│ │ └── lightbox.js # 图片灯箱
|
||||||
|
│ └── favicon.svg
|
||||||
│
|
│
|
||||||
├── data/ # 运行时数据(已 gitignore)
|
├── data/ # 运行时数据(已 gitignore)
|
||||||
│ ├── db/papers.db # SQLite 数据库
|
│ ├── db/papers.db # SQLite 数据库
|
||||||
│ ├── papers/{arxiv_id}/ # 长期资产(meta.json / summary.json / 图片)
|
│ ├── papers/{arxiv_id}/ # 长期资产(meta.json / summary.json / 图片)
|
||||||
│ ├── tmp/{arxiv_id}/ # 临时下载(流程完成后清理)
|
│ ├── tmp/{arxiv_id}/ # 临时下载(流程完成后清理)
|
||||||
│ └── chroma/ # ChromaDB 向量库(可选)
|
│ ├── chroma/ # ChromaDB 向量库(可选)
|
||||||
|
│ └── models/ # ONNX 模型(布局检测)
|
||||||
│
|
│
|
||||||
├── scripts/
|
├── scripts/
|
||||||
│ ├── init_db.py # 数据库初始化
|
│ ├── init_db.py # 数据库初始化
|
||||||
│ └── manual_crawl.py # 手动抓取脚本
|
│ ├── manual_crawl.py # 手动抓取脚本
|
||||||
|
│ ├── export_picodet_onnx.py # 导出布局检测 ONNX 模型
|
||||||
|
│ ├── reextract_images.py # 批量重新提取图片
|
||||||
|
│ └── validate_summary.py # 校验总结 JSON 结构
|
||||||
│
|
│
|
||||||
├── tests/ # 9 个测试模块
|
├── tests/ # 13 个测试模块
|
||||||
│ ├── conftest.py # 测试夹具(内存 DB、样本数据)
|
│ ├── conftest.py # 测试夹具(内存 DB、样本数据)
|
||||||
│ └── test_*.py # 各模块测试
|
│ └── test_*.py # 各模块测试
|
||||||
│
|
│
|
||||||
@@ -120,21 +143,23 @@ paper/
|
|||||||
### 1. 准备环境
|
### 1. 准备环境
|
||||||
|
|
||||||
- Python **3.12+**
|
- Python **3.12+**
|
||||||
- 可选:[`pi`](https://www.npmjs.com/package/@mariozechner/pi-coding-agent) CLI(用于 AI 总结)
|
- [uv](https://docs.astral.sh/uv/) 包管理器
|
||||||
|
- 可选:[`pi`](https://www.npmjs.com/package/@mariozechner/pi-coding-agent) CLI 或 [`claude`](https://claude.ai/code) CLI(用于 AI 总结)
|
||||||
|
|
||||||
### 2. 安装依赖
|
### 2. 安装依赖
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m venv .venv
|
cp .env.example .env
|
||||||
source .venv/bin/activate
|
uv sync
|
||||||
pip install -e ".[dev]"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. 配置环境变量
|
### 3. 配置环境变量
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cp .env.example .env
|
# 编辑 .env,至少修改以下三项
|
||||||
# 编辑 .env,至少修改 ADMIN_TOKEN
|
ADMIN_USERNAME=admin
|
||||||
|
ADMIN_PASSWORD=your_secure_password
|
||||||
|
SECRET_KEY=your_random_secret_key
|
||||||
```
|
```
|
||||||
|
|
||||||
关键配置项:
|
关键配置项:
|
||||||
@@ -145,19 +170,26 @@ cp .env.example .env
|
|||||||
| `APP_DEBUG` | `false` | 调试模式(开启 uvicorn reload) |
|
| `APP_DEBUG` | `false` | 调试模式(开启 uvicorn reload) |
|
||||||
| `BASE_URL` | `http://127.0.0.1:8000` | 站点根 URL(用于 RSS 生成) |
|
| `BASE_URL` | `http://127.0.0.1:8000` | 站点根 URL(用于 RSS 生成) |
|
||||||
| `APP_TIMEZONE` | `Asia/Shanghai` | 时区 |
|
| `APP_TIMEZONE` | `Asia/Shanghai` | 时区 |
|
||||||
| `ADMIN_TOKEN` | `change-me` | **必须修改** — 管理接口鉴权 |
|
| `ADMIN_USERNAME` | `admin` | 管理后台用户名 |
|
||||||
|
| `ADMIN_PASSWORD` | — | 管理后台密码 |
|
||||||
|
| `SECRET_KEY` | `change-me` | Session 签名密钥 |
|
||||||
| `HF_API_BASE` | `https://huggingface.co/api` | HuggingFace API 地址 |
|
| `HF_API_BASE` | `https://huggingface.co/api` | HuggingFace API 地址 |
|
||||||
| `HF_PROXY` | — | HTTP 代理 |
|
| `HF_PROXY` | — | HTTP 代理 |
|
||||||
| `TOP_N` | `20` | 每日抓取 Top N 论文 |
|
| `TOP_N` | `20` | 每日抓取 Top N 论文 |
|
||||||
| `HTTP_TIMEOUT_SECONDS` | `30` | HTTP 请求超时 |
|
| `HTTP_TIMEOUT_SECONDS` | `30` | HTTP 请求超时 |
|
||||||
| `HTTP_MAX_RETRIES` | `3` | HTTP 最大重试次数 |
|
| `HTTP_MAX_RETRIES` | `3` | HTTP 最大重试次数 |
|
||||||
|
| `SUMMARY_BACKEND` | `pi` | 总结后端:`pi` 或 `claude` |
|
||||||
| `PI_BIN` | — | `pi` CLI 路径 |
|
| `PI_BIN` | — | `pi` CLI 路径 |
|
||||||
|
| `CLAUDE_BIN` | `claude` | `claude` CLI 路径 |
|
||||||
| `SUMMARY_SKILL` | `daily-paper-summary` | pi 总结技能名 |
|
| `SUMMARY_SKILL` | `daily-paper-summary` | pi 总结技能名 |
|
||||||
| `SUMMARY_CONCURRENCY` | `3` | 最大并行总结数 |
|
| `SUMMARY_CONCURRENCY` | `3` | 最大并行总结数 |
|
||||||
| `SUMMARY_TIMEOUT_SECONDS` | `300` | 单篇总结超时 |
|
| `SUMMARY_TIMEOUT_SECONDS` | `1200` | 单篇总结超时 |
|
||||||
| `SUMMARY_MAX_RETRIES` | `1` | 总结最大重试次数 |
|
| `SUMMARY_MAX_RETRIES` | `2` | 总结最大重试次数 |
|
||||||
|
| `SUMMARY_PDF_MODE` | `auto` | PDF 传递方式:`auto` / `inject` / `search` |
|
||||||
|
| `UPVOTE_REFRESH_DAYS` | `7` | 自动刷新最近 N 天论文的 upvotes |
|
||||||
|
| `PDF_DOWNLOAD_TIMEOUT` | `120` | PDF 下载超时(秒) |
|
||||||
| `SCHEDULER_ENABLED` | `false` | 启用每日自动抓取 |
|
| `SCHEDULER_ENABLED` | `false` | 启用每日自动抓取 |
|
||||||
| `SCHEDULE_HOUR` / `SCHEDULE_MINUTE` | `8` / `0` | 定时任务时间(APP_TIMEZONE) |
|
| `SCHEDULE_HOUR` / `SCHEDULE_MINUTE` | `4` / `0` | 定时任务时间(APP_TIMEZONE) |
|
||||||
| `APP_WORKERS` | `1` | Uvicorn worker 数(必须为 1) |
|
| `APP_WORKERS` | `1` | Uvicorn worker 数(必须为 1) |
|
||||||
| `DATABASE_URL` | `sqlite:///data/db/papers.db` | 数据库路径 |
|
| `DATABASE_URL` | `sqlite:///data/db/papers.db` | 数据库路径 |
|
||||||
| `CHROMA_ENABLED` | `false` | 启用语义搜索 |
|
| `CHROMA_ENABLED` | `false` | 启用语义搜索 |
|
||||||
@@ -166,18 +198,19 @@ cp .env.example .env
|
|||||||
| `EMBED_API_KEY` | — | Embedding API Key |
|
| `EMBED_API_KEY` | — | Embedding API Key |
|
||||||
| `EMBED_MODEL` | — | Embedding 模型名 |
|
| `EMBED_MODEL` | — | Embedding 模型名 |
|
||||||
| `EMBED_DIMENSIONS` | `0` | 向量维度 |
|
| `EMBED_DIMENSIONS` | `0` | 向量维度 |
|
||||||
|
| `LAYOUT_MODEL_PATH` | `data/models/picodet_layout_3cls.onnx` | ONNX 布局检测模型路径(可选) |
|
||||||
|
| `LAYOUT_THRESHOLD` | `0.5` | 布局检测置信度阈值(可选) |
|
||||||
|
|
||||||
### 4. 初始化数据库
|
### 4. 初始化数据库
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/init_db.py
|
uv run python -m app.cli init-db
|
||||||
# 或:python -m app.cli init-db
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 5. 启动服务
|
### 5. 启动服务
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uvicorn app.main:app --host 127.0.0.1 --port 8000
|
uv run python -m app.main
|
||||||
```
|
```
|
||||||
|
|
||||||
> 调度器依赖单 worker:不可使用 `--workers > 1`,否则每日任务会被重复触发。
|
> 调度器依赖单 worker:不可使用 `--workers > 1`,否则每日任务会被重复触发。
|
||||||
@@ -188,51 +221,59 @@ uvicorn app.main:app --host 127.0.0.1 --port 8000
|
|||||||
|
|
||||||
## 常用命令
|
## 常用命令
|
||||||
|
|
||||||
### 手动抓取指定日期
|
### 手动抓取
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/manual_crawl.py 2025-01-15
|
# 自动探测今天/昨天
|
||||||
# 或
|
uv run python -m app.cli crawl
|
||||||
python -m app.cli crawl 2025-01-15 --top 20
|
|
||||||
|
# 指定日期
|
||||||
|
uv run python -m app.cli crawl 2025-01-15 --top 20
|
||||||
|
|
||||||
|
# 强制重抓(即使已有数据)
|
||||||
|
uv run python -m app.cli crawl --force
|
||||||
```
|
```
|
||||||
|
|
||||||
### 手动触发总结
|
### 手动触发总结
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 单篇
|
# 单篇
|
||||||
python -m app.cli summarize 2401.01234
|
uv run python -m app.cli summarize 2401.01234
|
||||||
|
|
||||||
# 批量(所有待总结论文)
|
# 批量(所有待总结论文)
|
||||||
python -m app.cli summarize
|
uv run python -m app.cli summarize
|
||||||
|
|
||||||
|
# 指定后端和 PDF 模式
|
||||||
|
uv run python -m app.cli summarize --backend claude --pdf-mode inject
|
||||||
```
|
```
|
||||||
|
|
||||||
### 管理接口(Token 鉴权)
|
### 管理后台
|
||||||
|
|
||||||
```bash
|
打开浏览器访问 `http://127.0.0.1:8000/admin/login`,使用 `.env` 中配置的用户名密码登录。
|
||||||
# 抓取今日论文
|
|
||||||
curl -X POST "http://127.0.0.1:8000/admin/crawl" \
|
|
||||||
-H "Authorization: Bearer $ADMIN_TOKEN"
|
|
||||||
|
|
||||||
# 批量总结
|
管理后台包含:
|
||||||
curl -X POST "http://127.0.0.1:8000/admin/summarize" \
|
- **仪表盘**:统计卡、调度器控制、存储信息、最近活动
|
||||||
-H "Authorization: Bearer $ADMIN_TOKEN"
|
- **论文管理**:搜索、筛选、单篇/批量删除
|
||||||
|
- **日志**:运行日志、总结状态、失败重试
|
||||||
# 单篇总结
|
|
||||||
curl -X POST "http://127.0.0.1:8000/admin/summarize/2401.01234" \
|
|
||||||
-H "Authorization: Bearer $ADMIN_TOKEN"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 运行测试
|
### 运行测试
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pytest
|
uv run pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
### 代码检查
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run ruff format # 格式化代码
|
||||||
|
uv run ruff check --fix # 自动修复 lint 问题
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 安全提示
|
## 安全提示
|
||||||
|
|
||||||
- `ADMIN_TOKEN` 是管理接口的唯一鉴权凭证,请使用强随机值并妥善保管。
|
- 管理后台使用 Session 认证,请务必在 `.env` 中设置强密码和随机 `SECRET_KEY`。
|
||||||
- 默认仅监听 `127.0.0.1`,如需内网访问请配合反向代理与 HTTPS。
|
- 默认仅监听 `127.0.0.1`,如需内网访问请配合反向代理与 HTTPS。
|
||||||
- 项目面向本地 / 内网部署,不包含多用户账号体系与公网防护。
|
- 项目面向本地 / 内网部署,不包含多用户账号体系与公网防护。
|
||||||
|
|
||||||
|
|||||||
+27
-14
@@ -42,9 +42,16 @@ def crawl(
|
|||||||
try:
|
try:
|
||||||
# 检查是否已抓取过(非 force 模式)
|
# 检查是否已抓取过(非 force 模式)
|
||||||
if not force and not date_str:
|
if not force and not date_str:
|
||||||
existing = db.scalar(select(func.count(Paper.id)).where(Paper.paper_date == target)) or 0
|
existing = (
|
||||||
|
db.scalar(
|
||||||
|
select(func.count(Paper.id)).where(Paper.paper_date == target)
|
||||||
|
)
|
||||||
|
or 0
|
||||||
|
)
|
||||||
if existing > 0:
|
if existing > 0:
|
||||||
typer.echo(f"⏭️ {target} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)")
|
typer.echo(
|
||||||
|
f"⏭️ {target} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"📡 开始抓取 {target} ...")
|
typer.echo(f"📡 开始抓取 {target} ...")
|
||||||
@@ -56,7 +63,12 @@ def crawl(
|
|||||||
)
|
)
|
||||||
if need_fallback:
|
if need_fallback:
|
||||||
fallback = yesterday_str()
|
fallback = yesterday_str()
|
||||||
existing = db.scalar(select(func.count(Paper.id)).where(Paper.paper_date == fallback)) or 0
|
existing = (
|
||||||
|
db.scalar(
|
||||||
|
select(func.count(Paper.id)).where(Paper.paper_date == fallback)
|
||||||
|
)
|
||||||
|
or 0
|
||||||
|
)
|
||||||
if existing > 0:
|
if existing > 0:
|
||||||
typer.echo(
|
typer.echo(
|
||||||
f"⏭️ {fallback} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
|
f"⏭️ {fallback} 已有 {existing} 篇论文,跳过(用 --force 强制重抓)"
|
||||||
@@ -103,7 +115,9 @@ def summarize(
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
if pdf_mode not in ("auto", "inject", "search"):
|
if pdf_mode not in ("auto", "inject", "search"):
|
||||||
typer.echo(f"❌ 无效的 pdf_mode: {pdf_mode},只支持 auto / inject / search", err=True)
|
typer.echo(
|
||||||
|
f"❌ 无效的 pdf_mode: {pdf_mode},只支持 auto / inject / search", err=True
|
||||||
|
)
|
||||||
raise typer.Exit(code=1)
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
if backend:
|
if backend:
|
||||||
@@ -122,6 +136,8 @@ def summarize(
|
|||||||
datefmt="%H:%M:%S",
|
datefmt="%H:%M:%S",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from app.exceptions import ConflictError, NotFoundError
|
||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
if arxiv_id:
|
if arxiv_id:
|
||||||
@@ -131,16 +147,13 @@ def summarize(
|
|||||||
typer.echo(f"🤖 开始批量总结 pending 论文 (mode={pdf_mode}) ...")
|
typer.echo(f"🤖 开始批量总结 pending 论文 (mode={pdf_mode}) ...")
|
||||||
result = asyncio.run(summarize_batch(db, pdf_mode=pdf_mode))
|
result = asyncio.run(summarize_batch(db, pdf_mode=pdf_mode))
|
||||||
|
|
||||||
if result.get("status") in ("success", "done"):
|
typer.echo(f"✅ 总结完成:{result}")
|
||||||
typer.echo(f"✅ 总结完成:{result}")
|
except NotFoundError as exc:
|
||||||
elif result.get("status") == "conflict":
|
typer.echo(f"❌ {exc.message}", err=True)
|
||||||
typer.echo("⚠️ 已有批量总结任务在运行中", err=True)
|
raise typer.Exit(code=1) from exc
|
||||||
raise typer.Exit(code=1)
|
except ConflictError as exc:
|
||||||
elif result.get("status") == "not_found":
|
typer.echo(f"⚠️ {exc.message}", err=True)
|
||||||
typer.echo(f"❌ 论文未找到:{arxiv_id}", err=True)
|
raise typer.Exit(code=1) from exc
|
||||||
raise typer.Exit(code=1)
|
|
||||||
else:
|
|
||||||
typer.echo(f"⚠️ 总结结果:{result}", err=True)
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|||||||
+8
-1
@@ -27,6 +27,7 @@ class Settings(BaseSettings):
|
|||||||
HTTP_TIMEOUT_SECONDS: int = 30
|
HTTP_TIMEOUT_SECONDS: int = 30
|
||||||
HTTP_MAX_RETRIES: int = 3
|
HTTP_MAX_RETRIES: int = 3
|
||||||
HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1"
|
HTTP_USER_AGENT: str = "hf-daily-papers-local/0.1"
|
||||||
|
PDF_DOWNLOAD_TIMEOUT: int = 120
|
||||||
|
|
||||||
# AI 总结
|
# AI 总结
|
||||||
SUMMARY_BACKEND: str = "pi" # "pi" | "claude"
|
SUMMARY_BACKEND: str = "pi" # "pi" | "claude"
|
||||||
@@ -36,7 +37,9 @@ class Settings(BaseSettings):
|
|||||||
SUMMARY_CONCURRENCY: int = 3
|
SUMMARY_CONCURRENCY: int = 3
|
||||||
SUMMARY_TIMEOUT_SECONDS: int = 1200
|
SUMMARY_TIMEOUT_SECONDS: int = 1200
|
||||||
SUMMARY_MAX_RETRIES: int = 2
|
SUMMARY_MAX_RETRIES: int = 2
|
||||||
SUMMARY_PDF_MODE: str = "auto" # "auto" = ≤80k 用 inject,>80k 用 search;也可强制 "inject" / "search"
|
SUMMARY_PDF_MODE: str = (
|
||||||
|
"auto" # "auto" = ≤80k 用 inject,>80k 用 search;也可强制 "inject" / "search"
|
||||||
|
)
|
||||||
|
|
||||||
# 调度
|
# 调度
|
||||||
SCHEDULER_ENABLED: bool = False
|
SCHEDULER_ENABLED: bool = False
|
||||||
@@ -56,6 +59,10 @@ class Settings(BaseSettings):
|
|||||||
EMBED_MODEL: str = ""
|
EMBED_MODEL: str = ""
|
||||||
EMBED_DIMENSIONS: int = 0
|
EMBED_DIMENSIONS: int = 0
|
||||||
|
|
||||||
|
# 布局检测
|
||||||
|
LAYOUT_MODEL_PATH: str = "data/models/picodet_layout_3cls.onnx"
|
||||||
|
LAYOUT_THRESHOLD: float = 0.5
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"env_file": str(BASE_DIR / ".env"),
|
"env_file": str(BASE_DIR / ".env"),
|
||||||
"env_file_encoding": "utf-8",
|
"env_file_encoding": "utf-8",
|
||||||
|
|||||||
+2
-5
@@ -82,15 +82,12 @@ def _migrate(engine) -> None:
|
|||||||
for table, columns in _MIGRATIONS.items():
|
for table, columns in _MIGRATIONS.items():
|
||||||
# 获取已有列名
|
# 获取已有列名
|
||||||
existing = {
|
existing = {
|
||||||
row[1]
|
row[1] for row in conn.execute(text(f"PRAGMA table_info({table})"))
|
||||||
for row in conn.execute(text(f"PRAGMA table_info({table})"))
|
|
||||||
}
|
}
|
||||||
for col_name, col_type in columns:
|
for col_name, col_type in columns:
|
||||||
if col_name not in existing:
|
if col_name not in existing:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
text(
|
text(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")
|
||||||
f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
logger.info("Migrated: %s.%s added", table, col_name)
|
logger.info("Migrated: %s.%s added", table, col_name)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
"""业务异常体系 — 统一错误类型,供路由层和 service 层使用。
|
||||||
|
|
||||||
|
路由层通过 main.py 的 @app.exception_handler(AppError) 统一捕获,
|
||||||
|
转为对应 HTTP 状态码 + JSON 响应。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
class AppError(Exception):
|
||||||
|
"""所有业务异常的基类。"""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "", *, detail: str = ""):
|
||||||
|
self.message = message or detail or self.__class__.__name__
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class NotFoundError(AppError):
|
||||||
|
"""资源不存在(404)。"""
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(AppError):
|
||||||
|
"""请求参数校验失败(400)。"""
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalAPIError(AppError):
|
||||||
|
"""外部 API 调用失败(502)。"""
|
||||||
|
|
||||||
|
|
||||||
|
class PdfProcessError(AppError):
|
||||||
|
"""PDF 处理错误(500)。"""
|
||||||
|
|
||||||
|
|
||||||
|
class ConflictError(AppError):
|
||||||
|
"""资源冲突(409)— 如锁冲突、并发任务冲突。"""
|
||||||
+30
-3
@@ -5,10 +5,12 @@ import os
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.exceptions import AppError, ConflictError, ExternalAPIError, NotFoundError, PdfProcessError, ValidationError
|
||||||
from app.database import engine, init_db
|
from app.database import engine, init_db
|
||||||
from app.routes.admin import router as admin_router
|
from app.routes.admin import router as admin_router
|
||||||
from app.routes.compare import router as compare_router
|
from app.routes.compare import router as compare_router
|
||||||
@@ -38,8 +40,10 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
# ── shutdown ──
|
# ── shutdown ──
|
||||||
from app.services.scheduler import stop_scheduler
|
from app.services.scheduler import stop_scheduler
|
||||||
|
from app.services.pdf_downloader import close_http_session
|
||||||
|
|
||||||
stop_scheduler()
|
stop_scheduler()
|
||||||
|
close_http_session()
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
@@ -60,15 +64,38 @@ def create_app() -> FastAPI:
|
|||||||
# Session 中间件
|
# Session 中间件
|
||||||
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY)
|
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY)
|
||||||
|
|
||||||
|
# ── 统一业务异常处理 ──
|
||||||
|
@app.exception_handler(NotFoundError)
|
||||||
|
async def _not_found_handler(request, exc):
|
||||||
|
return JSONResponse(status_code=404, content={"error": exc.message})
|
||||||
|
|
||||||
|
@app.exception_handler(ValidationError)
|
||||||
|
async def _validation_handler(request, exc):
|
||||||
|
return JSONResponse(status_code=400, content={"error": exc.message})
|
||||||
|
|
||||||
|
@app.exception_handler(ExternalAPIError)
|
||||||
|
async def _external_api_handler(request, exc):
|
||||||
|
return JSONResponse(status_code=502, content={"error": exc.message})
|
||||||
|
|
||||||
|
@app.exception_handler(PdfProcessError)
|
||||||
|
async def _pdf_process_handler(request, exc):
|
||||||
|
return JSONResponse(status_code=500, content={"error": exc.message})
|
||||||
|
|
||||||
|
@app.exception_handler(ConflictError)
|
||||||
|
async def _conflict_handler(request, exc):
|
||||||
|
return JSONResponse(status_code=409, content={"error": exc.message})
|
||||||
|
|
||||||
|
@app.exception_handler(AppError)
|
||||||
|
async def _app_error_handler(request, exc):
|
||||||
|
return JSONResponse(status_code=500, content={"error": exc.message})
|
||||||
|
|
||||||
# 安全警告
|
# 安全警告
|
||||||
if settings.SECRET_KEY == "change-me":
|
if settings.SECRET_KEY == "change-me":
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"⚠️ SECRET_KEY is the default value 'change-me'. Please change it in .env!"
|
"⚠️ SECRET_KEY is the default value 'change-me'. Please change it in .env!"
|
||||||
)
|
)
|
||||||
if not settings.ADMIN_PASSWORD:
|
if not settings.ADMIN_PASSWORD:
|
||||||
logger.warning(
|
logger.warning("⚠️ ADMIN_PASSWORD is empty. Please set it in .env!")
|
||||||
"⚠️ ADMIN_PASSWORD is empty. Please set it in .env!"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 静态文件
|
# 静态文件
|
||||||
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
||||||
|
|||||||
+17
-4
@@ -12,6 +12,7 @@ from sqlalchemy import (
|
|||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
UniqueConstraint,
|
UniqueConstraint,
|
||||||
|
select,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import joinedload, relationship
|
from sqlalchemy.orm import joinedload, relationship
|
||||||
|
|
||||||
@@ -93,7 +94,7 @@ class PaperAuthor(Base):
|
|||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
paper_id = Column(
|
paper_id = Column(
|
||||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
)
|
)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
position = Column(Integer, default=0)
|
position = Column(Integer, default=0)
|
||||||
@@ -108,7 +109,7 @@ class PaperTag(Base):
|
|||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
paper_id = Column(
|
paper_id = Column(
|
||||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
)
|
)
|
||||||
tag = Column(String, nullable=False)
|
tag = Column(String, nullable=False)
|
||||||
source = Column(String, default="hf")
|
source = Column(String, default="hf")
|
||||||
@@ -155,7 +156,7 @@ class SummaryStatus(Base):
|
|||||||
paper_id = Column(
|
paper_id = Column(
|
||||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
||||||
)
|
)
|
||||||
status = Column(String, nullable=False, default="pending")
|
status = Column(String, nullable=False, default="pending", index=True)
|
||||||
quality = Column(String)
|
quality = Column(String)
|
||||||
error_type = Column(String)
|
error_type = Column(String)
|
||||||
error = Column(Text)
|
error = Column(Text)
|
||||||
@@ -219,7 +220,7 @@ class UserReadingStatus(Base):
|
|||||||
paper_id = Column(
|
paper_id = Column(
|
||||||
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
Integer, ForeignKey("papers.id", ondelete="CASCADE"), nullable=False
|
||||||
)
|
)
|
||||||
status = Column(String, nullable=False, default="unread")
|
status = Column(String, nullable=False, default="unread", index=True)
|
||||||
updated_at = Column(DateTime, nullable=False)
|
updated_at = Column(DateTime, nullable=False)
|
||||||
|
|
||||||
paper = relationship("Paper", back_populates="reading_status")
|
paper = relationship("Paper", back_populates="reading_status")
|
||||||
@@ -271,3 +272,15 @@ PAPER_FULL_LOAD = (
|
|||||||
joinedload(Paper.bookmark),
|
joinedload(Paper.bookmark),
|
||||||
joinedload(Paper.reading_status),
|
joinedload(Paper.reading_status),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_paper_by_arxiv_id(db, arxiv_id: str, *, load=PAPER_DEFAULT_LOAD):
|
||||||
|
"""按 arxiv_id 查询论文(带关联加载),未找到返回 None。"""
|
||||||
|
stmt = select(Paper).where(Paper.arxiv_id == arxiv_id).options(*load)
|
||||||
|
return db.execute(stmt).unique().scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
def get_paper_by_id(db, paper_id: int, *, load=PAPER_DEFAULT_LOAD):
|
||||||
|
"""按主键查询论文(带关联加载),未找到返回 None。"""
|
||||||
|
stmt = select(Paper).where(Paper.id == paper_id).options(*load)
|
||||||
|
return db.execute(stmt).unique().scalar_one_or_none()
|
||||||
|
|||||||
+94
-144
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import hmac
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import date
|
from datetime import date
|
||||||
@@ -10,7 +11,7 @@ from datetime import date
|
|||||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
from sqlalchemy import func, select, text
|
from sqlalchemy import bindparam, func, select, text
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@@ -22,15 +23,15 @@ from app.models import (
|
|||||||
PaperTag,
|
PaperTag,
|
||||||
SummaryState,
|
SummaryState,
|
||||||
SummaryStatus,
|
SummaryStatus,
|
||||||
TaskLock,
|
|
||||||
)
|
)
|
||||||
|
from app.services import admin as admin_svc
|
||||||
from app.services.admin import get_admin_stats
|
from app.services.admin import get_admin_stats
|
||||||
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
||||||
from app.services.crawler import crawl_daily, refresh_upvotes
|
from app.services.crawler import refresh_upvotes
|
||||||
from app.services.pipeline import run_pipeline
|
from app.services.pipeline import run_crawl, run_pipeline
|
||||||
from app.services.scheduler import get_scheduler
|
from app.services.scheduler import get_scheduler
|
||||||
from app.services.summarizer import summarize_batch, summarize_single
|
from app.services.summarizer import summarize_batch, summarize_single
|
||||||
from app.utils import release_lock, templates, today_str, utc_now
|
from app.utils import templates, today_str, utc_now
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -41,14 +42,15 @@ router = APIRouter(prefix="/admin", tags=["admin"])
|
|||||||
|
|
||||||
|
|
||||||
def _check_password(password: str) -> bool:
|
def _check_password(password: str) -> bool:
|
||||||
"""校验密码,支持明文或 sha256 哈希。"""
|
"""校验密码,支持明文或 sha256 哈希(常量时间比较)。"""
|
||||||
stored = settings.ADMIN_PASSWORD
|
stored = settings.ADMIN_PASSWORD
|
||||||
if not stored:
|
if not stored:
|
||||||
return False
|
return False
|
||||||
if password == stored:
|
if hmac.compare_digest(password, stored):
|
||||||
return True
|
return True
|
||||||
# 也支持存 sha256 哈希
|
# 也支持存 sha256 哈希
|
||||||
return hashlib.sha256(password.encode()).hexdigest() == stored
|
hashed = hashlib.sha256(password.encode()).hexdigest()
|
||||||
|
return hmac.compare_digest(hashed, stored)
|
||||||
|
|
||||||
|
|
||||||
async def verify_admin(request: Request) -> None:
|
async def verify_admin(request: Request) -> None:
|
||||||
@@ -204,32 +206,12 @@ async def admin_crawl(
|
|||||||
):
|
):
|
||||||
"""手动抓取指定日期,默认今天。"""
|
"""手动抓取指定日期,默认今天。"""
|
||||||
target_date = date or today_str()
|
target_date = date or today_str()
|
||||||
|
|
||||||
# TaskLock 防重入
|
|
||||||
now = utc_now()
|
|
||||||
lock = TaskLock(
|
|
||||||
task="crawl",
|
|
||||||
lock_key=target_date,
|
|
||||||
status="running",
|
|
||||||
owner="admin_crawl",
|
|
||||||
acquired_at=now,
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
db.add(lock)
|
return await run_crawl(db, target_date, owner="admin_crawl")
|
||||||
db.commit()
|
except RuntimeError as exc:
|
||||||
except Exception:
|
raise HTTPException(status_code=409, detail=str(exc))
|
||||||
db.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=409, detail=f"Crawl already running for {target_date}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await crawl_daily(db, target_date)
|
|
||||||
return result
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise HTTPException(status_code=500, detail=str(exc))
|
raise HTTPException(status_code=500, detail=str(exc))
|
||||||
finally:
|
|
||||||
release_lock(db, lock)
|
|
||||||
|
|
||||||
|
|
||||||
# ── 总结 ──────────────────────────────────────────────────────────────
|
# ── 总结 ──────────────────────────────────────────────────────────────
|
||||||
@@ -241,12 +223,7 @@ async def admin_summarize_batch(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""批量总结所有 pending 论文。"""
|
"""批量总结所有 pending 论文。"""
|
||||||
result = await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
|
return await summarize_batch(db, pdf_mode=settings.SUMMARY_PDF_MODE)
|
||||||
if result.get("status") == "conflict":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=409, detail=result.get("error", "batch already running")
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/summarize/{arxiv_id}")
|
@router.post("/summarize/{arxiv_id}")
|
||||||
@@ -256,10 +233,9 @@ async def admin_summarize_single(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""总结或重跑单篇论文。"""
|
"""总结或重跑单篇论文。"""
|
||||||
result = await summarize_single(db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE)
|
return await summarize_single(
|
||||||
if result.get("status") == "not_found":
|
db, arxiv_id, force=True, pdf_mode=settings.SUMMARY_PDF_MODE
|
||||||
raise HTTPException(status_code=404, detail=f"Paper not found: {arxiv_id}")
|
)
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ── 清理 ──────────────────────────────────────────────────────────────
|
# ── 清理 ──────────────────────────────────────────────────────────────
|
||||||
@@ -284,10 +260,13 @@ async def admin_cleanup(
|
|||||||
result = cleanup_tmp()
|
result = cleanup_tmp()
|
||||||
log_entry.status = "success"
|
log_entry.status = "success"
|
||||||
log_entry.completed_at = utc_now()
|
log_entry.completed_at = utc_now()
|
||||||
log_entry.details_json = json.dumps({
|
log_entry.details_json = json.dumps(
|
||||||
"scanned": result.get("scanned", 0),
|
{
|
||||||
"removed": result.get("removed", 0),
|
"scanned": result.get("scanned", 0),
|
||||||
}, ensure_ascii=False)
|
"removed": result.get("removed", 0),
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
if result.get("errors"):
|
if result.get("errors"):
|
||||||
log_entry.error = "; ".join(result["errors"])[:2000]
|
log_entry.error = "; ".join(result["errors"])[:2000]
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -358,19 +337,34 @@ async def admin_logs(
|
|||||||
|
|
||||||
# 总结状态统计概要
|
# 总结状态统计概要
|
||||||
summary_total = db.scalar(select(func.count(Paper.id))) or 0
|
summary_total = db.scalar(select(func.count(Paper.id))) or 0
|
||||||
summary_done = db.scalar(
|
summary_done = (
|
||||||
select(func.count(SummaryStatus.id)).where(SummaryStatus.status == SummaryState.DONE)
|
db.scalar(
|
||||||
) or 0
|
select(func.count(SummaryStatus.id)).where(
|
||||||
summary_pending = db.scalar(
|
SummaryStatus.status == SummaryState.DONE
|
||||||
select(func.count(SummaryStatus.id)).where(
|
)
|
||||||
SummaryStatus.status.in_([SummaryState.PENDING, SummaryState.PROCESSING])
|
|
||||||
)
|
)
|
||||||
) or 0
|
or 0
|
||||||
summary_failed = db.scalar(
|
)
|
||||||
select(func.count(SummaryStatus.id)).where(
|
summary_pending = (
|
||||||
SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE])
|
db.scalar(
|
||||||
|
select(func.count(SummaryStatus.id)).where(
|
||||||
|
SummaryStatus.status.in_(
|
||||||
|
[SummaryState.PENDING, SummaryState.PROCESSING]
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
) or 0
|
or 0
|
||||||
|
)
|
||||||
|
summary_failed = (
|
||||||
|
db.scalar(
|
||||||
|
select(func.count(SummaryStatus.id)).where(
|
||||||
|
SummaryStatus.status.in_(
|
||||||
|
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
@@ -414,13 +408,8 @@ async def admin_summary_status(
|
|||||||
else:
|
else:
|
||||||
query = query.where(SummaryStatus.status == status)
|
query = query.where(SummaryStatus.status == status)
|
||||||
|
|
||||||
total = db.scalar(
|
total = db.scalar(select(func.count()).select_from(query.subquery()))
|
||||||
select(func.count()).select_from(query.subquery())
|
results = db.execute(query.offset((page - 1) * per_page).limit(per_page)).all()
|
||||||
)
|
|
||||||
results = (
|
|
||||||
db.execute(query.offset((page - 1) * per_page).limit(per_page))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 判断是否 HTMX 请求
|
# 判断是否 HTMX 请求
|
||||||
is_htmx = request.headers.get("HX-Request") == "true"
|
is_htmx = request.headers.get("HX-Request") == "true"
|
||||||
@@ -465,7 +454,11 @@ async def admin_summary_retry_failed(
|
|||||||
db.execute(
|
db.execute(
|
||||||
select(Paper.arxiv_id)
|
select(Paper.arxiv_id)
|
||||||
.join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
|
.join(SummaryStatus, SummaryStatus.paper_id == Paper.id)
|
||||||
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
|
.where(
|
||||||
|
SummaryStatus.status.in_(
|
||||||
|
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
.scalars()
|
.scalars()
|
||||||
.all()
|
.all()
|
||||||
@@ -477,7 +470,11 @@ async def admin_summary_retry_failed(
|
|||||||
# 重置失败任务的状态为 pending
|
# 重置失败任务的状态为 pending
|
||||||
db.execute(
|
db.execute(
|
||||||
SummaryStatus.__table__.update()
|
SummaryStatus.__table__.update()
|
||||||
.where(SummaryStatus.status.in_([SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]))
|
.where(
|
||||||
|
SummaryStatus.status.in_(
|
||||||
|
[SummaryState.FAILED, SummaryState.PERMANENT_FAILURE]
|
||||||
|
)
|
||||||
|
)
|
||||||
.values(status=SummaryState.PENDING, error=None, error_type=None)
|
.values(status=SummaryState.PENDING, error=None, error_type=None)
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -492,15 +489,6 @@ async def admin_summary_retry_failed(
|
|||||||
# ── 论文管理 ────────────────────────────────────────────────────────
|
# ── 论文管理 ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
# 排序映射
|
|
||||||
_SORT_MAP = {
|
|
||||||
"date_desc": Paper.paper_date.desc(),
|
|
||||||
"date_asc": Paper.paper_date.asc(),
|
|
||||||
"upvotes_desc": Paper.upvotes.desc(),
|
|
||||||
"title_asc": Paper.title_en.asc(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/papers")
|
@router.get("/papers")
|
||||||
async def admin_papers(
|
async def admin_papers(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -516,66 +504,18 @@ async def admin_papers(
|
|||||||
per_page: int = Query(20, ge=1, le=100),
|
per_page: int = Query(20, ge=1, le=100),
|
||||||
):
|
):
|
||||||
"""论文管理列表页面。"""
|
"""论文管理列表页面。"""
|
||||||
query = select(Paper)
|
papers, total, statuses = admin_svc.query_papers(
|
||||||
|
db,
|
||||||
# 搜索
|
q=q,
|
||||||
if q.strip():
|
date_from=date_from,
|
||||||
query = query.where(
|
date_to=date_to,
|
||||||
Paper.title_en.ilike(f"%{q}%")
|
tag=tag,
|
||||||
| Paper.title_zh.ilike(f"%{q}%")
|
summary_status=summary_status,
|
||||||
| Paper.abstract.ilike(f"%{q}%")
|
sort=sort,
|
||||||
)
|
page=page,
|
||||||
|
per_page=per_page,
|
||||||
# 日期范围
|
|
||||||
if date_from:
|
|
||||||
query = query.where(Paper.paper_date >= date_from)
|
|
||||||
if date_to:
|
|
||||||
query = query.where(Paper.paper_date <= date_to)
|
|
||||||
|
|
||||||
# 标签筛选
|
|
||||||
if tag:
|
|
||||||
query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
|
|
||||||
PaperTag.tag == tag
|
|
||||||
)
|
|
||||||
|
|
||||||
# 总结状态筛选
|
|
||||||
if summary_status != "all":
|
|
||||||
if summary_status == "none":
|
|
||||||
query = query.outerjoin(
|
|
||||||
SummaryStatus, SummaryStatus.paper_id == Paper.id
|
|
||||||
).where(SummaryStatus.paper_id == None) # noqa: E711
|
|
||||||
else:
|
|
||||||
query = query.join(
|
|
||||||
SummaryStatus, SummaryStatus.paper_id == Paper.id
|
|
||||||
).where(SummaryStatus.status == summary_status)
|
|
||||||
|
|
||||||
# 排序
|
|
||||||
order = _SORT_MAP.get(sort, Paper.paper_date.desc())
|
|
||||||
query = query.order_by(order)
|
|
||||||
|
|
||||||
# 计数
|
|
||||||
total = db.scalar(select(func.count()).select_from(query.subquery()))
|
|
||||||
|
|
||||||
# 分页
|
|
||||||
papers = (
|
|
||||||
db.execute(query.offset((page - 1) * per_page).limit(per_page))
|
|
||||||
.scalars()
|
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取每篇论文的总结状态
|
|
||||||
paper_ids = [p.id for p in papers]
|
|
||||||
statuses = {}
|
|
||||||
if paper_ids:
|
|
||||||
rows = db.execute(
|
|
||||||
select(SummaryStatus.paper_id, SummaryStatus.status).where(
|
|
||||||
SummaryStatus.paper_id.in_(paper_ids)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
paper_id_to_arxiv = {p.id: p.arxiv_id for p in papers}
|
|
||||||
for pid, st in rows:
|
|
||||||
statuses[paper_id_to_arxiv.get(pid, "")] = st
|
|
||||||
|
|
||||||
# 构建分页 URL 辅助函数
|
# 构建分页 URL 辅助函数
|
||||||
def pagination_url(p: int) -> str:
|
def pagination_url(p: int) -> str:
|
||||||
params = dict(request.query_params)
|
params = dict(request.query_params)
|
||||||
@@ -588,7 +528,7 @@ async def admin_papers(
|
|||||||
{
|
{
|
||||||
"papers": papers,
|
"papers": papers,
|
||||||
"paper_summary_statuses": statuses,
|
"paper_summary_statuses": statuses,
|
||||||
"total": total or 0,
|
"total": total,
|
||||||
"page": page,
|
"page": page,
|
||||||
"per_page": per_page,
|
"per_page": per_page,
|
||||||
"current_status": summary_status,
|
"current_status": summary_status,
|
||||||
@@ -615,7 +555,9 @@ async def admin_paper_delete(
|
|||||||
|
|
||||||
# 清理 FTS 索引
|
# 清理 FTS 索引
|
||||||
try:
|
try:
|
||||||
db.execute(text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id})
|
db.execute(
|
||||||
|
text("DELETE FROM papers_fts WHERE arxiv_id = :aid"), {"aid": arxiv_id}
|
||||||
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
|
logger.warning("Failed to clean FTS index for %s", arxiv_id, exc_info=True)
|
||||||
@@ -646,9 +588,11 @@ async def admin_papers_batch_action(
|
|||||||
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
|
raise HTTPException(status_code=400, detail="arxiv_ids 不能为空")
|
||||||
|
|
||||||
if body.action == "delete":
|
if body.action == "delete":
|
||||||
papers = db.execute(
|
papers = (
|
||||||
select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids))
|
db.execute(select(Paper).where(Paper.arxiv_id.in_(body.arxiv_ids)))
|
||||||
).scalars().all()
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
for paper in papers:
|
for paper in papers:
|
||||||
@@ -658,21 +602,27 @@ async def admin_papers_batch_action(
|
|||||||
|
|
||||||
# 清理 FTS 索引
|
# 清理 FTS 索引
|
||||||
try:
|
try:
|
||||||
db.execute(
|
stmt = text("DELETE FROM papers_fts WHERE arxiv_id IN :ids").bindparams(
|
||||||
text("DELETE FROM papers_fts WHERE arxiv_id IN :ids"),
|
bindparam("ids", expanding=True)
|
||||||
{"ids": tuple(body.arxiv_ids)},
|
|
||||||
)
|
)
|
||||||
|
db.execute(stmt, {"ids": body.arxiv_ids})
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to clean FTS index for batch delete", exc_info=True)
|
logger.warning("Failed to clean FTS index for batch delete", exc_info=True)
|
||||||
|
|
||||||
return {"status": "success", "message": f"已删除 {count} 篇论文", "count": count}
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message": f"已删除 {count} 篇论文",
|
||||||
|
"count": count,
|
||||||
|
}
|
||||||
|
|
||||||
elif body.action == "summarize":
|
elif body.action == "summarize":
|
||||||
# 将选中论文的总结状态重置为 pending
|
# 将选中论文的总结状态重置为 pending
|
||||||
paper_ids = db.execute(
|
paper_ids = (
|
||||||
select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids))
|
db.execute(select(Paper.id).where(Paper.arxiv_id.in_(body.arxiv_ids)))
|
||||||
).scalars().all()
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
if paper_ids:
|
if paper_ids:
|
||||||
# 删除旧的 status 记录让其重新进入 pipeline
|
# 删除旧的 status 记录让其重新进入 pipeline
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from app.utils import templates
|
|||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
MAX_COMPARE_PAPERS = 5
|
||||||
|
|
||||||
|
|
||||||
@router.get("/compare")
|
@router.get("/compare")
|
||||||
def compare_page(
|
def compare_page(
|
||||||
@@ -33,9 +35,9 @@ def compare_page(
|
|||||||
|
|
||||||
arxiv_ids = [i.strip() for i in ids.split(",") if i.strip()]
|
arxiv_ids = [i.strip() for i in ids.split(",") if i.strip()]
|
||||||
|
|
||||||
# 最多 5 篇
|
# 最多 MAX_COMPARE_PAPERS 篇
|
||||||
if len(arxiv_ids) > 5:
|
if len(arxiv_ids) > MAX_COMPARE_PAPERS:
|
||||||
arxiv_ids = arxiv_ids[:5]
|
arxiv_ids = arxiv_ids[:MAX_COMPARE_PAPERS]
|
||||||
|
|
||||||
if not arxiv_ids:
|
if not arxiv_ids:
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
|
|||||||
+2
-99
@@ -4,7 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
from datetime import date, timedelta
|
from datetime import date, timedelta
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
@@ -15,6 +14,7 @@ from sqlalchemy.orm import Session, joinedload
|
|||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models import PAPER_FULL_LOAD, Paper
|
from app.models import PAPER_FULL_LOAD, Paper
|
||||||
|
from app.services.pdf_image_extractor import link_figures_with_images
|
||||||
from app.utils import (
|
from app.utils import (
|
||||||
PAPERS_DIR,
|
PAPERS_DIR,
|
||||||
safe_json_loads,
|
safe_json_loads,
|
||||||
@@ -120,7 +120,7 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db))
|
|||||||
paper.summary.figures_json if paper.summary else None, default=[]
|
paper.summary.figures_json if paper.summary else None, default=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
linked_figures = _link_figures_with_images(figures_raw, images, arxiv_id)
|
linked_figures = link_figures_with_images(figures_raw, images, arxiv_id)
|
||||||
|
|
||||||
# 拆分图片到对应展示区域:
|
# 拆分图片到对应展示区域:
|
||||||
# table_figures → 实验结果区域(Table 截图,不变)
|
# table_figures → 实验结果区域(Table 截图,不变)
|
||||||
@@ -279,100 +279,3 @@ def _get_paper_images(arxiv_id: str) -> list[dict]:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
def _link_figures_with_images(
|
|
||||||
figures: list[dict], images: list[dict], arxiv_id: str
|
|
||||||
) -> list[dict]:
|
|
||||||
"""将 summary figures 元数据与提取的图片文件关联。
|
|
||||||
|
|
||||||
策略:
|
|
||||||
1. 优先用 manifest.json 的 label 做 ID 精确匹配
|
|
||||||
2. 未匹配的 figure 用序号兜底:第 N 个 Figure → 第 N 张提取图
|
|
||||||
"""
|
|
||||||
if not figures or not images:
|
|
||||||
return figures
|
|
||||||
|
|
||||||
manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
|
|
||||||
|
|
||||||
# ── 策略 1:manifest ID 精确匹配 ──
|
|
||||||
id_to_url: dict[str, str] = {}
|
|
||||||
if manifest_path.exists():
|
|
||||||
try:
|
|
||||||
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
manifest = {}
|
|
||||||
for filename, info in manifest.items():
|
|
||||||
url = f"/papers/{arxiv_id}/images/{filename}"
|
|
||||||
# 优先用 label 字段(新格式)
|
|
||||||
label = info.get("label", "")
|
|
||||||
if label:
|
|
||||||
id_to_url[label] = url
|
|
||||||
# 也兼容 figures/tables 列表(旧格式)
|
|
||||||
for fig_id in info.get("figures", []) + info.get("tables", []):
|
|
||||||
if fig_id not in id_to_url:
|
|
||||||
id_to_url[fig_id] = url
|
|
||||||
|
|
||||||
for fig in figures:
|
|
||||||
raw_id = fig.get("id", "")
|
|
||||||
normalized = _normalize_figure_id(raw_id)
|
|
||||||
if normalized in id_to_url:
|
|
||||||
fig["image_url"] = id_to_url[normalized]
|
|
||||||
|
|
||||||
# ── 策略 2:序号兜底(manifest 匹配不到时) ──
|
|
||||||
unmatched = [f for f in figures if not f.get("image_url")]
|
|
||||||
if not unmatched:
|
|
||||||
return figures
|
|
||||||
|
|
||||||
# 按类型分流:Figure vs Table
|
|
||||||
fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
|
|
||||||
table_type_unmatched = [
|
|
||||||
f for f in unmatched if not _is_figure_type(f.get("id", ""))
|
|
||||||
]
|
|
||||||
|
|
||||||
# 提取的图片按类型分流,按文件名中的编号排序
|
|
||||||
def _sort_key(name: str) -> tuple[int, int]:
|
|
||||||
# 新格式:figure_1.jpg, table_1.jpg
|
|
||||||
m = re.search(r"(?:figure|table)_(\d+)", name)
|
|
||||||
if m:
|
|
||||||
return (0, int(m.group(1)))
|
|
||||||
# 旧格式:page2_img1.png, page5_table1.png, figure_1.png
|
|
||||||
m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
|
|
||||||
if m2:
|
|
||||||
return (int(m2.group(1)), int(m2.group(2)))
|
|
||||||
return (0, 0)
|
|
||||||
|
|
||||||
fig_images = sorted(
|
|
||||||
[img for img in images if "table" not in img["name"].lower()],
|
|
||||||
key=lambda img: _sort_key(img["name"]),
|
|
||||||
)
|
|
||||||
table_images = sorted(
|
|
||||||
[img for img in images if "table" in img["name"].lower()],
|
|
||||||
key=lambda img: _sort_key(img["name"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, fig in enumerate(fig_type_unmatched):
|
|
||||||
if i < len(fig_images):
|
|
||||||
fig["image_url"] = fig_images[i]["url"]
|
|
||||||
|
|
||||||
for i, fig in enumerate(table_type_unmatched):
|
|
||||||
if i < len(table_images):
|
|
||||||
fig["image_url"] = table_images[i]["url"]
|
|
||||||
|
|
||||||
return figures
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_figure_id(raw_id: str) -> str:
|
|
||||||
"""归一化 Figure/Table ID:'Figure 1'/'Fig.1' → 'Figure 1'。"""
|
|
||||||
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
return f"Figure {m.group(1)}"
|
|
||||||
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
|
|
||||||
if m2:
|
|
||||||
return f"Table {m2.group(1)}"
|
|
||||||
return raw_id
|
|
||||||
|
|
||||||
|
|
||||||
def _is_figure_type(fig_id: str) -> bool:
|
|
||||||
"""判断是否为 Figure 类型(非 Table)。"""
|
|
||||||
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
|
|
||||||
|
|||||||
+5
-23
@@ -2,12 +2,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
from app.exceptions import NotFoundError
|
||||||
from app.services.user_data import (
|
from app.services.user_data import (
|
||||||
get_note,
|
get_note,
|
||||||
save_note,
|
save_note,
|
||||||
@@ -37,9 +38,6 @@ def bookmark_toggle(arxiv_id: str, request: Request, db: Session = Depends(get_d
|
|||||||
"""切换收藏状态。支持 HTMX 局部刷新和 JSON 响应。"""
|
"""切换收藏状态。支持 HTMX 局部刷新和 JSON 响应。"""
|
||||||
result = toggle_bookmark(db, arxiv_id)
|
result = toggle_bookmark(db, arxiv_id)
|
||||||
|
|
||||||
if "error" in result:
|
|
||||||
raise HTTPException(status_code=404, detail=result["error"])
|
|
||||||
|
|
||||||
# HTMX 请求 → 返回 HTML 片段
|
# HTMX 请求 → 返回 HTML 片段
|
||||||
if request.headers.get("HX-Request"):
|
if request.headers.get("HX-Request"):
|
||||||
star = "★" if result["bookmarked"] else "☆"
|
star = "★" if result["bookmarked"] else "☆"
|
||||||
@@ -66,18 +64,7 @@ def reading_status_update(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""更新阅读状态。"""
|
"""更新阅读状态。"""
|
||||||
result = set_reading_status(db, arxiv_id, body.status)
|
return set_reading_status(db, arxiv_id, body.status)
|
||||||
|
|
||||||
if "error" in result:
|
|
||||||
if result["error"] == "not_found":
|
|
||||||
raise HTTPException(status_code=404, detail="Paper not found")
|
|
||||||
elif result["error"] == "invalid_status":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=422,
|
|
||||||
detail=f"Invalid status. Valid: {result['valid']}",
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ── 笔记 ──────────────────────────────────────────────────────────────
|
# ── 笔记 ──────────────────────────────────────────────────────────────
|
||||||
@@ -88,16 +75,11 @@ def note_get(arxiv_id: str, db: Session = Depends(get_db)):
|
|||||||
"""获取笔记。"""
|
"""获取笔记。"""
|
||||||
result = get_note(db, arxiv_id)
|
result = get_note(db, arxiv_id)
|
||||||
if result is None:
|
if result is None:
|
||||||
raise HTTPException(status_code=404, detail="Paper not found")
|
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post("/note/{arxiv_id}")
|
@router.post("/note/{arxiv_id}")
|
||||||
def note_save(arxiv_id: str, body: NoteRequest, db: Session = Depends(get_db)):
|
def note_save(arxiv_id: str, body: NoteRequest, db: Session = Depends(get_db)):
|
||||||
"""保存笔记。"""
|
"""保存笔记。"""
|
||||||
result = save_note(db, arxiv_id, body.content)
|
return save_note(db, arxiv_id, body.content)
|
||||||
|
|
||||||
if "error" in result:
|
|
||||||
raise HTTPException(status_code=404, detail=result["error"])
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|||||||
+94
-12
@@ -9,10 +9,18 @@ from sqlalchemy import func, select, text
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.models import CrawlLog, Paper, SummaryState, TaskLock
|
from app.models import CrawlLog, Paper, PaperTag, SummaryState, SummaryStatus, TaskLock
|
||||||
from app.services.scheduler import get_scheduler
|
from app.services.scheduler import get_scheduler
|
||||||
from app.utils import PAPERS_DIR, TMP_DIR
|
from app.utils import PAPERS_DIR, TMP_DIR
|
||||||
|
|
||||||
|
# admin_papers 排序映射
|
||||||
|
SORT_MAP = {
|
||||||
|
"date_desc": Paper.paper_date.desc(),
|
||||||
|
"date_asc": Paper.paper_date.asc(),
|
||||||
|
"upvotes_desc": Paper.upvotes.desc(),
|
||||||
|
"title_asc": Paper.title_en.asc(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _dir_size(path: Path) -> int:
|
def _dir_size(path: Path) -> int:
|
||||||
"""递归计算目录总字节数。"""
|
"""递归计算目录总字节数。"""
|
||||||
@@ -52,7 +60,11 @@ def get_admin_stats(db: Session) -> dict:
|
|||||||
status_counts = {row[0]: row[1] for row in summary_rows}
|
status_counts = {row[0]: row[1] for row in summary_rows}
|
||||||
|
|
||||||
# ── 存储概况 ──────────────────────────────────────────────────────
|
# ── 存储概况 ──────────────────────────────────────────────────────
|
||||||
db_size = _fmt_size(settings.db_path.stat().st_size) if settings.db_path.exists() else "0 B"
|
db_size = (
|
||||||
|
_fmt_size(settings.db_path.stat().st_size)
|
||||||
|
if settings.db_path.exists()
|
||||||
|
else "0 B"
|
||||||
|
)
|
||||||
papers_size = _fmt_size(_dir_size(PAPERS_DIR))
|
papers_size = _fmt_size(_dir_size(PAPERS_DIR))
|
||||||
tmp_size = _fmt_size(_dir_size(TMP_DIR))
|
tmp_size = _fmt_size(_dir_size(TMP_DIR))
|
||||||
|
|
||||||
@@ -68,22 +80,14 @@ def get_admin_stats(db: Session) -> dict:
|
|||||||
|
|
||||||
# ── 最近日志(5 条) ──────────────────────────────────────────────
|
# ── 最近日志(5 条) ──────────────────────────────────────────────
|
||||||
recent_logs = (
|
recent_logs = (
|
||||||
db.execute(
|
db.execute(select(CrawlLog).order_by(CrawlLog.started_at.desc()).limit(5))
|
||||||
select(CrawlLog)
|
|
||||||
.order_by(CrawlLog.started_at.desc())
|
|
||||||
.limit(5)
|
|
||||||
)
|
|
||||||
.scalars()
|
.scalars()
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── 活跃锁 ────────────────────────────────────────────────────────
|
# ── 活跃锁 ────────────────────────────────────────────────────────
|
||||||
active_locks = (
|
active_locks = (
|
||||||
db.execute(
|
db.execute(select(TaskLock).where(TaskLock.status == "running")).scalars().all()
|
||||||
select(TaskLock).where(TaskLock.status == "running")
|
|
||||||
)
|
|
||||||
.scalars()
|
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -108,3 +112,81 @@ def get_admin_stats(db: Session) -> dict:
|
|||||||
"active_locks": active_locks,
|
"active_locks": active_locks,
|
||||||
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
|
"upvote_refresh_days": settings.UPVOTE_REFRESH_DAYS,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def query_papers(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
q: str = "",
|
||||||
|
date_from: str | None = None,
|
||||||
|
date_to: str | None = None,
|
||||||
|
tag: str = "",
|
||||||
|
summary_status: str = "all",
|
||||||
|
sort: str = "date_desc",
|
||||||
|
page: int = 1,
|
||||||
|
per_page: int = 20,
|
||||||
|
) -> tuple[list[Paper], int, dict[str, str]]:
|
||||||
|
"""论文管理查询 — 构建过滤、排序、分页。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(papers, total, statuses) — 论文列表、总数、{arxiv_id: summary_status}
|
||||||
|
"""
|
||||||
|
query = select(Paper)
|
||||||
|
|
||||||
|
# 搜索
|
||||||
|
if q.strip():
|
||||||
|
query = query.where(
|
||||||
|
Paper.title_en.ilike(f"%{q}%")
|
||||||
|
| Paper.title_zh.ilike(f"%{q}%")
|
||||||
|
| Paper.abstract.ilike(f"%{q}%")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 日期范围
|
||||||
|
if date_from:
|
||||||
|
query = query.where(Paper.paper_date >= date_from)
|
||||||
|
if date_to:
|
||||||
|
query = query.where(Paper.paper_date <= date_to)
|
||||||
|
|
||||||
|
# 标签筛选
|
||||||
|
if tag:
|
||||||
|
query = query.join(PaperTag, PaperTag.paper_id == Paper.id).where(
|
||||||
|
PaperTag.tag == tag
|
||||||
|
)
|
||||||
|
|
||||||
|
# 总结状态筛选
|
||||||
|
if summary_status != "all":
|
||||||
|
if summary_status == "none":
|
||||||
|
query = query.outerjoin(
|
||||||
|
SummaryStatus, SummaryStatus.paper_id == Paper.id
|
||||||
|
).where(SummaryStatus.paper_id == None) # noqa: E711
|
||||||
|
else:
|
||||||
|
query = query.join(SummaryStatus, SummaryStatus.paper_id == Paper.id).where(
|
||||||
|
SummaryStatus.status == summary_status
|
||||||
|
)
|
||||||
|
|
||||||
|
# 排序
|
||||||
|
order = SORT_MAP.get(sort, Paper.paper_date.desc())
|
||||||
|
query = query.order_by(order)
|
||||||
|
|
||||||
|
# 计数
|
||||||
|
total = db.scalar(select(func.count()).select_from(query.subquery()))
|
||||||
|
|
||||||
|
# 分页
|
||||||
|
papers = (
|
||||||
|
db.execute(query.offset((page - 1) * per_page).limit(per_page)).scalars().all()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每篇论文的总结状态
|
||||||
|
paper_ids = [p.id for p in papers]
|
||||||
|
statuses: dict[str, str] = {}
|
||||||
|
if paper_ids:
|
||||||
|
rows = db.execute(
|
||||||
|
select(SummaryStatus.paper_id, SummaryStatus.status).where(
|
||||||
|
SummaryStatus.paper_id.in_(paper_ids)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
paper_id_to_arxiv = {p.id: p.arxiv_id for p in papers}
|
||||||
|
for pid, st in rows:
|
||||||
|
statuses[paper_id_to_arxiv.get(pid, "")] = st
|
||||||
|
|
||||||
|
return papers, total or 0, statuses
|
||||||
|
|||||||
@@ -207,11 +207,14 @@ async def delete_papers_by_date_range(
|
|||||||
completed_at=utc_now(),
|
completed_at=utc_now(),
|
||||||
papers_found=total,
|
papers_found=total,
|
||||||
papers_new=deleted,
|
papers_new=deleted,
|
||||||
details_json=json.dumps({
|
details_json=json.dumps(
|
||||||
"total_before": total,
|
{
|
||||||
"deleted": deleted,
|
"total_before": total,
|
||||||
"failed": len(failed_items),
|
"deleted": deleted,
|
||||||
}, ensure_ascii=False),
|
"failed": len(failed_items),
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
),
|
||||||
error=job_error,
|
error=job_error,
|
||||||
)
|
)
|
||||||
db.add(log_entry)
|
db.add(log_entry)
|
||||||
|
|||||||
@@ -189,11 +189,15 @@ def index_paper(paper_id: str, texts_dict: dict | None = None) -> bool:
|
|||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
paper = db.execute(
|
paper = (
|
||||||
select(Paper)
|
db.execute(
|
||||||
.where(Paper.arxiv_id == paper_id)
|
select(Paper)
|
||||||
.options(joinedload(Paper.tags), joinedload(Paper.summary))
|
.where(Paper.arxiv_id == paper_id)
|
||||||
).unique().scalar_one_or_none()
|
.options(joinedload(Paper.tags), joinedload(Paper.summary))
|
||||||
|
)
|
||||||
|
.unique()
|
||||||
|
.scalar_one_or_none()
|
||||||
|
)
|
||||||
if not paper:
|
if not paper:
|
||||||
logger.warning("Paper %s not found for indexing", paper_id)
|
logger.warning("Paper %s not found for indexing", paper_id)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -0,0 +1,174 @@
|
|||||||
|
"""PicoDet-S_layout_3cls 布局检测 — 纯 ONNX Runtime 推理.
|
||||||
|
|
||||||
|
用 onnxruntime 加载导出好的 ONNX 模型,检测 PDF 页面中的 figure / table 区域。
|
||||||
|
模型自带 NMS + GFL decode,输出即为后处理完毕的检测框。
|
||||||
|
|
||||||
|
输入:
|
||||||
|
image: (1, 3, 480, 480) float32 — ImageNet 标准化后的图片
|
||||||
|
scale_factor: (1, 2) float32 — [y_scale, x_scale],用于坐标还原
|
||||||
|
|
||||||
|
输出:
|
||||||
|
fetch_name_0: (N, 6) float32 — [xmin, ymin, xmax, ymax, score, class_id]
|
||||||
|
fetch_name_1: (1,) int32 — 有效框数量 N
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import pymupdf
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 模型输入尺寸
|
||||||
|
_MODEL_SIZE = 480
|
||||||
|
# ImageNet normalize
|
||||||
|
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
||||||
|
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||||
|
# PicoDet label → 内部 boxclass
|
||||||
|
_LABEL_MAP: dict[int, str] = {
|
||||||
|
0: "picture", # PicoDet "image" → "picture"
|
||||||
|
1: "table",
|
||||||
|
# 2: seal — 忽略
|
||||||
|
}
|
||||||
|
# 最小 bbox 尺寸(PDF 点)
|
||||||
|
_MIN_BOX_SIZE = 20
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LayoutBox:
|
||||||
|
"""检测到的布局区域,兼容现有 _process_page 代码。"""
|
||||||
|
|
||||||
|
x0: float
|
||||||
|
y0: float
|
||||||
|
x1: float
|
||||||
|
y1: float
|
||||||
|
boxclass: str # "picture" | "table"
|
||||||
|
|
||||||
|
|
||||||
|
class _LayoutDetector:
|
||||||
|
"""单例:管理 ONNX InferenceSession 生命周期。"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._session: ort.InferenceSession | None = None
|
||||||
|
|
||||||
|
def _init_session(self) -> ort.InferenceSession:
|
||||||
|
if self._session is not None:
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
model_path = Path(settings.LAYOUT_MODEL_PATH)
|
||||||
|
if not model_path.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Layout model not found: {model_path}. "
|
||||||
|
"Run scripts/export_picodet_onnx.py first."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Loading ONNX layout model: %s", model_path)
|
||||||
|
self._session = ort.InferenceSession(
|
||||||
|
str(model_path), providers=["CPUExecutionProvider"]
|
||||||
|
)
|
||||||
|
logger.info("ONNX layout model loaded")
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
def detect_page(self, page: pymupdf.Page) -> list[LayoutBox]:
|
||||||
|
"""检测单页 PDF 的 figure / table 区域。
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. pymupdf 以 480×480 渲染页面
|
||||||
|
2. ImageNet normalize → NCHW
|
||||||
|
3. ONNX 推理 → 得到已解码+NMS 的检测框
|
||||||
|
4. 像素坐标 → PDF 点坐标
|
||||||
|
5. 过滤 seal 类和低置信度框
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: pymupdf Page 对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LayoutBox 列表,坐标为 PDF 点
|
||||||
|
"""
|
||||||
|
session = self._init_session()
|
||||||
|
|
||||||
|
page_w = page.rect.width
|
||||||
|
page_h = page.rect.height
|
||||||
|
|
||||||
|
# 1. 渲染页面到 _MODEL_SIZE × _MODEL_SIZE
|
||||||
|
zoom_x = _MODEL_SIZE / page_w
|
||||||
|
zoom_y = _MODEL_SIZE / page_h
|
||||||
|
mat = pymupdf.Matrix(zoom_x, zoom_y)
|
||||||
|
pix = page.get_pixmap(matrix=mat)
|
||||||
|
|
||||||
|
# 2. 预处理
|
||||||
|
img = (
|
||||||
|
np.frombuffer(pix.samples, dtype=np.uint8)
|
||||||
|
.reshape(pix.height, pix.width, pix.n)
|
||||||
|
.astype(np.float32)
|
||||||
|
/ 255.0
|
||||||
|
)
|
||||||
|
# 去掉 alpha 通道(如有)
|
||||||
|
if img.shape[2] == 4:
|
||||||
|
img = img[:, :, :3]
|
||||||
|
img = (img - _MEAN) / _STD
|
||||||
|
img = img.transpose(2, 0, 1)[np.newaxis] # (1, 3, H, W)
|
||||||
|
|
||||||
|
# scale_factor 用于坐标还原(模型内部可能用)
|
||||||
|
scale_factor = np.array([[1.0, 1.0]], dtype=np.float32)
|
||||||
|
|
||||||
|
# 3. 推理
|
||||||
|
input_names = [i.name for i in session.get_inputs()]
|
||||||
|
feed = {input_names[0]: img}
|
||||||
|
if len(input_names) > 1:
|
||||||
|
feed[input_names[1]] = scale_factor
|
||||||
|
|
||||||
|
outputs = session.run(None, feed)
|
||||||
|
boxes_raw = outputs[0] # (N, 6): [class_id, score, xmin, ymin, xmax, ymax]
|
||||||
|
num_boxes = int(outputs[1][0]) # 有效框数
|
||||||
|
|
||||||
|
if num_boxes == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 4. 像素 → PDF 点坐标
|
||||||
|
sx = page_w / _MODEL_SIZE
|
||||||
|
sy = page_h / _MODEL_SIZE
|
||||||
|
|
||||||
|
result: list[LayoutBox] = []
|
||||||
|
for i in range(min(num_boxes, len(boxes_raw))):
|
||||||
|
cls_id, score, xmin, ymin, xmax, ymax = boxes_raw[i]
|
||||||
|
cls_id = int(cls_id)
|
||||||
|
|
||||||
|
# 跳过 seal 类和低置信度
|
||||||
|
if cls_id not in _LABEL_MAP:
|
||||||
|
continue
|
||||||
|
if score < settings.LAYOUT_THRESHOLD:
|
||||||
|
continue
|
||||||
|
|
||||||
|
x0, y0 = xmin * sx, ymin * sy
|
||||||
|
x1, y1 = xmax * sx, ymax * sy
|
||||||
|
|
||||||
|
# 跳过极小区域
|
||||||
|
if (x1 - x0) < _MIN_BOX_SIZE or (y1 - y0) < _MIN_BOX_SIZE:
|
||||||
|
continue
|
||||||
|
|
||||||
|
result.append(
|
||||||
|
LayoutBox(x0=x0, y0=y0, x1=x1, y1=y1, boxclass=_LABEL_MAP[cls_id])
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# 模块级单例
|
||||||
|
_detector = _LayoutDetector()
|
||||||
|
|
||||||
|
|
||||||
|
def detect_page_layout(page: pymupdf.Page) -> list[LayoutBox]:
|
||||||
|
"""检测 PDF 页面中的 figure / table 区域。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LayoutBox 列表,坐标为 PDF 点,仅含 picture/table。
|
||||||
|
"""
|
||||||
|
return _detector.detect_page(page)
|
||||||
@@ -9,6 +9,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
from app.utils import PAPERS_DIR, TMP_DIR
|
from app.utils import PAPERS_DIR, TMP_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -51,6 +52,14 @@ def _get_session() -> requests.Session:
|
|||||||
return _http_session
|
return _http_session
|
||||||
|
|
||||||
|
|
||||||
|
def close_http_session() -> None:
|
||||||
|
"""关闭全局 HTTP Session,供应用 shutdown 时调用。"""
|
||||||
|
global _http_session
|
||||||
|
if _http_session is not None:
|
||||||
|
_http_session.close()
|
||||||
|
_http_session = None
|
||||||
|
|
||||||
|
|
||||||
async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||||
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
|
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
|
||||||
if not pdf_url:
|
if not pdf_url:
|
||||||
@@ -62,10 +71,16 @@ async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
session = _get_session()
|
session = _get_session()
|
||||||
resp = session.get(pdf_url, timeout=120, allow_redirects=True)
|
resp = session.get(pdf_url, timeout=settings.PDF_DOWNLOAD_TIMEOUT, allow_redirects=True)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
dest.write_bytes(resp.content)
|
dest.write_bytes(resp.content)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
# 清理残留的部分文件
|
||||||
|
if dest.exists():
|
||||||
|
try:
|
||||||
|
dest.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
|
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
|
||||||
|
|
||||||
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
|
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
|
||||||
|
|||||||
+383
-462
@@ -1,12 +1,12 @@
|
|||||||
"""PDF 图片与表格提取 — 基于 pymupdf4llm layout analysis。
|
"""PDF 图片与表格提取 — 两阶段流水线。
|
||||||
|
|
||||||
用 pymupdf4llm 的 layout analysis 检测 table / picture 区域,
|
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
|
||||||
再通过 caption 文字匹配确定 Figure/Table 编号,渲染为 JPEG。
|
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
|
||||||
|
|
||||||
相比旧方案(caption 正则 + pdfplumber/find_tables/文本块扫描三套策略):
|
相比旧方案(正则匹配 caption):
|
||||||
- layout analysis 直接给出区域 bbox,不存在相邻表格互相侵入的问题
|
- 不再依赖正则,用 LLM 输出的 ID 直接搜索 PDF 文本
|
||||||
- 无需手动调参(最大高度、间隙阈值等)
|
- page.search_for() 精确搜索 + 空间距离过滤,避免正文引用误匹配
|
||||||
- 页面级 caption 匹配:每个 caption 只分配给最近的 box,避免上下相邻表格抢夺同一个 caption
|
- 通用标签兜底,LLM 没提到的图表不会被丢弃
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -17,44 +17,30 @@ import re
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pymupdf
|
import pymupdf
|
||||||
import pymupdf4llm.helpers.document_layout as dl
|
|
||||||
|
|
||||||
|
from app.services.layout_detector import LayoutBox, detect_page_layout
|
||||||
from app.services.pdf_downloader import paper_dir
|
from app.services.pdf_downloader import paper_dir
|
||||||
from app.utils import TMP_DIR
|
from app.utils import PAPERS_DIR, TMP_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ── Caption 正则 ───────────────────────────────────────────────────────
|
# 截图区域的外边距(单位: pt)
|
||||||
|
|
||||||
# 用于从 caption 文字中提取 Figure/Table 编号
|
|
||||||
_FIGURE_CAPTION_RE = re.compile(
|
|
||||||
r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
_TABLE_CAPTION_RE = re.compile(
|
|
||||||
r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# caption 与 table/picture 的最大匹配距离(点)
|
|
||||||
_CAPTION_MATCH_DISTANCE = 100
|
|
||||||
# 截图区域的外边距
|
|
||||||
_REGION_PADDING = 5
|
_REGION_PADDING = 5
|
||||||
# 3x 渲染,保证清晰度
|
# 渲染倍率(3x 保证清晰度)
|
||||||
_RENDER_ZOOM = 3
|
_RENDER_ZOOM = 3
|
||||||
# 相邻 box 聚类间距(点)— 同一 figure/table 的碎片间距通常 < 15pt
|
# 相邻 box 聚类间距(单位: pt)— 同一 figure/table 的碎片间距通常 < 15pt
|
||||||
_CLUSTER_GAP = 15
|
_CLUSTER_GAP = 15
|
||||||
|
# 最小 bbox 面积(单位: pt²)— 过滤 icon/logo 等微小误检
|
||||||
|
_MIN_BOX_AREA = 2000
|
||||||
|
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
|
||||||
|
_LABEL_MATCH_DISTANCE = 100
|
||||||
|
|
||||||
|
|
||||||
# ── Box 聚类 ─────────────────────────────────────────────────────────
|
# ── Box 聚类 ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class _BoxCluster:
|
class _BoxCluster:
|
||||||
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。
|
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。"""
|
||||||
|
|
||||||
pymupdf4llm 有时将一个大图拆成多个小 picture box(如视频帧网格),
|
|
||||||
聚类后用整体 bbox 作为渲染区域。
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("x0", "y0", "x1", "y1", "boxclass")
|
__slots__ = ("x0", "y0", "x1", "y1", "boxclass")
|
||||||
|
|
||||||
@@ -63,17 +49,12 @@ class _BoxCluster:
|
|||||||
self.y0 = min(b.y0 for b in boxes)
|
self.y0 = min(b.y0 for b in boxes)
|
||||||
self.x1 = max(b.x1 for b in boxes)
|
self.x1 = max(b.x1 for b in boxes)
|
||||||
self.y1 = max(b.y1 for b in boxes)
|
self.y1 = max(b.y1 for b in boxes)
|
||||||
# table-fallback 归一化为 table(layout model 检测到表格但无法提取结构)
|
|
||||||
raw = boxes[0].boxclass
|
raw = boxes[0].boxclass
|
||||||
self.boxclass = "table" if raw == "table-fallback" else raw
|
self.boxclass = "table" if raw == "table-fallback" else raw
|
||||||
|
|
||||||
|
|
||||||
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
||||||
"""将相邻的同类型 box 合并为聚类。
|
"""将相邻的同类型 box 合并为聚类。"""
|
||||||
|
|
||||||
用 union-find 将间距 ≤ gap 的同类型 box 归为一组,
|
|
||||||
每组生成一个 _BoxCluster(整体 bbox)。
|
|
||||||
"""
|
|
||||||
if not boxes:
|
if not boxes:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -111,242 +92,58 @@ def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
|
|||||||
return [_BoxCluster(members) for members in groups.values()]
|
return [_BoxCluster(members) for members in groups.values()]
|
||||||
|
|
||||||
|
|
||||||
# ── 页面级 Caption 查找与匹配 ──────────────────────────────────────────
|
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _find_page_captions(page) -> list[dict]:
|
def _render_box(
|
||||||
"""查找页面上所有 Figure/Table caption 文字块。"""
|
|
||||||
blocks = page.get_text("blocks")
|
|
||||||
captions = []
|
|
||||||
for b in blocks:
|
|
||||||
if len(b) < 5:
|
|
||||||
continue
|
|
||||||
bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
|
|
||||||
text = str(b[4]).strip()
|
|
||||||
first_line = text.split("\n")[0].strip()
|
|
||||||
|
|
||||||
cap_type = None
|
|
||||||
m = _TABLE_CAPTION_RE.match(first_line)
|
|
||||||
if m:
|
|
||||||
cap_type = "table"
|
|
||||||
else:
|
|
||||||
m = _FIGURE_CAPTION_RE.match(first_line)
|
|
||||||
if m:
|
|
||||||
cap_type = "figure"
|
|
||||||
if m is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
captions.append(
|
|
||||||
{
|
|
||||||
"label": f"{'Table' if cap_type == 'table' else 'Figure'} {m.group(1)}",
|
|
||||||
"type": cap_type,
|
|
||||||
"caption_text": text,
|
|
||||||
"caption_y0": by0,
|
|
||||||
"caption_y1": by1,
|
|
||||||
"caption_x0": bx0,
|
|
||||||
"caption_x1": bx1,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return captions
|
|
||||||
|
|
||||||
|
|
||||||
def _vertical_distance(cap_y0, cap_y1, box_y0, box_y1) -> float | None:
|
|
||||||
"""计算 caption 到 box 的垂直距离。不邻接时返回 None。
|
|
||||||
|
|
||||||
三种情况:caption 完全在 box 上方、完全在下方、与 box 有垂直重叠。
|
|
||||||
重叠(含部分溢出)视为 distance=0,确保 caption 延伸到 box 边界外时不会丢失。
|
|
||||||
"""
|
|
||||||
# Caption 完全在 box 上方
|
|
||||||
if cap_y1 <= box_y0:
|
|
||||||
dist = box_y0 - cap_y1
|
|
||||||
return dist if dist <= _CAPTION_MATCH_DISTANCE else None
|
|
||||||
# Caption 完全在 box 下方
|
|
||||||
if cap_y0 >= box_y1:
|
|
||||||
dist = cap_y0 - box_y1
|
|
||||||
return dist if dist <= _CAPTION_MATCH_DISTANCE else None
|
|
||||||
# Caption 与 box 有垂直重叠(内部、部分溢出都算)→ 距离 0
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def _same_column(cap: dict, box, page_width: float) -> bool:
|
|
||||||
"""判断 caption 和 box 是否在同一列。
|
|
||||||
|
|
||||||
双栏论文中左右栏间距有限,简单的水平重叠检查会跨列匹配。
|
|
||||||
策略:用中心 X 坐标判断各自在哪半边,只有同半边才算同列。
|
|
||||||
跨栏图表(caption 或 box 宽度 >65% 页宽)不受此限制。
|
|
||||||
"""
|
|
||||||
cap_w = cap["caption_x1"] - cap["caption_x0"]
|
|
||||||
box_w = box.x1 - box.x0
|
|
||||||
|
|
||||||
# 跨栏元素:宽度超过页面的 65%
|
|
||||||
if cap_w > page_width * 0.65 or box_w > page_width * 0.65:
|
|
||||||
return True
|
|
||||||
|
|
||||||
cap_cx = (cap["caption_x0"] + cap["caption_x1"]) / 2
|
|
||||||
box_cx = (box.x0 + box.x1) / 2
|
|
||||||
mid = page_width / 2
|
|
||||||
|
|
||||||
# 同在左半边或同在右半边
|
|
||||||
return (cap_cx < mid) == (box_cx < mid)
|
|
||||||
|
|
||||||
|
|
||||||
def _match_captions_to_boxes(
|
|
||||||
page_boxes: list, captions: list[dict], page_width: float
|
|
||||||
) -> list[tuple[list[int], list[dict]]]:
|
|
||||||
"""将 caption 分配给 box,允许一个 caption 匹配多个同类型 box。
|
|
||||||
|
|
||||||
典型场景:
|
|
||||||
- Figure 由左右两个 picture box 组成,caption 同时靠近两者
|
|
||||||
- Table 的视觉内容被 layout analysis 误分类为 picture,需要跨类型匹配
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
[(box_indices, captions), ...] 每组是一个独立的渲染任务
|
|
||||||
"""
|
|
||||||
# 每个 caption 找到所有距离在阈值内的 box
|
|
||||||
# 优先匹配同类型;如果找不到,再匹配任意 table/picture box
|
|
||||||
cap_to_boxes: dict[int, list[tuple[int, float]]] = {}
|
|
||||||
|
|
||||||
for ci, cap in enumerate(captions):
|
|
||||||
same_type: list[tuple[int, float]] = []
|
|
||||||
any_type: list[tuple[int, float]] = []
|
|
||||||
expected = "table" if cap["type"] == "table" else "picture"
|
|
||||||
|
|
||||||
for bi, box in enumerate(page_boxes):
|
|
||||||
# 列感知:双栏论文中只匹配同栏的 box
|
|
||||||
if not _same_column(cap, box, page_width):
|
|
||||||
continue
|
|
||||||
# 水平重叠检查(同列内仍需有重叠)
|
|
||||||
if not (
|
|
||||||
cap["caption_x1"] > box.x0 - 5 and cap["caption_x0"] < box.x1 + 5
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
dist = _vertical_distance(
|
|
||||||
cap["caption_y0"], cap["caption_y1"], box.y0, box.y1
|
|
||||||
)
|
|
||||||
if dist is None:
|
|
||||||
continue
|
|
||||||
entry = (bi, dist)
|
|
||||||
any_type.append(entry)
|
|
||||||
if box.boxclass == expected:
|
|
||||||
same_type.append(entry)
|
|
||||||
|
|
||||||
# 优先用同类型匹配;没有时回退到任意类型;都没有则跳过
|
|
||||||
if same_type:
|
|
||||||
cap_to_boxes[ci] = same_type
|
|
||||||
elif any_type:
|
|
||||||
cap_to_boxes[ci] = any_type
|
|
||||||
# else: 该 caption 无匹配 box,不加入 cap_to_boxes
|
|
||||||
|
|
||||||
# 每个 caption → 最近的 box(用于分组),但记录所有匹配的 box
|
|
||||||
cap_primary: dict[int, int] = {} # caption → primary box index
|
|
||||||
cap_all_boxes: dict[int, list[int]] = {} # caption → all matched box indices
|
|
||||||
for ci, matches in cap_to_boxes.items():
|
|
||||||
matches.sort(key=lambda x: x[1])
|
|
||||||
cap_primary[ci] = matches[0][0]
|
|
||||||
# 所有距离最近的同组 box(距离差 < 20pt 视为同一组)
|
|
||||||
best_dist = matches[0][1]
|
|
||||||
cap_all_boxes[ci] = [bi for bi, d in matches if d <= best_dist + 20]
|
|
||||||
|
|
||||||
# 按 primary box 分组
|
|
||||||
box_to_caps: dict[int, list[int]] = {}
|
|
||||||
for ci, bi in cap_primary.items():
|
|
||||||
box_to_caps.setdefault(bi, []).append(ci)
|
|
||||||
|
|
||||||
# 构建渲染组:每个 caption 独立成组(共享 box 但各自渲染)
|
|
||||||
# 同类型同 label 的 caption 会合并;不同类型则分开
|
|
||||||
used_captions: set[int] = set()
|
|
||||||
groups: list[tuple[list[int], list[dict]]] = []
|
|
||||||
|
|
||||||
for bi in sorted(box_to_caps.keys()):
|
|
||||||
cis = box_to_caps[bi]
|
|
||||||
for ci in cis:
|
|
||||||
if ci in used_captions:
|
|
||||||
continue
|
|
||||||
used_captions.add(ci)
|
|
||||||
|
|
||||||
all_box_indices = set(cap_all_boxes.get(ci, [bi]))
|
|
||||||
# 只合并同 label 的 caption(同 figure/table 的重复 caption)
|
|
||||||
merged_captions = [captions[ci]]
|
|
||||||
for other_bi in all_box_indices:
|
|
||||||
if other_bi in box_to_caps:
|
|
||||||
for other_ci in box_to_caps[other_bi]:
|
|
||||||
if other_ci not in used_captions:
|
|
||||||
other_cap = captions[other_ci]
|
|
||||||
if other_cap["label"] == captions[ci]["label"]:
|
|
||||||
used_captions.add(other_ci)
|
|
||||||
merged_captions.append(other_cap)
|
|
||||||
groups.append((sorted(all_box_indices), merged_captions))
|
|
||||||
|
|
||||||
return groups
|
|
||||||
|
|
||||||
|
|
||||||
# ── 单页处理 ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _render_and_save(
|
|
||||||
page,
|
page,
|
||||||
clip: pymupdf.Rect,
|
box: _BoxCluster,
|
||||||
images_dest: Path,
|
images_dest: Path,
|
||||||
manifest: dict,
|
filename: str,
|
||||||
label: str,
|
|
||||||
cap_type: str,
|
cap_type: str,
|
||||||
caption_text: str,
|
page_num: int,
|
||||||
page_num_1based: int,
|
|
||||||
arxiv_id: str,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""渲染页面区域并保存 JPEG,写入 manifest。成功返回 True。"""
|
"""渲染单个 box 区域并保存 JPEG,成功返回 True。"""
|
||||||
|
page_width = page.rect.width
|
||||||
|
clip = pymupdf.Rect(
|
||||||
|
max(0, box.x0 - _REGION_PADDING),
|
||||||
|
max(0, box.y0 - _REGION_PADDING),
|
||||||
|
min(page_width, box.x1 + _REGION_PADDING),
|
||||||
|
box.y1 + _REGION_PADDING,
|
||||||
|
)
|
||||||
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
|
mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
|
||||||
try:
|
try:
|
||||||
pix = page.get_pixmap(matrix=mat, clip=clip)
|
pix = page.get_pixmap(matrix=mat, clip=clip)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Failed to render %s for %s", label, arxiv_id)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
filename = f"{label.replace(' ', '_').lower()}.jpg"
|
|
||||||
(images_dest / filename).write_bytes(pix.tobytes("jpeg"))
|
(images_dest / filename).write_bytes(pix.tobytes("jpeg"))
|
||||||
|
|
||||||
manifest[filename] = {
|
|
||||||
"page": page_num_1based,
|
|
||||||
"type": cap_type,
|
|
||||||
"label": label,
|
|
||||||
"caption_text": caption_text[:200] if caption_text else "",
|
|
||||||
"figures" if cap_type == "figure" else "tables": [label],
|
|
||||||
}
|
|
||||||
logger.debug(
|
|
||||||
"Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) → %s",
|
|
||||||
label,
|
|
||||||
page_num_1based,
|
|
||||||
clip.x0,
|
|
||||||
clip.y0,
|
|
||||||
clip.x1,
|
|
||||||
clip.y1,
|
|
||||||
filename,
|
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _process_page(
|
def _process_page(
|
||||||
doc,
|
doc,
|
||||||
page_idx: int,
|
page_idx: int,
|
||||||
page_layout,
|
page_boxes: list[LayoutBox],
|
||||||
images_dest: Path,
|
images_dest: Path,
|
||||||
manifest: dict,
|
manifest: dict,
|
||||||
seen_labels: set,
|
seen_labels: set,
|
||||||
arxiv_id: str,
|
arxiv_id: str,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""处理单页:caption 匹配 + orphan 兜底,返回本页提取数量。"""
|
"""处理单页:检测 → 聚类 → 渲染,全部用通用标签。"""
|
||||||
page = doc[page_idx]
|
page = doc[page_idx]
|
||||||
page_width = page.rect.width
|
|
||||||
page_num = page_idx + 1
|
page_num = page_idx + 1
|
||||||
orphan_fig_counter = 0
|
fig_counter = 0
|
||||||
orphan_tbl_counter = 0
|
tbl_counter = 0
|
||||||
|
|
||||||
# 收集本页的 table/picture box(跳过极小区域)
|
# 收集本页的 table/picture box(跳过极小区域)
|
||||||
raw_boxes = []
|
raw_boxes = []
|
||||||
for box in page_layout.boxes:
|
for box in page_boxes:
|
||||||
if box.boxclass not in ("table", "table-fallback", "picture"):
|
if box.boxclass not in ("table", "table-fallback", "picture"):
|
||||||
continue
|
continue
|
||||||
if (box.x1 - box.x0) < 20 or (box.y1 - box.y0) < 20:
|
w = box.x1 - box.x0
|
||||||
|
h = box.y1 - box.y0
|
||||||
|
if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
|
||||||
continue
|
continue
|
||||||
raw_boxes.append(box)
|
raw_boxes.append(box)
|
||||||
|
|
||||||
@@ -354,153 +151,48 @@ def _process_page(
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 聚类:将同一 figure/table 的碎片 box 合并
|
# 聚类:将同一 figure/table 的碎片 box 合并
|
||||||
page_boxes = _cluster_boxes(raw_boxes)
|
clusters = _cluster_boxes(raw_boxes)
|
||||||
|
|
||||||
# 页面级匹配:查找所有 caption,分配给 box
|
|
||||||
captions = _find_page_captions(page)
|
|
||||||
groups = _match_captions_to_boxes(page_boxes, captions, page_width)
|
|
||||||
|
|
||||||
# 只合并同 label 的 group(同一个 figure/table 的重复 caption)
|
|
||||||
# 不同 label 的 group 即使共享 box 也不合并(如 Figure 7 和 Figure 8),
|
|
||||||
# 渲染时用 caption 位置切割区域
|
|
||||||
_merged_groups: set[int] = set()
|
|
||||||
merged_groups: list[tuple[list[int], list[dict]]] = []
|
|
||||||
for gi, (box_indices, caps) in enumerate(groups):
|
|
||||||
if gi in _merged_groups:
|
|
||||||
continue
|
|
||||||
this_labels = {c["label"] for c in caps}
|
|
||||||
all_box_set = set(box_indices)
|
|
||||||
merge_targets = {gi}
|
|
||||||
for other_gi, (other_bi, other_caps) in enumerate(groups):
|
|
||||||
if other_gi <= gi or other_gi in _merged_groups:
|
|
||||||
continue
|
|
||||||
other_labels = {c["label"] for c in other_caps}
|
|
||||||
# 只在 label 有交集时合并(同一个 figure/table)
|
|
||||||
if this_labels & other_labels and all_box_set & set(other_bi):
|
|
||||||
merge_targets.add(other_gi)
|
|
||||||
all_box_set |= set(other_bi)
|
|
||||||
all_caps = []
|
|
||||||
for mgi in sorted(merge_targets):
|
|
||||||
_merged_groups.add(mgi)
|
|
||||||
all_caps.extend(groups[mgi][1])
|
|
||||||
merged_groups.append((sorted(all_box_set), all_caps))
|
|
||||||
groups = merged_groups
|
|
||||||
|
|
||||||
# ── 阶段 1:渲染有 caption 匹配的图/表 ──
|
|
||||||
matched_box_indices: set[int] = set()
|
|
||||||
extracted = 0
|
extracted = 0
|
||||||
|
for cluster in clusters:
|
||||||
for box_indices, caps in groups:
|
cap_type = "figure" if cluster.boxclass == "picture" else "table"
|
||||||
matched_box_indices.update(box_indices)
|
|
||||||
|
|
||||||
# 去重同一 label,跳过已处理的
|
|
||||||
unique_caps = []
|
|
||||||
for cap in caps:
|
|
||||||
if cap["label"] not in seen_labels:
|
|
||||||
seen_labels.add(cap["label"])
|
|
||||||
unique_caps.append(cap)
|
|
||||||
if not unique_caps:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 合并所有关联 box 的 bbox
|
|
||||||
bx0 = min(page_boxes[i].x0 for i in box_indices)
|
|
||||||
by0 = min(page_boxes[i].y0 for i in box_indices)
|
|
||||||
bx1 = max(page_boxes[i].x1 for i in box_indices)
|
|
||||||
by1 = max(page_boxes[i].y1 for i in box_indices)
|
|
||||||
|
|
||||||
# 渲染区域:box + caption
|
|
||||||
all_cap_y0 = min(c["caption_y0"] for c in unique_caps)
|
|
||||||
all_cap_y1 = max(c["caption_y1"] for c in unique_caps)
|
|
||||||
all_cap_x0 = min(c["caption_x0"] for c in unique_caps)
|
|
||||||
all_cap_x1 = max(c["caption_x1"] for c in unique_caps)
|
|
||||||
|
|
||||||
top = max(0, min(by0, all_cap_y0) - _REGION_PADDING)
|
|
||||||
bottom = max(by1, all_cap_y1) + _REGION_PADDING
|
|
||||||
rx0 = max(0, min(bx0, all_cap_x0) - _REGION_PADDING)
|
|
||||||
rx1 = min(page_width, max(bx1, all_cap_x1) + _REGION_PADDING)
|
|
||||||
|
|
||||||
clip = pymupdf.Rect(rx0, top, rx1, bottom)
|
|
||||||
# 多个 caption 可能共享同一区域(如 subfigure),只需渲染一次
|
|
||||||
jpeg_bytes = None
|
|
||||||
for cap in unique_caps:
|
|
||||||
if jpeg_bytes is None:
|
|
||||||
if not _render_and_save(
|
|
||||||
page,
|
|
||||||
clip,
|
|
||||||
images_dest,
|
|
||||||
manifest,
|
|
||||||
cap["label"],
|
|
||||||
cap["type"],
|
|
||||||
cap["caption_text"],
|
|
||||||
page_num,
|
|
||||||
arxiv_id,
|
|
||||||
):
|
|
||||||
break
|
|
||||||
# 读取刚写入的 bytes 供后续同名 caption 复用
|
|
||||||
filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
|
|
||||||
jpeg_bytes = (images_dest / filename).read_bytes()
|
|
||||||
extracted += 1
|
|
||||||
else:
|
|
||||||
# 同区域的不同 caption(如 subfigure),复用图片
|
|
||||||
filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
|
|
||||||
(images_dest / filename).write_bytes(jpeg_bytes)
|
|
||||||
cap_preview = cap["caption_text"][:200]
|
|
||||||
manifest[filename] = {
|
|
||||||
"page": page_num,
|
|
||||||
"type": cap["type"],
|
|
||||||
"label": cap["label"],
|
|
||||||
"caption_text": cap_preview,
|
|
||||||
"figures" if cap["type"] == "figure" else "tables": [cap["label"]],
|
|
||||||
}
|
|
||||||
extracted += 1
|
|
||||||
|
|
||||||
# ── 阶段 2:渲染无 caption 匹配的图/表(orphan boxes) ──
|
|
||||||
orphan_indices = set(range(len(page_boxes))) - matched_box_indices
|
|
||||||
for bi in sorted(orphan_indices):
|
|
||||||
box = page_boxes[bi]
|
|
||||||
cap_type = "figure" if box.boxclass == "picture" else "table"
|
|
||||||
|
|
||||||
if cap_type == "figure":
|
if cap_type == "figure":
|
||||||
orphan_fig_counter += 1
|
fig_counter += 1
|
||||||
label = f"Figure (p{page_num}-{orphan_fig_counter})"
|
label = f"Figure (p{page_num}-{fig_counter})"
|
||||||
else:
|
else:
|
||||||
orphan_tbl_counter += 1
|
tbl_counter += 1
|
||||||
label = f"Table (p{page_num}-{orphan_tbl_counter})"
|
label = f"Table (p{page_num}-{tbl_counter})"
|
||||||
|
|
||||||
if label in seen_labels:
|
if label in seen_labels:
|
||||||
continue
|
continue
|
||||||
seen_labels.add(label)
|
seen_labels.add(label)
|
||||||
|
|
||||||
clip = pymupdf.Rect(
|
filename = f"{label.replace(' ', '_').lower()}.jpg"
|
||||||
max(0, box.x0 - _REGION_PADDING),
|
if not _render_box(page, cluster, images_dest, filename, cap_type, page_num):
|
||||||
max(0, box.y0 - _REGION_PADDING),
|
continue
|
||||||
min(page_width, box.x1 + _REGION_PADDING),
|
|
||||||
box.y1 + _REGION_PADDING,
|
manifest[filename] = {
|
||||||
)
|
"page": page_num,
|
||||||
if _render_and_save(
|
"type": cap_type,
|
||||||
page,
|
"label": label,
|
||||||
clip,
|
"box": [
|
||||||
images_dest,
|
round(float(cluster.x0), 1),
|
||||||
manifest,
|
round(float(cluster.y0), 1),
|
||||||
label,
|
round(float(cluster.x1), 1),
|
||||||
cap_type,
|
round(float(cluster.y1), 1),
|
||||||
"",
|
],
|
||||||
page_num,
|
}
|
||||||
arxiv_id,
|
extracted += 1
|
||||||
):
|
|
||||||
extracted += 1
|
|
||||||
|
|
||||||
return extracted
|
return extracted
|
||||||
|
|
||||||
|
|
||||||
# ── 核心提取 ───────────────────────────────────────────────────────────
|
# ── Phase 1 核心入口 ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
||||||
"""从 PDF 提取 Figure/Table 截图,生成 manifest。
|
"""Phase 1: 从 PDF 提取 Figure/Table 截图,生成通用标签的 manifest。
|
||||||
|
|
||||||
用 pymupdf4llm layout analysis 检测 table/picture 区域,
|
|
||||||
再通过 caption 文字确定编号,渲染为 JPEG。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
arxiv_id: 论文 ID
|
arxiv_id: 论文 ID
|
||||||
@@ -526,45 +218,31 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
|||||||
if (images_dest / "manifest.json").exists():
|
if (images_dest / "manifest.json").exists():
|
||||||
(images_dest / "manifest.json").unlink()
|
(images_dest / "manifest.json").unlink()
|
||||||
|
|
||||||
doc = pymupdf.open(str(pdf_path))
|
with pymupdf.open(str(pdf_path)) as doc:
|
||||||
|
extracted = 0
|
||||||
|
manifest: dict[str, dict] = {}
|
||||||
|
seen_labels: set[str] = set()
|
||||||
|
|
||||||
# layout analysis
|
for page_idx in range(doc.page_count):
|
||||||
try:
|
try:
|
||||||
parsed = dl.parse_document(
|
page_boxes = detect_page_layout(doc[page_idx])
|
||||||
doc, filename=str(pdf_path), use_ocr=dl.OCRMode.NEVER
|
extracted += _process_page(
|
||||||
)
|
doc,
|
||||||
except Exception:
|
page_idx,
|
||||||
logger.warning(
|
page_boxes,
|
||||||
"pymupdf4llm layout analysis failed for %s", arxiv_id, exc_info=True
|
images_dest=images_dest,
|
||||||
)
|
manifest=manifest,
|
||||||
doc.close()
|
seen_labels=seen_labels,
|
||||||
return 0
|
arxiv_id=arxiv_id,
|
||||||
|
)
|
||||||
extracted = 0
|
except Exception:
|
||||||
manifest: dict[str, dict] = {}
|
logger.warning(
|
||||||
seen_labels: set[str] = set()
|
"Failed to process page %d for %s",
|
||||||
|
page_idx + 1,
|
||||||
for page_idx, page_layout in enumerate(parsed.pages):
|
arxiv_id,
|
||||||
try:
|
exc_info=True,
|
||||||
extracted += _process_page(
|
)
|
||||||
doc,
|
continue
|
||||||
page_idx,
|
|
||||||
page_layout,
|
|
||||||
images_dest=images_dest,
|
|
||||||
manifest=manifest,
|
|
||||||
seen_labels=seen_labels,
|
|
||||||
arxiv_id=arxiv_id,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning(
|
|
||||||
"Failed to process page %d for %s",
|
|
||||||
page_idx + 1,
|
|
||||||
arxiv_id,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
doc.close()
|
|
||||||
|
|
||||||
# 保存 manifest
|
# 保存 manifest
|
||||||
manifest_path = images_dest / "manifest.json"
|
manifest_path = images_dest / "manifest.json"
|
||||||
@@ -580,78 +258,321 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
|
|||||||
return extracted
|
return extracted
|
||||||
|
|
||||||
|
|
||||||
# ── 按 summary 过滤 ────────────────────────────────────────────────────
|
# ── Phase 2: 用 summary 的 figures ID 定位并重命名 ─────────────────────
|
||||||
|
|
||||||
|
|
||||||
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
|
def _distance_text_to_box(rect: pymupdf.Rect, box: list[float]) -> float | None:
|
||||||
"""根据 summary 中的 figures 字段过滤提取的图片/表格。
|
"""计算搜索到的文本 rect 到 box 的距离。超出阈值返回 None。
|
||||||
|
|
||||||
用 manifest.json 中的 label 匹配,保留被 AI 总结引用的图片。
|
判断逻辑:rect 中心与 box 的垂直距离 + 水平重叠检查。
|
||||||
|
"""
|
||||||
|
rect_cx = (rect.x0 + rect.x1) / 2
|
||||||
|
rect_cy = (rect.y0 + rect.y1) / 2
|
||||||
|
bx0, by0, bx1, by1 = box
|
||||||
|
|
||||||
|
# 水平重叠:rect 中心在 box 水平范围内(或接近)
|
||||||
|
if not (bx0 - 20 <= rect_cx <= bx1 + 20):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 垂直距离
|
||||||
|
if rect_cy < by0:
|
||||||
|
dist = by0 - rect_cy
|
||||||
|
elif rect_cy > by1:
|
||||||
|
dist = rect_cy - by1
|
||||||
|
else:
|
||||||
|
dist = 0
|
||||||
|
|
||||||
|
return dist if dist <= _LABEL_MATCH_DISTANCE else None
|
||||||
|
|
||||||
|
|
||||||
|
def _search_variants(fig_id: str) -> list[str]:
|
||||||
|
"""为 figure/table ID 生成搜索变体。
|
||||||
|
|
||||||
|
"Figure 1" → ["Figure 1", "Fig. 1", "Fig 1"]
|
||||||
|
"Fig. 1" → ["Fig. 1", "Figure 1", "Fig 1"]
|
||||||
|
"Table A1" → ["Table A1"]
|
||||||
|
"""
|
||||||
|
variants = [fig_id]
|
||||||
|
|
||||||
|
m = re.match(r"(Fig\.?|Figure)\s+(\d+.*)", fig_id, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
num_part = m.group(2)
|
||||||
|
variants.extend(
|
||||||
|
[
|
||||||
|
f"Figure {num_part}",
|
||||||
|
f"Fig. {num_part}",
|
||||||
|
f"Fig {num_part}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 去重保序
|
||||||
|
seen = set()
|
||||||
|
result = []
|
||||||
|
for v in variants:
|
||||||
|
if v not in seen:
|
||||||
|
seen.add(v)
|
||||||
|
result.append(v)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def label_images_by_summary(
|
||||||
|
arxiv_id: str,
|
||||||
|
figures: list[dict],
|
||||||
|
pdf_path: Path | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Phase 2: 用 summary 的 figures ID 在 PDF 中搜索定位,重命名图片。
|
||||||
|
|
||||||
|
对 summary 中的每个 figure/table ID:
|
||||||
|
1. page.search_for(id) 在所有页面搜索文本位置
|
||||||
|
2. 计算搜索位置与 manifest 中 box 坐标的距离
|
||||||
|
3. 最近匹配 → 重命名文件、更新 manifest
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arxiv_id: 论文 ID
|
||||||
|
figures: summary 的 figures 列表,每项含 id/caption/description 等
|
||||||
|
pdf_path: PDF 路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
成功重命名的图片数量
|
||||||
"""
|
"""
|
||||||
if not figures:
|
if not figures:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
images_dir = paper_dir(arxiv_id) / "images"
|
if pdf_path is None:
|
||||||
manifest_path = images_dir / "manifest.json"
|
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
|
||||||
|
if not pdf_path.exists():
|
||||||
if not images_dir.exists() or not manifest_path.exists():
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
all_files = [
|
images_dest = paper_dir(arxiv_id) / "images"
|
||||||
f for f in images_dir.iterdir() if f.suffix.lower() in (".png", ".jpg", ".jpeg")
|
manifest_path = images_dest / "manifest.json"
|
||||||
]
|
if not manifest_path.exists():
|
||||||
if not all_files:
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))
|
manifest: dict[str, dict] = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||||
|
if not manifest:
|
||||||
|
return 0
|
||||||
|
|
||||||
# 收集 summary 中引用的所有 Figure/Table ID(归一化)
|
# 构建候选列表:只对通用标签的条目做匹配
|
||||||
referenced_ids: set[str] = set()
|
candidates: dict[str, dict] = {} # filename → {page, box, ...}
|
||||||
for fig in figures:
|
for fname, info in manifest.items():
|
||||||
fig_id = fig.get("id", "")
|
if "(p" in info.get("label", ""):
|
||||||
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", fig_id, re.IGNORECASE)
|
candidates[fname] = info
|
||||||
if m:
|
|
||||||
referenced_ids.add(f"Figure {m.group(1)}")
|
|
||||||
m2 = re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
|
|
||||||
if m2:
|
|
||||||
referenced_ids.add(f"Table {m2.group(1)}")
|
|
||||||
|
|
||||||
if not referenced_ids:
|
if not candidates:
|
||||||
logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
|
return 0
|
||||||
return len(all_files)
|
|
||||||
|
|
||||||
# 根据 manifest 的 label 字段匹配
|
with pymupdf.open(str(pdf_path)) as doc:
|
||||||
keep_filenames: set[str] = set()
|
# 收集所有匹配候选:(fig_id, fig_index, filename, distance)
|
||||||
for filename, info in manifest.items():
|
matches: list[tuple[str, int, str, float]] = []
|
||||||
label = info.get("label", "")
|
|
||||||
if label in referenced_ids:
|
for fig_idx, fig in enumerate(figures):
|
||||||
keep_filenames.add(filename)
|
fig_id = fig.get("id", "")
|
||||||
|
if not fig_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 生成搜索变体:Figure 1 / Fig. 1 / Fig 1 等
|
||||||
|
search_terms = _search_variants(fig_id)
|
||||||
|
|
||||||
|
# 在所有页面搜索该文本(含变体)
|
||||||
|
search_hits: list[tuple[int, pymupdf.Rect]] = [] # (page_num_1based, Rect)
|
||||||
|
for page_idx in range(doc.page_count):
|
||||||
|
page = doc[page_idx]
|
||||||
|
seen_rects: set[tuple[float, float]] = set()
|
||||||
|
for term in search_terms:
|
||||||
|
for r in page.search_for(term):
|
||||||
|
key = (round(r.x0, 1), round(r.y0, 1))
|
||||||
|
if key not in seen_rects:
|
||||||
|
seen_rects.add(key)
|
||||||
|
search_hits.append((page_idx + 1, r))
|
||||||
|
|
||||||
|
if not search_hits:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 对每个候选 manifest 条目,找最近的搜索命中
|
||||||
|
for fname, info in candidates.items():
|
||||||
|
box = info.get("box")
|
||||||
|
if not box:
|
||||||
|
continue
|
||||||
|
manifest_page = info.get("page", 0)
|
||||||
|
|
||||||
|
best_dist: float | None = None
|
||||||
|
for hit_page, rect in search_hits:
|
||||||
|
# 只匹配同页面
|
||||||
|
if hit_page != manifest_page:
|
||||||
|
continue
|
||||||
|
dist = _distance_text_to_box(rect, box)
|
||||||
|
if dist is not None and (best_dist is None or dist < best_dist):
|
||||||
|
best_dist = dist
|
||||||
|
|
||||||
|
if best_dist is not None:
|
||||||
|
matches.append((fig_id, fig_idx, fname, best_dist))
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
logger.info("No label matches for %s", arxiv_id)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 去冲突:按距离排序,每个 fig_id 和每个 filename 只匹配一次
|
||||||
|
matches.sort(key=lambda x: x[3])
|
||||||
|
used_fig_ids: set[int] = set()
|
||||||
|
used_filenames: set[str] = set()
|
||||||
|
renames: list[tuple[str, str, str]] = [] # (old_fname, new_fname, fig_id)
|
||||||
|
|
||||||
|
for fig_id, fig_idx, fname, dist in matches:
|
||||||
|
if fig_idx in used_fig_ids or fname in used_filenames:
|
||||||
continue
|
continue
|
||||||
for ref in info.get("figures", []) + info.get("tables", []):
|
used_fig_ids.add(fig_idx)
|
||||||
if ref in referenced_ids:
|
used_filenames.add(fname)
|
||||||
keep_filenames.add(filename)
|
new_fname = f"{fig_id.replace(' ', '_').lower()}.jpg"
|
||||||
|
renames.append((fname, new_fname, fig_id))
|
||||||
|
|
||||||
|
# 执行重命名
|
||||||
|
labeled = 0
|
||||||
|
new_manifest: dict[str, dict] = {}
|
||||||
|
|
||||||
|
for fname, info in manifest.items():
|
||||||
|
if fname in used_filenames:
|
||||||
|
continue
|
||||||
|
# 未匹配的保持原样
|
||||||
|
new_manifest[fname] = info
|
||||||
|
|
||||||
|
for old_fname, new_fname, fig_id in renames:
|
||||||
|
old_path = images_dest / old_fname
|
||||||
|
new_path = images_dest / new_fname
|
||||||
|
if not old_path.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 搬运 manifest 信息
|
||||||
|
info = manifest[old_fname].copy()
|
||||||
|
cap_type = info.get("type", "figure")
|
||||||
|
|
||||||
|
# 读取 caption 文本(从 figures 列表)
|
||||||
|
caption_text = ""
|
||||||
|
for fig in figures:
|
||||||
|
if fig.get("id") == fig_id:
|
||||||
|
caption_text = fig.get("caption", "")
|
||||||
break
|
break
|
||||||
|
|
||||||
if not keep_filenames:
|
info["label"] = fig_id
|
||||||
logger.warning(
|
info["caption_text"] = caption_text[:200] if caption_text else ""
|
||||||
"No manifest matches for %s (refs=%s), keeping all",
|
info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
|
||||||
arxiv_id,
|
fig_id
|
||||||
referenced_ids,
|
|
||||||
)
|
)
|
||||||
return len(all_files)
|
|
||||||
|
|
||||||
removed = 0
|
# 重命名文件
|
||||||
for f in all_files:
|
if new_fname != old_fname:
|
||||||
if f.name not in keep_filenames:
|
old_path.rename(new_path)
|
||||||
f.unlink()
|
new_manifest[new_fname] = info
|
||||||
removed += 1
|
labeled += 1
|
||||||
|
|
||||||
|
# 写回 manifest
|
||||||
|
manifest_path.write_text(json.dumps(new_manifest, ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
kept = len(all_files) - removed
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Filtered images for %s: kept %d, removed %d (refs=%s)",
|
"Labeled %d/%d images for %s using summary figures",
|
||||||
|
labeled,
|
||||||
|
len(manifest),
|
||||||
arxiv_id,
|
arxiv_id,
|
||||||
kept,
|
|
||||||
removed,
|
|
||||||
referenced_ids,
|
|
||||||
)
|
)
|
||||||
return kept
|
return labeled
|
||||||
|
|
||||||
|
|
||||||
|
# ── Figure ↔ Image 关联 ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_figure_id(raw_id: str) -> str:
|
||||||
|
"""归一化 Figure/Table ID:'Figure 1'/'Fig.1' → 'Figure 1'。"""
|
||||||
|
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
return f"Figure {m.group(1)}"
|
||||||
|
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
|
||||||
|
if m2:
|
||||||
|
return f"Table {m2.group(1)}"
|
||||||
|
return raw_id
|
||||||
|
|
||||||
|
|
||||||
|
def _is_figure_type(fig_id: str) -> bool:
|
||||||
|
"""判断是否为 Figure 类型(非 Table)。"""
|
||||||
|
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def _image_sort_key(name: str) -> tuple[int, int]:
|
||||||
|
"""按文件名中的编号排序提取的图片。"""
|
||||||
|
# 新格式:figure_1.jpg, table_1.jpg
|
||||||
|
m = re.search(r"(?:figure|table)_(\d+)", name)
|
||||||
|
if m:
|
||||||
|
return (0, int(m.group(1)))
|
||||||
|
# 旧格式:page2_img1.png, page5_table1.png, figure_1.png
|
||||||
|
m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
|
||||||
|
if m2:
|
||||||
|
return (int(m2.group(1)), int(m2.group(2)))
|
||||||
|
return (0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def link_figures_with_images(
|
||||||
|
figures: list[dict], images: list[dict], arxiv_id: str
|
||||||
|
) -> list[dict]:
|
||||||
|
"""将 summary figures 元数据与提取的图片文件关联。
|
||||||
|
|
||||||
|
策略:
|
||||||
|
1. 优先用 manifest.json 的 label 做 ID 精确匹配
|
||||||
|
2. 未匹配的 figure 用序号兜底:第 N 个 Figure → 第 N 张提取图
|
||||||
|
"""
|
||||||
|
if not figures or not images:
|
||||||
|
return figures
|
||||||
|
|
||||||
|
manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
|
||||||
|
|
||||||
|
# ── 策略 1:manifest ID 精确匹配 ──
|
||||||
|
id_to_url: dict[str, str] = {}
|
||||||
|
if manifest_path.exists():
|
||||||
|
try:
|
||||||
|
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
manifest = {}
|
||||||
|
for filename, info in manifest.items():
|
||||||
|
url = f"/papers/{arxiv_id}/images/{filename}"
|
||||||
|
# 优先用 label 字段(新格式)
|
||||||
|
label = info.get("label", "")
|
||||||
|
if label:
|
||||||
|
id_to_url[label] = url
|
||||||
|
# 也兼容 figures/tables 列表(旧格式)
|
||||||
|
for fig_id in info.get("figures", []) + info.get("tables", []):
|
||||||
|
if fig_id not in id_to_url:
|
||||||
|
id_to_url[fig_id] = url
|
||||||
|
|
||||||
|
for fig in figures:
|
||||||
|
raw_id = fig.get("id", "")
|
||||||
|
normalized = _normalize_figure_id(raw_id)
|
||||||
|
if normalized in id_to_url:
|
||||||
|
fig["image_url"] = id_to_url[normalized]
|
||||||
|
|
||||||
|
# ── 策略 2:序号兜底(manifest 匹配不到时) ──
|
||||||
|
unmatched = [f for f in figures if not f.get("image_url")]
|
||||||
|
if not unmatched:
|
||||||
|
return figures
|
||||||
|
|
||||||
|
# 按类型分流:Figure vs Table
|
||||||
|
fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
|
||||||
|
table_type_unmatched = [
|
||||||
|
f for f in unmatched if not _is_figure_type(f.get("id", ""))
|
||||||
|
]
|
||||||
|
|
||||||
|
# 提取的图片按类型分流,按文件名中的编号排序
|
||||||
|
fig_images = sorted(
|
||||||
|
[img for img in images if "table" not in img["name"].lower()],
|
||||||
|
key=lambda img: _image_sort_key(img["name"]),
|
||||||
|
)
|
||||||
|
table_images = sorted(
|
||||||
|
[img for img in images if "table" in img["name"].lower()],
|
||||||
|
key=lambda img: _image_sort_key(img["name"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, fig in enumerate(fig_type_unmatched):
|
||||||
|
if i < len(fig_images):
|
||||||
|
fig["image_url"] = fig_images[i]["url"]
|
||||||
|
|
||||||
|
for i, fig in enumerate(table_type_unmatched):
|
||||||
|
if i < len(table_images):
|
||||||
|
fig["image_url"] = table_images[i]["url"]
|
||||||
|
|
||||||
|
return figures
|
||||||
|
|||||||
+25
-10
@@ -11,6 +11,7 @@ import uuid
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.utils import truncate_error
|
||||||
from app.services.summary_utils import (
|
from app.services.summary_utils import (
|
||||||
JsonNotFoundError,
|
JsonNotFoundError,
|
||||||
build_prompt,
|
build_prompt,
|
||||||
@@ -21,6 +22,9 @@ from app.services.summary_utils import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# PDF 全文注入模式的字符上限 — 超过此阈值自动切换到 search 模式
|
||||||
|
_PDF_MAX_CHARS = 80_000
|
||||||
|
|
||||||
# 重新导出,保持向后兼容
|
# 重新导出,保持向后兼容
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PiTimeoutError",
|
"PiTimeoutError",
|
||||||
@@ -45,7 +49,7 @@ class PiProcessError(Exception):
|
|||||||
def __init__(self, returncode: int, stderr: str):
|
def __init__(self, returncode: int, stderr: str):
|
||||||
self.returncode = returncode
|
self.returncode = returncode
|
||||||
self.stderr = stderr
|
self.stderr = stderr
|
||||||
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
|
super().__init__(f"pi exited with code {returncode}: {truncate_error(stderr)}")
|
||||||
|
|
||||||
|
|
||||||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||||||
@@ -72,23 +76,27 @@ async def call_pi(
|
|||||||
|
|
||||||
actual_mode = pdf_mode
|
actual_mode = pdf_mode
|
||||||
if pdf_mode == "auto":
|
if pdf_mode == "auto":
|
||||||
if txt_size > 80_000:
|
if txt_size > _PDF_MAX_CHARS:
|
||||||
actual_mode = "search"
|
actual_mode = "search"
|
||||||
logger.info(
|
logger.info(
|
||||||
"Auto mode: %s text=%d chars > 80k → search", arxiv_id, txt_size
|
"Auto mode: %s text=%d chars > %dk → search",
|
||||||
|
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
actual_mode = "inject"
|
actual_mode = "inject"
|
||||||
logger.info(
|
logger.info(
|
||||||
"Auto mode: %s text=%d chars ≤ 80k → inject", arxiv_id, txt_size
|
"Auto mode: %s text=%d chars ≤ %dk → inject",
|
||||||
|
arxiv_id, txt_size, _PDF_MAX_CHARS // 1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
# inject 模式需要截断过长的文本(避免撑爆 context)
|
# inject 模式需要截断过长的文本(避免撑爆 context)
|
||||||
if actual_mode == "inject" and txt_size > 80_000:
|
if actual_mode == "inject" and txt_size > _PDF_MAX_CHARS:
|
||||||
body = txt_path.read_text(encoding="utf-8")
|
body = txt_path.read_text(encoding="utf-8")
|
||||||
trimmed = body[:80_000].rstrip()
|
trimmed = body[:_PDF_MAX_CHARS].rstrip()
|
||||||
txt_path.write_text(trimmed, encoding="utf-8")
|
txt_path.write_text(trimmed, encoding="utf-8")
|
||||||
logger.info("Truncated %s for inject: %d → %d chars", arxiv_id, txt_size, len(trimmed))
|
logger.info(
|
||||||
|
"Truncated %s for inject: %d → %d chars", arxiv_id, txt_size, len(trimmed)
|
||||||
|
)
|
||||||
|
|
||||||
prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
|
prompt_text = build_prompt(arxiv_id, meta_path, txt_path, actual_mode, fix_errors)
|
||||||
|
|
||||||
@@ -101,7 +109,8 @@ async def call_pi(
|
|||||||
cmd = [
|
cmd = [
|
||||||
settings.PI_BIN,
|
settings.PI_BIN,
|
||||||
"-p",
|
"-p",
|
||||||
"--tools", tools,
|
"--tools",
|
||||||
|
tools,
|
||||||
]
|
]
|
||||||
if fix_errors:
|
if fix_errors:
|
||||||
cmd += ["--session", session_id, "--continue"]
|
cmd += ["--session", session_id, "--continue"]
|
||||||
@@ -118,10 +127,14 @@ async def call_pi(
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Calling pi for %s (fix=%s, session=%s, mode=%s)",
|
"Calling pi for %s (fix=%s, session=%s, mode=%s)",
|
||||||
arxiv_id, bool(fix_errors), session_id, actual_mode,
|
arxiv_id,
|
||||||
|
bool(fix_errors),
|
||||||
|
session_id,
|
||||||
|
actual_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
import time as _time
|
import time as _time
|
||||||
|
|
||||||
_t_sub_start = _time.monotonic()
|
_t_sub_start = _time.monotonic()
|
||||||
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
@@ -151,7 +164,9 @@ async def call_pi(
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"pi subprocess for %s: %.2fs%s",
|
"pi subprocess for %s: %.2fs%s",
|
||||||
arxiv_id, _t_sub_end - _t_sub_start, _file_info,
|
arxiv_id,
|
||||||
|
_t_sub_end - _t_sub_start,
|
||||||
|
_file_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
if proc.returncode != 0:
|
if proc.returncode != 0:
|
||||||
|
|||||||
+56
-11
@@ -8,6 +8,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from datetime import date as date_type
|
from datetime import date as date_type
|
||||||
|
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@@ -15,11 +16,50 @@ from app.models import CrawlLog, TaskLock
|
|||||||
from app.services.cleaner import cleanup_tmp
|
from app.services.cleaner import cleanup_tmp
|
||||||
from app.services.crawler import crawl_daily
|
from app.services.crawler import crawl_daily
|
||||||
from app.services.summarizer import summarize_batch
|
from app.services.summarizer import summarize_batch
|
||||||
from app.utils import utc_now, yesterday_str
|
from app.utils import release_lock, truncate_error, utc_now, yesterday_str
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def acquire_lock(db: Session, task: str, lock_key: str, owner: str) -> TaskLock:
|
||||||
|
"""获取 TaskLock,锁冲突时抛出 RuntimeError。
|
||||||
|
|
||||||
|
供需要防重入的操作(crawl、pipeline 等)统一调用。
|
||||||
|
"""
|
||||||
|
lock = TaskLock(
|
||||||
|
task=task,
|
||||||
|
lock_key=lock_key,
|
||||||
|
status="running",
|
||||||
|
owner=owner,
|
||||||
|
acquired_at=utc_now(),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
db.add(lock)
|
||||||
|
db.commit()
|
||||||
|
except IntegrityError:
|
||||||
|
db.rollback()
|
||||||
|
raise RuntimeError(f"{task} already running for {lock_key}")
|
||||||
|
return lock
|
||||||
|
|
||||||
|
|
||||||
|
async def run_crawl(db: Session, target_date: str, owner: str = "admin_crawl") -> dict:
|
||||||
|
"""执行单次抓取(带防重入锁)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库 session
|
||||||
|
target_date: 目标日期 YYYY-MM-DD
|
||||||
|
owner: 调用者标识
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
crawl_daily() 的原始返回值
|
||||||
|
"""
|
||||||
|
lock = acquire_lock(db, "crawl", target_date, owner)
|
||||||
|
try:
|
||||||
|
return await crawl_daily(db, target_date)
|
||||||
|
finally:
|
||||||
|
release_lock(db, lock)
|
||||||
|
|
||||||
|
|
||||||
async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
||||||
"""执行完整流水线:crawl → summarize → cleanup。
|
"""执行完整流水线:crawl → summarize → cleanup。
|
||||||
|
|
||||||
@@ -47,7 +87,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
|||||||
try:
|
try:
|
||||||
db.add(lock)
|
db.add(lock)
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception:
|
except IntegrityError:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
raise RuntimeError(f"Pipeline already running for {target_date}")
|
raise RuntimeError(f"Pipeline already running for {target_date}")
|
||||||
|
|
||||||
@@ -66,9 +106,13 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
|||||||
try:
|
try:
|
||||||
# Step 1: 抓取(先试今天,无数据则回退昨天)
|
# Step 1: 抓取(先试今天,无数据则回退昨天)
|
||||||
crawl_result = await crawl_daily(db, target_date)
|
crawl_result = await crawl_daily(db, target_date)
|
||||||
logger.info("Pipeline [%s]: crawl %s, found=%d new=%d",
|
logger.info(
|
||||||
owner, target_date,
|
"Pipeline [%s]: crawl %s, found=%d new=%d",
|
||||||
crawl_result.get("found", 0), crawl_result.get("new", 0))
|
owner,
|
||||||
|
target_date,
|
||||||
|
crawl_result.get("found", 0),
|
||||||
|
crawl_result.get("new", 0),
|
||||||
|
)
|
||||||
|
|
||||||
if crawl_result.get("status") == "success" and crawl_result.get("found") == 0:
|
if crawl_result.get("status") == "success" and crawl_result.get("found") == 0:
|
||||||
yesterday = yesterday_str()
|
yesterday = yesterday_str()
|
||||||
@@ -81,8 +125,11 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
|||||||
|
|
||||||
# Step 3: 清理
|
# Step 3: 清理
|
||||||
cleanup_result = cleanup_tmp()
|
cleanup_result = cleanup_tmp()
|
||||||
logger.info("Pipeline [%s]: cleanup done, removed=%d",
|
logger.info(
|
||||||
owner, cleanup_result.get("removed", 0))
|
"Pipeline [%s]: cleanup done, removed=%d",
|
||||||
|
owner,
|
||||||
|
cleanup_result.get("removed", 0),
|
||||||
|
)
|
||||||
|
|
||||||
log_entry.status = "success"
|
log_entry.status = "success"
|
||||||
log_entry.papers_found = crawl_result.get("found", 0)
|
log_entry.papers_found = crawl_result.get("found", 0)
|
||||||
@@ -91,7 +138,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Pipeline [%s] failed", owner)
|
logger.exception("Pipeline [%s] failed", owner)
|
||||||
log_entry.status = "failed"
|
log_entry.status = "failed"
|
||||||
error_msg = str(exc)[:2000]
|
error_msg = truncate_error(exc, limit=2000)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
log_entry.completed_at = utc_now()
|
log_entry.completed_at = utc_now()
|
||||||
@@ -99,9 +146,7 @@ async def run_pipeline(db: Session, target_date: str, owner: str) -> dict:
|
|||||||
log_entry.error = error_msg
|
log_entry.error = error_msg
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
lock.status = "finished"
|
release_lock(db, lock)
|
||||||
lock.released_at = utc_now()
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
if error_msg:
|
if error_msg:
|
||||||
return {"status": "failed", "error": error_msg}
|
return {"status": "failed", "error": error_msg}
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ class SummarySchema(BaseModel):
|
|||||||
|
|
||||||
# ── 质量评估 ────────────────────────────────────────────────────────────
|
# ── 质量评估 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def assess_quality(schema: SummarySchema) -> str:
|
def assess_quality(schema: SummarySchema) -> str:
|
||||||
"""评估总结质量:normal / degraded / low。"""
|
"""评估总结质量:normal / degraded / low。"""
|
||||||
# low:内容空洞的启发式判断
|
# low:内容空洞的启发式判断
|
||||||
|
|||||||
@@ -213,11 +213,7 @@ def _search_semantic(
|
|||||||
arxiv_ids = [c["arxiv_id"] for c in candidates]
|
arxiv_ids = [c["arxiv_id"] for c in candidates]
|
||||||
distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
|
distance_map = {c["arxiv_id"]: c["distance"] for c in candidates}
|
||||||
|
|
||||||
stmt = (
|
stmt = select(Paper).where(Paper.arxiv_id.in_(arxiv_ids)).options(*PAPER_FULL_LOAD)
|
||||||
select(Paper)
|
|
||||||
.where(Paper.arxiv_id.in_(arxiv_ids))
|
|
||||||
.options(*PAPER_FULL_LOAD)
|
|
||||||
)
|
|
||||||
if tag:
|
if tag:
|
||||||
stmt = stmt.where(Paper.tags.any(tag=tag))
|
stmt = stmt.where(Paper.tags.any(tag=tag))
|
||||||
|
|
||||||
@@ -298,9 +294,7 @@ def _load_papers_by_ids(
|
|||||||
|
|
||||||
papers = (
|
papers = (
|
||||||
db.execute(
|
db.execute(
|
||||||
select(Paper)
|
select(Paper).where(Paper.id.in_(paper_ids)).options(*PAPER_FULL_LOAD)
|
||||||
.where(Paper.id.in_(paper_ids))
|
|
||||||
.options(*PAPER_FULL_LOAD)
|
|
||||||
)
|
)
|
||||||
.unique()
|
.unique()
|
||||||
.scalars()
|
.scalars()
|
||||||
|
|||||||
+39
-502
@@ -1,233 +1,42 @@
|
|||||||
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。"""
|
"""AI 总结编排服务 — 协调生成器、持久化、批量处理的顶层入口。"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from pydantic import ValidationError
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import SessionLocal
|
from app.database import SessionLocal
|
||||||
|
from app.exceptions import ConflictError, NotFoundError
|
||||||
from app.models import (
|
from app.models import (
|
||||||
PAPER_DEFAULT_LOAD,
|
PAPER_DEFAULT_LOAD,
|
||||||
CrawlLog,
|
CrawlLog,
|
||||||
Paper,
|
Paper,
|
||||||
PaperSummary,
|
|
||||||
PaperTag,
|
|
||||||
SummaryState,
|
SummaryState,
|
||||||
SummaryStatus,
|
SummaryStatus,
|
||||||
TaskLock,
|
TaskLock,
|
||||||
|
get_paper_by_arxiv_id,
|
||||||
|
get_paper_by_id,
|
||||||
)
|
)
|
||||||
from app.services.pdf_downloader import (
|
from app.services.pdf_downloader import download_pdf
|
||||||
PdfDownloadError,
|
from app.services.summary_utils import write_meta_json
|
||||||
cleanup_tmp,
|
from app.services.summary_generator import (
|
||||||
download_pdf,
|
_generate_with_retry,
|
||||||
paper_dir,
|
|
||||||
)
|
)
|
||||||
from app.services.summary_utils import (
|
from app.services.summary_persister import (
|
||||||
JsonNotFoundError,
|
_cleanup_old_images,
|
||||||
build_prompt,
|
_handle_summary_failure,
|
||||||
extract_json,
|
_persist_summary,
|
||||||
write_meta_json,
|
|
||||||
extract_pdf_text,
|
|
||||||
)
|
)
|
||||||
from app.services.pi_client import (
|
from app.utils import TMP_DIR, release_lock, truncate_error, utc_now
|
||||||
PiProcessError,
|
|
||||||
PiTimeoutError,
|
|
||||||
call_pi,
|
|
||||||
)
|
|
||||||
from app.services import claude_backend
|
|
||||||
from app.services.schemas import (
|
|
||||||
SummarySchema,
|
|
||||||
assess_quality,
|
|
||||||
classify_validation_error,
|
|
||||||
flatten_for_db,
|
|
||||||
)
|
|
||||||
from app.utils import TMP_DIR, release_lock, utc_now
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# ── 错误分类 ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _classify_error(exc: Exception) -> str:
|
|
||||||
"""将异常映射到 error_type 枚举值。"""
|
|
||||||
if isinstance(exc, PdfDownloadError):
|
|
||||||
return "pdf_download_failed"
|
|
||||||
if isinstance(exc, PiTimeoutError):
|
|
||||||
return "timeout"
|
|
||||||
if isinstance(exc, PiProcessError):
|
|
||||||
return "process_error"
|
|
||||||
if isinstance(exc, JsonNotFoundError):
|
|
||||||
return "json_not_found"
|
|
||||||
if isinstance(exc, json.JSONDecodeError):
|
|
||||||
return "json_invalid"
|
|
||||||
if isinstance(exc, ValidationError):
|
|
||||||
return classify_validation_error(exc)
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _build_fts_summary_text(schema: SummarySchema) -> str:
|
|
||||||
"""拼接用于 FTS5 索引的总结文本。"""
|
|
||||||
parts = [
|
|
||||||
schema.one_line or "",
|
|
||||||
schema.motivation.problem or "",
|
|
||||||
schema.motivation.goal or "",
|
|
||||||
schema.method.overview or "",
|
|
||||||
schema.method.key_idea or "",
|
|
||||||
schema.results.main_findings or "",
|
|
||||||
]
|
|
||||||
return " ".join(p for p in parts if p)
|
|
||||||
|
|
||||||
|
|
||||||
# ── DB 更新 ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _update_summary_in_db(
|
|
||||||
db: Session,
|
|
||||||
paper: Paper,
|
|
||||||
schema: SummarySchema,
|
|
||||||
quality: str,
|
|
||||||
raw_output: str,
|
|
||||||
) -> None:
|
|
||||||
"""将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。"""
|
|
||||||
from sqlalchemy import text
|
|
||||||
|
|
||||||
# 1. paper_summaries:upsert
|
|
||||||
existing = db.get(PaperSummary, paper.id)
|
|
||||||
flat = flatten_for_db(schema)
|
|
||||||
if existing:
|
|
||||||
for k, v in flat.items():
|
|
||||||
setattr(existing, k, v)
|
|
||||||
else:
|
|
||||||
db.add(PaperSummary(paper_id=paper.id, **flat))
|
|
||||||
|
|
||||||
# 2. papers 表
|
|
||||||
paper.title_zh = schema.title_zh
|
|
||||||
paper.summary_quality = quality
|
|
||||||
p_dir = paper_dir(paper.arxiv_id)
|
|
||||||
paper.summary_path = str(p_dir / "summary.json")
|
|
||||||
paper.raw_output_path = str(p_dir / "raw_output.txt")
|
|
||||||
|
|
||||||
# 3. AI 标签
|
|
||||||
existing_tag_names = {t.tag for t in paper.tags}
|
|
||||||
for tag_name in schema.tags:
|
|
||||||
if tag_name not in existing_tag_names:
|
|
||||||
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
|
|
||||||
existing_tag_names.add(tag_name)
|
|
||||||
|
|
||||||
# 4. FTS5 更新
|
|
||||||
summary_text = _build_fts_summary_text(schema)
|
|
||||||
db.execute(
|
|
||||||
text(
|
|
||||||
"UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
|
|
||||||
"WHERE rowid=:paper_id"
|
|
||||||
),
|
|
||||||
{
|
|
||||||
"title_zh": schema.title_zh,
|
|
||||||
"summary_text": summary_text,
|
|
||||||
"paper_id": paper.id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
|
|
||||||
|
|
||||||
|
|
||||||
# ── JSON 验证 ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
|
|
||||||
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
|
|
||||||
errors: list[str] = []
|
|
||||||
|
|
||||||
if not isinstance(json_data, dict):
|
|
||||||
return ["顶层必须是 JSON 对象"]
|
|
||||||
|
|
||||||
# 必填字段
|
|
||||||
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
|
|
||||||
if f not in json_data or not json_data[f]:
|
|
||||||
errors.append(f"缺少必填字段: {f}")
|
|
||||||
|
|
||||||
# tags 必须是非空数组
|
|
||||||
tags = json_data.get("tags")
|
|
||||||
if not isinstance(tags, list) or len(tags) == 0:
|
|
||||||
errors.append("tags 必须是非空数组")
|
|
||||||
|
|
||||||
# 字符串段落字段(必须是 str 且 ≥50 字)
|
|
||||||
string_fields = [
|
|
||||||
("motivation", "problem"), ("motivation", "goal"), ("motivation", "gap"),
|
|
||||||
("method", "overview"), ("method", "key_idea"), ("method", "steps"),
|
|
||||||
("method", "novelty"),
|
|
||||||
("results", "main_findings"), ("results", "limitations"),
|
|
||||||
("improvements", "weaknesses"), ("improvements", "future_work"),
|
|
||||||
("improvements", "reproducibility"),
|
|
||||||
]
|
|
||||||
for section, field in string_fields:
|
|
||||||
val = json_data.get(section, {}).get(field)
|
|
||||||
if isinstance(val, list):
|
|
||||||
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
|
|
||||||
elif not isinstance(val, str) or len(val.strip()) < 50:
|
|
||||||
errors.append(
|
|
||||||
f"{section}.{field} 必须是详细段落(≥50字),"
|
|
||||||
f"当前: {type(val).__name__} ({len(str(val))}字)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# benchmarks 必须是数组
|
|
||||||
benchmarks = json_data.get("results", {}).get("benchmarks")
|
|
||||||
if benchmarks is not None and not isinstance(benchmarks, list):
|
|
||||||
errors.append("results.benchmarks 必须是数组")
|
|
||||||
|
|
||||||
# prerequisites.concepts 必须是对象数组,每个有 term
|
|
||||||
concepts = json_data.get("prerequisites", {}).get("concepts")
|
|
||||||
if concepts is not None:
|
|
||||||
if not isinstance(concepts, list):
|
|
||||||
errors.append("prerequisites.concepts 必须是数组")
|
|
||||||
elif len(concepts) == 0:
|
|
||||||
errors.append("prerequisites.concepts 不能为空")
|
|
||||||
else:
|
|
||||||
for i, c in enumerate(concepts):
|
|
||||||
if isinstance(c, str):
|
|
||||||
errors.append(f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串")
|
|
||||||
elif isinstance(c, dict) and not c.get("term"):
|
|
||||||
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
|
|
||||||
|
|
||||||
# figures 必须是数组,每个元素应有 id
|
|
||||||
figures = json_data.get("figures")
|
|
||||||
if figures is not None:
|
|
||||||
if not isinstance(figures, list):
|
|
||||||
errors.append("figures 必须是数组")
|
|
||||||
else:
|
|
||||||
for i, fig in enumerate(figures):
|
|
||||||
if isinstance(fig, dict) and not fig.get("id"):
|
|
||||||
errors.append(f"figures[{i}] 缺少 id 字段")
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
# ── 文件操作 ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
|
|
||||||
d = paper_dir(arxiv_id)
|
|
||||||
d.mkdir(parents=True, exist_ok=True)
|
|
||||||
if schema:
|
|
||||||
(d / "summary.json").write_text(
|
|
||||||
schema.model_dump_json(ensure_ascii=False, indent=2),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
# ── 单篇总结 ────────────────────────────────────────────────────────────
|
# ── 单篇总结 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -264,277 +73,7 @@ async def summarize_one(
|
|||||||
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
|
return await _do_summarize_one(db, paper, pdf_mode=pdf_mode)
|
||||||
|
|
||||||
|
|
||||||
async def _generate_with_retry(
|
async def _do_summarize_one(db: Session, paper: Paper, pdf_mode: str = "auto") -> dict:
|
||||||
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
|
|
||||||
) -> tuple[dict, str]:
|
|
||||||
"""调用 AI 后端生成总结,最多 4 轮验证循环。
|
|
||||||
|
|
||||||
根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(json_data, raw_output)
|
|
||||||
Raises:
|
|
||||||
ValueError: 4 轮验证仍未通过
|
|
||||||
"""
|
|
||||||
import time as _time
|
|
||||||
|
|
||||||
backend = settings.SUMMARY_BACKEND
|
|
||||||
validation_errors: list[str] = []
|
|
||||||
json_data: dict | None = None
|
|
||||||
raw_output = ""
|
|
||||||
session_id = None
|
|
||||||
|
|
||||||
summary_file = paper_dir(arxiv_id) / "summary.json"
|
|
||||||
|
|
||||||
# claude 后端需要预构建 prompt(pi 后端在 call_pi 内部构建)
|
|
||||||
claude_prompt: str | None = None
|
|
||||||
if backend == "claude":
|
|
||||||
_t0 = _time.monotonic()
|
|
||||||
txt_path = extract_pdf_text(pdf_path, max_chars=None)
|
|
||||||
body = txt_path.read_text(encoding="utf-8")
|
|
||||||
if len(body) > 80_000:
|
|
||||||
trimmed = body[:80_000].rstrip()
|
|
||||||
txt_path.write_text(trimmed, encoding="utf-8")
|
|
||||||
claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
|
|
||||||
logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)
|
|
||||||
|
|
||||||
for attempt in range(1, 5):
|
|
||||||
# 清理上一轮写入的不完整文件
|
|
||||||
if summary_file.exists():
|
|
||||||
summary_file.unlink()
|
|
||||||
|
|
||||||
# 记录 AI 调用开始时间
|
|
||||||
_t_call_start = _time.monotonic()
|
|
||||||
|
|
||||||
if backend == "claude":
|
|
||||||
if attempt == 1:
|
|
||||||
raw_output, session_id = await claude_backend.call_claude(
|
|
||||||
claude_prompt, session_id=None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
retry_prompt = build_prompt(
|
|
||||||
arxiv_id, meta_path,
|
|
||||||
extract_pdf_text(pdf_path, max_chars=80000),
|
|
||||||
"inject", fix_errors=validation_errors,
|
|
||||||
)
|
|
||||||
raw_output, session_id = await claude_backend.call_claude(
|
|
||||||
retry_prompt, session_id=session_id, fix_errors=validation_errors,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if attempt == 1:
|
|
||||||
raw_output, session_id = await call_pi(meta_path, pdf_path, pdf_mode=pdf_mode)
|
|
||||||
else:
|
|
||||||
raw_output, session_id = await call_pi(
|
|
||||||
meta_path, pdf_path,
|
|
||||||
fix_errors=validation_errors,
|
|
||||||
session_id=session_id,
|
|
||||||
pdf_mode=pdf_mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
_t_call_end = _time.monotonic()
|
|
||||||
|
|
||||||
# 检查 summary.json 是否由 AI 子进程写入
|
|
||||||
file_written_by_ai = summary_file.exists()
|
|
||||||
file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
|
|
||||||
file_size = summary_file.stat().st_size if file_written_by_ai else 0
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
" [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
|
|
||||||
arxiv_id, attempt,
|
|
||||||
_t_call_end - _t_call_start,
|
|
||||||
f"已写入({file_size}B)" if file_written_by_ai else "未写入",
|
|
||||||
f" mtime={file_mtime:.2f}" if file_mtime else "",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 提取 JSON
|
|
||||||
_t_json_start = _time.monotonic()
|
|
||||||
try:
|
|
||||||
if file_written_by_ai:
|
|
||||||
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
|
|
||||||
logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
|
|
||||||
else:
|
|
||||||
json_data = extract_json(raw_output)
|
|
||||||
except (json.JSONDecodeError, JsonNotFoundError) as exc:
|
|
||||||
_t_json_end = _time.monotonic()
|
|
||||||
logger.warning(
|
|
||||||
" [%s] JSON提取失败: %.2fs %s",
|
|
||||||
arxiv_id, _t_json_end - _t_json_start, str(exc)[:200],
|
|
||||||
)
|
|
||||||
validation_errors = [f"无法提取有效 JSON: {str(exc)[:100]}"]
|
|
||||||
continue
|
|
||||||
_t_json_end = _time.monotonic()
|
|
||||||
|
|
||||||
# 验证
|
|
||||||
_t_val_start = _time.monotonic()
|
|
||||||
validation_errors = _validate_summary(json_data, arxiv_id)
|
|
||||||
_t_val_end = _time.monotonic()
|
|
||||||
|
|
||||||
if not validation_errors:
|
|
||||||
logger.info(
|
|
||||||
" [%s] JSON提取: %.2fs 验证: %.2fs ✅",
|
|
||||||
arxiv_id,
|
|
||||||
_t_json_end - _t_json_start,
|
|
||||||
_t_val_end - _t_val_start,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
logger.warning(
|
|
||||||
" [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
|
|
||||||
arxiv_id,
|
|
||||||
_t_json_end - _t_json_start,
|
|
||||||
_t_val_end - _t_val_start,
|
|
||||||
"; ".join(validation_errors)[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
if validation_errors:
|
|
||||||
exc = ValueError(
|
|
||||||
f"Summary validation failed after 4 attempts: {'; '.join(validation_errors)}"
|
|
||||||
)
|
|
||||||
exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
return json_data, raw_output
|
|
||||||
|
|
||||||
|
|
||||||
def _persist_summary(
|
|
||||||
db: Session, paper: Paper, json_data: dict, raw_output: str
|
|
||||||
) -> str:
|
|
||||||
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
|
|
||||||
import time as _time
|
|
||||||
arxiv_id = paper.arxiv_id
|
|
||||||
|
|
||||||
_t0 = _time.monotonic()
|
|
||||||
schema = SummarySchema.model_validate(json_data)
|
|
||||||
quality = assess_quality(schema)
|
|
||||||
_t1 = _time.monotonic()
|
|
||||||
|
|
||||||
_save_files(arxiv_id, schema, raw_output)
|
|
||||||
_t2 = _time.monotonic()
|
|
||||||
|
|
||||||
_update_summary_in_db(db, paper, schema, quality, raw_output)
|
|
||||||
_t3 = _time.monotonic()
|
|
||||||
|
|
||||||
# 状态 → done
|
|
||||||
paper.summary_status.status = SummaryState.DONE
|
|
||||||
paper.summary_status.quality = quality
|
|
||||||
paper.summary_status.completed_at = utc_now()
|
|
||||||
paper.summary_status.raw_output_saved = True
|
|
||||||
db.commit()
|
|
||||||
_t4 = _time.monotonic()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
" [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
|
|
||||||
arxiv_id,
|
|
||||||
_t1 - _t0,
|
|
||||||
_t2 - _t1,
|
|
||||||
_t3 - _t2,
|
|
||||||
_t4 - _t3,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 触发性增强(失败不影响总结)
|
|
||||||
_t5 = _time.monotonic()
|
|
||||||
_maybe_extract_images(arxiv_id, schema)
|
|
||||||
_t6 = _time.monotonic()
|
|
||||||
_maybe_index_chroma(arxiv_id, paper, schema)
|
|
||||||
_t7 = _time.monotonic()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
|
|
||||||
arxiv_id,
|
|
||||||
_t6 - _t5,
|
|
||||||
_t7 - _t6,
|
|
||||||
)
|
|
||||||
|
|
||||||
return quality
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_summary_failure(
|
|
||||||
db: Session, paper: Paper, exc: Exception, raw_output: str,
|
|
||||||
) -> dict:
|
|
||||||
"""记录失败:保存 raw_output、重试计数、错误分类。"""
|
|
||||||
error_type = _classify_error(exc)
|
|
||||||
logger.error(
|
|
||||||
"Summarize failed: %s error_type=%s %s",
|
|
||||||
paper.arxiv_id, error_type, str(exc)[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
status = paper.summary_status
|
|
||||||
if raw_output:
|
|
||||||
_save_files(paper.arxiv_id, None, raw_output)
|
|
||||||
status.raw_output_saved = True
|
|
||||||
|
|
||||||
status.retry_count = (status.retry_count or 0) + 1
|
|
||||||
status.error_type = error_type
|
|
||||||
status.error = str(exc)[:2000]
|
|
||||||
|
|
||||||
if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
|
|
||||||
status.status = SummaryState.PERMANENT_FAILURE
|
|
||||||
else:
|
|
||||||
status.status = SummaryState.PENDING
|
|
||||||
|
|
||||||
status.completed_at = utc_now()
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"arxiv_id": paper.arxiv_id,
|
|
||||||
"status": "failed",
|
|
||||||
"error_type": error_type,
|
|
||||||
"error": str(exc)[:200],
|
|
||||||
"retry_count": status.retry_count,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_old_images(db: Session, paper: Paper) -> None:
|
|
||||||
"""清理旧的图片文件和 figures_json,避免重新总结时残留。"""
|
|
||||||
arxiv_id = paper.arxiv_id
|
|
||||||
images_dir = paper_dir(arxiv_id) / "images"
|
|
||||||
if images_dir.exists():
|
|
||||||
for old_file in images_dir.iterdir():
|
|
||||||
if old_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg") or old_file.name == "manifest.json":
|
|
||||||
old_file.unlink(missing_ok=True)
|
|
||||||
# 清除数据库中的 figures_json
|
|
||||||
if paper.summary and paper.summary.figures_json:
|
|
||||||
paper.summary.figures_json = None
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
|
|
||||||
"""从 PDF 提取图片和表格(失败不影响总结)。"""
|
|
||||||
try:
|
|
||||||
from app.services.pdf_image_extractor import (
|
|
||||||
extract_images_from_pdf,
|
|
||||||
filter_images_by_summary,
|
|
||||||
)
|
|
||||||
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
|
|
||||||
extract_images_from_pdf(arxiv_id, pdf_path)
|
|
||||||
if schema.figures:
|
|
||||||
filter_images_by_summary(arxiv_id, schema.figures)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
|
|
||||||
"""写入 ChromaDB 语义索引(失败不影响总结)。"""
|
|
||||||
try:
|
|
||||||
from app.services.embedder import index_paper
|
|
||||||
|
|
||||||
texts_dict = {
|
|
||||||
"arxiv_id": arxiv_id,
|
|
||||||
"title_zh": schema.title_zh or "",
|
|
||||||
"title_en": paper.title_en or "",
|
|
||||||
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
|
|
||||||
"one_line": schema.one_line or "",
|
|
||||||
"motivation_problem": schema.motivation.problem or "",
|
|
||||||
"method_key_idea": schema.method.key_idea or "",
|
|
||||||
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
|
|
||||||
}
|
|
||||||
index_paper(arxiv_id, texts_dict)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
async def _do_summarize_one(
|
|
||||||
db: Session, paper: Paper, pdf_mode: str = "auto"
|
|
||||||
) -> dict:
|
|
||||||
"""实际的单篇总结执行(在 semaphore 保护下)。"""
|
"""实际的单篇总结执行(在 semaphore 保护下)。"""
|
||||||
arxiv_id = paper.arxiv_id
|
arxiv_id = paper.arxiv_id
|
||||||
title_short = (paper.title_en or "")[:50]
|
title_short = (paper.title_en or "")[:50]
|
||||||
@@ -548,6 +87,7 @@ async def _do_summarize_one(
|
|||||||
|
|
||||||
# 清理旧的图片文件和 figures_json,避免重新总结时残留
|
# 清理旧的图片文件和 figures_json,避免重新总结时残留
|
||||||
import time as _time
|
import time as _time
|
||||||
|
|
||||||
_t_cleanup_start = _time.monotonic()
|
_t_cleanup_start = _time.monotonic()
|
||||||
_cleanup_old_images(db, paper)
|
_cleanup_old_images(db, paper)
|
||||||
_t_cleanup_end = _time.monotonic()
|
_t_cleanup_end = _time.monotonic()
|
||||||
@@ -567,7 +107,9 @@ async def _do_summarize_one(
|
|||||||
|
|
||||||
logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
|
logger.info(" [%s] 调用 pi 生成总结...", arxiv_id)
|
||||||
json_data, raw_output = await _generate_with_retry(
|
json_data, raw_output = await _generate_with_retry(
|
||||||
arxiv_id, meta_path, TMP_DIR / arxiv_id / "paper.pdf",
|
arxiv_id,
|
||||||
|
meta_path,
|
||||||
|
TMP_DIR / arxiv_id / "paper.pdf",
|
||||||
pdf_mode=pdf_mode,
|
pdf_mode=pdf_mode,
|
||||||
)
|
)
|
||||||
_t3 = _time.monotonic()
|
_t3 = _time.monotonic()
|
||||||
@@ -577,7 +119,9 @@ async def _do_summarize_one(
|
|||||||
_t4 = _time.monotonic()
|
_t4 = _time.monotonic()
|
||||||
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
|
logger.info(" [%s] 持久化: %.2fs", arxiv_id, _t4 - _t3)
|
||||||
|
|
||||||
logger.info("✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0)
|
logger.info(
|
||||||
|
"✅ [%s] 完成: quality=%s 总耗时: %.2fs", arxiv_id, quality, _t4 - _t0
|
||||||
|
)
|
||||||
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
|
return {"arxiv_id": arxiv_id, "status": "done", "quality": quality}
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -586,7 +130,7 @@ async def _do_summarize_one(
|
|||||||
return _handle_summary_failure(db, paper, exc, fail_output)
|
return _handle_summary_failure(db, paper, exc, fail_output)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
cleanup_tmp(arxiv_id)
|
pass # cleanup_tmp(arxiv_id) # 暂时禁用,保留 PDF 用于调试图片提取
|
||||||
|
|
||||||
|
|
||||||
# ── 单篇入口 ────────────────────────────────────────────────────────────
|
# ── 单篇入口 ────────────────────────────────────────────────────────────
|
||||||
@@ -604,25 +148,19 @@ async def summarize_single(
|
|||||||
|
|
||||||
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
|
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
|
||||||
"""
|
"""
|
||||||
paper = db.execute(
|
paper = get_paper_by_arxiv_id(db, arxiv_id)
|
||||||
select(Paper)
|
|
||||||
.where(Paper.arxiv_id == arxiv_id)
|
|
||||||
.options(*PAPER_DEFAULT_LOAD)
|
|
||||||
).unique().scalar_one_or_none()
|
|
||||||
if not paper:
|
if not paper:
|
||||||
return {"status": "not_found", "arxiv_id": arxiv_id}
|
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||||
|
|
||||||
make_session = _session_factory or SessionLocal
|
make_session = _session_factory or SessionLocal
|
||||||
|
|
||||||
# 每篇用独立 session 避免并发问题
|
# 每篇用独立 session 避免并发问题
|
||||||
paper_db = make_session()
|
paper_db = make_session()
|
||||||
try:
|
try:
|
||||||
paper_in_new_session = paper_db.execute(
|
paper_in_new_session = get_paper_by_arxiv_id(paper_db, arxiv_id)
|
||||||
select(Paper)
|
result = await summarize_one(
|
||||||
.where(Paper.arxiv_id == arxiv_id)
|
paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode
|
||||||
.options(*PAPER_DEFAULT_LOAD)
|
)
|
||||||
).unique().scalar_one_or_none()
|
|
||||||
result = await summarize_one(paper_db, paper_in_new_session, force=force, pdf_mode=pdf_mode)
|
|
||||||
finally:
|
finally:
|
||||||
paper_db.close()
|
paper_db.close()
|
||||||
|
|
||||||
@@ -656,10 +194,10 @@ async def summarize_batch(
|
|||||||
try:
|
try:
|
||||||
db.add(lock)
|
db.add(lock)
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception:
|
except IntegrityError:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.warning("Summarize batch already running (lock conflict)")
|
logger.warning("Summarize batch already running (lock conflict)")
|
||||||
return {"status": "conflict", "error": "summarize batch already running"}
|
raise ConflictError("summarize batch already running")
|
||||||
|
|
||||||
# CrawlLog
|
# CrawlLog
|
||||||
log_entry = CrawlLog(
|
log_entry = CrawlLog(
|
||||||
@@ -717,19 +255,18 @@ async def summarize_batch(
|
|||||||
break
|
break
|
||||||
paper_db = make_session()
|
paper_db = make_session()
|
||||||
try:
|
try:
|
||||||
p = paper_db.execute(
|
p = get_paper_by_id(paper_db, paper.id)
|
||||||
select(Paper)
|
|
||||||
.where(Paper.id == paper.id)
|
|
||||||
.options(*PAPER_DEFAULT_LOAD)
|
|
||||||
).unique().scalar_one_or_none()
|
|
||||||
result = await summarize_one(paper_db, p, pdf_mode=pdf_mode)
|
result = await summarize_one(paper_db, p, pdf_mode=pdf_mode)
|
||||||
status = result.get("status", "failed")
|
status = result.get("status", "failed")
|
||||||
progress[status] = progress.get(status, 0) + 1
|
progress[status] = progress.get(status, 0) + 1
|
||||||
finished = sum(progress.values())
|
finished = sum(progress.values())
|
||||||
logger.info(
|
logger.info(
|
||||||
"📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s",
|
"📊 进度: %d/%d (✅%d ❌%d ⏭️%d) — %s",
|
||||||
finished, total,
|
finished,
|
||||||
progress["done"], progress["failed"], progress["skipped"],
|
total,
|
||||||
|
progress["done"],
|
||||||
|
progress["failed"],
|
||||||
|
progress["skipped"],
|
||||||
paper.arxiv_id,
|
paper.arxiv_id,
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
@@ -785,10 +322,10 @@ async def summarize_batch(
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Summarize batch failed")
|
logger.exception("Summarize batch failed")
|
||||||
log_entry.status = "failed"
|
log_entry.status = "failed"
|
||||||
log_entry.error = str(exc)[:2000]
|
log_entry.error = truncate_error(exc, limit=2000)
|
||||||
log_entry.completed_at = utc_now()
|
log_entry.completed_at = utc_now()
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"status": "failed", "error": str(exc)}
|
return {"status": "failed", "error": truncate_error(exc)}
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
release_lock(db, lock)
|
release_lock(db, lock)
|
||||||
|
|||||||
@@ -0,0 +1,275 @@
|
|||||||
|
"""AI 总结生成器 — AI 后端调用、重试循环、JSON 验证、错误分类。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.services.pdf_downloader import (
|
||||||
|
PdfDownloadError,
|
||||||
|
paper_dir,
|
||||||
|
)
|
||||||
|
from app.services.summary_utils import (
|
||||||
|
JsonNotFoundError,
|
||||||
|
build_prompt,
|
||||||
|
extract_json,
|
||||||
|
extract_pdf_text,
|
||||||
|
)
|
||||||
|
from app.services.pi_client import (
|
||||||
|
PiProcessError,
|
||||||
|
PiTimeoutError,
|
||||||
|
call_pi,
|
||||||
|
)
|
||||||
|
from app.services import claude_backend
|
||||||
|
from app.services.schemas import classify_validation_error
|
||||||
|
from app.utils import truncate_error
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_error(exc: Exception) -> str:
|
||||||
|
"""将异常映射到 error_type 枚举值。"""
|
||||||
|
if isinstance(exc, PdfDownloadError):
|
||||||
|
return "pdf_download_failed"
|
||||||
|
if isinstance(exc, PiTimeoutError):
|
||||||
|
return "timeout"
|
||||||
|
if isinstance(exc, PiProcessError):
|
||||||
|
return "process_error"
|
||||||
|
if isinstance(exc, JsonNotFoundError):
|
||||||
|
return "json_not_found"
|
||||||
|
if isinstance(exc, json.JSONDecodeError):
|
||||||
|
return "json_invalid"
|
||||||
|
if isinstance(exc, ValidationError):
|
||||||
|
return classify_validation_error(exc)
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# ── JSON 验证 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_summary(json_data: dict, arxiv_id: str) -> list[str]:
|
||||||
|
"""验证 JSON 数据是否符合要求,返回错误列表(空=通过)。"""
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
if not isinstance(json_data, dict):
|
||||||
|
return ["顶层必须是 JSON 对象"]
|
||||||
|
|
||||||
|
# 必填字段
|
||||||
|
for f in ["arxiv_id", "title_zh", "one_line", "tags"]:
|
||||||
|
if f not in json_data or not json_data[f]:
|
||||||
|
errors.append(f"缺少必填字段: {f}")
|
||||||
|
|
||||||
|
# tags 必须是非空数组
|
||||||
|
tags = json_data.get("tags")
|
||||||
|
if not isinstance(tags, list) or len(tags) == 0:
|
||||||
|
errors.append("tags 必须是非空数组")
|
||||||
|
|
||||||
|
# 字符串段落字段(必须是 str 且 ≥50 字)
|
||||||
|
string_fields = [
|
||||||
|
("motivation", "problem"),
|
||||||
|
("motivation", "goal"),
|
||||||
|
("motivation", "gap"),
|
||||||
|
("method", "overview"),
|
||||||
|
("method", "key_idea"),
|
||||||
|
("method", "steps"),
|
||||||
|
("method", "novelty"),
|
||||||
|
("results", "main_findings"),
|
||||||
|
("results", "limitations"),
|
||||||
|
("improvements", "weaknesses"),
|
||||||
|
("improvements", "future_work"),
|
||||||
|
("improvements", "reproducibility"),
|
||||||
|
]
|
||||||
|
for section, field in string_fields:
|
||||||
|
val = json_data.get(section, {}).get(field)
|
||||||
|
if isinstance(val, list):
|
||||||
|
errors.append(f"{section}.{field} 应该是字符串段落,不能是数组")
|
||||||
|
elif not isinstance(val, str) or len(val.strip()) < 50:
|
||||||
|
errors.append(
|
||||||
|
f"{section}.{field} 必须是详细段落(≥50字),"
|
||||||
|
f"当前: {type(val).__name__} ({len(str(val))}字)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# benchmarks 必须是数组
|
||||||
|
benchmarks = json_data.get("results", {}).get("benchmarks")
|
||||||
|
if benchmarks is not None and not isinstance(benchmarks, list):
|
||||||
|
errors.append("results.benchmarks 必须是数组")
|
||||||
|
|
||||||
|
# prerequisites.concepts 必须是对象数组,每个有 term
|
||||||
|
concepts = json_data.get("prerequisites", {}).get("concepts")
|
||||||
|
if concepts is not None:
|
||||||
|
if not isinstance(concepts, list):
|
||||||
|
errors.append("prerequisites.concepts 必须是数组")
|
||||||
|
elif len(concepts) == 0:
|
||||||
|
errors.append("prerequisites.concepts 不能为空")
|
||||||
|
else:
|
||||||
|
for i, c in enumerate(concepts):
|
||||||
|
if isinstance(c, str):
|
||||||
|
errors.append(
|
||||||
|
f"prerequisites.concepts[{i}] 应该是对象 {{term,explanation,why_matters}},不能是字符串"
|
||||||
|
)
|
||||||
|
elif isinstance(c, dict) and not c.get("term"):
|
||||||
|
errors.append(f"prerequisites.concepts[{i}] 缺少 term 字段")
|
||||||
|
|
||||||
|
# figures 必须是数组,每个元素应有 id
|
||||||
|
figures = json_data.get("figures")
|
||||||
|
if figures is not None:
|
||||||
|
if not isinstance(figures, list):
|
||||||
|
errors.append("figures 必须是数组")
|
||||||
|
else:
|
||||||
|
for i, fig in enumerate(figures):
|
||||||
|
if isinstance(fig, dict) and not fig.get("id"):
|
||||||
|
errors.append(f"figures[{i}] 缺少 id 字段")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
# ── AI 调用 + 重试 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_with_retry(
|
||||||
|
arxiv_id: str, meta_path: Path, pdf_path: Path, pdf_mode: str = "auto"
|
||||||
|
) -> tuple[dict, str]:
|
||||||
|
"""调用 AI 后端生成总结,最多 4 轮验证循环。
|
||||||
|
|
||||||
|
根据 settings.SUMMARY_BACKEND 选择 pi 或 claude 后端。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(json_data, raw_output)
|
||||||
|
Raises:
|
||||||
|
ValueError: 4 轮验证仍未通过
|
||||||
|
"""
|
||||||
|
import time as _time
|
||||||
|
|
||||||
|
backend = settings.SUMMARY_BACKEND
|
||||||
|
validation_errors: list[str] = []
|
||||||
|
json_data: dict | None = None
|
||||||
|
raw_output = ""
|
||||||
|
session_id = None
|
||||||
|
|
||||||
|
summary_file = paper_dir(arxiv_id) / "summary.json"
|
||||||
|
|
||||||
|
# claude 后端需要预构建 prompt(pi 后端在 call_pi 内部构建)
|
||||||
|
claude_prompt: str | None = None
|
||||||
|
if backend == "claude":
|
||||||
|
_t0 = _time.monotonic()
|
||||||
|
txt_path = extract_pdf_text(pdf_path, max_chars=None)
|
||||||
|
body = txt_path.read_text(encoding="utf-8")
|
||||||
|
if len(body) > 80_000:
|
||||||
|
trimmed = body[:80_000].rstrip()
|
||||||
|
txt_path.write_text(trimmed, encoding="utf-8")
|
||||||
|
claude_prompt = build_prompt(arxiv_id, meta_path, txt_path, "inject", None)
|
||||||
|
logger.info(" [%s] 构建prompt: %.2fs", arxiv_id, _time.monotonic() - _t0)
|
||||||
|
|
||||||
|
for attempt in range(1, settings.SUMMARY_MAX_RETRIES + 1):
|
||||||
|
# 清理上一轮写入的不完整文件
|
||||||
|
if summary_file.exists():
|
||||||
|
summary_file.unlink()
|
||||||
|
|
||||||
|
# 记录 AI 调用开始时间
|
||||||
|
_t_call_start = _time.monotonic()
|
||||||
|
|
||||||
|
if backend == "claude":
|
||||||
|
if attempt == 1:
|
||||||
|
raw_output, session_id = await claude_backend.call_claude(
|
||||||
|
claude_prompt,
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
retry_prompt = build_prompt(
|
||||||
|
arxiv_id,
|
||||||
|
meta_path,
|
||||||
|
extract_pdf_text(pdf_path, max_chars=80000),
|
||||||
|
"inject",
|
||||||
|
fix_errors=validation_errors,
|
||||||
|
)
|
||||||
|
raw_output, session_id = await claude_backend.call_claude(
|
||||||
|
retry_prompt,
|
||||||
|
session_id=session_id,
|
||||||
|
fix_errors=validation_errors,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if attempt == 1:
|
||||||
|
raw_output, session_id = await call_pi(
|
||||||
|
meta_path, pdf_path, pdf_mode=pdf_mode
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_output, session_id = await call_pi(
|
||||||
|
meta_path,
|
||||||
|
pdf_path,
|
||||||
|
fix_errors=validation_errors,
|
||||||
|
session_id=session_id,
|
||||||
|
pdf_mode=pdf_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
_t_call_end = _time.monotonic()
|
||||||
|
|
||||||
|
# 检查 summary.json 是否由 AI 子进程写入
|
||||||
|
file_written_by_ai = summary_file.exists()
|
||||||
|
file_mtime = summary_file.stat().st_mtime if file_written_by_ai else None
|
||||||
|
file_size = summary_file.stat().st_size if file_written_by_ai else 0
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
" [%s] attempt %d AI调用: %.2fs summary.json=%s%s",
|
||||||
|
arxiv_id,
|
||||||
|
attempt,
|
||||||
|
_t_call_end - _t_call_start,
|
||||||
|
f"已写入({file_size}B)" if file_written_by_ai else "未写入",
|
||||||
|
f" mtime={file_mtime:.2f}" if file_mtime else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取 JSON
|
||||||
|
_t_json_start = _time.monotonic()
|
||||||
|
try:
|
||||||
|
if file_written_by_ai:
|
||||||
|
json_data = json.loads(summary_file.read_text(encoding="utf-8"))
|
||||||
|
logger.info(" [%s] 从AI写入的summary.json读取", arxiv_id)
|
||||||
|
else:
|
||||||
|
json_data = extract_json(raw_output)
|
||||||
|
except (json.JSONDecodeError, JsonNotFoundError) as exc:
|
||||||
|
_t_json_end = _time.monotonic()
|
||||||
|
logger.warning(
|
||||||
|
" [%s] JSON提取失败: %.2fs %s",
|
||||||
|
arxiv_id,
|
||||||
|
_t_json_end - _t_json_start,
|
||||||
|
str(exc)[:200],
|
||||||
|
)
|
||||||
|
validation_errors = [f"无法提取有效 JSON: {truncate_error(exc)}"]
|
||||||
|
continue
|
||||||
|
_t_json_end = _time.monotonic()
|
||||||
|
|
||||||
|
# 验证
|
||||||
|
_t_val_start = _time.monotonic()
|
||||||
|
validation_errors = _validate_summary(json_data, arxiv_id)
|
||||||
|
_t_val_end = _time.monotonic()
|
||||||
|
|
||||||
|
if not validation_errors:
|
||||||
|
logger.info(
|
||||||
|
" [%s] JSON提取: %.2fs 验证: %.2fs ✅",
|
||||||
|
arxiv_id,
|
||||||
|
_t_json_end - _t_json_start,
|
||||||
|
_t_val_end - _t_val_start,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
logger.warning(
|
||||||
|
" [%s] JSON提取: %.2fs 验证: %.2fs ❌ %s",
|
||||||
|
arxiv_id,
|
||||||
|
_t_json_end - _t_json_start,
|
||||||
|
_t_val_end - _t_val_start,
|
||||||
|
"; ".join(validation_errors)[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
if validation_errors:
|
||||||
|
exc = ValueError(
|
||||||
|
f"Summary validation failed after {settings.SUMMARY_MAX_RETRIES} attempts: {'; '.join(validation_errors)}"
|
||||||
|
)
|
||||||
|
exc.raw_output = raw_output # 供上层 _handle_summary_failure 使用
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
return json_data, raw_output
|
||||||
@@ -0,0 +1,273 @@
|
|||||||
|
"""AI 总结持久化 — DB 写入、文件保存、FTS 索引、图片提取、ChromaDB 索引。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.models import (
|
||||||
|
Paper,
|
||||||
|
PaperSummary,
|
||||||
|
PaperTag,
|
||||||
|
SummaryState,
|
||||||
|
)
|
||||||
|
from app.services.pdf_downloader import paper_dir
|
||||||
|
from app.services.schemas import (
|
||||||
|
SummarySchema,
|
||||||
|
assess_quality,
|
||||||
|
flatten_for_db,
|
||||||
|
)
|
||||||
|
from app.services.summary_generator import _classify_error
|
||||||
|
from app.utils import TMP_DIR, truncate_error, utc_now
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── FTS5 文本构建 ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_fts_summary_text(schema: SummarySchema) -> str:
|
||||||
|
"""拼接用于 FTS5 索引的总结文本。"""
|
||||||
|
parts = [
|
||||||
|
schema.one_line or "",
|
||||||
|
schema.motivation.problem or "",
|
||||||
|
schema.motivation.goal or "",
|
||||||
|
schema.method.overview or "",
|
||||||
|
schema.method.key_idea or "",
|
||||||
|
schema.results.main_findings or "",
|
||||||
|
]
|
||||||
|
return " ".join(p for p in parts if p)
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB 更新 ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _update_summary_in_db(
|
||||||
|
db: Session,
|
||||||
|
paper: Paper,
|
||||||
|
schema: SummarySchema,
|
||||||
|
quality: str,
|
||||||
|
raw_output: str,
|
||||||
|
) -> None:
|
||||||
|
"""将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。"""
|
||||||
|
# 1. paper_summaries:upsert
|
||||||
|
existing = db.get(PaperSummary, paper.id)
|
||||||
|
flat = flatten_for_db(schema)
|
||||||
|
if existing:
|
||||||
|
for k, v in flat.items():
|
||||||
|
setattr(existing, k, v)
|
||||||
|
else:
|
||||||
|
db.add(PaperSummary(paper_id=paper.id, **flat))
|
||||||
|
|
||||||
|
# 2. papers 表
|
||||||
|
paper.title_zh = schema.title_zh
|
||||||
|
paper.summary_quality = quality
|
||||||
|
p_dir = paper_dir(paper.arxiv_id)
|
||||||
|
paper.summary_path = str(p_dir / "summary.json")
|
||||||
|
paper.raw_output_path = str(p_dir / "raw_output.txt")
|
||||||
|
|
||||||
|
# 3. AI 标签
|
||||||
|
existing_tag_names = {t.tag for t in paper.tags}
|
||||||
|
for tag_name in schema.tags:
|
||||||
|
if tag_name not in existing_tag_names:
|
||||||
|
db.add(PaperTag(paper_id=paper.id, tag=tag_name, source="ai"))
|
||||||
|
existing_tag_names.add(tag_name)
|
||||||
|
|
||||||
|
# 4. FTS5 更新
|
||||||
|
summary_text = _build_fts_summary_text(schema)
|
||||||
|
db.execute(
|
||||||
|
text(
|
||||||
|
"UPDATE papers_fts SET title_zh=:title_zh, summary_text=:summary_text "
|
||||||
|
"WHERE rowid=:paper_id"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"title_zh": schema.title_zh,
|
||||||
|
"summary_text": summary_text,
|
||||||
|
"paper_id": paper.id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
logger.info("DB updated: paper=%s quality=%s", paper.arxiv_id, quality)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 文件操作 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _save_files(arxiv_id: str, schema: SummarySchema | None, raw_output: str) -> None:
|
||||||
|
d = paper_dir(arxiv_id)
|
||||||
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
|
if schema:
|
||||||
|
(d / "summary.json").write_text(
|
||||||
|
schema.model_dump_json(ensure_ascii=False, indent=2),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
# ── 失败处理 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_summary_failure(
|
||||||
|
db: Session,
|
||||||
|
paper: Paper,
|
||||||
|
exc: Exception,
|
||||||
|
raw_output: str,
|
||||||
|
) -> dict:
|
||||||
|
"""记录失败:保存 raw_output、重试计数、错误分类。"""
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
error_type = _classify_error(exc)
|
||||||
|
logger.error(
|
||||||
|
"Summarize failed: %s error_type=%s %s",
|
||||||
|
paper.arxiv_id,
|
||||||
|
error_type,
|
||||||
|
truncate_error(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
status = paper.summary_status
|
||||||
|
if raw_output:
|
||||||
|
_save_files(paper.arxiv_id, None, raw_output)
|
||||||
|
status.raw_output_saved = True
|
||||||
|
|
||||||
|
status.retry_count = (status.retry_count or 0) + 1
|
||||||
|
status.error_type = error_type
|
||||||
|
status.error = truncate_error(exc, limit=2000)
|
||||||
|
|
||||||
|
if status.retry_count >= settings.SUMMARY_MAX_RETRIES + 1:
|
||||||
|
status.status = SummaryState.PERMANENT_FAILURE
|
||||||
|
else:
|
||||||
|
status.status = SummaryState.PENDING
|
||||||
|
|
||||||
|
status.completed_at = utc_now()
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"arxiv_id": paper.arxiv_id,
|
||||||
|
"status": "failed",
|
||||||
|
"error_type": error_type,
|
||||||
|
"error": truncate_error(exc),
|
||||||
|
"retry_count": status.retry_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── 持久化 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _persist_summary(
|
||||||
|
db: Session, paper: Paper, json_data: dict, raw_output: str
|
||||||
|
) -> str:
|
||||||
|
"""Pydantic 校验 → 质量评估 → 保存文件 → 更新 DB → 返回 quality。"""
|
||||||
|
import time as _time
|
||||||
|
|
||||||
|
arxiv_id = paper.arxiv_id
|
||||||
|
|
||||||
|
_t0 = _time.monotonic()
|
||||||
|
schema = SummarySchema.model_validate(json_data)
|
||||||
|
quality = assess_quality(schema)
|
||||||
|
_t1 = _time.monotonic()
|
||||||
|
|
||||||
|
_save_files(arxiv_id, schema, raw_output)
|
||||||
|
_t2 = _time.monotonic()
|
||||||
|
|
||||||
|
_update_summary_in_db(db, paper, schema, quality, raw_output)
|
||||||
|
_t3 = _time.monotonic()
|
||||||
|
|
||||||
|
# 状态 → done
|
||||||
|
paper.summary_status.status = SummaryState.DONE
|
||||||
|
paper.summary_status.quality = quality
|
||||||
|
paper.summary_status.completed_at = utc_now()
|
||||||
|
paper.summary_status.raw_output_saved = True
|
||||||
|
db.commit()
|
||||||
|
_t4 = _time.monotonic()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
" [%s] persist: pydantic=%.2fs 文件=%.2fs DB写入=%.2fs 状态commit=%.2fs",
|
||||||
|
arxiv_id,
|
||||||
|
_t1 - _t0,
|
||||||
|
_t2 - _t1,
|
||||||
|
_t3 - _t2,
|
||||||
|
_t4 - _t3,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 触发性增强(失败不影响总结)
|
||||||
|
_t5 = _time.monotonic()
|
||||||
|
_maybe_extract_images(arxiv_id, schema)
|
||||||
|
_t6 = _time.monotonic()
|
||||||
|
_maybe_index_chroma(arxiv_id, paper, schema)
|
||||||
|
_t7 = _time.monotonic()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
" [%s] 后处理: 图片提取=%.2fs ChromaDB=%.2fs",
|
||||||
|
arxiv_id,
|
||||||
|
_t6 - _t5,
|
||||||
|
_t7 - _t6,
|
||||||
|
)
|
||||||
|
|
||||||
|
return quality
|
||||||
|
|
||||||
|
|
||||||
|
# ── 清理 ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_old_images(db: Session, paper: Paper) -> None:
|
||||||
|
"""清理旧的图片文件和 figures_json,避免重新总结时残留。"""
|
||||||
|
arxiv_id = paper.arxiv_id
|
||||||
|
images_dir = paper_dir(arxiv_id) / "images"
|
||||||
|
if images_dir.exists():
|
||||||
|
for old_file in images_dir.iterdir():
|
||||||
|
if (
|
||||||
|
old_file.suffix.lower() in (".png", ".jpg", ".jpeg", ".gif", ".svg")
|
||||||
|
or old_file.name == "manifest.json"
|
||||||
|
):
|
||||||
|
old_file.unlink(missing_ok=True)
|
||||||
|
# 清除数据库中的 figures_json
|
||||||
|
if paper.summary and paper.summary.figures_json:
|
||||||
|
paper.summary.figures_json = None
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 触发性增强 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_extract_images(arxiv_id: str, schema: SummarySchema) -> None:
|
||||||
|
"""从 PDF 提取图片和表格(失败不影响总结)。
|
||||||
|
|
||||||
|
两阶段流水线:
|
||||||
|
1. PicoDet 检测 + 渲染截图(通用标签)
|
||||||
|
2. 用 summary 的 figures ID 在 PDF 中搜索定位 → 重命名
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.services.pdf_image_extractor import (
|
||||||
|
extract_images_from_pdf,
|
||||||
|
label_images_by_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
|
||||||
|
extract_images_from_pdf(arxiv_id, pdf_path)
|
||||||
|
if schema.figures:
|
||||||
|
label_images_by_summary(arxiv_id, schema.figures, pdf_path)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_index_chroma(arxiv_id: str, paper: Paper, schema: SummarySchema) -> None:
|
||||||
|
"""写入 ChromaDB 语义索引(失败不影响总结)。"""
|
||||||
|
try:
|
||||||
|
from app.services.embedder import index_paper
|
||||||
|
|
||||||
|
texts_dict = {
|
||||||
|
"arxiv_id": arxiv_id,
|
||||||
|
"title_zh": schema.title_zh or "",
|
||||||
|
"title_en": paper.title_en or "",
|
||||||
|
"tags": " ".join(t.tag for t in paper.tags) if paper.tags else "",
|
||||||
|
"one_line": schema.one_line or "",
|
||||||
|
"motivation_problem": schema.motivation.problem or "",
|
||||||
|
"method_key_idea": schema.method.key_idea or "",
|
||||||
|
"paper_date": paper.paper_date.isoformat() if paper.paper_date else "",
|
||||||
|
}
|
||||||
|
index_paper(arxiv_id, texts_dict)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to index paper %s in ChromaDB", arxiv_id, exc_info=True)
|
||||||
@@ -80,11 +80,16 @@ def _trim_body(text: str, max_chars: int | None = None) -> str:
|
|||||||
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
|
ack_match = re.search(r"(?m)^(?:Acknowledgments?\s*|致谢\s*)$", text)
|
||||||
if ack_match:
|
if ack_match:
|
||||||
# 只删 Acknowledgments 本身,不删后面的内容
|
# 只删 Acknowledgments 本身,不删后面的内容
|
||||||
next_section = re.search(r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start():])
|
next_section = re.search(
|
||||||
|
r"(?m)^(?:A\s|Appendix|Supplementary|附录)\s*$", text[ack_match.start() :]
|
||||||
|
)
|
||||||
if next_section:
|
if next_section:
|
||||||
text = text[:ack_match.start()] + text[ack_match.start() + next_section.start():]
|
text = (
|
||||||
|
text[: ack_match.start()]
|
||||||
|
+ text[ack_match.start() + next_section.start() :]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
text = text[:ack_match.start()].rstrip()
|
text = text[: ack_match.start()].rstrip()
|
||||||
|
|
||||||
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
|
# 最后:如果指定了上限且超长,从末尾截断(附录在后面,正文在前面,优先保留正文)
|
||||||
if max_chars is not None and len(text) > max_chars:
|
if max_chars is not None and len(text) > max_chars:
|
||||||
@@ -105,10 +110,9 @@ def extract_pdf_text(pdf_path: Path, max_chars: int | None = None) -> Path:
|
|||||||
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
|
# 缓存优先;如果需重新提取(不同 max_chars),先删旧文件
|
||||||
return txt_path
|
return txt_path
|
||||||
|
|
||||||
doc = pymupdf.open(str(pdf_path))
|
with pymupdf.open(str(pdf_path)) as doc:
|
||||||
# sort=True 启用阅读顺序检测,避免双栏论文中跨栏错位
|
# sort=True 启用阅读顺序检测,避免双栏论文中跨栏错位
|
||||||
raw_text = "\n\n".join(page.get_text(sort=True) for page in doc)
|
raw_text = "\n\n".join(page.get_text(sort=True) for page in doc)
|
||||||
doc.close()
|
|
||||||
|
|
||||||
body = _trim_body(raw_text, max_chars=max_chars)
|
body = _trim_body(raw_text, max_chars=max_chars)
|
||||||
txt_path.write_text(body, encoding="utf-8")
|
txt_path.write_text(body, encoding="utf-8")
|
||||||
@@ -160,7 +164,8 @@ def build_prompt(
|
|||||||
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
|
'"reproducibility": "详细段落:复现评估(开源情况、数据、算力、难度")}, '
|
||||||
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要","section":"method"},'
|
'"figures": [{"id":"Figure 1","caption":"原图标题","description":"文字描述图展示了什么","reason":"为什么这张图对理解论文重要","section":"method"},'
|
||||||
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要","section":"results"}]'
|
'{"id":"Table 1","caption":"表格标题","description":"文字描述表格包含的数据和结论","reason":"为什么这个表格对理解论文重要","section":"results"}]'
|
||||||
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table,id 严格使用 \"Figure N\" 或 \"Table N\" 格式。"
|
"\n注意:figures 必须包含论文中的所有重要图表,包括 Figure 和 Table。"
|
||||||
|
'id 必须严格复用论文原文的写法(原文用 "Fig. 1" 就写 "Fig. 1",用 "Figure A1" 就写 "Figure A1",用 "Table 1" 就写 "Table 1")。'
|
||||||
"section 必须是 motivation/method/results/limitations 之一,表示该图最适合展示在哪个章节。"
|
"section 必须是 motivation/method/results/limitations 之一,表示该图最适合展示在哪个章节。"
|
||||||
"}"
|
"}"
|
||||||
)
|
)
|
||||||
|
|||||||
+28
-11
@@ -5,7 +5,15 @@ from __future__ import annotations
|
|||||||
from sqlalchemy import or_, select
|
from sqlalchemy import or_, select
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
from app.models import PAPER_FULL_LOAD, Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
|
from app.exceptions import NotFoundError, ValidationError
|
||||||
|
from app.models import (
|
||||||
|
PAPER_FULL_LOAD,
|
||||||
|
Paper,
|
||||||
|
PaperTag,
|
||||||
|
UserBookmark,
|
||||||
|
UserNote,
|
||||||
|
UserReadingStatus,
|
||||||
|
)
|
||||||
from app.utils import utc_now
|
from app.utils import utc_now
|
||||||
|
|
||||||
# ── 收藏 ──────────────────────────────────────────────────────────────
|
# ── 收藏 ──────────────────────────────────────────────────────────────
|
||||||
@@ -13,9 +21,11 @@ from app.utils import utc_now
|
|||||||
|
|
||||||
def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
|
def toggle_bookmark(db: Session, arxiv_id: str) -> dict:
|
||||||
"""切换收藏状态。返回 {"bookmarked": bool, "arxiv_id": str}。"""
|
"""切换收藏状态。返回 {"bookmarked": bool, "arxiv_id": str}。"""
|
||||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
paper = db.execute(
|
||||||
|
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
if not paper:
|
if not paper:
|
||||||
return {"error": "not_found"}
|
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||||
|
|
||||||
existing = db.execute(
|
existing = db.execute(
|
||||||
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
select(UserBookmark).where(UserBookmark.paper_id == paper.id)
|
||||||
@@ -42,11 +52,15 @@ VALID_STATUSES = {"unread", "skimmed", "read_summary", "read_full"}
|
|||||||
def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
|
def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
|
||||||
"""设置阅读状态。status 必须是 unread/skimmed/read_summary/read_full。"""
|
"""设置阅读状态。status 必须是 unread/skimmed/read_summary/read_full。"""
|
||||||
if status not in VALID_STATUSES:
|
if status not in VALID_STATUSES:
|
||||||
return {"error": "invalid_status", "valid": sorted(VALID_STATUSES)}
|
raise ValidationError(
|
||||||
|
f"Invalid reading status: {status}. Valid: {', '.join(sorted(VALID_STATUSES))}"
|
||||||
|
)
|
||||||
|
|
||||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
paper = db.execute(
|
||||||
|
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
if not paper:
|
if not paper:
|
||||||
return {"error": "not_found"}
|
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||||
|
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
existing = db.execute(
|
existing = db.execute(
|
||||||
@@ -72,7 +86,9 @@ def set_reading_status(db: Session, arxiv_id: str, status: str) -> dict:
|
|||||||
|
|
||||||
def get_note(db: Session, arxiv_id: str) -> dict | None:
|
def get_note(db: Session, arxiv_id: str) -> dict | None:
|
||||||
"""获取笔记。返回 {"arxiv_id", "content", "updated_at"} 或 None(论文不存在时)。"""
|
"""获取笔记。返回 {"arxiv_id", "content", "updated_at"} 或 None(论文不存在时)。"""
|
||||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
paper = db.execute(
|
||||||
|
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
if not paper:
|
if not paper:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -91,9 +107,11 @@ def get_note(db: Session, arxiv_id: str) -> dict | None:
|
|||||||
|
|
||||||
def save_note(db: Session, arxiv_id: str, content: str) -> dict:
|
def save_note(db: Session, arxiv_id: str, content: str) -> dict:
|
||||||
"""创建或更新笔记。返回 {"arxiv_id", "content", "updated_at"}。"""
|
"""创建或更新笔记。返回 {"arxiv_id", "content", "updated_at"}。"""
|
||||||
paper = db.execute(select(Paper).where(Paper.arxiv_id == arxiv_id)).scalar_one_or_none()
|
paper = db.execute(
|
||||||
|
select(Paper).where(Paper.arxiv_id == arxiv_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
if not paper:
|
if not paper:
|
||||||
return {"error": "not_found"}
|
raise NotFoundError(f"Paper not found: {arxiv_id}")
|
||||||
|
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
existing = db.execute(
|
existing = db.execute(
|
||||||
@@ -154,8 +172,7 @@ def query_reading_list(
|
|||||||
stmt.options(
|
stmt.options(
|
||||||
joinedload(Paper.note),
|
joinedload(Paper.note),
|
||||||
*PAPER_FULL_LOAD,
|
*PAPER_FULL_LOAD,
|
||||||
)
|
).order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
|
||||||
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
|
|
||||||
)
|
)
|
||||||
.unique()
|
.unique()
|
||||||
.scalars()
|
.scalars()
|
||||||
|
|||||||
+42
-6
@@ -137,12 +137,35 @@ def safe_json_loads(text: str | None, default: Any = None) -> Any:
|
|||||||
|
|
||||||
# AI 生成内容中允许的 HTML 标签和属性
|
# AI 生成内容中允许的 HTML 标签和属性
|
||||||
_ALLOWED_TAGS = {
|
_ALLOWED_TAGS = {
|
||||||
"p", "br", "strong", "b", "em", "i", "u", "s", "del",
|
"p",
|
||||||
"h3", "h4", "h5", "h6",
|
"br",
|
||||||
"ul", "ol", "li",
|
"strong",
|
||||||
"a", "code", "pre", "blockquote",
|
"b",
|
||||||
"table", "thead", "tbody", "tr", "th", "td",
|
"em",
|
||||||
"sup", "sub", "span",
|
"i",
|
||||||
|
"u",
|
||||||
|
"s",
|
||||||
|
"del",
|
||||||
|
"h3",
|
||||||
|
"h4",
|
||||||
|
"h5",
|
||||||
|
"h6",
|
||||||
|
"ul",
|
||||||
|
"ol",
|
||||||
|
"li",
|
||||||
|
"a",
|
||||||
|
"code",
|
||||||
|
"pre",
|
||||||
|
"blockquote",
|
||||||
|
"table",
|
||||||
|
"thead",
|
||||||
|
"tbody",
|
||||||
|
"tr",
|
||||||
|
"th",
|
||||||
|
"td",
|
||||||
|
"sup",
|
||||||
|
"sub",
|
||||||
|
"span",
|
||||||
}
|
}
|
||||||
_ALLOWED_ATTRS = {
|
_ALLOWED_ATTRS = {
|
||||||
"a": {"href", "title"},
|
"a": {"href", "title"},
|
||||||
@@ -167,3 +190,16 @@ def sanitize_html(text: str | None) -> str:
|
|||||||
strip=True,
|
strip=True,
|
||||||
)
|
)
|
||||||
return cleaned
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
# ── 错误消息截断 ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_ERROR_TRUNCATE_LIMIT = 500
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_error(exc: Exception | str, limit: int = _ERROR_TRUNCATE_LIMIT) -> str:
|
||||||
|
"""将异常或字符串截断到指定长度,保持统一的错误消息格式。"""
|
||||||
|
text = str(exc)
|
||||||
|
if len(text) <= limit:
|
||||||
|
return text
|
||||||
|
return text[:limit] + f"... ({len(text)} chars total)"
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"arxiv_id": "2602.21760",
|
||||||
|
"title_zh": "基于条件引导调度的混合数据-流水线并行加速扩散模型",
|
||||||
|
"one_line": "提出混合并行框架,通过条件划分与自适应流水线切换加速扩散推理,实现2.31倍提速。",
|
||||||
|
"tags": ["Diffusion Models", "Distributed Inference", "Parallel Computing", "Image Generation"],
|
||||||
|
"difficulty": "进阶",
|
||||||
|
"prerequisites": {
|
||||||
|
"concepts": [
|
||||||
|
{
|
||||||
|
"term": "Diffusion Models",
|
||||||
|
"explanation": "扩散模型是一类基于去噪过程的生成模型。在正向过程中,它逐渐向数据添加高斯噪声直到变成纯噪声;在反向过程中,模型学习逐步去噪以恢复原始数据。这种迭代特性虽然能生成高质量的样本,但也导致了高昂的推理计算成本。",
|
||||||
|
"why_matters": "理解扩散模型的迭代去噪机制是理解本文如何通过并行化减少推理延迟的基础。"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"term": "Classifier-Free Guidance (CFG)",
|
||||||
|
"explanation": "无分类器引导是一种在推理时提升生成样本与文本条件一致性的技术。模型同时预测有条件噪声(给定文本提示)和无条件噪声(不给定提示),最终通过加权组合两者来获得最终预测。公式为 $\epsilon_{cfg} = \epsilon_\theta(x_t, c, t) + w (\epsilon_\theta(x_t, c, t) - \epsilon_\theta(x_t, t))$,其中 $w$ 是引导强度。",
|
||||||
|
"why_matters": "本文的核心创新点在于利用CFG中存在的有条件和无条件双路径作为数据划分的基础。"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"term": "Distributed Inference",
|
||||||
|
"explanation": "分布式推理利用多个GPU并行处理计算任务以减少延迟。主要分为数据并行(如将图像切片处理)和流水线并行(如将模型层切分)。然而,现有的分布式方法在扩散模型中往往面临通信开销大或生成图像出现拼接伪影的问题。",
|
||||||
|
"why_matters": "本文提出的混合并行框架正是为了解决现有分布式推理方法中的这些痛点。"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"motivation": {
|
||||||
|
"problem": "现有的扩散模型加速方法,无论是单卡优化(如减少采样步数、模型剪枝)还是多卡分布式并行(如DistriFusion和AsyncDiff),都存在明显的局限性。单卡优化受限于硬件算力上限,而现有多卡并行方法通常只能实现次线性的加速比。例如,DistriFusion将图像切片并行处理,容易在拼接处产生明显的伪影;AsyncDiff采用异步流水线,虽然加速了但会引入估计误差,且通信开销巨大(在SDXL上高达9.83GB)。",
|
||||||
|
"goal": "本文旨在提出一种新颖的混合并行框架,在仅使用两张GPU的情况下,不仅能实现超过线性的加速比(即 $>2\times$),还要严格保持甚至提升生成图像的质量,同时将通信开销降到最低。",
|
||||||
|
"gap": "与以往将图像空间切片(Patch-based)的思路不同,本文独辟蹊径,利用无分类器引导(CFG)中天然存在的“有条件”和“无条件”两条路径作为新的数据划分维度(Condition-based Partitioning)。同时,作者发现这两条路径的预测误差差异在整个去噪过程中呈现出先大后小再变大的U型曲线,因此引入了自适应的并行切换策略,只在误差差异最小时才进行并行流水线处理。"
|
||||||
|
},
|
||||||
|
"method": {
|
||||||
|
"overview": "该框架的核心思想是将扩散推理过程划分为三个阶段:预热阶段(Warm-Up)、并行阶段(Parallelism)和完全连接阶段(Fully-Connecting)。在预热和完全连接阶段,使用“基于条件的划分”策略,即一张GPU处理有条件的预测,另一张处理无条件的预测。而在中间的并行阶段,由于两个预测结果非常接近,框架切换到“自适应流水线并行”,利用两张GPU交替执行推理步骤,从而大幅压缩时间。",
|
||||||
|
"key_idea": "核心创新在于不再将图片在空间上切片,而是沿“条件”维度切分数据。这保证了每个GPU都能看到整张图片的全局信息,从而避免了拼接伪影。此外,引入了“去噪差异度”(Denoising Discrepancy,即 rel-MAE)这一指标来动态评估两条路径的相似性,并以此自动决定何时开启和关闭流水线并行,实现了最优的加速-质量平衡。",
|
||||||
|
"steps": "1. 数据划分:输入潜变量同时送入GPU 1(有条件预测 $\epsilon_\theta(x_t, c, t)$)和GPU 2(无条件预测 $\epsilon_\theta(x_t, t)$)。2. 阶段判断:根据实时计算的“去噪差异度” $G_t$ 与阈值 $g_{slope}$ 的关系,确定切换点 $\tau_1$ 和 $\tau_2$。3. 混合执行:在 $[T, \tau_1]$ 阶段同步运行;在 $[\tau_1, \tau_2]$ 阶段启用流水线并行(如GPU 1处理 $t-1$ 步时GPU 2处理 $t$ 步);在 $[\tau_2, 0]$ 阶段重新恢复同步以精细调整细节。",
|
||||||
|
"novelty": "该方法的另一大新颖之处在于其“安全性”设计:通过设置 $\tau_{cap}$ 作为安全上限,确保即使自动算法失效,也不会在错误的时间点引入并行,从而保证了算法的鲁棒性。此外,该框架对U-Net(如SDXL)和DiT(如SD3)架构均具有良好的泛化性。"
|
||||||
|
},
|
||||||
|
"results": {
|
||||||
|
"main_findings": "实验在SDXL和SD3模型上进行,使用MS-COCO 2014验证集。结果显示,在SDXL上,该方法实现了2.31倍加速,延迟从16.49秒降至7.12秒,且FID指标与原始单卡模型持平(甚至略优)。相比此前最强的DistriFusion(1.22倍)和AsyncDiff(1.31倍),提速效果显著。在通信开销方面,本方法仅为0.516GB,比AsyncDiff的9.83GB降低了19.6倍。在SD3模型上,同样实现了2.07倍的加速。",
|
||||||
|
"benchmarks": [
|
||||||
|
{
|
||||||
|
"task": "Text-to-Image (SDXL)",
|
||||||
|
"metric": "Speed-Up",
|
||||||
|
"this_work": "2.31x",
|
||||||
|
"baseline": "1.31x (AsyncDiff)",
|
||||||
|
"improvement": "1.0x (Extra speed)"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"task": "Text-to-Image (SDXL)",
|
||||||
|
"metric": "Comm. (GB)",
|
||||||
|
"this_work": "0.516",
|
||||||
|
"baseline": "9.830 (AsyncDiff)",
|
||||||
|
"improvement": "Reduced by 19.6x"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"task": "Text-to-Image (SD3)",
|
||||||
|
"metric": "Speed-Up",
|
||||||
|
"this_work": "2.07x",
|
||||||
|
"baseline": "1.97x (AsyncDiff)",
|
||||||
|
"improvement": "0.1x (Extra speed)"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"limitations": "尽管该方法在通用性上表现出色,但在处理极高分辨率(如4K以上)时,加速比会随分辨率提升而有所下降(从2.72x降至1.62x)。此外,目前的实现仅针对两张GPU进行了深度优化,虽然文中提出了多卡扩展策略,但在单个样本推理场景下,如何高效地扩展到四卡或更多卡仍是一个挑战。最后,参数 $k$ 的选取目前仍需人工根据经验设定。"
|
||||||
|
},
|
||||||
|
"improvements": {
|
||||||
|
"weaknesses": "主要弱点在于自适应切换参数(如 $k$ 和 $\tau_{cap}$)的确定目前仍偏向经验性,缺乏完全自动化的端到端学习机制。此外,虽然避免了图像切片,但条件分支的“信息量”并不总是完全对等的,特别是在极早期的噪声阶段,可能导致其中一张GPU负载不均衡。改进方向可以是结合动态负载均衡算法,根据当前步骤的预测难度动态分配计算资源。",
|
||||||
|
"future_work": "未来的研究方向包括:1. 将该混合并行策略扩展到视频生成模型(Video Diffusion)中,利用时间轴上的相关性进行更细粒度的流水线调度。2. 结合模型量化(Quantization)和蒸馏技术,在多卡并行的基础上进一步压缩单步推理时间。3. 探索在“去噪差异度”指标指导下自动学习最优的 $k$ 值和切换点。",
|
||||||
|
"reproducibility": "代码已在GitHub开源(https://github.com/kaist-dmlab/Hybridiff)。实验环境基于PyTorch,使用的GPU为NVIDIA GeForce 3090,硬件门槛相对较低。文中详细列出了关键超参数(如SDXL上的 $L=12, k=5, \tau_{cap}=15$),使得复现结果的难度较低。"
|
||||||
|
},
|
||||||
|
"figures": [
|
||||||
|
{
|
||||||
|
"id": "Figure 1",
|
||||||
|
"caption": "Summary of the proposed hybrid data-pipeline parallelism",
|
||||||
|
"description": "五维雷达图展示了该方法在速度、图像质量、通用性、高分辨率能力和通信开销五个方面均优于现有分布式框架。",
|
||||||
|
"reason": "直观概括了本文的核心优势,即全方位的性能提升。",
|
||||||
|
"section": "results"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Figure 2",
|
||||||
|
"caption": "Comparison of parallel strategies",
|
||||||
|
"description": "对比了三种并行策略:(a)基于切片的数据并行容易产生伪影,(b)流水线并行通信开销大,(c)本文提出的混合并行既保留全局一致性又实现了高效并行。",
|
||||||
|
"reason": "通过对比展示了本文方法设计的合理性和必要性。",
|
||||||
|
"section": "method"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Figure 3",
|
||||||
|
"caption": "Overview of the hybrid parallel framework",
|
||||||
|
"description": "详细展示了三个阶段(Warm-Up, Parallelism, Fully-Connecting)的数据流和通信模式,清晰地说明了自适应切换的动态过程。",
|
||||||
|
"reason": "这是理解整个算法执行流程的关键示意图。",
|
||||||
|
"section": "method"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Table 1",
|
||||||
|
"caption": "Quantitative comparison on SDXL and SD3",
|
||||||
|
"description": "表格列出了该方法与基线方法在延迟、加速比、通信开销及生成质量指标(FID, LPIPS, PSNR)上的详细对比数据。",
|
||||||
|
"reason": "提供了最核心的定量证据,证明了该方法的有效性。",
|
||||||
|
"section": "results"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with open("data/papers/2602.21760/summary.json", "w", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
print("summary.json created successfully.")
|
||||||
+1
-1
@@ -19,7 +19,7 @@ dependencies = [
|
|||||||
"pymupdf>=1.25",
|
"pymupdf>=1.25",
|
||||||
"itsdangerous>=2.2.0",
|
"itsdangerous>=2.2.0",
|
||||||
"bleach>=6.4.0",
|
"bleach>=6.4.0",
|
||||||
"pymupdf4llm>=1.27.2.3",
|
"onnxruntime>=1.17",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@@ -0,0 +1,172 @@
|
|||||||
|
"""导出 PicoDet-S_layout_3cls 为 ONNX 格式.
|
||||||
|
|
||||||
|
一次性脚本,在独立 venv 中运行:
|
||||||
|
python -m venv .venv-export && source .venv-export/bin/activate
|
||||||
|
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle paddleocr paddle2onnx onnxruntime opencv-python-headless
|
||||||
|
HF_ENDPOINT=https://hf-mirror.com python scripts/export_picodet_onnx.py
|
||||||
|
|
||||||
|
输出:
|
||||||
|
data/models/picodet_layout_3cls.onnx
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# hf-mirror
|
||||||
|
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
MODEL_DIR = PROJECT_ROOT / "data" / "models"
|
||||||
|
OUTPUT_PATH = MODEL_DIR / "picodet_layout_3cls.onnx"
|
||||||
|
MODEL_NAME = "PicoDet-S_layout_3cls"
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ── Step 1: 用 PaddleOCR paddle_static 引擎加载模型,触发下载 ──
|
||||||
|
print(f"[1/4] Loading model '{MODEL_NAME}' (paddle_static engine, triggers download) ...")
|
||||||
|
from paddleocr import LayoutDetection
|
||||||
|
|
||||||
|
model = LayoutDetection(
|
||||||
|
model_name=MODEL_NAME,
|
||||||
|
engine="paddle_static",
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
print(" ✓ Model loaded and cached")
|
||||||
|
|
||||||
|
# ── Step 2: 找到 PaddleX 缓存的 Paddle 模型文件 ────────────────
|
||||||
|
paddlex_cache = Path.home() / ".paddlex"
|
||||||
|
print(f"\n[2/4] Searching Paddle model cache in {paddlex_cache} ...")
|
||||||
|
|
||||||
|
# 搜索 layout 相关的缓存目录
|
||||||
|
candidates = []
|
||||||
|
for d in paddlex_cache.rglob("*"):
|
||||||
|
if d.is_dir() and (d / "inference.pdiparams").exists():
|
||||||
|
# 检查是否是 layout 模型
|
||||||
|
marker = d.name
|
||||||
|
parent_name = d.parent.name
|
||||||
|
if "layout" in marker.lower() or "layout" in parent_name.lower() or "picodet" in marker.lower():
|
||||||
|
candidates.append(d)
|
||||||
|
elif "PicoDet" in str(d):
|
||||||
|
candidates.append(d)
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
# 如果没找到明确的 layout 目录,列出所有含 inference.pdiparams 的目录
|
||||||
|
all_model_dirs = [d for d in paddlex_cache.rglob("*") if d.is_dir() and (d / "inference.pdiparams").exists()]
|
||||||
|
print(" No layout-specific dir found. All model dirs with inference.pdiparams:")
|
||||||
|
for d in all_model_dirs:
|
||||||
|
files = [f.name for f in d.iterdir()]
|
||||||
|
print(f" {d} ({', '.join(files)})")
|
||||||
|
if all_model_dirs:
|
||||||
|
# 取最新的(刚下载的)
|
||||||
|
candidates = sorted(all_model_dirs, key=lambda d: (d / "inference.pdiparams").stat().st_mtime, reverse=True)[:1]
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
print(" ✗ No cached model found")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
model_cache_dir = candidates[0]
|
||||||
|
files_in_dir = list(model_cache_dir.iterdir())
|
||||||
|
print(f" Using: {model_cache_dir}")
|
||||||
|
for f in files_in_dir:
|
||||||
|
print(f" {f.name} ({f.stat().st_size / 1024:.1f} KB)")
|
||||||
|
|
||||||
|
# ── Step 3: 用 paddle2onnx 转换 ─────────────────────────────────
|
||||||
|
print("\n[3/4] Converting to ONNX with paddle2onnx ...")
|
||||||
|
tmp_onnx = OUTPUT_PATH.with_suffix(".tmp.onnx")
|
||||||
|
|
||||||
|
# 确定 model_filename
|
||||||
|
pdmodel = model_cache_dir / "inference.pdmodel"
|
||||||
|
has_pdmodel = pdmodel.exists()
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
sys.executable, "-m", "paddle2onnx",
|
||||||
|
"--model_dir", str(model_cache_dir),
|
||||||
|
"--save_file", str(tmp_onnx),
|
||||||
|
"--opset_version", "11",
|
||||||
|
"--enable_onnx_checker", "True",
|
||||||
|
]
|
||||||
|
if has_pdmodel:
|
||||||
|
cmd.extend(["--model_filename", "inference.pdmodel"])
|
||||||
|
cmd.extend(["--params_filename", "inference.pdiparams"])
|
||||||
|
|
||||||
|
print(f" Running: {' '.join(cmd)}")
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
if result.stdout:
|
||||||
|
print(f" stdout: {result.stdout[:500]}")
|
||||||
|
if result.returncode != 0:
|
||||||
|
print(f" ✗ paddle2onnx failed (exit {result.returncode})")
|
||||||
|
print(f" stderr: {result.stderr[:500]}")
|
||||||
|
|
||||||
|
# 尝试不带 model_filename(combined format)
|
||||||
|
if has_pdmodel:
|
||||||
|
print(" Retrying without explicit model_filename ...")
|
||||||
|
cmd2 = [
|
||||||
|
sys.executable, "-m", "paddle2onnx",
|
||||||
|
"--model_dir", str(model_cache_dir),
|
||||||
|
"--params_filename", "inference.pdiparams",
|
||||||
|
"--save_file", str(tmp_onnx),
|
||||||
|
"--opset_version", "11",
|
||||||
|
]
|
||||||
|
result2 = subprocess.run(cmd2, capture_output=True, text=True)
|
||||||
|
if result2.returncode != 0:
|
||||||
|
print(f" ✗ Retry also failed: {result2.stderr[:500]}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not tmp_onnx.exists() or tmp_onnx.stat().st_size < 1000:
|
||||||
|
print(" ✗ ONNX file not created or too small")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
shutil.move(str(tmp_onnx), str(OUTPUT_PATH))
|
||||||
|
print(f" ✓ ONNX saved ({OUTPUT_PATH.stat().st_size / 1024 / 1024:.2f} MB)")
|
||||||
|
|
||||||
|
# ── Step 4: 用 onnxruntime 验证 ─────────────────────────────────
|
||||||
|
print("\n[4/4] Verifying with onnxruntime ...")
|
||||||
|
_inspect_onnx(OUTPUT_PATH)
|
||||||
|
|
||||||
|
print(f"\n✓ Done! ONNX model saved to {OUTPUT_PATH}")
|
||||||
|
|
||||||
|
|
||||||
|
def _inspect_onnx(onnx_path: Path) -> None:
|
||||||
|
"""用 onnxruntime 加载模型,打印输入输出信息."""
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
|
||||||
|
|
||||||
|
print(" Inputs:")
|
||||||
|
for inp in session.get_inputs():
|
||||||
|
print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
|
||||||
|
|
||||||
|
print(" Outputs:")
|
||||||
|
for out in session.get_outputs():
|
||||||
|
print(f" {out.name}: shape={out.shape}, dtype={out.type}")
|
||||||
|
|
||||||
|
# 试推理
|
||||||
|
input_info = session.get_inputs()[0]
|
||||||
|
input_name = input_info.name
|
||||||
|
batch_size = input_info.shape[0] if isinstance(input_info.shape[0], int) else 1
|
||||||
|
channels = input_info.shape[1] if isinstance(input_info.shape[1], int) else 3
|
||||||
|
height = input_info.shape[2] if isinstance(input_info.shape[2], int) else 480
|
||||||
|
width = input_info.shape[3] if isinstance(input_info.shape[3], int) else 480
|
||||||
|
|
||||||
|
dummy_input = np.random.rand(batch_size, channels, height, width).astype(np.float32)
|
||||||
|
outputs = session.run(None, {input_name: dummy_input})
|
||||||
|
|
||||||
|
print(" Inference test outputs:")
|
||||||
|
for i, (out_info, out_val) in enumerate(zip(session.get_outputs(), outputs)):
|
||||||
|
print(f" output[{i}] '{out_info.name}': shape={out_val.shape}, dtype={out_val.dtype}")
|
||||||
|
if out_val.size <= 20:
|
||||||
|
print(f" values: {out_val}")
|
||||||
|
|
||||||
|
print(" ✓ Inference OK")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,212 @@
|
|||||||
|
"""批量重新提取所有论文的图片 — 下载 PDF + PicoDet 检测 + caption 匹配.
|
||||||
|
|
||||||
|
用法:
|
||||||
|
PROXY_SERVER=http://... uv run python scripts/reextract_images.py
|
||||||
|
uv run python scripts/reextract_images.py --limit 10 # 只处理前 10 篇
|
||||||
|
uv run python scripts/reextract_images.py --id 2512.24880 # 只处理指定论文
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# 让脚本可以从项目根目录直接运行
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
from app.database import SessionLocal, init_db, engine # noqa: E402
|
||||||
|
from app.models import Paper # noqa: E402
|
||||||
|
from app.services.pdf_image_extractor import extract_images_from_pdf # noqa: E402
|
||||||
|
from app.utils import TMP_DIR # noqa: E402
|
||||||
|
from sqlalchemy import select # noqa: E402
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)-5s %(message)s",
|
||||||
|
datefmt="%H:%M:%S",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 下载并发数
|
||||||
|
MAX_WORKERS = 3
|
||||||
|
# 下载超时(秒)
|
||||||
|
DOWNLOAD_TIMEOUT = 120
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session() -> requests.Session:
|
||||||
|
"""创建带代理的 HTTP session。"""
|
||||||
|
sess = requests.Session()
|
||||||
|
sess.headers.update({"User-Agent": "hf-daily-papers/1.0"})
|
||||||
|
proxy = os.environ.get("PROXY_SERVER") or os.environ.get("HTTPS_PROXY")
|
||||||
|
if proxy:
|
||||||
|
sess.proxies = {"http": proxy, "https": proxy}
|
||||||
|
logger.info("使用代理: %s", proxy)
|
||||||
|
else:
|
||||||
|
logger.warning("未设置代理 (PROXY_SERVER / HTTPS_PROXY),直连 arxiv.org")
|
||||||
|
return sess
|
||||||
|
|
||||||
|
|
||||||
|
def download_pdf(session: requests.Session, arxiv_id: str, pdf_url: str) -> Path | None:
|
||||||
|
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf,返回路径或 None。"""
|
||||||
|
dest_dir = TMP_DIR / arxiv_id
|
||||||
|
dest = dest_dir / "paper.pdf"
|
||||||
|
if dest.exists() and dest.stat().st_size > 1000:
|
||||||
|
return dest
|
||||||
|
|
||||||
|
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
resp = session.get(pdf_url, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
|
||||||
|
resp.raise_for_status()
|
||||||
|
dest.write_bytes(resp.content)
|
||||||
|
return dest
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("下载失败 %s: %s", arxiv_id, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def process_one(session: requests.Session, arxiv_id: str, pdf_url: str) -> dict:
|
||||||
|
"""处理单篇论文:下载 → 提取图片 → 返回统计。"""
|
||||||
|
result = {"arxiv_id": arxiv_id, "downloaded": False, "extracted": 0, "error": None}
|
||||||
|
|
||||||
|
# 下载 PDF
|
||||||
|
pdf_path = download_pdf(session, arxiv_id, pdf_url)
|
||||||
|
if pdf_path is None:
|
||||||
|
result["error"] = "download_failed"
|
||||||
|
return result
|
||||||
|
result["downloaded"] = True
|
||||||
|
|
||||||
|
# 提取图片
|
||||||
|
try:
|
||||||
|
n = extract_images_from_pdf(arxiv_id, pdf_path)
|
||||||
|
result["extracted"] = n
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("提取失败 %s: %s", arxiv_id, exc, exc_info=True)
|
||||||
|
result["error"] = f"extract_failed: {exc}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 统计 matched / orphan
|
||||||
|
mf = Path(f"data/papers/{arxiv_id}/images/manifest.json")
|
||||||
|
if mf.exists():
|
||||||
|
m = json.loads(mf.read_text(encoding="utf-8"))
|
||||||
|
result["matched"] = sum(1 for v in m.values() if "(p" not in v.get("label", ""))
|
||||||
|
result["orphan"] = sum(1 for v in m.values() if "(p" in v.get("label", ""))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="批量重新提取论文图片")
|
||||||
|
parser.add_argument("--limit", type=int, default=0, help="只处理前 N 篇")
|
||||||
|
parser.add_argument("--id", dest="arxiv_id", help="只处理指定 arxiv_id")
|
||||||
|
parser.add_argument("--workers", type=int, default=MAX_WORKERS, help="并发数")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 初始化数据库
|
||||||
|
os.makedirs("data/db", exist_ok=True)
|
||||||
|
init_db(engine)
|
||||||
|
|
||||||
|
# 读取论文列表
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
if args.arxiv_id:
|
||||||
|
papers = (
|
||||||
|
db.execute(select(Paper).where(Paper.arxiv_id == args.arxiv_id))
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
papers = db.execute(select(Paper)).scalars().all()
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
if args.limit > 0:
|
||||||
|
papers = papers[: args.limit]
|
||||||
|
|
||||||
|
total = len(papers)
|
||||||
|
logger.info("待处理论文: %d 篇", total)
|
||||||
|
if total == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
session = _get_session()
|
||||||
|
|
||||||
|
# 统计
|
||||||
|
done = 0
|
||||||
|
failed = 0
|
||||||
|
total_extracted = 0
|
||||||
|
total_matched = 0
|
||||||
|
total_orphan = 0
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=args.workers) as pool:
|
||||||
|
futures = {}
|
||||||
|
for p in papers:
|
||||||
|
f = pool.submit(process_one, session, p.arxiv_id, p.pdf_url)
|
||||||
|
futures[f] = p.arxiv_id
|
||||||
|
|
||||||
|
for f in as_completed(futures):
|
||||||
|
arxiv_id = futures[f]
|
||||||
|
try:
|
||||||
|
r = f.result()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("异常 %s: %s", arxiv_id, exc)
|
||||||
|
failed += 1
|
||||||
|
done += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
done += 1
|
||||||
|
if r["error"]:
|
||||||
|
failed += 1
|
||||||
|
logger.info("[%d/%d] ✗ %s — %s", done, total, arxiv_id, r["error"])
|
||||||
|
else:
|
||||||
|
total_extracted += r["extracted"]
|
||||||
|
total_matched += r.get("matched", 0)
|
||||||
|
total_orphan += r.get("orphan", 0)
|
||||||
|
matched = r.get("matched", 0)
|
||||||
|
orphan = r.get("orphan", 0)
|
||||||
|
elapsed = time.time() - t0
|
||||||
|
rate = done / elapsed if elapsed > 0 else 0
|
||||||
|
eta = (total - done) / rate if rate > 0 else 0
|
||||||
|
logger.info(
|
||||||
|
"[%d/%d] ✓ %s — %d 张 (matched=%d, orphan=%d) ETA %.0fs",
|
||||||
|
done,
|
||||||
|
total,
|
||||||
|
arxiv_id,
|
||||||
|
r["extracted"],
|
||||||
|
matched,
|
||||||
|
orphan,
|
||||||
|
eta,
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = time.time() - t0
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info(
|
||||||
|
"完成: %d/%d 成功, %d 失败, 耗时 %.1fs",
|
||||||
|
done - failed,
|
||||||
|
total,
|
||||||
|
failed,
|
||||||
|
elapsed,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"图片: %d 总计, %d matched, %d orphan (%.1f%%)",
|
||||||
|
total_extracted,
|
||||||
|
total_matched,
|
||||||
|
total_orphan,
|
||||||
|
total_orphan / total_extracted * 100 if total_extracted else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
+7
-1
@@ -161,7 +161,13 @@ def sample_summary_dict() -> dict:
|
|||||||
"results": {
|
"results": {
|
||||||
"main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。",
|
"main_findings": "在长文本基准 LongBench 上取得了 SOTA 结果,平均得分提升 3.2 个百分点。推理速度相比全注意力提升了 2 倍,显存占用降低 60%。在 32k 序列长度下仍保持与全注意力相当的生成质量。",
|
||||||
"benchmarks": [
|
"benchmarks": [
|
||||||
{"task": "长文本摘要", "metric": "ROUGE-L", "this_work": "42.1", "baseline": "38.9", "improvement": "+3.2"},
|
{
|
||||||
|
"task": "长文本摘要",
|
||||||
|
"metric": "ROUGE-L",
|
||||||
|
"this_work": "42.1",
|
||||||
|
"baseline": "38.9",
|
||||||
|
"improvement": "+3.2",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。",
|
"limitations": "在超长文本(>100k tokens)上效果有所下降,主要原因是全局采样点数量不足以覆盖所有关键信息。此外,在小规模数据集上的优势不如大规模数据集明显。",
|
||||||
},
|
},
|
||||||
|
|||||||
+9
-13
@@ -67,7 +67,7 @@ class TestAdminAuth:
|
|||||||
def test_correct_session_accepted(self, auth_client):
|
def test_correct_session_accepted(self, auth_client):
|
||||||
"""已登录 session 应被接受(crawl 可能会失败但不是 303)。"""
|
"""已登录 session 应被接受(crawl 可能会失败但不是 303)。"""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
"app.routes.admin.run_crawl", new_callable=AsyncMock
|
||||||
) as mock_crawl:
|
) as mock_crawl:
|
||||||
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
|
mock_crawl.return_value = {"found": 0, "new": 0, "status": "success"}
|
||||||
resp = auth_client.post("/admin/crawl")
|
resp = auth_client.post("/admin/crawl")
|
||||||
@@ -83,9 +83,7 @@ class TestAdminAuth:
|
|||||||
|
|
||||||
def test_correct_session_batch_summarize(self, auth_client):
|
def test_correct_session_batch_summarize(self, auth_client):
|
||||||
"""已登录调用 batch summarize,mock 掉服务层。"""
|
"""已登录调用 batch summarize,mock 掉服务层。"""
|
||||||
with patch(
|
with patch("app.routes.admin.summarize_batch", new_callable=AsyncMock) as mock:
|
||||||
"app.routes.admin.summarize_batch", new_callable=AsyncMock
|
|
||||||
) as mock:
|
|
||||||
mock.return_value = {
|
mock.return_value = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"done": 0,
|
"done": 0,
|
||||||
@@ -98,10 +96,12 @@ class TestAdminAuth:
|
|||||||
|
|
||||||
def test_single_paper_not_found(self, auth_client):
|
def test_single_paper_not_found(self, auth_client):
|
||||||
"""单篇总结不存在的论文返回 404。"""
|
"""单篇总结不存在的论文返回 404。"""
|
||||||
|
from app.exceptions import NotFoundError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routes.admin.summarize_single",
|
"app.routes.admin.summarize_single",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value={"status": "not_found", "arxiv_id": "nonexistent.99999"},
|
side_effect=NotFoundError("Paper not found: nonexistent.99999"),
|
||||||
):
|
):
|
||||||
resp = auth_client.post("/admin/summarize/nonexistent.99999")
|
resp = auth_client.post("/admin/summarize/nonexistent.99999")
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
@@ -118,7 +118,7 @@ class TestAdminCrawl:
|
|||||||
def test_crawl_default_today(self, auth_client):
|
def test_crawl_default_today(self, auth_client):
|
||||||
"""不指定日期时默认抓取今天。"""
|
"""不指定日期时默认抓取今天。"""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
"app.routes.admin.run_crawl", new_callable=AsyncMock
|
||||||
) as mock_crawl:
|
) as mock_crawl:
|
||||||
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
|
mock_crawl.return_value = {"found": 5, "new": 3, "status": "success"}
|
||||||
resp = auth_client.post("/admin/crawl")
|
resp = auth_client.post("/admin/crawl")
|
||||||
@@ -130,7 +130,7 @@ class TestAdminCrawl:
|
|||||||
def test_crawl_specific_date(self, auth_client):
|
def test_crawl_specific_date(self, auth_client):
|
||||||
"""指定日期抓取。"""
|
"""指定日期抓取。"""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routes.admin.crawl_daily", new_callable=AsyncMock
|
"app.routes.admin.run_crawl", new_callable=AsyncMock
|
||||||
) as mock_crawl:
|
) as mock_crawl:
|
||||||
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
|
mock_crawl.return_value = {"found": 2, "new": 1, "status": "success"}
|
||||||
resp = auth_client.post("/admin/crawl?date=2024-01-15")
|
resp = auth_client.post("/admin/crawl?date=2024-01-15")
|
||||||
@@ -194,9 +194,7 @@ class TestAdminDelete:
|
|||||||
)
|
)
|
||||||
assert resp.status_code == 422
|
assert resp.status_code == 422
|
||||||
|
|
||||||
def test_delete_with_confirm(
|
def test_delete_with_confirm(self, auth_client, db_session, sample_papers_range):
|
||||||
self, auth_client, db_session, sample_papers_range
|
|
||||||
):
|
|
||||||
"""confirm='DELETE' 时应执行删除。"""
|
"""confirm='DELETE' 时应执行删除。"""
|
||||||
resp = auth_client.post(
|
resp = auth_client.post(
|
||||||
"/admin/delete",
|
"/admin/delete",
|
||||||
@@ -255,9 +253,7 @@ class TestAdminLogs:
|
|||||||
resp = client.get("/admin/logs", follow_redirects=False)
|
resp = client.get("/admin/logs", follow_redirects=False)
|
||||||
assert resp.status_code == 303
|
assert resp.status_code == 303
|
||||||
|
|
||||||
def test_logs_contains_data(
|
def test_logs_contains_data(self, auth_client, db_session, sample_papers_range):
|
||||||
self, auth_client, db_session, sample_papers_range
|
|
||||||
):
|
|
||||||
"""日志页面应包含日志数据。"""
|
"""日志页面应包含日志数据。"""
|
||||||
# 先创建一条日志
|
# 先创建一条日志
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|||||||
@@ -0,0 +1,189 @@
|
|||||||
|
"""爬虫服务测试 — _parse_paper、fetch_daily、upsert_papers、crawl_daily。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.crawler import (
|
||||||
|
_parse_paper,
|
||||||
|
crawl_daily,
|
||||||
|
fetch_daily,
|
||||||
|
upsert_papers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# _parse_paper
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestParsePaper:
|
||||||
|
def test_normal_item(self):
|
||||||
|
item = {
|
||||||
|
"paper": {
|
||||||
|
"id": "2401.12345",
|
||||||
|
"title": "Test Paper",
|
||||||
|
"abstract": "Abstract text",
|
||||||
|
"publishedAt": "2024-01-15T00:00:00",
|
||||||
|
"authors": [{"name": "Alice"}, {"name": "Bob"}],
|
||||||
|
"tags": [{"name": "NLP"}, {"name": "LLM"}],
|
||||||
|
"upvotes": 42,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result = _parse_paper(item)
|
||||||
|
assert result["arxiv_id"] == "2401.12345"
|
||||||
|
assert result["title_en"] == "Test Paper"
|
||||||
|
assert len(result["authors"]) == 2
|
||||||
|
assert result["authors"] == ["Alice", "Bob"]
|
||||||
|
assert result["tags"] == ["NLP", "LLM"]
|
||||||
|
assert result["upvotes"] == 42
|
||||||
|
assert "huggingface.co" in result["hf_url"]
|
||||||
|
|
||||||
|
def test_empty_id(self):
|
||||||
|
item = {"paper": {"id": "", "authors": [], "tags": []}}
|
||||||
|
result = _parse_paper(item)
|
||||||
|
assert result["arxiv_id"] == ""
|
||||||
|
assert result["hf_url"] == ""
|
||||||
|
|
||||||
|
def test_missing_published_at(self):
|
||||||
|
item = {"paper": {"id": "2401.00001", "title": "T", "authors": [], "tags": []}}
|
||||||
|
result = _parse_paper(item)
|
||||||
|
assert result["published_at"] is None
|
||||||
|
|
||||||
|
def test_flat_structure_fallback(self):
|
||||||
|
"""无 paper 包装时直接从顶层取字段。"""
|
||||||
|
item = {"id": "2401.99999", "title": "Flat", "authors": [], "tags": []}
|
||||||
|
result = _parse_paper(item)
|
||||||
|
assert result["arxiv_id"] == "2401.99999"
|
||||||
|
assert result["title_en"] == "Flat"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# fetch_daily
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchDaily:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_papers(self, monkeypatch):
|
||||||
|
fake_data = [{"paper": {"id": "2401.00001"}}]
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = fake_data
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get.return_value = mock_resp
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("app.services.crawler.make_http_client", return_value=mock_client):
|
||||||
|
result = await fetch_daily("2024-01-15")
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["paper"]["id"] == "2401.00001"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_respects_top_n(self, monkeypatch):
|
||||||
|
fake_data = [{"paper": {"id": f"2401.{i:05d}"}} for i in range(10)]
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = fake_data
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get.return_value = mock_resp
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("app.services.crawler.make_http_client", return_value=mock_client):
|
||||||
|
result = await fetch_daily("2024-01-15", top_n=3)
|
||||||
|
assert len(result) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# upsert_papers
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpsertPapers:
|
||||||
|
def test_inserts_new_paper(self, db_session):
|
||||||
|
papers_raw = [
|
||||||
|
{
|
||||||
|
"paper": {
|
||||||
|
"id": "2401.00001",
|
||||||
|
"title": "New Paper",
|
||||||
|
"abstract": "Abstract",
|
||||||
|
"authors": [{"name": "Alice"}],
|
||||||
|
"tags": [{"name": "CV"}],
|
||||||
|
"upvotes": 5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
new = upsert_papers(db_session, papers_raw, "2024-01-15")
|
||||||
|
assert len(new) == 1
|
||||||
|
assert new[0].arxiv_id == "2401.00001"
|
||||||
|
assert new[0].title_en == "New Paper"
|
||||||
|
|
||||||
|
def test_updates_existing_upvotes(self, db_session, sample_paper):
|
||||||
|
papers_raw = [
|
||||||
|
{
|
||||||
|
"paper": {
|
||||||
|
"id": sample_paper.arxiv_id,
|
||||||
|
"title": sample_paper.title_en,
|
||||||
|
"upvotes": 999,
|
||||||
|
"authors": [],
|
||||||
|
"tags": [],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
new = upsert_papers(db_session, papers_raw, "2024-01-15")
|
||||||
|
assert len(new) == 0 # 不新增
|
||||||
|
db_session.refresh(sample_paper)
|
||||||
|
assert sample_paper.upvotes == 999
|
||||||
|
|
||||||
|
def test_skips_empty_id(self, db_session):
|
||||||
|
papers_raw = [{"paper": {"id": "", "title": "Nope", "authors": [], "tags": []}}]
|
||||||
|
new = upsert_papers(db_session, papers_raw, "2024-01-15")
|
||||||
|
assert len(new) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# crawl_daily
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrawlDaily:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_success_flow(self, db_session):
|
||||||
|
with patch(
|
||||||
|
"app.services.crawler.fetch_daily",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_fetch:
|
||||||
|
mock_fetch.return_value = [
|
||||||
|
{
|
||||||
|
"paper": {
|
||||||
|
"id": "2401.00001",
|
||||||
|
"title": "T",
|
||||||
|
"authors": [],
|
||||||
|
"tags": [],
|
||||||
|
"upvotes": 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
result = await crawl_daily(db_session, "2024-01-15")
|
||||||
|
|
||||||
|
assert result["status"] == "success"
|
||||||
|
assert result["new"] == 1
|
||||||
|
assert result["found"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_failure_returns_failed(self, db_session):
|
||||||
|
with patch(
|
||||||
|
"app.services.crawler.fetch_daily",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=ConnectionError("network error"),
|
||||||
|
):
|
||||||
|
result = await crawl_daily(db_session, "2024-01-15")
|
||||||
|
|
||||||
|
assert result["status"] == "failed"
|
||||||
|
assert "network error" in result["error"]
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
"""PDF 下载测试 — download_pdf、路径工具、错误处理。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.pdf_downloader import (
|
||||||
|
PdfDownloadError,
|
||||||
|
download_pdf,
|
||||||
|
paper_dir,
|
||||||
|
tmp_dir,
|
||||||
|
)
|
||||||
|
from app.utils import PAPERS_DIR, TMP_DIR
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# 路径工具
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathHelpers:
|
||||||
|
def test_paper_dir(self):
|
||||||
|
assert paper_dir("2401.12345") == PAPERS_DIR / "2401.12345"
|
||||||
|
|
||||||
|
def test_tmp_dir(self):
|
||||||
|
assert tmp_dir("2401.12345") == TMP_DIR / "2401.12345"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# download_pdf
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestDownloadPdf:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_success_download(self, tmp_path):
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.content = b"%PDF-1.4 fake"
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.return_value = mock_resp
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.pdf_downloader.TMP_DIR", tmp_path),
|
||||||
|
patch(
|
||||||
|
"app.services.pdf_downloader._get_session", return_value=mock_session
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await download_pdf(
|
||||||
|
"2401.12345", "https://arxiv.org/pdf/2401.12345.pdf"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exists()
|
||||||
|
assert result.name == "paper.pdf"
|
||||||
|
assert result.read_bytes() == b"%PDF-1.4 fake"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_pdf_url_raises(self):
|
||||||
|
with pytest.raises(PdfDownloadError, match="no pdf_url"):
|
||||||
|
await download_pdf("2401.12345", "")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_failure_raises(self, tmp_path):
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.side_effect = ConnectionError("refused")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.pdf_downloader.TMP_DIR", tmp_path),
|
||||||
|
patch(
|
||||||
|
"app.services.pdf_downloader._get_session", return_value=mock_session
|
||||||
|
),
|
||||||
|
):
|
||||||
|
with pytest.raises(PdfDownloadError, match="failed to download"):
|
||||||
|
await download_pdf("2401.12345", "https://bad.url/pdf.pdf")
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
"""流水线编排测试 — run_pipeline (crawl → summarize → cleanup)。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models import TaskLock
|
||||||
|
from app.services.pipeline import run_pipeline
|
||||||
|
from app.utils import utc_now
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunPipeline:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_pipeline_success(self, db_session):
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.pipeline.crawl_daily", new_callable=AsyncMock
|
||||||
|
) as mock_crawl,
|
||||||
|
patch(
|
||||||
|
"app.services.pipeline.summarize_batch", new_callable=AsyncMock
|
||||||
|
) as mock_summ,
|
||||||
|
patch("app.services.pipeline.cleanup_tmp") as mock_clean,
|
||||||
|
):
|
||||||
|
mock_crawl.return_value = {"status": "success", "found": 5, "new": 2}
|
||||||
|
mock_summ.return_value = {"status": "success", "done": 2, "failed": 0}
|
||||||
|
mock_clean.return_value = {"removed": 0}
|
||||||
|
|
||||||
|
result = await run_pipeline(db_session, "2024-01-15", "test")
|
||||||
|
|
||||||
|
assert result["status"] == "success"
|
||||||
|
mock_crawl.assert_called_once()
|
||||||
|
mock_summ.assert_called_once()
|
||||||
|
mock_clean.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_lock_prevents_reentry(self, db_session):
|
||||||
|
"""已有 running 锁时抛出 RuntimeError。"""
|
||||||
|
now = utc_now()
|
||||||
|
db_session.add(
|
||||||
|
TaskLock(
|
||||||
|
task="scheduler",
|
||||||
|
lock_key="pipeline-2024-01-15",
|
||||||
|
status="running",
|
||||||
|
owner="other",
|
||||||
|
acquired_at=now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="already running"):
|
||||||
|
await run_pipeline(db_session, "2024-01-15", "test")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_crawl_failure_still_runs_summarize_and_cleanup(self, db_session):
|
||||||
|
"""crawl 失败时 pipeline 继续执行 summarize 和 cleanup。"""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.pipeline.crawl_daily", new_callable=AsyncMock
|
||||||
|
) as mock_crawl,
|
||||||
|
patch(
|
||||||
|
"app.services.pipeline.summarize_batch", new_callable=AsyncMock
|
||||||
|
) as mock_summ,
|
||||||
|
patch("app.services.pipeline.cleanup_tmp") as mock_clean,
|
||||||
|
):
|
||||||
|
mock_crawl.side_effect = ConnectionError("timeout")
|
||||||
|
mock_summ.return_value = {"status": "success", "done": 0}
|
||||||
|
mock_clean.return_value = {"removed": 0}
|
||||||
|
|
||||||
|
result = await run_pipeline(db_session, "2024-01-15", "test")
|
||||||
|
|
||||||
|
# pipeline 捕获异常,返回 failed
|
||||||
|
assert result["status"] == "failed"
|
||||||
|
assert "timeout" in result["error"]
|
||||||
|
# summarize 和 cleanup 不会被调用(exception 跳出 try 块)
|
||||||
|
mock_summ.assert_not_called()
|
||||||
@@ -20,7 +20,7 @@ from app.services.schemas import (
|
|||||||
classify_validation_error,
|
classify_validation_error,
|
||||||
flatten_for_db,
|
flatten_for_db,
|
||||||
)
|
)
|
||||||
from app.services.summarizer import _classify_error
|
from app.services.summary_generator import _classify_error
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|||||||
+35
-22
@@ -23,12 +23,8 @@ from app.services.pdf_downloader import (
|
|||||||
)
|
)
|
||||||
from app.services.pi_client import PiTimeoutError
|
from app.services.pi_client import PiTimeoutError
|
||||||
from app.services.schemas import SummarySchema
|
from app.services.schemas import SummarySchema
|
||||||
from app.services.summarizer import (
|
from app.services.summarizer import summarize_batch, summarize_one
|
||||||
_save_files,
|
from app.services.summary_persister import _save_files, _update_summary_in_db
|
||||||
_update_summary_in_db,
|
|
||||||
summarize_batch,
|
|
||||||
summarize_one,
|
|
||||||
)
|
|
||||||
from app.utils import utc_now
|
from app.utils import utc_now
|
||||||
|
|
||||||
|
|
||||||
@@ -39,7 +35,14 @@ from app.utils import utc_now
|
|||||||
def _summarize_tmp_paths(tmp_path):
|
def _summarize_tmp_paths(tmp_path):
|
||||||
"""将 data 目录重定向到 tmp_path(供 summarizer 测试使用)。"""
|
"""将 data 目录重定向到 tmp_path(供 summarizer 测试使用)。"""
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
|
patch(
|
||||||
|
"app.services.summary_persister.paper_dir",
|
||||||
|
lambda aid: tmp_path / "papers" / aid,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"app.services.summary_generator.paper_dir",
|
||||||
|
lambda aid: tmp_path / "papers" / aid,
|
||||||
|
),
|
||||||
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
||||||
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
||||||
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
||||||
@@ -134,7 +137,9 @@ class TestFileOperations:
|
|||||||
|
|
||||||
def test_save_files(self, tmp_path, sample_summary_dict):
|
def test_save_files(self, tmp_path, sample_summary_dict):
|
||||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||||
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
|
with patch(
|
||||||
|
"app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
|
||||||
|
):
|
||||||
_save_files("2401.12345", schema, "raw output text")
|
_save_files("2401.12345", schema, "raw output text")
|
||||||
|
|
||||||
paper_dir = tmp_path / "2401.12345"
|
paper_dir = tmp_path / "2401.12345"
|
||||||
@@ -144,7 +149,9 @@ class TestFileOperations:
|
|||||||
assert saved["title_zh"] == "测试论文中文标题"
|
assert saved["title_zh"] == "测试论文中文标题"
|
||||||
|
|
||||||
def test_save_raw_output_only(self, tmp_path):
|
def test_save_raw_output_only(self, tmp_path):
|
||||||
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
|
with patch(
|
||||||
|
"app.services.summary_persister.paper_dir", lambda aid: tmp_path / aid
|
||||||
|
):
|
||||||
_save_files("2401.12345", None, "raw output")
|
_save_files("2401.12345", None, "raw output")
|
||||||
paper_dir = tmp_path / "2401.12345"
|
paper_dir = tmp_path / "2401.12345"
|
||||||
assert (paper_dir / "raw_output.txt").exists()
|
assert (paper_dir / "raw_output.txt").exists()
|
||||||
@@ -180,7 +187,7 @@ class TestSummarizeOneFlow:
|
|||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer.call_pi",
|
"app.services.summary_generator.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=(mock_pi_output, "test-session-id"),
|
return_value=(mock_pi_output, "test-session-id"),
|
||||||
),
|
),
|
||||||
@@ -209,7 +216,9 @@ class TestSummarizeOneFlow:
|
|||||||
assert fts_row[0] == "测试论文中文标题"
|
assert fts_row[0] == "测试论文中文标题"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pdf_download_failure(self, db_session, sample_paper, _summarize_tmp_paths):
|
async def test_pdf_download_failure(
|
||||||
|
self, db_session, sample_paper, _summarize_tmp_paths
|
||||||
|
):
|
||||||
"""PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。"""
|
"""PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。"""
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
@@ -233,7 +242,7 @@ class TestSummarizeOneFlow:
|
|||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer.call_pi",
|
"app.services.summary_generator.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=PiTimeoutError("timeout after 300s"),
|
side_effect=PiTimeoutError("timeout after 300s"),
|
||||||
),
|
),
|
||||||
@@ -250,7 +259,7 @@ class TestSummarizeOneFlow:
|
|||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer.call_pi",
|
"app.services.summary_generator.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=("No JSON in this output at all.", "test-session-id"),
|
return_value=("No JSON in this output at all.", "test-session-id"),
|
||||||
),
|
),
|
||||||
@@ -281,7 +290,7 @@ class TestSummarizeOneFlow:
|
|||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer.call_pi",
|
"app.services.summary_generator.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=(bad_output, "test-session-id"),
|
return_value=(bad_output, "test-session-id"),
|
||||||
),
|
),
|
||||||
@@ -300,7 +309,7 @@ class TestSummarizeOneFlow:
|
|||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer.call_pi",
|
"app.services.summary_generator.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=("Some output without JSON", "test-session-id"),
|
return_value=("Some output without JSON", "test-session-id"),
|
||||||
),
|
),
|
||||||
@@ -319,7 +328,7 @@ class TestSummarizeOneFlow:
|
|||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer.call_pi",
|
"app.services.summary_generator.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=(mock_pi_output, "test-session-id"),
|
return_value=(mock_pi_output, "test-session-id"),
|
||||||
),
|
),
|
||||||
@@ -347,7 +356,9 @@ class TestSummarizeOneFlow:
|
|||||||
assert not tmp_paper.exists()
|
assert not tmp_paper.exists()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_skips_done_paper(self, db_session, sample_paper, _summarize_tmp_paths):
|
async def test_skips_done_paper(
|
||||||
|
self, db_session, sample_paper, _summarize_tmp_paths
|
||||||
|
):
|
||||||
"""已完成的论文跳过。"""
|
"""已完成的论文跳过。"""
|
||||||
sample_paper.summary_status.status = "done"
|
sample_paper.summary_status.status = "done"
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
@@ -393,7 +404,7 @@ class TestBatchSummarize:
|
|||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer.call_pi",
|
"app.services.summary_generator.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=(mock_pi_output, "test-session-id"),
|
return_value=(mock_pi_output, "test-session-id"),
|
||||||
),
|
),
|
||||||
@@ -446,7 +457,7 @@ class TestBatchSummarize:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch("app.services.summarizer.call_pi", side_effect=_mock_call_pi),
|
patch("app.services.summary_generator.call_pi", side_effect=_mock_call_pi),
|
||||||
):
|
):
|
||||||
result = await summarize_batch(db_session, _session_factory=_TestSession)
|
result = await summarize_batch(db_session, _session_factory=_TestSession)
|
||||||
|
|
||||||
@@ -456,6 +467,8 @@ class TestBatchSummarize:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_task_lock_conflict(self, db_session, _summarize_tmp_paths):
|
async def test_task_lock_conflict(self, db_session, _summarize_tmp_paths):
|
||||||
"""TaskLock 防止并发 batch。"""
|
"""TaskLock 防止并发 batch。"""
|
||||||
|
from app.exceptions import ConflictError
|
||||||
|
|
||||||
# 先插入一个 running 锁
|
# 先插入一个 running 锁
|
||||||
db_session.add(
|
db_session.add(
|
||||||
TaskLock(
|
TaskLock(
|
||||||
@@ -467,8 +480,8 @@ class TestBatchSummarize:
|
|||||||
)
|
)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
result = await summarize_batch(db_session)
|
with pytest.raises(ConflictError):
|
||||||
assert result["status"] == "conflict"
|
await summarize_batch(db_session)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_task_lock_released(
|
async def test_task_lock_released(
|
||||||
@@ -482,7 +495,7 @@ class TestBatchSummarize:
|
|||||||
with (
|
with (
|
||||||
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer.call_pi",
|
"app.services.summary_generator.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=(mock_pi_output, "test-session-id"),
|
return_value=(mock_pi_output, "test-session-id"),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -0,0 +1,174 @@
|
|||||||
|
"""summary_utils 测试 — PDF 文本提取、正文裁剪、JSON 提取、meta.json 写入、prompt 构建。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.summary_utils import (
|
||||||
|
JsonNotFoundError,
|
||||||
|
_trim_body,
|
||||||
|
build_prompt,
|
||||||
|
extract_json,
|
||||||
|
extract_pdf_text,
|
||||||
|
write_meta_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# _trim_body 正文裁剪
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrimBody:
|
||||||
|
def test_removes_references_section(self):
|
||||||
|
text = "Intro\n\nSome content here.\n\nReferences\n[1] Smith et al."
|
||||||
|
result = _trim_body(text)
|
||||||
|
assert "References" not in result
|
||||||
|
assert "Intro" in result
|
||||||
|
|
||||||
|
def test_removes_bibliography(self):
|
||||||
|
text = "Body\n\nBibliography\n[1] Smith"
|
||||||
|
result = _trim_body(text)
|
||||||
|
assert "Bibliography" not in result
|
||||||
|
|
||||||
|
def test_keeps_appendix_after_references(self):
|
||||||
|
text = "Body\n\nReferences\n[1] X\n\nAppendix\nExtra content"
|
||||||
|
result = _trim_body(text)
|
||||||
|
assert "Appendix" in result
|
||||||
|
assert "Extra content" in result
|
||||||
|
assert "References" not in result
|
||||||
|
|
||||||
|
def test_removes_acknowledgments(self):
|
||||||
|
text = "Body\n\nAcknowledgments\nThanks to everyone."
|
||||||
|
result = _trim_body(text)
|
||||||
|
assert "Acknowledgments" not in result
|
||||||
|
|
||||||
|
def test_max_chars_truncation(self):
|
||||||
|
text = "A" * 1000
|
||||||
|
result = _trim_body(text, max_chars=100)
|
||||||
|
assert len(result) <= 100
|
||||||
|
|
||||||
|
def test_no_truncation_when_none(self):
|
||||||
|
text = "A" * 500
|
||||||
|
result = _trim_body(text, max_chars=None)
|
||||||
|
assert len(result) == 500
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# extract_pdf_text
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractPdfText:
|
||||||
|
def test_extracts_text_and_saves(self, tmp_path):
|
||||||
|
pdf_path = tmp_path / "test.pdf"
|
||||||
|
pdf_path.write_bytes(b"%PDF-fake")
|
||||||
|
|
||||||
|
mock_page = MagicMock()
|
||||||
|
mock_page.get_text.return_value = "Page 1 text"
|
||||||
|
mock_doc = MagicMock()
|
||||||
|
mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
|
||||||
|
mock_doc.__enter__ = MagicMock(return_value=mock_doc)
|
||||||
|
mock_doc.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("pymupdf.open", return_value=mock_doc),
|
||||||
|
patch(
|
||||||
|
"app.services.summary_utils._trim_body", side_effect=lambda t, **kw: t
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result_path = extract_pdf_text(pdf_path)
|
||||||
|
|
||||||
|
assert result_path.suffix == ".txt"
|
||||||
|
assert result_path.exists()
|
||||||
|
assert "Page 1 text" in result_path.read_text()
|
||||||
|
|
||||||
|
def test_uses_cached_txt(self, tmp_path):
|
||||||
|
pdf_path = tmp_path / "test.pdf"
|
||||||
|
pdf_path.write_bytes(b"%PDF-fake")
|
||||||
|
txt_path = tmp_path / "test.txt"
|
||||||
|
txt_path.write_text("cached", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("pymupdf.open") as mock_open:
|
||||||
|
result = extract_pdf_text(pdf_path)
|
||||||
|
|
||||||
|
mock_open.assert_not_called()
|
||||||
|
assert result == txt_path
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# write_meta_json
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestWriteMetaJson:
|
||||||
|
def test_writes_meta_json(self, tmp_path, sample_paper):
|
||||||
|
with patch("app.services.pdf_downloader.paper_dir", lambda aid: tmp_path / aid):
|
||||||
|
result = write_meta_json(sample_paper)
|
||||||
|
|
||||||
|
assert result.exists()
|
||||||
|
assert result.name == "meta.json"
|
||||||
|
data = json.loads(result.read_text(encoding="utf-8"))
|
||||||
|
assert data["arxiv_id"] == "2401.12345"
|
||||||
|
assert data["title_en"] == "Test Paper Title"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# build_prompt
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildPrompt:
|
||||||
|
def test_inject_mode_contains_schema(self, tmp_path):
|
||||||
|
prompt = build_prompt(
|
||||||
|
"2401.12345", tmp_path / "meta", tmp_path / "txt", "inject"
|
||||||
|
)
|
||||||
|
assert "title_zh" in prompt
|
||||||
|
assert "必须包含以下字段" in prompt
|
||||||
|
|
||||||
|
def test_search_mode_contains_read_instruction(self, tmp_path):
|
||||||
|
prompt = build_prompt(
|
||||||
|
"2401.12345", tmp_path / "meta", tmp_path / "txt", "search"
|
||||||
|
)
|
||||||
|
assert "read" in prompt.lower()
|
||||||
|
assert "title_zh" in prompt
|
||||||
|
|
||||||
|
def test_fix_errors_mode(self, tmp_path):
|
||||||
|
prompt = build_prompt(
|
||||||
|
"2401.12345",
|
||||||
|
tmp_path / "meta",
|
||||||
|
tmp_path / "txt",
|
||||||
|
"inject",
|
||||||
|
fix_errors=["字段缺失"],
|
||||||
|
)
|
||||||
|
assert "字段缺失" in prompt
|
||||||
|
assert "修正" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# extract_json
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractJson:
|
||||||
|
def test_direct_json(self, sample_summary_json):
|
||||||
|
result = extract_json(sample_summary_json)
|
||||||
|
assert result["title_zh"] == "测试论文中文标题"
|
||||||
|
|
||||||
|
def test_fenced_code_block(self, sample_summary_json):
|
||||||
|
raw = f"some text\n```json\n{sample_summary_json}\n```\nmore text"
|
||||||
|
result = extract_json(raw)
|
||||||
|
assert result["title_zh"] == "测试论文中文标题"
|
||||||
|
|
||||||
|
def test_brace_matching_fallback(self, sample_summary_dict):
|
||||||
|
json_str = json.dumps(sample_summary_dict, ensure_ascii=False)
|
||||||
|
raw = f"Here is the result: {json_str} end."
|
||||||
|
result = extract_json(raw)
|
||||||
|
assert result["title_zh"] == "测试论文中文标题"
|
||||||
|
|
||||||
|
def test_no_json_raises(self):
|
||||||
|
with pytest.raises(JsonNotFoundError):
|
||||||
|
extract_json("plain text no json here at all")
|
||||||
+13
-14
@@ -2,6 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.exceptions import NotFoundError, ValidationError
|
||||||
from app.services.user_data import (
|
from app.services.user_data import (
|
||||||
get_note,
|
get_note,
|
||||||
save_note,
|
save_note,
|
||||||
@@ -27,9 +30,8 @@ class TestBookmarkService:
|
|||||||
assert result["bookmarked"] is False
|
assert result["bookmarked"] is False
|
||||||
|
|
||||||
def test_toggle_bookmark_not_found(self, db_session):
|
def test_toggle_bookmark_not_found(self, db_session):
|
||||||
result = toggle_bookmark(db_session, "nonexistent")
|
with pytest.raises(NotFoundError):
|
||||||
assert "error" in result
|
toggle_bookmark(db_session, "nonexistent")
|
||||||
assert result["error"] == "not_found"
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
@@ -44,9 +46,8 @@ class TestReadingStatusService:
|
|||||||
assert result["arxiv_id"] == "2401.12345"
|
assert result["arxiv_id"] == "2401.12345"
|
||||||
|
|
||||||
def test_set_reading_status_invalid(self, db_session, sample_paper):
|
def test_set_reading_status_invalid(self, db_session, sample_paper):
|
||||||
result = set_reading_status(db_session, "2401.12345", "invalid_status")
|
with pytest.raises(ValidationError):
|
||||||
assert "error" in result
|
set_reading_status(db_session, "2401.12345", "invalid_status")
|
||||||
assert result["error"] == "invalid_status"
|
|
||||||
|
|
||||||
def test_update_existing_status(self, db_session, sample_paper):
|
def test_update_existing_status(self, db_session, sample_paper):
|
||||||
set_reading_status(db_session, "2401.12345", "skimmed")
|
set_reading_status(db_session, "2401.12345", "skimmed")
|
||||||
@@ -54,9 +55,8 @@ class TestReadingStatusService:
|
|||||||
assert result["status"] == "read_full"
|
assert result["status"] == "read_full"
|
||||||
|
|
||||||
def test_set_reading_status_not_found(self, db_session):
|
def test_set_reading_status_not_found(self, db_session):
|
||||||
result = set_reading_status(db_session, "nonexistent", "unread")
|
with pytest.raises(NotFoundError):
|
||||||
assert "error" in result
|
set_reading_status(db_session, "nonexistent", "unread")
|
||||||
assert result["error"] == "not_found"
|
|
||||||
|
|
||||||
def test_all_valid_statuses(self, db_session, sample_paper):
|
def test_all_valid_statuses(self, db_session, sample_paper):
|
||||||
for status in ("unread", "skimmed", "read_summary", "read_full"):
|
for status in ("unread", "skimmed", "read_summary", "read_full"):
|
||||||
@@ -93,9 +93,8 @@ class TestNoteService:
|
|||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
def test_save_note_paper_not_found(self, db_session):
|
def test_save_note_paper_not_found(self, db_session):
|
||||||
result = save_note(db_session, "nonexistent", "内容")
|
with pytest.raises(NotFoundError):
|
||||||
assert "error" in result
|
save_note(db_session, "nonexistent", "内容")
|
||||||
assert result["error"] == "not_found"
|
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
@@ -143,12 +142,12 @@ class TestUserDataRoutes:
|
|||||||
assert data["status"] == "read_summary"
|
assert data["status"] == "read_summary"
|
||||||
|
|
||||||
def test_reading_status_invalid(self, client, sample_paper):
|
def test_reading_status_invalid(self, client, sample_paper):
|
||||||
"""无效状态返回 422。"""
|
"""无效状态返回 400 (ValidationError)。"""
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/reading-status/2401.12345",
|
"/api/reading-status/2401.12345",
|
||||||
json={"status": "invalid"},
|
json={"status": "invalid"},
|
||||||
)
|
)
|
||||||
assert resp.status_code == 422
|
assert resp.status_code == 400
|
||||||
|
|
||||||
def test_reading_status_not_found(self, client):
|
def test_reading_status_not_found(self, client):
|
||||||
"""不存在的论文返回 404。"""
|
"""不存在的论文返回 404。"""
|
||||||
|
|||||||
@@ -709,10 +709,10 @@ dependencies = [
|
|||||||
{ name = "httpx", extra = ["http2"] },
|
{ name = "httpx", extra = ["http2"] },
|
||||||
{ name = "itsdangerous" },
|
{ name = "itsdangerous" },
|
||||||
{ name = "jinja2" },
|
{ name = "jinja2" },
|
||||||
|
{ name = "onnxruntime" },
|
||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
{ name = "pydantic-settings" },
|
{ name = "pydantic-settings" },
|
||||||
{ name = "pymupdf" },
|
{ name = "pymupdf" },
|
||||||
{ name = "pymupdf4llm" },
|
|
||||||
{ name = "python-dotenv" },
|
{ name = "python-dotenv" },
|
||||||
{ name = "python-multipart" },
|
{ name = "python-multipart" },
|
||||||
{ name = "sqlalchemy" },
|
{ name = "sqlalchemy" },
|
||||||
@@ -741,10 +741,10 @@ requires-dist = [
|
|||||||
{ name = "httpx", extras = ["http2"], specifier = ">=0.28" },
|
{ name = "httpx", extras = ["http2"], specifier = ">=0.28" },
|
||||||
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
||||||
{ name = "jinja2", specifier = ">=3.1" },
|
{ name = "jinja2", specifier = ">=3.1" },
|
||||||
|
{ name = "onnxruntime", specifier = ">=1.17" },
|
||||||
{ name = "pydantic", specifier = ">=2.0" },
|
{ name = "pydantic", specifier = ">=2.0" },
|
||||||
{ name = "pydantic-settings", specifier = ">=2.0" },
|
{ name = "pydantic-settings", specifier = ">=2.0" },
|
||||||
{ name = "pymupdf", specifier = ">=1.25" },
|
{ name = "pymupdf", specifier = ">=1.25" },
|
||||||
{ name = "pymupdf4llm", specifier = ">=1.27.2.3" },
|
|
||||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" },
|
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" },
|
||||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24" },
|
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24" },
|
||||||
{ name = "python-dotenv", specifier = ">=1.0" },
|
{ name = "python-dotenv", specifier = ">=1.0" },
|
||||||
@@ -1261,15 +1261,6 @@ wheels = [
|
|||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" },
|
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "networkx"
|
|
||||||
version = "3.6.1"
|
|
||||||
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
|
||||||
sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "numpy"
|
name = "numpy"
|
||||||
version = "2.4.6"
|
version = "2.4.6"
|
||||||
@@ -1889,39 +1880,6 @@ wheels = [
|
|||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/53/a4/b9e91aac82293f9c954654c85581ee8212b5b05efadc534b581141241e6f/pymupdf-1.27.2.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:77691604c5d1d0233827139bbcdea61fd57879c84712b8e49b1f45520f7ab9c2", size = 25000393, upload-time = "2026-04-24T14:11:01.669Z" },
|
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/53/a4/b9e91aac82293f9c954654c85581ee8212b5b05efadc534b581141241e6f/pymupdf-1.27.2.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:77691604c5d1d0233827139bbcdea61fd57879c84712b8e49b1f45520f7ab9c2", size = 25000393, upload-time = "2026-04-24T14:11:01.669Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pymupdf-layout"
|
|
||||||
version = "1.27.2.3"
|
|
||||||
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
|
||||||
dependencies = [
|
|
||||||
{ name = "networkx" },
|
|
||||||
{ name = "numpy" },
|
|
||||||
{ name = "onnxruntime" },
|
|
||||||
{ name = "pymupdf" },
|
|
||||||
{ name = "pyyaml" },
|
|
||||||
]
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bc/ee/067726c3ee5574ad5c605d00d7419e264ef509d626a726f99388111f8216/pymupdf_layout-1.27.2.3-cp310-abi3-macosx_10_9_x86_64.whl", hash = "sha256:75c2ab3c0e8830ac2bc50cfd32d375a30768a2610dac72a02f08265336e0834f", size = 15799844, upload-time = "2026-04-24T14:11:13.177Z" },
|
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0a/ba/46a7a36474722f9280d885f6eec878561a257d9378e52590b43d32ffb96c/pymupdf_layout-1.27.2.3-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:5656b09669dcd7c51f539afb6fdaf853602bab4cbc20479ee5ee1a85a4e32b60", size = 15795220, upload-time = "2026-04-24T14:11:23.17Z" },
|
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/84/87/bfdcca67346052943a4549814f2009b38f4d15ec025798cdf7dfa5f57c84/pymupdf_layout-1.27.2.3-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:fcf03aa815cbceebdb3263dd6a190de4547c46b1d168928836ec38738afe127d", size = 15805240, upload-time = "2026-04-24T14:11:33.465Z" },
|
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/32/e9/7ce6eaf97cebd46c3808593282e9eb99a60cddd6183e25a636980d5c7986/pymupdf_layout-1.27.2.3-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:303b9414216dfaf711ec7d807b6f1e4c3e0a92bbb4569340fcedd9d5593d16ca", size = 15806269, upload-time = "2026-04-24T14:11:43.481Z" },
|
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bf/61/3b2417d8f2cdfaa0f4749cd9dafa3379cb5cdaddf4233165f1ff81953c30/pymupdf_layout-1.27.2.3-cp310-abi3-win_amd64.whl", hash = "sha256:503b64d9b6b31ea3af79ef85cf7d36950c5048af468cb297684d2953553c62ad", size = 15809163, upload-time = "2026-04-24T14:11:53.956Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pymupdf4llm"
|
|
||||||
version = "1.27.2.3"
|
|
||||||
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
|
||||||
dependencies = [
|
|
||||||
{ name = "pymupdf" },
|
|
||||||
{ name = "pymupdf-layout" },
|
|
||||||
{ name = "tabulate" },
|
|
||||||
]
|
|
||||||
sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/87/c0/e3830452d82032c3d82a9879616c05bf0c51e0dea03c1d80d57b3a6ec0d1/pymupdf4llm-1.27.2.3.tar.gz", hash = "sha256:42ec1a47ddc62be3f4f40c116d27618611c6f9fa366719016d9ddc3f3a3dc22b", size = 1406297, upload-time = "2026-04-24T14:13:18.843Z" }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e6/38/84bf29f4dd72e6c450546df6ca8f53021f764fd945ba67dcc235d39bc20e/pymupdf4llm-1.27.2.3-py3-none-any.whl", hash = "sha256:bd724b79fa3f06a5b28d7a65f7acfa8de56e04bdb603ac2d6dff315e0d151aaa", size = 77348, upload-time = "2026-04-24T14:11:04.305Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pypika"
|
name = "pypika"
|
||||||
version = "0.51.1"
|
version = "0.51.1"
|
||||||
@@ -2282,15 +2240,6 @@ wheels = [
|
|||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1c/54/196d0c1db10af76baa4f64894448505d60d3cdf70ef92cbb35f46a4e4c71/starlette-1.2.1-py3-none-any.whl", hash = "sha256:4de0082d08c8f6764a85a54cf1120d6939507a19905c7768acad2a9f875d2b89", size = 73350, upload-time = "2026-05-31T01:07:50.09Z" },
|
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1c/54/196d0c1db10af76baa4f64894448505d60d3cdf70ef92cbb35f46a4e4c71/starlette-1.2.1-py3-none-any.whl", hash = "sha256:4de0082d08c8f6764a85a54cf1120d6939507a19905c7768acad2a9f875d2b89", size = 73350, upload-time = "2026-05-31T01:07:50.09Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tabulate"
|
|
||||||
version = "0.10.0"
|
|
||||||
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
|
||||||
sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/46/58/8c37dea7bbf769b20d58e7ace7e5edfe65b849442b00ffcdd56be88697c6/tabulate-0.10.0.tar.gz", hash = "sha256:e2cfde8f79420f6deeffdeda9aaec3b6bc5abce947655d17ac662b126e48a60d", size = 91754, upload-time = "2026-03-04T18:55:34.402Z" }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl", hash = "sha256:f0b0622e567335c8fabaaa659f1b33bcb6ddfe2e496071b743aa113f8774f2d3", size = 39814, upload-time = "2026-03-04T18:55:31.284Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tenacity"
|
name = "tenacity"
|
||||||
version = "9.1.4"
|
version = "9.1.4"
|
||||||
|
|||||||
Reference in New Issue
Block a user