Files
daily-paper/app/services/pdf_image_extractor.py
T

658 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""PDF 图片与表格提取 — 基于 pymupdf4llm layout analysis。
用 pymupdf4llm 的 layout analysis 检测 table / picture 区域,
再通过 caption 文字匹配确定 Figure/Table 编号,渲染为 JPEG。
相比旧方案(caption 正则 + pdfplumber/find_tables/文本块扫描三套策略):
- layout analysis 直接给出区域 bbox,不存在相邻表格互相侵入的问题
- 无需手动调参(最大高度、间隙阈值等)
- 页面级 caption 匹配:每个 caption 只分配给最近的 box,避免上下相邻表格抢夺同一个 caption
"""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
import pymupdf
import pymupdf4llm.helpers.document_layout as dl
from app.services.pdf_downloader import paper_dir
from app.utils import TMP_DIR
logger = logging.getLogger(__name__)
# ── Caption regexes ───────────────────────────────────────────────────
# Used to extract the Figure/Table number from caption text.
# Both patterns are anchored at the start of the string and capture the number
# as group 1. After the number they require either a ':' / '.' separator, or
# whitespace followed by an uppercase letter; the `(?-i:...)` group keeps that
# uppercase test case-sensitive even though the pattern itself is IGNORECASE.
# Figure captions may start with "Fig", "Fig." or "Figure".
_FIGURE_CAPTION_RE = re.compile(
r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
re.IGNORECASE,
)
# Table captions start with "Table".
_TABLE_CAPTION_RE = re.compile(
r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
re.IGNORECASE,
)
# Maximum distance (points) at which a caption is matched to a table/picture
_CAPTION_MATCH_DISTANCE = 100
# Outer padding added around a screenshot region
_REGION_PADDING = 5
# Render at 3x zoom to keep the output sharp
_RENDER_ZOOM = 3
# Gap (points) for clustering adjacent boxes — fragments of the same
# figure/table are usually less than 15pt apart
_CLUSTER_GAP = 15
# ── Box 聚类 ─────────────────────────────────────────────────────────
class _BoxCluster:
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。
pymupdf4llm 有时将一个大图拆成多个小 picture box(如视频帧网格),
聚类后用整体 bbox 作为渲染区域。
"""
__slots__ = ("x0", "y0", "x1", "y1", "boxclass")
def __init__(self, boxes: list):
self.x0 = min(b.x0 for b in boxes)
self.y0 = min(b.y0 for b in boxes)
self.x1 = max(b.x1 for b in boxes)
self.y1 = max(b.y1 for b in boxes)
# table-fallback 归一化为 tablelayout model 检测到表格但无法提取结构)
raw = boxes[0].boxclass
self.boxclass = "table" if raw == "table-fallback" else raw
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
    """Merge neighbouring boxes of the same class into clusters.

    Boxes with identical ``boxclass`` that lie within ``gap`` of each other are
    joined with a union-find structure; every resulting set becomes one
    ``_BoxCluster`` covering the bounding box of its members.

    Args:
        boxes: layout boxes exposing ``x0``, ``y0``, ``x1``, ``y1`` and
            ``boxclass``.
        gap: maximum spacing for two boxes to be considered adjacent.

    Returns:
        One cluster per connected set, ordered by the lowest box index in
        each set. An empty input yields an empty list.
    """
    if not boxes:
        return []

    count = len(boxes)
    # Disjoint-set forest over box indices; every box starts as its own root.
    root_of = list(range(count))

    def find(node: int) -> int:
        # Walk up to the root, shortening the path along the way.
        while root_of[node] != node:
            root_of[node] = root_of[root_of[node]]
            node = root_of[node]
        return node

    for i, first in enumerate(boxes):
        for j in range(i + 1, count):
            second = boxes[j]
            if first.boxclass != second.boxclass:
                continue

            # Size of the empty space between the two boxes on each axis
            # (0 when their projections overlap).
            gap_x = max(0.0, max(first.x0, second.x0) - min(first.x1, second.x1))
            gap_y = max(0.0, max(first.y0, second.y0) - min(first.y1, second.y1))
            # Whether the projections overlap once widened by `gap`.
            near_x = first.x1 > second.x0 - gap and second.x1 > first.x0 - gap
            near_y = first.y1 > second.y0 - gap and second.y1 > first.y0 - gap

            if (gap_x <= gap and near_y) or (gap_y <= gap and near_x):
                root_i, root_j = find(i), find(j)
                if root_i != root_j:
                    root_of[root_i] = root_j

    # Collect members per root, preserving the original index order.
    members_by_root: dict[int, list] = {}
    for index, box in enumerate(boxes):
        members_by_root.setdefault(find(index), []).append(box)
    return [_BoxCluster(group) for group in members_by_root.values()]
