Files
daily-paper/app/services/pdf_image_extractor.py
T
Rain-Bus 21f16e6756 feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
2026-06-13 13:16:47 +08:00

579 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""PDF 图片与表格提取 — 两阶段流水线。
Phase 1: PicoDet-S_layout_3cls 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
相比旧方案(正则匹配 caption):
- 不再依赖正则,用 LLM 输出的 ID 直接搜索 PDF 文本
- page.search_for() 精确搜索 + 空间距离过滤,避免正文引用误匹配
- 通用标签兜底,LLM 没提到的图表不会被丢弃
"""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
import pymupdf
from app.services.layout_detector import LayoutBox, detect_page_layout
from app.services.pdf_downloader import paper_dir
from app.utils import PAPERS_DIR, TMP_DIR
# Module-level logger.
logger = logging.getLogger(__name__)
# Padding added around each detected region before cropping (unit: pt).
_REGION_PADDING = 5
# Render zoom factor used for the crops (3x keeps them sharp).
_RENDER_ZOOM = 3
# Max distance between neighbouring same-class boxes for them to be merged
# into one cluster (unit: pt) — fragments of one figure/table are usually
# less than 15 pt apart.
_CLUSTER_GAP = 15
# Minimum bbox area (unit: pt²) — drops tiny false positives such as icons/logos.
_MIN_BOX_AREA = 2000
# Phase 2: max vertical distance between a searched label text and a box for
# the two to be considered a match (unit: pt).
_LABEL_MATCH_DISTANCE = 100
# ── Box 聚类 ─────────────────────────────────────────────────────────
class _BoxCluster:
    """Bounding rectangle of one or more adjacent LayoutBox fragments.

    The rectangle is the union of all member boxes. ``boxclass`` comes from
    the first member, with the detector's ``"table-fallback"`` class folded
    into plain ``"table"``.
    """

    __slots__ = ("x0", "y0", "x1", "y1", "boxclass")

    def __init__(self, boxes: list):
        lefts, tops, rights, bottoms = zip(
            *((b.x0, b.y0, b.x1, b.y1) for b in boxes)
        )
        self.x0, self.y0 = min(lefts), min(tops)
        self.x1, self.y1 = max(rights), max(bottoms)
        kind = boxes[0].boxclass
        self.boxclass = "table" if kind == "table-fallback" else kind
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
    """Merge same-class boxes that touch or lie within *gap* pt of each other.

    Two boxes are linked when they share ``boxclass`` and sit either side by
    side (horizontal gap <= *gap*, vertical extents overlapping within *gap*)
    or stacked (vertical gap <= *gap*, horizontal extents overlapping within
    *gap*). Linking is transitive (union-find), so a chain of nearby
    fragments ends up in one cluster.

    Returns:
        One ``_BoxCluster`` per connected group, ordered by the index of the
        group's first member in *boxes*; an empty list for empty input.
    """
    if not boxes:
        return []

    count = len(boxes)
    link = list(range(count))  # union-find parent pointers

    def root(k: int) -> int:
        # Walk to the representative, halving the path on the way.
        while link[k] != k:
            link[k] = link[link[k]]
            k = link[k]
        return k

    def touching(a, b) -> bool:
        # Same class, and close enough either side by side or stacked.
        if a.boxclass != b.boxclass:
            return False
        gap_x = max(0.0, max(a.x0, b.x0) - min(a.x1, b.x1))
        gap_y = max(0.0, max(a.y0, b.y0) - min(a.y1, b.y1))
        spans_x = a.x1 > b.x0 - gap and b.x1 > a.x0 - gap
        spans_y = a.y1 > b.y0 - gap and b.y1 > a.y0 - gap
        return (gap_x <= gap and spans_y) or (gap_y <= gap and spans_x)

    for i in range(count):
        for j in range(i + 1, count):
            if touching(boxes[i], boxes[j]):
                ri, rj = root(i), root(j)
                if ri != rj:
                    link[ri] = rj

    clusters: dict[int, list] = {}
    for idx, box in enumerate(boxes):
        clusters.setdefault(root(idx), []).append(box)
    return [_BoxCluster(group) for group in clusters.values()]
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
def _render_box(
    page,
    box: _BoxCluster,
    images_dest: Path,
    filename: str,
    cap_type: str,
    page_num: int,
) -> bool:
    """Render one layout cluster to a JPEG file.

    The cluster is padded by ``_REGION_PADDING`` pt on every side. The padded
    rectangle is clipped to the page bounds on all four edges, then rendered
    at ``_RENDER_ZOOM``x.

    Args:
        page: pymupdf page that contains the cluster.
        box: Cluster to render, in page coordinates (pt).
        images_dest: Directory the JPEG is written into.
        filename: Output file name inside ``images_dest``.
        cap_type: ``"figure"`` or ``"table"``; only used in log messages.
        page_num: 1-based page number; only used in log messages.

    Returns:
        True if the file was written. False if the clipped region is empty,
        or if rendering, encoding or writing failed (logged).
    """
    bounds = page.rect
    # Clamp every edge, including the bottom one: the original clamped only
    # x0/y0/x1 and let y1 run past the page.
    clip = pymupdf.Rect(
        max(bounds.x0, box.x0 - _REGION_PADDING),
        max(bounds.y0, box.y0 - _REGION_PADDING),
        min(bounds.x1, box.x1 + _REGION_PADDING),
        min(bounds.y1, box.y1 + _REGION_PADDING),
    )
    if clip.is_empty:
        logger.debug(
            "Empty clip for %s on page %d (%s); skipped", cap_type, page_num, filename
        )
        return False

    mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
    try:
        pix = page.get_pixmap(matrix=mat, clip=clip)
        (images_dest / filename).write_bytes(pix.tobytes("jpeg"))
    except Exception:
        # Best effort: one failed crop must not abort the rest of the page.
        logger.warning(
            "Failed to render %s on page %d (%s)",
            cap_type,
            page_num,
            filename,
            exc_info=True,
        )
        return False
    return True
def _process_page(
    doc,
    page_idx: int,
    page_boxes: list[LayoutBox],
    images_dest: Path,
    manifest: dict,
    seen_labels: set,
    arxiv_id: str,
) -> int:
    """Render the figure/table regions of one page under generic labels.

    Steps:

    1. Keep ``table``/``table-fallback``/``picture`` boxes that are at least
       20 pt on each side and at least ``_MIN_BOX_AREA`` pt² in area.
    2. Merge nearby fragments with ``_cluster_boxes``.
    3. Render each cluster as ``figure_(p{page}-{n}).jpg`` or
       ``table_(p{page}-{n}).jpg``.

    Figures and tables are numbered independently, ordered by the cluster's
    top edge and then its left edge.

    Args:
        doc: Open pymupdf document.
        page_idx: 0-based page index.
        page_boxes: Layout boxes detected on that page.
        images_dest: Output directory for the JPEG files.
        manifest: filename -> metadata dict, updated in place.
        seen_labels: Labels already emitted for this document, updated in
            place; a cluster whose label was seen before is skipped.
        arxiv_id: Paper ID (not used by the current implementation).

    Returns:
        Number of images written for this page.
    """
    page = doc[page_idx]
    page_num = page_idx + 1

    raw_boxes = []
    for box in page_boxes:
        if box.boxclass not in ("table", "table-fallback", "picture"):
            continue
        w = box.x1 - box.x0
        h = box.y1 - box.y0
        # Skip slivers and icon-sized detections.
        if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
            continue
        raw_boxes.append(box)
    if not raw_boxes:
        return 0

    # Number clusters top-to-bottom, then left-to-right, so the generic labels
    # do not depend on the order in which the detector emitted boxes.
    # NOTE: this is not true reading order on two-column pages.
    clusters = sorted(_cluster_boxes(raw_boxes), key=lambda c: (c.y0, c.x0))

    counters = {"figure": 0, "table": 0}
    extracted = 0
    for cluster in clusters:
        cap_type = "figure" if cluster.boxclass == "picture" else "table"
        counters[cap_type] += 1
        label = f"{cap_type.capitalize()} (p{page_num}-{counters[cap_type]})"
        if label in seen_labels:
            continue
        seen_labels.add(label)

        filename = f"{label.replace(' ', '_').lower()}.jpg"
        if not _render_box(page, cluster, images_dest, filename, cap_type, page_num):
            continue
        manifest[filename] = {
            "page": page_num,
            "type": cap_type,
            "label": label,
            "box": [
                round(float(cluster.x0), 1),
                round(float(cluster.y0), 1),
                round(float(cluster.x1), 1),
                round(float(cluster.y1), 1),
            ],
        }
        extracted += 1
    return extracted
# ── Phase 1 核心入口 ───────────────────────────────────────────────────
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Phase 1: crop figure/table regions from the PDF under generic labels.

    Every page is run through ``detect_page_layout`` and ``_process_page``.
    The crops are written as JPEG files into ``paper_dir(arxiv_id)/images``
    and recorded in ``manifest.json`` in that directory, encoded as UTF-8.
    Images and the manifest left over from a previous run are deleted first.
    A page that fails is logged and skipped; the remaining pages are still
    processed.

    Args:
        arxiv_id: Paper ID.
        pdf_path: PDF to read; defaults to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of figure/table images extracted (0 if the PDF does not exist).
    """
    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)
    manifest_path = images_dest / "manifest.json"

    # Clear the previous run's output. Snapshot the listing first so entries
    # are not removed while the directory iterator is still open. Skip
    # sub-directories whose names merely end in an image suffix.
    for old_file in list(images_dest.iterdir()):
        if old_file.is_file() and old_file.suffix.lower() in (".png", ".jpg", ".jpeg"):
            old_file.unlink()
    if manifest_path.exists():
        manifest_path.unlink()

    extracted = 0
    manifest: dict[str, dict] = {}
    seen_labels: set[str] = set()
    with pymupdf.open(str(pdf_path)) as doc:
        for page_idx in range(doc.page_count):
            try:
                page_boxes = detect_page_layout(doc[page_idx])
                extracted += _process_page(
                    doc,
                    page_idx,
                    page_boxes,
                    images_dest=images_dest,
                    manifest=manifest,
                    seen_labels=seen_labels,
                    arxiv_id=arxiv_id,
                )
            except Exception:
                # Best effort: one bad page must not abort the whole paper.
                logger.warning(
                    "Failed to process page %d for %s",
                    page_idx + 1,
                    arxiv_id,
                    exc_info=True,
                )

    # The manifest is read back with encoding="utf-8", so write it as UTF-8
    # explicitly instead of relying on the platform default encoding.
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    if extracted > 0:
        logger.info(
            "Extracted %d figure/table screenshots from PDF for %s",
            extracted,
            arxiv_id,
        )
    return extracted
