feat: refactor summarizer and PDF extraction pipeline

- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
This commit is contained in:
2026-06-13 13:16:47 +08:00
parent e2f0e1a8be
commit 21f16e6756
43 changed files with 3304 additions and 1494 deletions
+212
View File
@@ -0,0 +1,212 @@
"""批量重新提取所有论文的图片 — 下载 PDF + PicoDet 检测 + caption 匹配.
用法:
PROXY_SERVER=http://... uv run python scripts/reextract_images.py
uv run python scripts/reextract_images.py --limit 10 # 只处理前 10 篇
uv run python scripts/reextract_images.py --id 2512.24880 # 只处理指定论文
"""
from __future__ import annotations
import json
import logging
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import requests
# 让脚本可以从项目根目录直接运行
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from dotenv import load_dotenv
load_dotenv()
from app.database import SessionLocal, init_db, engine # noqa: E402
from app.models import Paper # noqa: E402
from app.services.pdf_image_extractor import extract_images_from_pdf # noqa: E402
from app.utils import TMP_DIR # noqa: E402
from sqlalchemy import select # noqa: E402
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-5s %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
# 下载并发数
MAX_WORKERS = 3
# 下载超时(秒)
DOWNLOAD_TIMEOUT = 120
def _get_session() -> requests.Session:
"""创建带代理的 HTTP session。"""
sess = requests.Session()
sess.headers.update({"User-Agent": "hf-daily-papers/1.0"})
proxy = os.environ.get("PROXY_SERVER") or os.environ.get("HTTPS_PROXY")
if proxy:
sess.proxies = {"http": proxy, "https": proxy}
logger.info("使用代理: %s", proxy)
else:
logger.warning("未设置代理 (PROXY_SERVER / HTTPS_PROXY),直连 arxiv.org")
return sess
def download_pdf(session: requests.Session, arxiv_id: str, pdf_url: str) -> Path | None:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf,返回路径或 None。"""
dest_dir = TMP_DIR / arxiv_id
dest = dest_dir / "paper.pdf"
if dest.exists() and dest.stat().st_size > 1000:
return dest
dest_dir.mkdir(parents=True, exist_ok=True)
try:
resp = session.get(pdf_url, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
resp.raise_for_status()
dest.write_bytes(resp.content)
return dest
except Exception as exc:
logger.warning("下载失败 %s: %s", arxiv_id, exc)
return None
def process_one(session: requests.Session, arxiv_id: str, pdf_url: str) -> dict:
"""处理单篇论文:下载 → 提取图片 → 返回统计。"""
result = {"arxiv_id": arxiv_id, "downloaded": False, "extracted": 0, "error": None}
# 下载 PDF
pdf_path = download_pdf(session, arxiv_id, pdf_url)
if pdf_path is None:
result["error"] = "download_failed"
return result
result["downloaded"] = True
# 提取图片
try:
n = extract_images_from_pdf(arxiv_id, pdf_path)
result["extracted"] = n
except Exception as exc:
logger.warning("提取失败 %s: %s", arxiv_id, exc, exc_info=True)
result["error"] = f"extract_failed: {exc}"
return result
# 统计 matched / orphan
mf = Path(f"data/papers/{arxiv_id}/images/manifest.json")
if mf.exists():
m = json.loads(mf.read_text(encoding="utf-8"))
result["matched"] = sum(1 for v in m.values() if "(p" not in v.get("label", ""))
result["orphan"] = sum(1 for v in m.values() if "(p" in v.get("label", ""))
return result
def main():
import argparse
parser = argparse.ArgumentParser(description="批量重新提取论文图片")
parser.add_argument("--limit", type=int, default=0, help="只处理前 N 篇")
parser.add_argument("--id", dest="arxiv_id", help="只处理指定 arxiv_id")
parser.add_argument("--workers", type=int, default=MAX_WORKERS, help="并发数")
args = parser.parse_args()
# 初始化数据库
os.makedirs("data/db", exist_ok=True)
init_db(engine)
# 读取论文列表
db = SessionLocal()
try:
if args.arxiv_id:
papers = (
db.execute(select(Paper).where(Paper.arxiv_id == args.arxiv_id))
.scalars()
.all()
)
else:
papers = db.execute(select(Paper)).scalars().all()
finally:
db.close()
if args.limit > 0:
papers = papers[: args.limit]
total = len(papers)
logger.info("待处理论文: %d", total)
if total == 0:
return
session = _get_session()
# 统计
done = 0
failed = 0
total_extracted = 0
total_matched = 0
total_orphan = 0
t0 = time.time()
with ThreadPoolExecutor(max_workers=args.workers) as pool:
futures = {}
for p in papers:
f = pool.submit(process_one, session, p.arxiv_id, p.pdf_url)
futures[f] = p.arxiv_id
for f in as_completed(futures):
arxiv_id = futures[f]
try:
r = f.result()
except Exception as exc:
logger.error("异常 %s: %s", arxiv_id, exc)
failed += 1
done += 1
continue
done += 1
if r["error"]:
failed += 1
logger.info("[%d/%d] ✗ %s%s", done, total, arxiv_id, r["error"])
else:
total_extracted += r["extracted"]
total_matched += r.get("matched", 0)
total_orphan += r.get("orphan", 0)
matched = r.get("matched", 0)
orphan = r.get("orphan", 0)
elapsed = time.time() - t0
rate = done / elapsed if elapsed > 0 else 0
eta = (total - done) / rate if rate > 0 else 0
logger.info(
"[%d/%d] ✓ %s%d 张 (matched=%d, orphan=%d) ETA %.0fs",
done,
total,
arxiv_id,
r["extracted"],
matched,
orphan,
eta,
)
elapsed = time.time() - t0
logger.info("=" * 60)
logger.info(
"完成: %d/%d 成功, %d 失败, 耗时 %.1fs",
done - failed,
total,
failed,
elapsed,
)
logger.info(
"图片: %d 总计, %d matched, %d orphan (%.1f%%)",
total_extracted,
total_matched,
total_orphan,
total_orphan / total_extracted * 100 if total_extracted else 0,
)
if __name__ == "__main__":
main()