# ── 页面级 Caption 查找与匹配 ──────────────────────────────────────────
def _find_page_captions(page) -> list[dict]:
    """Collect every Figure/Table caption text block found on a page.

    Each entry of ``page.get_text("blocks")`` with at least five fields is
    inspected; the first line of its text is tested against the table caption
    regex first and the figure caption regex second.

    Returns:
        A list of dicts with the keys ``label`` (e.g. ``"Table 3"``),
        ``type`` (``"table"`` or ``"figure"``), ``caption_text`` (the whole
        stripped block text) and ``caption_x0`` / ``caption_y0`` /
        ``caption_x1`` / ``caption_y1`` (the block's bounding box).
    """
    found = []
    for block in page.get_text("blocks"):
        if len(block) < 5:
            continue

        x0, y0, x1, y1 = block[:4]
        body = str(block[4]).strip()
        heading = body.split("\n")[0].strip()

        match = _TABLE_CAPTION_RE.match(heading)
        if match:
            kind, prefix = "table", "Table"
        else:
            match = _FIGURE_CAPTION_RE.match(heading)
            if match is None:
                # Neither a table nor a figure caption.
                continue
            kind, prefix = "figure", "Figure"

        found.append(
            {
                "label": f"{prefix} {match.group(1)}",
                "type": kind,
                "caption_text": body,
                "caption_y0": y0,
                "caption_y1": y1,
                "caption_x0": x0,
                "caption_x1": x1,
            }
        )
    return found
def _vertical_distance(cap_y0, cap_y1, box_y0, box_y1) -> float | None:
    """Return the vertical distance between a caption and a box.

    Three situations are distinguished: the caption lies entirely above the
    box, entirely below it, or overlaps it vertically. Any overlap (including
    a caption that partly sticks out of the box) counts as distance 0 so such
    captions are never dropped.

    Returns:
        The gap between the two when it does not exceed
        ``_CAPTION_MATCH_DISTANCE``, ``0`` on vertical overlap, and ``None``
        when the caption is too far away to be considered adjacent.
    """
    if cap_y1 <= box_y0:
        # Caption sits completely above the box.
        gap = box_y0 - cap_y1
    elif cap_y0 >= box_y1:
        # Caption sits completely below the box.
        gap = cap_y0 - box_y1
    else:
        # Vertical overlap (inside the box or partly overflowing it).
        return 0

    if gap <= _CAPTION_MATCH_DISTANCE:
        return gap
    return None