# ── Phase 2: 用 summary 的 figures ID 定位并重命名 ─────────────────────
def _distance_text_to_box(rect: pymupdf.Rect, box: list[float]) -> float | None:
    """Return how far a text hit lies above or below *box*, or None.

    *box* is ``[x0, y0, x1, y1]`` in page points. The hit is reduced to its
    center point. That point must fall within the box's horizontal span,
    widened by 20 pt on each side.

    Returns:
        The vertical gap between the center and the box (0 when the center
        lies inside the box's vertical span). None when the hit is
        horizontally off, or when the gap exceeds ``_LABEL_MATCH_DISTANCE``.
    """
    mid_x = (rect.x0 + rect.x1) / 2
    mid_y = (rect.y0 + rect.y1) / 2
    left, top, right, bottom = box

    within_columns = left - 20 <= mid_x <= right + 20
    if not within_columns:
        return None

    if mid_y < top:
        vertical_gap = top - mid_y  # text above the box
    elif mid_y > bottom:
        vertical_gap = mid_y - bottom  # text below the box
    else:
        vertical_gap = 0  # text vertically inside the box
    return vertical_gap if vertical_gap <= _LABEL_MATCH_DISTANCE else None
def _search_variants(fig_id: str) -> list[str]:
    """Return the spellings under which a figure/table ID may appear in a PDF.

    The stripped ID itself comes first, followed by the common long/short
    forms of its prefix; duplicates are dropped with order preserved.
    Whitespace between prefix and number is optional on input ("Fig.1"),
    and the number may carry one leading letter (appendix or supplementary
    numbering such as "S1"/"A1"). Examples:

        "Figure 1" -> ["Figure 1", "Fig. 1", "Fig 1"]
        "Fig.1"    -> ["Fig.1", "Figure 1", "Fig. 1", "Fig 1"]
        "Table A1" -> ["Table A1", "Tab. A1"]
        any other ID -> [ID]
    """
    fig_id = fig_id.strip()
    variants = [fig_id]

    figure = re.match(r"(?:Figure|Fig\.?)\s*([A-Z]?\d.*)", fig_id, re.IGNORECASE)
    if figure:
        num = figure.group(1)
        variants += [f"Figure {num}", f"Fig. {num}", f"Fig {num}"]
    else:
        table = re.match(r"(?:Table|Tab\.?)\s*([A-Z]?\d.*)", fig_id, re.IGNORECASE)
        if table:
            num = table.group(1)
            variants += [f"Table {num}", f"Tab. {num}"]

    # dict preserves insertion order: de-duplicate while keeping the order.
    return list(dict.fromkeys(variants))
