90fe705e8f
- 核心变更: - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024) - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备 - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式 - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls] - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table) - 新增文件: - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行) - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例) - 配置更新: - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID) - app/config.py: Settings 类对应字段 - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等) - 删除旧文件: - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本 - 文档更新: - README.md: 更新环境变量说明 - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py) 此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
213 lines
6.4 KiB
Python
213 lines
6.4 KiB
Python
"""批量重新提取所有论文的图片 — 下载 PDF + DocLayout-YOLO 检测 + caption 匹配.
|
|
|
|
用法:
|
|
PROXY_SERVER=http://... uv run python scripts/reextract_images.py
|
|
uv run python scripts/reextract_images.py --limit 10 # 只处理前 10 篇
|
|
uv run python scripts/reextract_images.py --id 2512.24880 # 只处理指定论文
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
|
|
# 让脚本可以从项目根目录直接运行
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
|
|
from app.database import SessionLocal, init_db, engine # noqa: E402
|
|
from app.models import Paper # noqa: E402
|
|
from app.services.pdf_image_extractor import extract_images_from_pdf # noqa: E402
|
|
from app.utils import TMP_DIR # noqa: E402
|
|
from sqlalchemy import select # noqa: E402
|
|
|
|
# Script-wide logging: INFO and above, each record prefixed with a
# clock-time stamp and a padded level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)-5s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Default thread-pool size; used as the default of the --workers option.
MAX_WORKERS = 3
# Timeout, in seconds, passed to each PDF download request.
DOWNLOAD_TIMEOUT = 120
|
|
|
|
|
|
def _get_session() -> requests.Session:
    """Build the HTTP session shared by all PDF downloads.

    The session always sends a fixed ``User-Agent`` header.  The proxy URL is
    read from ``PROXY_SERVER`` first and ``HTTPS_PROXY`` second; when either
    holds a non-empty value it is applied to both ``http`` and ``https``
    traffic.  When neither is set, a warning is logged and no proxy is
    configured on the session.

    Returns:
        The configured ``requests.Session``.
    """
    http = requests.Session()
    http.headers.update({"User-Agent": "hf-daily-papers/1.0"})

    proxy_url = os.environ.get("PROXY_SERVER") or os.environ.get("HTTPS_PROXY")
    if not proxy_url:
        logger.warning("未设置代理 (PROXY_SERVER / HTTPS_PROXY),直连 arxiv.org")
        return http

    # One proxy endpoint serves both URL schemes.
    http.proxies = {"http": proxy_url, "https": proxy_url}
    logger.info("使用代理: %s", proxy_url)
    return http
|
|
|
|
|
|
def download_pdf(session: requests.Session, arxiv_id: str, pdf_url: str) -> Path | None:
    """Download a paper's PDF to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    A file already present at the destination and larger than 1000 bytes is
    reused without touching the network.  Otherwise the URL is fetched with
    ``DOWNLOAD_TIMEOUT``; the body must carry the ``%PDF-`` signature within
    its first 1024 bytes, and is written to a sibling ``.part`` file that is
    renamed into place only after the write succeeded.  This keeps HTML error
    pages and half-written files from being picked up as a valid cached PDF
    on a later run.

    Args:
        session: HTTP session used for the request.
        arxiv_id: Paper identifier; used as the sub-directory name.
        pdf_url: URL the PDF is fetched from (redirects are followed).

    Returns:
        Path of the local PDF, or ``None`` when the download failed or the
        response was not a PDF.  Failures are logged as warnings, not raised.
    """
    dest_dir = TMP_DIR / arxiv_id
    dest = dest_dir / "paper.pdf"
    if dest.exists() and dest.stat().st_size > 1000:
        return dest

    dest_dir.mkdir(parents=True, exist_ok=True)
    partial = dest_dir / "paper.pdf.part"
    try:
        resp = session.get(pdf_url, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
        resp.raise_for_status()
        payload = resp.content

        # A 200 response is not necessarily a PDF (e.g. an HTML error or
        # rate-limit page); never let such a body reach the cache path.
        if b"%PDF-" not in payload[:1024]:
            logger.warning("下载失败 %s: %s", arxiv_id, "response is not a PDF")
            return None

        # Write-then-rename so `dest` only ever holds a complete file.
        partial.write_bytes(payload)
        partial.replace(dest)
        return dest
    except Exception as exc:
        # Best-effort download: report and let the caller record the failure.
        logger.warning("下载失败 %s: %s", arxiv_id, exc)
        try:
            partial.unlink(missing_ok=True)
        except OSError:
            # Leftover .part files are harmless; they are never read back.
            pass
        return None
|
|
|
|
|
|
def process_one(session: requests.Session, arxiv_id: str, pdf_url: str) -> dict:
    """Run the download -> image-extraction pipeline for a single paper.

    Args:
        session: HTTP session handed to ``download_pdf``.
        arxiv_id: Identifier of the paper being processed.
        pdf_url: URL of the paper's PDF.

    Returns:
        A dict with the keys ``arxiv_id``, ``downloaded``, ``extracted`` and
        ``error`` (``None`` on success, otherwise ``"download_failed"`` or an
        ``"extract_failed: ..."`` message).  When the paper's
        ``manifest.json`` exists after extraction, ``matched`` and ``orphan``
        counts are added as well.
    """
    outcome = {"arxiv_id": arxiv_id, "downloaded": False, "extracted": 0, "error": None}

    # Step 1: fetch the PDF.
    pdf_path = download_pdf(session, arxiv_id, pdf_url)
    if pdf_path is None:
        outcome["error"] = "download_failed"
        return outcome
    outcome["downloaded"] = True

    # Step 2: extract images; a failure here is reported, not raised.
    try:
        outcome["extracted"] = extract_images_from_pdf(arxiv_id, pdf_path)
    except Exception as exc:
        logger.warning("提取失败 %s: %s", arxiv_id, exc, exc_info=True)
        outcome["error"] = f"extract_failed: {exc}"
        return outcome

    # Step 3: classify manifest entries. An entry whose label contains "(p"
    # is counted as orphan, every other entry as matched.
    manifest_path = Path(f"data/papers/{arxiv_id}/images/manifest.json")
    if not manifest_path.exists():
        return outcome

    entries = json.loads(manifest_path.read_text(encoding="utf-8")).values()
    orphan_flags = ["(p" in entry.get("label", "") for entry in entries]
    outcome["matched"] = orphan_flags.count(False)
    outcome["orphan"] = orphan_flags.count(True)
    return outcome
|
|
|
|
|
|
def main():
    """Command-line entry point: re-extract images for papers in the database.

    Options:
        --limit N    process only the first N papers (0 disables the limit)
        --id ID      process only the paper whose arxiv_id equals ID
        --workers N  thread-pool size (defaults to ``MAX_WORKERS``)

    Papers are loaded from the database, each one is handed to
    ``process_one`` on a thread pool, per-paper progress is logged as results
    arrive, and a two-line summary (success/failure counts, image totals with
    the orphan percentage) is logged at the end.
    """
    import argparse

    cli = argparse.ArgumentParser(description="批量重新提取论文图片")
    cli.add_argument("--limit", type=int, default=0, help="只处理前 N 篇")
    cli.add_argument("--id", dest="arxiv_id", help="只处理指定 arxiv_id")
    cli.add_argument("--workers", type=int, default=MAX_WORKERS, help="并发数")
    args = cli.parse_args()

    # Make sure the database directory and schema exist before querying.
    os.makedirs("data/db", exist_ok=True)
    init_db(engine)

    # Load either the single requested paper or every paper.
    query = select(Paper)
    if args.arxiv_id:
        query = query.where(Paper.arxiv_id == args.arxiv_id)
    db = SessionLocal()
    try:
        papers = db.execute(query).scalars().all()
    finally:
        db.close()

    if args.limit > 0:
        papers = papers[: args.limit]

    total = len(papers)
    logger.info("待处理论文: %d 篇", total)
    if not total:
        return

    session = _get_session()

    # Running counters for the progress lines and the final summary.
    done = failed = 0
    total_extracted = total_matched = total_orphan = 0
    t0 = time.time()

    with ThreadPoolExecutor(max_workers=args.workers) as pool:
        pending = {
            pool.submit(process_one, session, paper.arxiv_id, paper.pdf_url): paper.arxiv_id
            for paper in papers
        }

        for future in as_completed(pending):
            arxiv_id = pending[future]
            done += 1

            try:
                outcome = future.result()
            except Exception as exc:
                logger.error("异常 %s: %s", arxiv_id, exc)
                failed += 1
                continue

            if outcome["error"]:
                failed += 1
                logger.info("[%d/%d] ✗ %s — %s", done, total, arxiv_id, outcome["error"])
                continue

            # matched/orphan are absent when no manifest was found.
            matched = outcome.get("matched", 0)
            orphan = outcome.get("orphan", 0)
            total_extracted += outcome["extracted"]
            total_matched += matched
            total_orphan += orphan

            # ETA extrapolates from the average completion rate so far.
            elapsed = time.time() - t0
            rate = done / elapsed if elapsed > 0 else 0
            eta = (total - done) / rate if rate > 0 else 0
            logger.info(
                "[%d/%d] ✓ %s — %d 张 (matched=%d, orphan=%d) ETA %.0fs",
                done,
                total,
                arxiv_id,
                outcome["extracted"],
                matched,
                orphan,
                eta,
            )

    elapsed = time.time() - t0
    orphan_pct = total_orphan / total_extracted * 100 if total_extracted else 0
    logger.info("=" * 60)
    logger.info(
        "完成: %d/%d 成功, %d 失败, 耗时 %.1fs",
        done - failed,
        total,
        failed,
        elapsed,
    )
    logger.info(
        "图片: %d 总计, %d matched, %d orphan (%.1f%%)",
        total_extracted,
        total_matched,
        total_orphan,
        orphan_pct,
    )


if __name__ == "__main__":
    main()
|