def _same_column(cap: dict, box, page_width: float) -> bool:
"""判断 caption 和 box 是否在同一列。
双栏论文中左右栏间距有限,简单的水平重叠检查会跨列匹配。
策略:用中心 X 坐标判断各自在哪半边,只有同半边才算同列。
跨栏图表(caption 或 box 宽度 >65% 页宽)不受此限制。
"""
cap_w = cap["caption_x1"] - cap["caption_x0"]
box_w = box.x1 - box.x0
# 跨栏元素:宽度超过页面的 65%
if cap_w > page_width * 0.65 or box_w > page_width * 0.65:
return True
cap_cx = (cap["caption_x0"] + cap["caption_x1"]) / 2
box_cx = (box.x0 + box.x1) / 2
mid = page_width / 2
# 同在左半边或同在右半边
return (cap_cx < mid) == (box_cx < mid)
def _match_captions_to_boxes(
    page_boxes: list, captions: list[dict], page_width: float
) -> list[tuple[list[int], list[dict]]]:
    """Assign captions to boxes; one caption may claim several boxes.

    Typical situations handled here:
    - A figure is made of a left and a right picture box and the caption is
      close to both.
    - The visual content of a table was classified as a picture by the layout
      analysis, so a cross-class match is needed.

    For every caption, boxes qualify when they are in the same column, overlap
    horizontally (with a 5pt tolerance) and are vertically adjacent. Boxes of
    the expected class are preferred; boxes of any class are the fallback.
    The nearest box is the caption's primary box, and every candidate within
    20pt of that nearest distance is attached to the caption as well.

    Args:
        page_boxes: clustered boxes exposing ``x0``/``y0``/``x1``/``y1`` and
            ``boxclass``.
        captions: caption dicts as produced by ``_find_page_captions``.
        page_width: width of the page.

    Returns:
        ``[(box_indices, captions), ...]`` — each entry is one independent
        render task. Captions without any candidate box are omitted.
    """
    primary_box: dict[int, int] = {}  # caption index -> nearest box index
    nearby_boxes: dict[int, list[int]] = {}  # caption index -> all attached boxes

    for cap_idx, cap in enumerate(captions):
        wanted_class = "table" if cap["type"] == "table" else "picture"
        preferred: list[tuple[int, float]] = []
        fallback: list[tuple[int, float]] = []

        for box_idx, box in enumerate(page_boxes):
            # Column awareness: only boxes in the caption's own column.
            if not _same_column(cap, box, page_width):
                continue
            # Within a column the two must still overlap horizontally.
            overlaps_x = (
                cap["caption_x1"] > box.x0 - 5 and cap["caption_x0"] < box.x1 + 5
            )
            if not overlaps_x:
                continue
            dist = _vertical_distance(
                cap["caption_y0"], cap["caption_y1"], box.y0, box.y1
            )
            if dist is None:
                continue

            fallback.append((box_idx, dist))
            if box.boxclass == wanted_class:
                preferred.append((box_idx, dist))

        # Same-class matches win; otherwise any class; otherwise skip the caption.
        candidates = preferred or fallback
        if not candidates:
            continue

        ranked = sorted(candidates, key=lambda pair: pair[1])
        nearest_idx, nearest_dist = ranked[0]
        primary_box[cap_idx] = nearest_idx
        # Boxes whose distance is within 20pt of the best one belong together.
        nearby_boxes[cap_idx] = [
            box_idx for box_idx, dist in ranked if dist <= nearest_dist + 20
        ]

    # Index captions by their primary box.
    caps_by_primary: dict[int, list[int]] = {}
    for cap_idx, box_idx in primary_box.items():
        caps_by_primary.setdefault(box_idx, []).append(cap_idx)

    # Build render groups: every caption forms its own group (boxes may be
    # shared), except that captions carrying the same label are folded together.
    consumed: set[int] = set()
    render_groups: list[tuple[list[int], list[dict]]] = []
    for box_idx in sorted(caps_by_primary):
        for cap_idx in caps_by_primary[box_idx]:
            if cap_idx in consumed:
                continue
            consumed.add(cap_idx)

            lead = captions[cap_idx]
            box_set = set(nearby_boxes[cap_idx])
            bundle = [lead]
            # Pull in duplicate captions (same label) anchored on any attached box.
            for member_idx in box_set:
                for other_idx in caps_by_primary.get(member_idx, ()):
                    if other_idx in consumed:
                        continue
                    other = captions[other_idx]
                    if other["label"] == lead["label"]:
                        consumed.add(other_idx)
                        bundle.append(other)

            render_groups.append((sorted(box_set), bundle))
    return render_groups