def _summary_image_filename(fig_id: str) -> str:
    """Return the JPEG file name used for a summary figure/table ID.

    Spaces become underscores and the result is lower-cased ("Figure 1" ->
    "figure_1.jpg"). Every other character that is not a letter, digit,
    underscore, dot, parenthesis or hyphen is also replaced by "_". The ID
    comes from LLM output, and a "/" in it would otherwise turn the name into
    a sub-path (possibly outside the images directory).
    """
    return re.sub(r"[^\w.()\-]", "_", fig_id.strip()).lower() + ".jpg"


def _collect_label_matches(
    doc, figures: list[dict], candidates: dict[str, dict]
) -> list[tuple[str, int, str, float]]:
    """For every (summary ID, candidate image) pair, find the closest text hit.

    The ID and its variants from ``_search_variants`` are searched only on
    pages that hold a candidate, because hits on other pages can never
    match. A pair is recorded only if at least one hit on the candidate's
    page is in range according to ``_distance_text_to_box``.

    Returns:
        ``(fig_id, fig_index, filename, distance)`` tuples, unsorted.
    """
    searchable_pages = sorted(
        page_num
        for page_num in {info.get("page") for info in candidates.values()}
        if isinstance(page_num, int) and 1 <= page_num <= doc.page_count
    )

    matches: list[tuple[str, int, str, float]] = []
    for fig_idx, fig in enumerate(figures):
        if not isinstance(fig, dict):
            continue
        fig_id = fig.get("id", "")
        if not isinstance(fig_id, str) or not fig_id.strip():
            continue
        search_terms = _search_variants(fig_id)

        # 1-based page number -> text rects, de-duplicated by top-left corner
        # because several variants can hit the same text.
        hits_by_page: dict[int, list] = {}
        for page_num in searchable_pages:
            page = doc[page_num - 1]
            seen_rects: set[tuple[float, float]] = set()
            for term in search_terms:
                for rect in page.search_for(term):
                    key = (round(rect.x0, 1), round(rect.y0, 1))
                    if key not in seen_rects:
                        seen_rects.add(key)
                        hits_by_page.setdefault(page_num, []).append(rect)
        if not hits_by_page:
            continue

        for fname, info in candidates.items():
            box = info.get("box")
            if not box:
                continue
            page_hits = hits_by_page.get(info.get("page", 0), [])
            distances = [_distance_text_to_box(rect, box) for rect in page_hits]
            in_range = [d for d in distances if d is not None]
            if in_range:
                matches.append((fig_id, fig_idx, fname, min(in_range)))
    return matches


