feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules - Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection - Add layout_detector service for PicoDet-S_layout_3cls integration - Add exceptions module with ConflictError and NotFoundError - Improve admin dashboard with better statistics and task management - Add design review document with system optimization suggestions - Add new tests for crawler, pdf_downloader, pipeline, and summary_utils - Update dependencies and configuration - Clean up dead code and improve error handling
This commit is contained in:
@@ -0,0 +1,172 @@
|
||||
"""导出 PicoDet-S_layout_3cls 为 ONNX 格式.
|
||||
|
||||
一次性脚本,在独立 venv 中运行:
|
||||
python -m venv .venv-export && source .venv-export/bin/activate
|
||||
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle paddleocr paddle2onnx onnxruntime opencv-python-headless
|
||||
HF_ENDPOINT=https://hf-mirror.com python scripts/export_picodet_onnx.py
|
||||
|
||||
输出:
|
||||
data/models/picodet_layout_3cls.onnx
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# hf-mirror
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
MODEL_DIR = PROJECT_ROOT / "data" / "models"
|
||||
OUTPUT_PATH = MODEL_DIR / "picodet_layout_3cls.onnx"
|
||||
MODEL_NAME = "PicoDet-S_layout_3cls"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ── Step 1: 用 PaddleOCR paddle_static 引擎加载模型,触发下载 ──
|
||||
print(f"[1/4] Loading model '{MODEL_NAME}' (paddle_static engine, triggers download) ...")
|
||||
from paddleocr import LayoutDetection
|
||||
|
||||
model = LayoutDetection(
|
||||
model_name=MODEL_NAME,
|
||||
engine="paddle_static",
|
||||
device="cpu",
|
||||
)
|
||||
print(" ✓ Model loaded and cached")
|
||||
|
||||
# ── Step 2: 找到 PaddleX 缓存的 Paddle 模型文件 ────────────────
|
||||
paddlex_cache = Path.home() / ".paddlex"
|
||||
print(f"\n[2/4] Searching Paddle model cache in {paddlex_cache} ...")
|
||||
|
||||
# 搜索 layout 相关的缓存目录
|
||||
candidates = []
|
||||
for d in paddlex_cache.rglob("*"):
|
||||
if d.is_dir() and (d / "inference.pdiparams").exists():
|
||||
# 检查是否是 layout 模型
|
||||
marker = d.name
|
||||
parent_name = d.parent.name
|
||||
if "layout" in marker.lower() or "layout" in parent_name.lower() or "picodet" in marker.lower():
|
||||
candidates.append(d)
|
||||
elif "PicoDet" in str(d):
|
||||
candidates.append(d)
|
||||
|
||||
if not candidates:
|
||||
# 如果没找到明确的 layout 目录,列出所有含 inference.pdiparams 的目录
|
||||
all_model_dirs = [d for d in paddlex_cache.rglob("*") if d.is_dir() and (d / "inference.pdiparams").exists()]
|
||||
print(" No layout-specific dir found. All model dirs with inference.pdiparams:")
|
||||
for d in all_model_dirs:
|
||||
files = [f.name for f in d.iterdir()]
|
||||
print(f" {d} ({', '.join(files)})")
|
||||
if all_model_dirs:
|
||||
# 取最新的(刚下载的)
|
||||
candidates = sorted(all_model_dirs, key=lambda d: (d / "inference.pdiparams").stat().st_mtime, reverse=True)[:1]
|
||||
|
||||
if not candidates:
|
||||
print(" ✗ No cached model found")
|
||||
sys.exit(1)
|
||||
|
||||
model_cache_dir = candidates[0]
|
||||
files_in_dir = list(model_cache_dir.iterdir())
|
||||
print(f" Using: {model_cache_dir}")
|
||||
for f in files_in_dir:
|
||||
print(f" {f.name} ({f.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
# ── Step 3: 用 paddle2onnx 转换 ─────────────────────────────────
|
||||
print("\n[3/4] Converting to ONNX with paddle2onnx ...")
|
||||
tmp_onnx = OUTPUT_PATH.with_suffix(".tmp.onnx")
|
||||
|
||||
# 确定 model_filename
|
||||
pdmodel = model_cache_dir / "inference.pdmodel"
|
||||
has_pdmodel = pdmodel.exists()
|
||||
|
||||
cmd = [
|
||||
sys.executable, "-m", "paddle2onnx",
|
||||
"--model_dir", str(model_cache_dir),
|
||||
"--save_file", str(tmp_onnx),
|
||||
"--opset_version", "11",
|
||||
"--enable_onnx_checker", "True",
|
||||
]
|
||||
if has_pdmodel:
|
||||
cmd.extend(["--model_filename", "inference.pdmodel"])
|
||||
cmd.extend(["--params_filename", "inference.pdiparams"])
|
||||
|
||||
print(f" Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.stdout:
|
||||
print(f" stdout: {result.stdout[:500]}")
|
||||
if result.returncode != 0:
|
||||
print(f" ✗ paddle2onnx failed (exit {result.returncode})")
|
||||
print(f" stderr: {result.stderr[:500]}")
|
||||
|
||||
# 尝试不带 model_filename(combined format)
|
||||
if has_pdmodel:
|
||||
print(" Retrying without explicit model_filename ...")
|
||||
cmd2 = [
|
||||
sys.executable, "-m", "paddle2onnx",
|
||||
"--model_dir", str(model_cache_dir),
|
||||
"--params_filename", "inference.pdiparams",
|
||||
"--save_file", str(tmp_onnx),
|
||||
"--opset_version", "11",
|
||||
]
|
||||
result2 = subprocess.run(cmd2, capture_output=True, text=True)
|
||||
if result2.returncode != 0:
|
||||
print(f" ✗ Retry also failed: {result2.stderr[:500]}")
|
||||
sys.exit(1)
|
||||
|
||||
if not tmp_onnx.exists() or tmp_onnx.stat().st_size < 1000:
|
||||
print(" ✗ ONNX file not created or too small")
|
||||
sys.exit(1)
|
||||
|
||||
shutil.move(str(tmp_onnx), str(OUTPUT_PATH))
|
||||
print(f" ✓ ONNX saved ({OUTPUT_PATH.stat().st_size / 1024 / 1024:.2f} MB)")
|
||||
|
||||
# ── Step 4: 用 onnxruntime 验证 ─────────────────────────────────
|
||||
print("\n[4/4] Verifying with onnxruntime ...")
|
||||
_inspect_onnx(OUTPUT_PATH)
|
||||
|
||||
print(f"\n✓ Done! ONNX model saved to {OUTPUT_PATH}")
|
||||
|
||||
|
||||
def _inspect_onnx(onnx_path: Path) -> None:
|
||||
"""用 onnxruntime 加载模型,打印输入输出信息."""
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
|
||||
|
||||
print(" Inputs:")
|
||||
for inp in session.get_inputs():
|
||||
print(f" {inp.name}: shape={inp.shape}, dtype={inp.type}")
|
||||
|
||||
print(" Outputs:")
|
||||
for out in session.get_outputs():
|
||||
print(f" {out.name}: shape={out.shape}, dtype={out.type}")
|
||||
|
||||
# 试推理
|
||||
input_info = session.get_inputs()[0]
|
||||
input_name = input_info.name
|
||||
batch_size = input_info.shape[0] if isinstance(input_info.shape[0], int) else 1
|
||||
channels = input_info.shape[1] if isinstance(input_info.shape[1], int) else 3
|
||||
height = input_info.shape[2] if isinstance(input_info.shape[2], int) else 480
|
||||
width = input_info.shape[3] if isinstance(input_info.shape[3], int) else 480
|
||||
|
||||
dummy_input = np.random.rand(batch_size, channels, height, width).astype(np.float32)
|
||||
outputs = session.run(None, {input_name: dummy_input})
|
||||
|
||||
print(" Inference test outputs:")
|
||||
for i, (out_info, out_val) in enumerate(zip(session.get_outputs(), outputs)):
|
||||
print(f" output[{i}] '{out_info.name}': shape={out_val.shape}, dtype={out_val.dtype}")
|
||||
if out_val.size <= 20:
|
||||
print(f" values: {out_val}")
|
||||
|
||||
print(" ✓ Inference OK")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,212 @@
|
||||
"""批量重新提取所有论文的图片 — 下载 PDF + PicoDet 检测 + caption 匹配.
|
||||
|
||||
用法:
|
||||
PROXY_SERVER=http://... uv run python scripts/reextract_images.py
|
||||
uv run python scripts/reextract_images.py --limit 10 # 只处理前 10 篇
|
||||
uv run python scripts/reextract_images.py --id 2512.24880 # 只处理指定论文
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
# 让脚本可以从项目根目录直接运行
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from app.database import SessionLocal, init_db, engine # noqa: E402
|
||||
from app.models import Paper # noqa: E402
|
||||
from app.services.pdf_image_extractor import extract_images_from_pdf # noqa: E402
|
||||
from app.utils import TMP_DIR # noqa: E402
|
||||
from sqlalchemy import select # noqa: E402
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)-5s %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 下载并发数
|
||||
MAX_WORKERS = 3
|
||||
# 下载超时(秒)
|
||||
DOWNLOAD_TIMEOUT = 120
|
||||
|
||||
|
||||
def _get_session() -> requests.Session:
|
||||
"""创建带代理的 HTTP session。"""
|
||||
sess = requests.Session()
|
||||
sess.headers.update({"User-Agent": "hf-daily-papers/1.0"})
|
||||
proxy = os.environ.get("PROXY_SERVER") or os.environ.get("HTTPS_PROXY")
|
||||
if proxy:
|
||||
sess.proxies = {"http": proxy, "https": proxy}
|
||||
logger.info("使用代理: %s", proxy)
|
||||
else:
|
||||
logger.warning("未设置代理 (PROXY_SERVER / HTTPS_PROXY),直连 arxiv.org")
|
||||
return sess
|
||||
|
||||
|
||||
def download_pdf(session: requests.Session, arxiv_id: str, pdf_url: str) -> Path | None:
|
||||
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf,返回路径或 None。"""
|
||||
dest_dir = TMP_DIR / arxiv_id
|
||||
dest = dest_dir / "paper.pdf"
|
||||
if dest.exists() and dest.stat().st_size > 1000:
|
||||
return dest
|
||||
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
resp = session.get(pdf_url, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
|
||||
resp.raise_for_status()
|
||||
dest.write_bytes(resp.content)
|
||||
return dest
|
||||
except Exception as exc:
|
||||
logger.warning("下载失败 %s: %s", arxiv_id, exc)
|
||||
return None
|
||||
|
||||
|
||||
def process_one(session: requests.Session, arxiv_id: str, pdf_url: str) -> dict:
|
||||
"""处理单篇论文:下载 → 提取图片 → 返回统计。"""
|
||||
result = {"arxiv_id": arxiv_id, "downloaded": False, "extracted": 0, "error": None}
|
||||
|
||||
# 下载 PDF
|
||||
pdf_path = download_pdf(session, arxiv_id, pdf_url)
|
||||
if pdf_path is None:
|
||||
result["error"] = "download_failed"
|
||||
return result
|
||||
result["downloaded"] = True
|
||||
|
||||
# 提取图片
|
||||
try:
|
||||
n = extract_images_from_pdf(arxiv_id, pdf_path)
|
||||
result["extracted"] = n
|
||||
except Exception as exc:
|
||||
logger.warning("提取失败 %s: %s", arxiv_id, exc, exc_info=True)
|
||||
result["error"] = f"extract_failed: {exc}"
|
||||
return result
|
||||
|
||||
# 统计 matched / orphan
|
||||
mf = Path(f"data/papers/{arxiv_id}/images/manifest.json")
|
||||
if mf.exists():
|
||||
m = json.loads(mf.read_text(encoding="utf-8"))
|
||||
result["matched"] = sum(1 for v in m.values() if "(p" not in v.get("label", ""))
|
||||
result["orphan"] = sum(1 for v in m.values() if "(p" in v.get("label", ""))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="批量重新提取论文图片")
|
||||
parser.add_argument("--limit", type=int, default=0, help="只处理前 N 篇")
|
||||
parser.add_argument("--id", dest="arxiv_id", help="只处理指定 arxiv_id")
|
||||
parser.add_argument("--workers", type=int, default=MAX_WORKERS, help="并发数")
|
||||
args = parser.parse_args()
|
||||
|
||||
# 初始化数据库
|
||||
os.makedirs("data/db", exist_ok=True)
|
||||
init_db(engine)
|
||||
|
||||
# 读取论文列表
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if args.arxiv_id:
|
||||
papers = (
|
||||
db.execute(select(Paper).where(Paper.arxiv_id == args.arxiv_id))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
papers = db.execute(select(Paper)).scalars().all()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if args.limit > 0:
|
||||
papers = papers[: args.limit]
|
||||
|
||||
total = len(papers)
|
||||
logger.info("待处理论文: %d 篇", total)
|
||||
if total == 0:
|
||||
return
|
||||
|
||||
session = _get_session()
|
||||
|
||||
# 统计
|
||||
done = 0
|
||||
failed = 0
|
||||
total_extracted = 0
|
||||
total_matched = 0
|
||||
total_orphan = 0
|
||||
t0 = time.time()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=args.workers) as pool:
|
||||
futures = {}
|
||||
for p in papers:
|
||||
f = pool.submit(process_one, session, p.arxiv_id, p.pdf_url)
|
||||
futures[f] = p.arxiv_id
|
||||
|
||||
for f in as_completed(futures):
|
||||
arxiv_id = futures[f]
|
||||
try:
|
||||
r = f.result()
|
||||
except Exception as exc:
|
||||
logger.error("异常 %s: %s", arxiv_id, exc)
|
||||
failed += 1
|
||||
done += 1
|
||||
continue
|
||||
|
||||
done += 1
|
||||
if r["error"]:
|
||||
failed += 1
|
||||
logger.info("[%d/%d] ✗ %s — %s", done, total, arxiv_id, r["error"])
|
||||
else:
|
||||
total_extracted += r["extracted"]
|
||||
total_matched += r.get("matched", 0)
|
||||
total_orphan += r.get("orphan", 0)
|
||||
matched = r.get("matched", 0)
|
||||
orphan = r.get("orphan", 0)
|
||||
elapsed = time.time() - t0
|
||||
rate = done / elapsed if elapsed > 0 else 0
|
||||
eta = (total - done) / rate if rate > 0 else 0
|
||||
logger.info(
|
||||
"[%d/%d] ✓ %s — %d 张 (matched=%d, orphan=%d) ETA %.0fs",
|
||||
done,
|
||||
total,
|
||||
arxiv_id,
|
||||
r["extracted"],
|
||||
matched,
|
||||
orphan,
|
||||
eta,
|
||||
)
|
||||
|
||||
elapsed = time.time() - t0
|
||||
logger.info("=" * 60)
|
||||
logger.info(
|
||||
"完成: %d/%d 成功, %d 失败, 耗时 %.1fs",
|
||||
done - failed,
|
||||
total,
|
||||
failed,
|
||||
elapsed,
|
||||
)
|
||||
logger.info(
|
||||
"图片: %d 总计, %d matched, %d orphan (%.1f%%)",
|
||||
total_extracted,
|
||||
total_matched,
|
||||
total_orphan,
|
||||
total_orphan / total_extracted * 100 if total_extracted else 0,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user