# ── 单页处理 ─────────────────────────────────────────────────────────
def _render_and_save(
    page,
    clip: pymupdf.Rect,
    images_dest: Path,
    manifest: dict,
    label: str,
    cap_type: str,
    caption_text: str,
    page_num_1based: int,
    arxiv_id: str,
) -> bool:
    """Render a page region to JPEG, save it and record it in the manifest.

    Args:
        page: page object providing ``get_pixmap``.
        clip: region of the page to render.
        images_dest: directory the JPEG is written to.
        manifest: mapping filename -> metadata; updated in place.
        label: figure/table label; also determines the file name
            (spaces become underscores, lower-cased, ``.jpg`` suffix).
        cap_type: ``"figure"`` stores the label under the ``figures`` key,
            anything else under ``tables``.
        caption_text: caption text; only the first 200 characters are stored.
        page_num_1based: page number recorded in the manifest.
        arxiv_id: paper ID, used for log messages only.

    Returns:
        True on success, False when rendering the pixmap raised.
    """
    zoom = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
    try:
        pixmap = page.get_pixmap(matrix=zoom, clip=clip)
    except Exception:
        logger.debug("Failed to render %s for %s", label, arxiv_id)
        return False

    filename = f"{label.replace(' ', '_').lower()}.jpg"
    target = images_dest / filename
    target.write_bytes(pixmap.tobytes("jpeg"))

    refs_key = "figures" if cap_type == "figure" else "tables"
    manifest[filename] = {
        "page": page_num_1based,
        "type": cap_type,
        "label": label,
        "caption_text": caption_text[:200] if caption_text else "",
        refs_key: [label],
    }

    logger.debug(
        "Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) → %s",
        label,
        page_num_1based,
        clip.x0,
        clip.y0,
        clip.x1,
        clip.y1,
        filename,
    )
    return True