def label_images_by_summary(
    arxiv_id: str,
    figures: list[dict],
    pdf_path: Path | None = None,
) -> int:
    """Phase 2: rename generic Phase 1 images after the summary's figure IDs.

    For every ``figures[i]["id"]`` (e.g. "Figure 1"):

    1. Search the ID and its spelling variants in the PDF, on the pages that
       hold generically labelled images.
    2. Measure the distance from each text hit to each such image box on the
       same page, and keep the closest in-range hit per (ID, image) pair.
    3. Resolve conflicts greedily by ascending distance, so every summary
       entry, every image and every target file name is used at most once.
       Then rename the file and rewrite its manifest entry: ``label``,
       ``caption_text`` (truncated to 200 chars) and the
       ``figures``/``tables`` list.

    The manifest is read and written as UTF-8.

    Args:
        arxiv_id: Paper ID.
        figures: The summary's ``figures`` list; each item is a dict with
            ``id`` and optionally ``caption``.
        pdf_path: PDF to search; defaults to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of images renamed. 0 when there is nothing to do or the
        manifest cannot be read.
    """
    if not figures:
        return 0
    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
    if not pdf_path.exists():
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    manifest_path = images_dest / "manifest.json"
    if not manifest_path.exists():
        return 0
    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        logger.warning(
            "Unreadable manifest for %s: %s", arxiv_id, manifest_path, exc_info=True
        )
        return 0
    if not isinstance(manifest, dict) or not manifest:
        return 0

    # Only entries that still carry a generic Phase 1 label ("Figure (p3-1)")
    # are eligible for renaming.
    candidates = {
        fname: info
        for fname, info in manifest.items()
        if isinstance(info, dict) and "(p" in info.get("label", "")
    }
    if not candidates:
        return 0

    with pymupdf.open(str(pdf_path)) as doc:
        matches = _collect_label_matches(doc, figures, candidates)
    if not matches:
        logger.info("No label matches for %s", arxiv_id)
        return 0

    # Greedy conflict resolution: closest pairs first. The sort is stable, so
    # equal distances keep their discovery order.
    matches.sort(key=lambda m: m[3])
    used_fig_idx: set[int] = set()
    used_sources: set[str] = set()
    used_targets: set[str] = set()
    renames: list[tuple[str, str, int]] = []  # (old_fname, new_fname, fig_idx)
    for fig_id, fig_idx, fname, _dist in matches:
        if fig_idx in used_fig_idx or fname in used_sources:
            continue
        new_fname = _summary_image_filename(fig_id)
        if new_fname in used_targets:
            # Same ID listed twice in the summary: renaming a second file to
            # the same name would overwrite the first one.
            continue
        used_fig_idx.add(fig_idx)
        used_sources.add(fname)
        used_targets.add(new_fname)
        renames.append((fname, new_fname, fig_idx))

    # Unmatched entries are carried over unchanged. Matched entries are
    # re-added below under their new name; one whose file has vanished is
    # dropped from the manifest.
    new_manifest = {f: i for f, i in manifest.items() if f not in used_sources}
    labeled = 0
    for old_fname, new_fname, fig_idx in renames:
        old_path = images_dest / old_fname
        new_path = images_dest / new_fname
        if not old_path.exists():
            continue

        fig = figures[fig_idx]
        fig_id = fig["id"]
        caption_text = fig.get("caption") or ""

        info = dict(manifest[old_fname])
        cap_type = info.get("type", "figure")
        info["label"] = fig_id
        info["caption_text"] = str(caption_text)[:200]
        list_key = "figures" if cap_type == "figure" else "tables"
        # Build a new list: ``info`` is a shallow copy, and appending would
        # mutate the list still referenced by the old manifest entry.
        info[list_key] = [*info.get(list_key, []), fig_id]

        if new_fname != old_fname:
            # Path.replace overwrites an existing target on every platform
            # (Path.rename raises FileExistsError on Windows), e.g. a file
            # left over from an earlier Phase 2 run for the same ID.
            old_path.replace(new_path)
        new_manifest[new_fname] = info
        labeled += 1

    manifest_path.write_text(
        json.dumps(new_manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    logger.info(
        "Labeled %d/%d images for %s using summary figures",
        labeled,
        len(manifest),
        arxiv_id,
    )
    return labeled
# ── Figure ↔ Image 关联 ────────────────────────────────────────────────
def _normalize_figure_id(raw_id: str) -> str:
    """Normalize a figure/table ID to the canonical "Figure N" / "Table N" form.

    "Fig.1", "Fig. 1" and "figure 1" all become "Figure 1"; "Tab. 2" becomes
    "Table 2". Surrounding whitespace is ignored. The number may carry one
    leading letter ("Fig. S1" -> "Figure S1"), and anything after the number
    is dropped ("Figure 1a" -> "Figure 1"). IDs that match neither pattern
    are returned unchanged.
    """
    text = raw_id.strip()
    m = re.match(r"(?:Figure|Fig\.?)\s*([A-Z]?\d+)", text, re.IGNORECASE)
    if m:
        return f"Figure {m.group(1)}"
    m = re.match(r"(?:Table|Tab\.?)\s*([A-Z]?\d+)", text, re.IGNORECASE)
    if m:
        return f"Table {m.group(1)}"
    return raw_id
def _is_figure_type(fig_id: str) -> bool:
    """Return True unless *fig_id* names a table.

    An ID counts as a table when it starts with "Table", "Tables", "Tab" or
    "Tab." that is not followed by another letter. The check is
    case-insensitive and ignores leading whitespace, so "Table 2", "Table1",
    "Table A1" and "Tab. 3" all count as tables. Everything else — including
    an empty ID — is treated as a figure.
    """
    return not re.match(r"\s*tab(?:les?)?(?![a-z])", fig_id, re.IGNORECASE)
def _image_sort_key(name: str) -> tuple[int, int]:
    """Sort key that orders extracted image files by their embedded numbers.

    Recognised file-name formats:

    - generic Phase 1 names ``figure_(p3-2).jpg`` / ``table_(p3-2).jpg``
      -> ``(page, index)``, here ``(3, 2)``;
    - Phase 2 names ``figure_1.jpg`` / ``table_1.jpg`` -> ``(0, 1)``;
    - legacy names ``page2_img1.png`` / ``page5_table1.png`` -> ``(2, 1)``.

    Anything else maps to ``(0, 0)``.
    """
    # Generic Phase 1 format: "<type>_(p<page>-<n>).jpg".
    generic = re.search(r"\(p(\d+)-(\d+)\)", name)
    if generic:
        return (int(generic.group(1)), int(generic.group(2)))
    # Phase 2 format: figure_1.jpg, table_1.jpg
    numbered = re.search(r"(?:figure|table)_(\d+)", name)
    if numbered:
        return (0, int(numbered.group(1)))
    # Legacy format: page2_img1.png, page5_table1.png
    legacy = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
    if legacy:
        return (int(legacy.group(1)), int(legacy.group(2)))
    return (0, 0)
def _manifest_id_to_url(arxiv_id: str) -> dict[str, str]:
    """Map figure/table IDs recorded in the paper's manifest.json to image URLs.

    Keys come from two sources:

    - every entry's ``label`` (new format); a later file's label overrides an
      earlier key;
    - the ``figures``/``tables`` lists (old format); these never override.

    Each key's ``_normalize_figure_id`` form is then added as an alias where
    that form is not already a key, so a label "Fig. 3" is also found as
    "Figure 3". A missing, unreadable or malformed manifest yields an empty
    mapping.
    """
    # NOTE(review): Phase 1/2 write the manifest under paper_dir(arxiv_id);
    # confirm that this resolves to PAPERS_DIR / arxiv_id.
    manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
    if not manifest_path.exists():
        return {}
    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError, TypeError):
        return {}
    if not isinstance(manifest, dict):
        return {}

    id_to_url: dict[str, str] = {}
    for filename, info in manifest.items():
        if not isinstance(info, dict):
            continue
        url = f"/papers/{arxiv_id}/images/{filename}"
        label = info.get("label", "")
        if label:
            id_to_url[label] = url
        for fig_id in info.get("figures", []) + info.get("tables", []):
            id_to_url.setdefault(fig_id, url)

    # Phase 2 stores the summary ID verbatim (e.g. "Fig. 3") as the label, so
    # also register the normalized spelling without overriding exact keys.
    aliases: dict[str, str] = {}
    for fig_id, url in id_to_url.items():
        if isinstance(fig_id, str):
            aliases.setdefault(_normalize_figure_id(fig_id), url)
    for alias, url in aliases.items():
        id_to_url.setdefault(alias, url)
    return id_to_url


def link_figures_with_images(
    figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
    """Attach ``image_url`` to summary figures; mutates and returns *figures*.

    Strategy:

    1. Exact ID match against manifest.json, trying the summary ID as-is
       first and then its ``_normalize_figure_id`` form.
    2. Ordinal fallback for figures still without an image: the N-th
       unmatched figure gets the N-th image not yet linked to any figure.
       Tables are handled the same way. Candidate images are split by
       "table" in the file name and ordered by ``_image_sort_key``.

    Args:
        figures: Summary figure dicts; each may carry an ``id``.
        images: Extracted image dicts with ``name`` and ``url`` keys.
        arxiv_id: Paper ID, used to locate the manifest.

    Returns:
        The same *figures* list object.
    """
    if not figures or not images:
        return figures

    # Strategy 1: manifest ID match.
    id_to_url = _manifest_id_to_url(arxiv_id)
    if id_to_url:
        for fig in figures:
            raw_id = fig.get("id", "")
            if not raw_id:
                continue
            url = id_to_url.get(raw_id) or id_to_url.get(_normalize_figure_id(raw_id))
            if url:
                fig["image_url"] = url

    # Strategy 2: ordinal fallback for whatever is still unmatched.
    unmatched = [f for f in figures if not f.get("image_url")]
    if not unmatched:
        return figures

    # An image already shown for one figure is not handed to a second one.
    taken = {f["image_url"] for f in figures if f.get("image_url")}
    free_images = [img for img in images if img.get("url") not in taken]

    fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
    table_type_unmatched = [
        f for f in unmatched if not _is_figure_type(f.get("id", ""))
    ]
    fig_images = sorted(
        (img for img in free_images if "table" not in img["name"].lower()),
        key=lambda img: _image_sort_key(img["name"]),
    )
    table_images = sorted(
        (img for img in free_images if "table" in img["name"].lower()),
        key=lambda img: _image_sort_key(img["name"]),
    )

    # zip stops at the shorter list: surplus figures simply stay unlinked.
    for fig, img in zip(fig_type_unmatched, fig_images):
        fig["image_url"] = img["url"]
    for fig, img in zip(table_type_unmatched, table_images):
        fig["image_url"] = img["url"]
    return figures