"""批量重新提取所有论文的图片 — 下载 PDF + PicoDet 检测 + caption 匹配. 用法: PROXY_SERVER=http://... uv run python scripts/reextract_images.py uv run python scripts/reextract_images.py --limit 10 # 只处理前 10 篇 uv run python scripts/reextract_images.py --id 2512.24880 # 只处理指定论文 """ from __future__ import annotations import json import logging import os import sys import time from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import requests # 让脚本可以从项目根目录直接运行 sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from dotenv import load_dotenv load_dotenv() from app.database import SessionLocal, init_db, engine # noqa: E402 from app.models import Paper # noqa: E402 from app.services.pdf_image_extractor import extract_images_from_pdf # noqa: E402 from app.utils import TMP_DIR # noqa: E402 from sqlalchemy import select # noqa: E402 logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-5s %(message)s", datefmt="%H:%M:%S", ) logger = logging.getLogger(__name__) # 下载并发数 MAX_WORKERS = 3 # 下载超时(秒) DOWNLOAD_TIMEOUT = 120 def _get_session() -> requests.Session: """创建带代理的 HTTP session。""" sess = requests.Session() sess.headers.update({"User-Agent": "hf-daily-papers/1.0"}) proxy = os.environ.get("PROXY_SERVER") or os.environ.get("HTTPS_PROXY") if proxy: sess.proxies = {"http": proxy, "https": proxy} logger.info("使用代理: %s", proxy) else: logger.warning("未设置代理 (PROXY_SERVER / HTTPS_PROXY),直连 arxiv.org") return sess def download_pdf(session: requests.Session, arxiv_id: str, pdf_url: str) -> Path | None: """下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf,返回路径或 None。""" dest_dir = TMP_DIR / arxiv_id dest = dest_dir / "paper.pdf" if dest.exists() and dest.stat().st_size > 1000: return dest dest_dir.mkdir(parents=True, exist_ok=True) try: resp = session.get(pdf_url, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True) resp.raise_for_status() dest.write_bytes(resp.content) return dest except Exception as exc: logger.warning("下载失败 %s: %s", arxiv_id, exc) return None def process_one(session: requests.Session, arxiv_id: str, pdf_url: str) -> dict: """处理单篇论文:下载 → 提取图片 → 返回统计。""" result = {"arxiv_id": arxiv_id, "downloaded": False, "extracted": 0, "error": None} # 下载 PDF pdf_path = download_pdf(session, arxiv_id, pdf_url) if pdf_path is None: result["error"] = "download_failed" return result result["downloaded"] = True # 提取图片 try: n = extract_images_from_pdf(arxiv_id, pdf_path) result["extracted"] = n except Exception as exc: logger.warning("提取失败 %s: %s", arxiv_id, exc, exc_info=True) result["error"] = f"extract_failed: {exc}" return result # 统计 matched / orphan mf = Path(f"data/papers/{arxiv_id}/images/manifest.json") if mf.exists(): m = json.loads(mf.read_text(encoding="utf-8")) result["matched"] = sum(1 for v in m.values() if "(p" not in v.get("label", "")) result["orphan"] = sum(1 for v in m.values() if "(p" in v.get("label", "")) return result def main(): import argparse parser = argparse.ArgumentParser(description="批量重新提取论文图片") parser.add_argument("--limit", type=int, default=0, help="只处理前 N 篇") parser.add_argument("--id", dest="arxiv_id", help="只处理指定 arxiv_id") parser.add_argument("--workers", type=int, default=MAX_WORKERS, help="并发数") args = parser.parse_args() # 初始化数据库 os.makedirs("data/db", exist_ok=True) init_db(engine) # 读取论文列表 db = SessionLocal() try: if args.arxiv_id: papers = ( db.execute(select(Paper).where(Paper.arxiv_id == args.arxiv_id)) .scalars() .all() ) else: papers = db.execute(select(Paper)).scalars().all() finally: db.close() if args.limit > 0: papers = papers[: args.limit] total = len(papers) logger.info("待处理论文: %d 篇", total) if total == 0: return session = _get_session() # 统计 done = 0 failed = 0 total_extracted = 0 total_matched = 0 total_orphan = 0 t0 = time.time() with ThreadPoolExecutor(max_workers=args.workers) as pool: futures = {} for p in papers: f = pool.submit(process_one, session, p.arxiv_id, p.pdf_url) futures[f] = p.arxiv_id for f in as_completed(futures): arxiv_id = futures[f] try: r = f.result() except Exception as exc: logger.error("异常 %s: %s", arxiv_id, exc) failed += 1 done += 1 continue done += 1 if r["error"]: failed += 1 logger.info("[%d/%d] ✗ %s — %s", done, total, arxiv_id, r["error"]) else: total_extracted += r["extracted"] total_matched += r.get("matched", 0) total_orphan += r.get("orphan", 0) matched = r.get("matched", 0) orphan = r.get("orphan", 0) elapsed = time.time() - t0 rate = done / elapsed if elapsed > 0 else 0 eta = (total - done) / rate if rate > 0 else 0 logger.info( "[%d/%d] ✓ %s — %d 张 (matched=%d, orphan=%d) ETA %.0fs", done, total, arxiv_id, r["extracted"], matched, orphan, eta, ) elapsed = time.time() - t0 logger.info("=" * 60) logger.info( "完成: %d/%d 成功, %d 失败, 耗时 %.1fs", done - failed, total, failed, elapsed, ) logger.info( "图片: %d 总计, %d matched, %d orphan (%.1f%%)", total_extracted, total_matched, total_orphan, total_orphan / total_extracted * 100 if total_extracted else 0, ) if __name__ == "__main__": main()