def _process_page(
    doc,
    page_idx: int,
    page_layout,
    images_dest: Path,
    manifest: dict,
    seen_labels: set,
    arxiv_id: str,
) -> int:
    """Process one page: caption matching plus an orphan fallback.

    Steps:
    1. Collect the page's table / table-fallback / picture boxes (skipping
       boxes narrower or shorter than 20) and cluster the fragments.
    2. Match captions to clusters and render every matched region
       (boxes plus caption, with padding).
    3. Render every cluster that no caption claimed under a synthetic
       ``"Figure (p<page>-<n>)"`` / ``"Table (p<page>-<n>)"`` label.

    Args:
        doc: opened document; ``doc[page_idx]`` yields the page.
        page_idx: 0-based page index.
        page_layout: layout result for the page, exposing ``boxes``.
        images_dest: directory images are written to.
        manifest: filename -> metadata mapping, updated in place.
        seen_labels: labels already handled for this document; updated in
            place so a label is rendered at most once.
        arxiv_id: paper ID, used for log messages.

    Returns:
        Number of images extracted from this page.
    """
    page = doc[page_idx]
    page_width = page.rect.width
    page_num = page_idx + 1

    # Gather this page's table/picture boxes, ignoring tiny regions.
    wanted_classes = ("table", "table-fallback", "picture")
    raw_boxes = [
        box
        for box in page_layout.boxes
        if box.boxclass in wanted_classes
        and not ((box.x1 - box.x0) < 20 or (box.y1 - box.y0) < 20)
    ]
    if not raw_boxes:
        return 0

    # Cluster: merge fragment boxes belonging to the same figure/table.
    page_boxes = _cluster_boxes(raw_boxes)

    # Page-level matching: find all captions and assign them to boxes.
    captions = _find_page_captions(page)
    groups = _match_captions_to_boxes(page_boxes, captions, page_width)

    # Fuse only groups that share a label (duplicate captions of one
    # figure/table) and at least one box. Groups with different labels stay
    # separate even when they share a box (e.g. Figure 7 and Figure 8).
    absorbed: set[int] = set()
    fused: list[tuple[list[int], list[dict]]] = []
    for gi, (box_indices, caps) in enumerate(groups):
        if gi in absorbed:
            continue
        labels_here = {c["label"] for c in caps}
        box_union = set(box_indices)
        partners = {gi}
        for other_gi in range(gi + 1, len(groups)):
            if other_gi in absorbed:
                continue
            other_boxes, other_caps = groups[other_gi]
            other_labels = {c["label"] for c in other_caps}
            if labels_here & other_labels and box_union & set(other_boxes):
                partners.add(other_gi)
                box_union |= set(other_boxes)

        fused_caps: list[dict] = []
        for member in sorted(partners):
            absorbed.add(member)
            fused_caps.extend(groups[member][1])
        fused.append((sorted(box_union), fused_caps))

    # ── Phase 1: render figures/tables that have a matching caption ──
    claimed_boxes: set[int] = set()
    extracted = 0
    for box_indices, caps in fused:
        claimed_boxes.update(box_indices)

        # Drop labels that were already handled (here or on an earlier page).
        fresh_caps = []
        for cap in caps:
            if cap["label"] in seen_labels:
                continue
            seen_labels.add(cap["label"])
            fresh_caps.append(cap)
        if not fresh_caps:
            continue

        # Bounding box over every associated cluster.
        members = [page_boxes[i] for i in box_indices]
        box_left = min(b.x0 for b in members)
        box_top = min(b.y0 for b in members)
        box_right = max(b.x1 for b in members)
        box_bottom = max(b.y1 for b in members)

        # Bounding box over the captions.
        cap_top = min(c["caption_y0"] for c in fresh_caps)
        cap_bottom = max(c["caption_y1"] for c in fresh_caps)
        cap_left = min(c["caption_x0"] for c in fresh_caps)
        cap_right = max(c["caption_x1"] for c in fresh_caps)

        # Render region = boxes + captions + padding (clamped on three sides).
        top = max(0, min(box_top, cap_top) - _REGION_PADDING)
        bottom = max(box_bottom, cap_bottom) + _REGION_PADDING
        left = max(0, min(box_left, cap_left) - _REGION_PADDING)
        right = min(page_width, max(box_right, cap_right) + _REGION_PADDING)
        clip = pymupdf.Rect(left, top, right, bottom)

        # Several captions may share one region (e.g. subfigures): render once
        # for the first caption and copy the bytes for the others.
        first_cap, *other_caps = fresh_caps
        rendered = _render_and_save(
            page,
            clip,
            images_dest,
            manifest,
            first_cap["label"],
            first_cap["type"],
            first_cap["caption_text"],
            page_num,
            arxiv_id,
        )
        if not rendered:
            continue

        # Read back the bytes just written so the remaining captions can reuse them.
        first_file = f"{first_cap['label'].replace(' ', '_').lower()}.jpg"
        jpeg_bytes = (images_dest / first_file).read_bytes()
        extracted += 1

        for cap in other_caps:
            filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
            (images_dest / filename).write_bytes(jpeg_bytes)
            refs_key = "figures" if cap["type"] == "figure" else "tables"
            manifest[filename] = {
                "page": page_num,
                "type": cap["type"],
                "label": cap["label"],
                "caption_text": cap["caption_text"][:200],
                refs_key: [cap["label"]],
            }
            extracted += 1

    # ── Phase 2: render boxes without a matching caption (orphan boxes) ──
    orphan_fig_counter = 0
    orphan_tbl_counter = 0
    for box_idx, box in enumerate(page_boxes):
        if box_idx in claimed_boxes:
            continue

        if box.boxclass == "picture":
            cap_type = "figure"
            orphan_fig_counter += 1
            label = f"Figure (p{page_num}-{orphan_fig_counter})"
        else:
            cap_type = "table"
            orphan_tbl_counter += 1
            label = f"Table (p{page_num}-{orphan_tbl_counter})"

        if label in seen_labels:
            continue
        seen_labels.add(label)

        clip = pymupdf.Rect(
            max(0, box.x0 - _REGION_PADDING),
            max(0, box.y0 - _REGION_PADDING),
            min(page_width, box.x1 + _REGION_PADDING),
            box.y1 + _REGION_PADDING,
        )
        saved = _render_and_save(
            page,
            clip,
            images_dest,
            manifest,
            label,
            cap_type,
            "",
            page_num,
            arxiv_id,
        )
        if saved:
            extracted += 1

    return extracted
# ── 核心提取 ───────────────────────────────────────────────────────────
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Extract Figure/Table screenshots from a PDF and write a manifest.

    pymupdf4llm layout analysis detects the table/picture regions, the caption
    text determines the numbering, and every region is rendered as a JPEG into
    ``paper_dir(arxiv_id) / "images"``. Images and the manifest left over from
    a previous run are removed first.

    Args:
        arxiv_id: paper ID.
        pdf_path: path to the PDF; defaults to
            ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of images extracted. 0 when the PDF does not exist or layout
        analysis fails (in those cases no manifest is written).
    """
    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)
    manifest_path = images_dest / "manifest.json"

    # Remove images and the manifest produced by a previous extraction.
    for old_file in images_dest.iterdir():
        if old_file.suffix.lower() in (".png", ".jpg", ".jpeg"):
            old_file.unlink()
    if manifest_path.exists():
        manifest_path.unlink()

    extracted = 0
    manifest: dict[str, dict] = {}
    seen_labels: set[str] = set()

    doc = pymupdf.open(str(pdf_path))
    # try/finally guarantees the document is closed on every exit path.
    try:
        # Layout analysis
        try:
            parsed = dl.parse_document(
                doc, filename=str(pdf_path), use_ocr=dl.OCRMode.NEVER
            )
        except Exception:
            logger.warning(
                "pymupdf4llm layout analysis failed for %s", arxiv_id, exc_info=True
            )
            return 0

        for page_idx, page_layout in enumerate(parsed.pages):
            # Best effort: a failing page is logged and skipped so the
            # remaining pages are still processed.
            try:
                extracted += _process_page(
                    doc,
                    page_idx,
                    page_layout,
                    images_dest=images_dest,
                    manifest=manifest,
                    seen_labels=seen_labels,
                    arxiv_id=arxiv_id,
                )
            except Exception:
                logger.warning(
                    "Failed to process page %d for %s",
                    page_idx + 1,
                    arxiv_id,
                    exc_info=True,
                )
    finally:
        doc.close()

    # Save the manifest. The encoding is explicit because ensure_ascii=False
    # emits raw non-ASCII caption text and the manifest is read back as UTF-8.
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    if extracted > 0:
        logger.info(
            "Extracted %d figure/table screenshots from PDF for %s",
            extracted,
            arxiv_id,
        )
    return extracted
# ── 按 summary 过滤 ────────────────────────────────────────────────────
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
    """Keep only the extracted images that the summary's ``figures`` reference.

    The ``id`` of every entry in ``figures`` is normalised to
    ``"Figure N"`` / ``"Table N"`` and compared with the ``label`` (and the
    ``figures`` / ``tables`` lists) stored in ``manifest.json``. Image files
    that are not referenced are deleted.

    Args:
        arxiv_id: paper ID.
        figures: figure entries from the summary; each is read via
            ``.get("id", "")``.

    Returns:
        Number of image files kept. 0 when ``figures`` is empty, the images
        directory or manifest is missing, or there are no image files. When no
        valid ID or no manifest match is found, nothing is deleted and the
        total number of image files is returned.
    """
    if not figures:
        return 0

    images_dir = paper_dir(arxiv_id) / "images"
    manifest_path = images_dir / "manifest.json"
    if not (images_dir.exists() and manifest_path.exists()):
        return 0

    image_suffixes = (".png", ".jpg", ".jpeg")
    all_files = [
        entry for entry in images_dir.iterdir() if entry.suffix.lower() in image_suffixes
    ]
    if not all_files:
        return 0

    manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))

    # Collect every Figure/Table ID referenced by the summary (normalised).
    id_patterns = (
        (r"(?:Fig\.?|Figure)\s*(\d+)", "Figure"),
        (r"Table\s*(\d+)", "Table"),
    )
    referenced_ids: set[str] = set()
    for fig in figures:
        fig_id = fig.get("id", "")
        for pattern, prefix in id_patterns:
            hit = re.match(pattern, fig_id, re.IGNORECASE)
            if hit:
                referenced_ids.add(f"{prefix} {hit.group(1)}")

    if not referenced_ids:
        logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
        return len(all_files)

    # Match against the manifest: the label first, then the reference lists.
    keep_filenames: set[str] = set()
    for filename, info in manifest.items():
        if info.get("label", "") in referenced_ids:
            keep_filenames.add(filename)
            continue
        listed_refs = info.get("figures", []) + info.get("tables", [])
        if any(ref in referenced_ids for ref in listed_refs):
            keep_filenames.add(filename)

    if not keep_filenames:
        logger.warning(
            "No manifest matches for %s (refs=%s), keeping all",
            arxiv_id,
            referenced_ids,
        )
        return len(all_files)

    # Delete every image the summary does not reference.
    removed = 0
    for image in all_files:
        if image.name in keep_filenames:
            continue
        image.unlink()
        removed += 1

    kept = len(all_files) - removed
    logger.info(
        "Filtered images for %s: kept %d, removed %d (refs=%s)",
        arxiv_id,
        kept,
        removed,
        referenced_ids,
    )
    return kept