1ccac1f29a
- Extract captions from PDF text dict instead of DocLayout caption boxes - Use _CaptionBlock dataclass to carry authoritative ID, kind, text, bbox - Pair captions to content boxes with directional preference (figure below, table above) - Filter out uncaptioned boxes (Algorithm pseudo-code, unnumbered appendix tables, false positives) - Remove label_images_by_summary and Phase 2 rename pipeline entirely - Update tests to cover text-based caption pairing and filtering
486 lines
17 KiB
Python
486 lines
17 KiB
Python
"""Extract figure and table screenshots from a PDF.

DocLayout-YOLO detects figure/table content regions -> the PDF text stream
locates captions -> only regions paired with a "Figure"/"Table" caption are
rendered, each named with the caption's own authoritative ID. Regions with no
caption (Algorithm pseudo-code, unnumbered appendix tables, DocLayout
false-positive fragments) are filtered out and not emitted.

Captions are located from the PDF text rather than from DocLayout caption
boxes, because the latter are unreliable: a multi-line caption boxed as a single
line gets truncated, missed detections leave a region without a caption, and
wrong pairings mix captions up. page.get_text("dict") is used to find text
blocks that start with "Figure N"/"Table N". Such a text block naturally holds
the full multi-line caption, and its ID is the paper's actual numbering, so
naming files by it avoids mix-ups. Figure captions are preferentially paired
below the content, table captions above.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
import pymupdf
|
||
|
||
from app.services.layout_detector import LayoutBox, detect_page_layout
|
||
from app.services.pdf_downloader import paper_dir
|
||
from app.utils import PAPERS_DIR, TMP_DIR
|
||
|
||
logger = logging.getLogger(__name__)

# Padding added around each screenshot region (unit: pt).
_REGION_PADDING = 5

# Render zoom factor (3x keeps screenshots sharp).
_RENDER_ZOOM = 3

# Max gap between adjacent boxes merged into one cluster (unit: pt);
# fragments of the same figure/table are usually < 15pt apart.
_CLUSTER_GAP = 15

# Minimum bbox area (unit: pt^2); filters tiny false positives such as icons/logos.
_MIN_BOX_AREA = 2000

# Max vertical distance between a caption text block and a figure/table
# content block (unit: pt).
_CAPTION_MATCH_DISTANCE = 120

# Score penalty for a wrong-side pairing (figure caption above / table caption
# below). Still allowed, as a fallback for unusual layouts.
_CAPTION_WRONG_SIDE_PENALTY = 300

# Caption head: "Figure 3" / "Fig. 3" / "Table C1" / "Figure 3.5" (any case).
# The number starts with a digit, or with a letter followed by a digit
# (appendix "C1"). A dot is consumed only when an alphanumeric follows it, so
# the sentence-ending period in "Figure 3. Overview" is NOT captured as part
# of the ID ("3", not "3."). Anchored at line start to skip in-text references
# such as "see Table 3".
_CAPTION_HEAD_RE = re.compile(
    r"^\s*(Figure|Fig\.?|Table)\b\.?\s+"
    r"((?:[0-9]|[A-Z]\d)[0-9A-Za-z]*(?:\.[0-9A-Za-z]+)*)",
    re.IGNORECASE,
)
|
||
|
||
|
||
# ── Box 聚类 ─────────────────────────────────────────────────────────
|
||
|
||
|
||
class _BoxCluster:
    """Bounding region formed by merging one or more adjacent LayoutBoxes.

    The extent is the union of all member boxes. ``boxclass`` is taken from
    the first member.
    """

    __slots__ = ("x0", "y0", "x1", "y1", "boxclass")

    def __init__(self, boxes: list):
        lefts, tops, rights, bottoms = zip(
            *((member.x0, member.y0, member.x1, member.y1) for member in boxes)
        )
        self.x0 = min(lefts)
        self.y0 = min(tops)
        self.x1 = max(rights)
        self.y1 = max(bottoms)
        self.boxclass = boxes[0].boxclass
|
||
|
||
|
||
@dataclass(frozen=True)
class _CaptionBlock:
    """Immutable caption record located in the PDF text stream.

    It carries the caption's own ID, its kind, the full (possibly multi-line)
    caption text, and the text block's bounding box.
    """

    # Caption ID as printed in the paper, e.g. "Figure 3" or "Table C1".
    id: str
    # Either "figure" or "table".
    kind: str
    # Caption text with all of its lines joined.
    text: str
    # Text block bounding box as [x0, y0, x1, y1].
    bbox: list[float]
|
||
|
||
|
||
def _cluster_to_box(cluster: _BoxCluster) -> list[float]:
    """Return the cluster extent as ``[x0, y0, x1, y1]`` floats rounded to 0.1."""
    corners = (cluster.x0, cluster.y0, cluster.x1, cluster.y1)
    return [round(float(coord), 1) for coord in corners]
|
||
|
||
|
||
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
    """Group same-class boxes that lie within ``gap`` pt of each other.

    Runs a union-find over all box pairs, so proximity is transitive. If A is
    near B and B is near C, then A, B and C share a cluster even when A and C
    are far apart. Boxes with different ``boxclass`` are never merged.

    Returns:
        One ``_BoxCluster`` per group, ordered by the index of each group's
        first box. An empty list for empty input.
    """
    if not boxes:
        return []

    leader = list(range(len(boxes)))

    def _leader_of(idx: int) -> int:
        # Path halving keeps the union-find trees shallow.
        while leader[idx] != idx:
            leader[idx] = leader[leader[idx]]
            idx = leader[idx]
        return idx

    def _are_adjacent(a, b) -> bool:
        if a.boxclass != b.boxclass:
            return False
        horiz_gap = max(0.0, max(a.x0, b.x0) - min(a.x1, b.x1))
        vert_gap = max(0.0, max(a.y0, b.y0) - min(a.y1, b.y1))
        horiz_near = a.x1 > b.x0 - gap and b.x1 > a.x0 - gap
        vert_near = a.y1 > b.y0 - gap and b.y1 > a.y0 - gap
        return (horiz_gap <= gap and vert_near) or (vert_gap <= gap and horiz_near)

    for i, first in enumerate(boxes):
        for j in range(i + 1, len(boxes)):
            if not _are_adjacent(first, boxes[j]):
                continue
            root_i, root_j = _leader_of(i), _leader_of(j)
            if root_i != root_j:
                leader[root_i] = root_j

    grouped: dict[int, list] = {}
    for idx, box in enumerate(boxes):
        grouped.setdefault(_leader_of(idx), []).append(box)

    return [_BoxCluster(members) for members in grouped.values()]
|
||
|
||
|
||
def _find_caption_blocks(page) -> list[_CaptionBlock]:
    """Collect caption text blocks starting with "Figure N" / "Fig. N" / "Table N".

    Uses the PDF text stream rather than DocLayout caption boxes. A text block
    naturally holds the full multi-line caption, and its leading ID is the
    paper's actual numbering (e.g. "Table C1"), so no model detection is needed.

    Args:
        page: a pymupdf page.

    Returns:
        One ``_CaptionBlock`` per matching text block, in text-dict order.
        An empty list if the page text cannot be read.
    """
    try:
        text_dict = page.get_text("dict")
    except Exception:
        # Best effort: a page whose text can't be read simply yields no captions.
        logger.debug("Failed to read page text dict; skipping caption lookup", exc_info=True)
        return []

    results: list[_CaptionBlock] = []
    for block in text_dict.get("blocks", []):
        if block.get("type") != 0:  # text blocks only
            continue
        lines = block.get("lines", [])
        if not lines:
            continue

        line_texts = [
            "".join(span.get("text", "") for span in line.get("spans", []))
            for line in lines
        ]
        first_line = next((t for t in line_texts if t.strip()), "")
        m = _CAPTION_HEAD_RE.match(first_line)
        if not m:
            continue

        kind_word = m.group(1)
        # Drop a trailing sentence period ("Figure 3." -> "3"). This keeps the
        # ID equal to the paper's numbering and keeps the derived filename
        # from becoming "figure_3..jpg".
        num = m.group(2).rstrip(".")
        is_table = kind_word.lower().startswith("table")

        bbox = block.get("bbox")
        if not bbox or len(bbox) != 4:
            continue

        full_text = " ".join(t.strip() for t in line_texts if t.strip())
        results.append(
            _CaptionBlock(
                id=f"{'Table' if is_table else 'Figure'} {num}",
                kind="table" if is_table else "figure",
                text=full_text,
                bbox=[float(v) for v in bbox],
            )
        )
    return results
|
||
|
||
|
||
def _pair_caption_blocks(
    content_clusters: list[_BoxCluster],
    caption_blocks: list[_CaptionBlock],
) -> dict[int, _CaptionBlock]:
    """Pair each content cluster with the nearest caption of the matching kind.

    "picture" clusters take "figure" captions, preferably below the content.
    Every other cluster takes "table" captions, preferably above. A caption
    must overlap the cluster horizontally by at least 25% of the narrower
    width. It must also sit strictly above or below the cluster, within
    ``_CAPTION_MATCH_DISTANCE``. A caption on the unexpected side is still
    allowed but scored with ``_CAPTION_WRONG_SIDE_PENALTY`` added.

    Candidates are taken greedily by ascending score, so every cluster and
    every caption is used at most once.

    Returns:
        Mapping from cluster index to its paired caption block.
    """
    scored: list[tuple[float, int, int]] = []
    for content_idx, content in enumerate(content_clusters):
        is_picture = content.boxclass == "picture"  # figure captions go below
        expected_kind = "figure" if is_picture else "table"
        content_width = content.x1 - content.x0

        for cap_idx, cap in enumerate(caption_blocks):
            if cap.kind != expected_kind:
                continue

            left, top, right, bottom = cap.bbox
            overlap = min(content.x1, right) - max(content.x0, left)
            narrower = min(content_width, right - left)
            if narrower <= 0 or overlap < narrower * 0.25:
                continue

            if bottom <= content.y0:  # caption above the content
                below, distance = False, content.y0 - bottom
            elif top >= content.y1:  # caption below the content
                below, distance = True, top - content.y1
            else:  # vertically overlapping: not a candidate
                continue

            if distance > _CAPTION_MATCH_DISTANCE:
                continue

            extra = 0.0 if below == is_picture else _CAPTION_WRONG_SIDE_PENALTY
            scored.append((distance + extra, content_idx, cap_idx))

    paired: dict[int, _CaptionBlock] = {}
    taken: set[int] = set()
    for _score, content_idx, cap_idx in sorted(scored):
        if content_idx not in paired and cap_idx not in taken:
            paired[content_idx] = caption_blocks[cap_idx]
            taken.add(cap_idx)
    return paired
|
||
|
||
|
||
# ── Detection + rendering ──────────────────────────────────────────────
|
||
|
||
|
||
def _render_box(
    page,
    box: _BoxCluster,
    images_dest: Path,
    filename: str,
    cap_type: str,
    page_num: int,
    caption_bbox: list[float] | None = None,
) -> bool:
    """Render one cluster region to a JPEG file and report success.

    When ``caption_bbox`` is given, the clip is widened to the union of the
    content and caption regions. The screenshot then contains the figure/table
    together with its complete caption. The clip is padded by
    ``_REGION_PADDING`` and clamped to the page.

    Args:
        page: pymupdf page to render from.
        box: content cluster to capture.
        images_dest: output directory.
        filename: output file name inside ``images_dest``.
        cap_type: not used by this function.
        page_num: not used by this function.
        caption_bbox: optional ``[x0, y0, x1, y1]`` caption region to include.

    Returns:
        True once the file is written. False if rendering the pixmap raised.
        Errors while writing the file are not caught here.
    """
    left, top, right, bottom = box.x0, box.y0, box.x1, box.y1
    if caption_bbox is not None:
        cap_left, cap_top, cap_right, cap_bottom = caption_bbox
        left, top = min(left, cap_left), min(top, cap_top)
        right, bottom = max(right, cap_right), max(bottom, cap_bottom)

    page_rect = page.rect
    clip = pymupdf.Rect(
        max(0, left - _REGION_PADDING),
        max(0, top - _REGION_PADDING),
        min(page_rect.width, right + _REGION_PADDING),
        min(page_rect.height, bottom + _REGION_PADDING),
    )
    zoom = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)

    try:
        pix = page.get_pixmap(matrix=zoom, clip=clip)
    except Exception:
        return False

    out_path = images_dest / filename
    out_path.write_bytes(pix.tobytes("jpeg", jpg_quality=92))
    return True
|
||
|
||
|
||
def _process_page(
    doc,
    page_idx: int,
    page_boxes: list[LayoutBox],
    images_dest: Path,
    manifest: dict,
    seen_labels: set,
    arxiv_id: str,
) -> int:
    """Process one page: content boxes, then text-located captions, then render matches.

    Boxes paired with a Figure/Table caption are named after the caption's own
    ID (e.g. ``figure_3.jpg``). Unpaired boxes (Algorithm pseudo-code,
    unnumbered appendix tables, detection fragments) are dropped.

    Args:
        doc: open pymupdf document.
        page_idx: 0-based page index.
        page_boxes: layout boxes detected on this page.
        images_dest: output directory for the JPEG files.
        manifest: shared ``filename -> metadata`` dict, updated in place.
        seen_labels: caption IDs already written (shared across pages),
            updated in place.
        arxiv_id: paper ID, used in log messages.

    Returns:
        Number of images written for this page.
    """
    page = doc[page_idx]
    page_num = page_idx + 1

    # Keep figure/table content boxes only, skipping tiny regions. Captions are
    # located from the page text, not from layout boxes.
    raw_boxes = []
    for box in page_boxes:
        if box.boxclass not in ("table", "picture"):
            continue
        w = box.x1 - box.x0
        h = box.y1 - box.y0
        if w < 20 or h < 20 or w * h < _MIN_BOX_AREA:
            continue
        raw_boxes.append(box)

    if not raw_boxes:
        return 0

    # Merge fragments of the same figure/table, then pair clusters with captions.
    clusters = _cluster_boxes(raw_boxes)
    caption_blocks = _find_caption_blocks(page)
    caption_matches = _pair_caption_blocks(clusters, caption_blocks)

    extracted = 0
    for cluster_idx, cluster in enumerate(clusters):
        cap_match = caption_matches.get(cluster_idx)
        if cap_match is None:
            continue  # no Figure/Table caption -> filtered out
        if cap_match.id in seen_labels:
            continue  # same figure/table already written from another detection

        filename = f"{cap_match.id.replace(' ', '_').lower()}.jpg"
        if not _render_box(
            page,
            cluster,
            images_dest,
            filename,
            cap_match.kind,
            page_num,
            caption_bbox=cap_match.bbox,
        ):
            logger.debug(
                "Failed to render %s on page %d for %s", cap_match.id, page_num, arxiv_id
            )
            continue

        # Claim the label only after a successful render. Otherwise a failed
        # render would block a later detection of the same caption.
        seen_labels.add(cap_match.id)
        manifest[filename] = {
            "page": page_num,
            "type": cap_match.kind,
            "label": cap_match.id,
            "box": _cluster_to_box(cluster),
            "caption_text": cap_match.text[:500],
            "caption_box": cap_match.bbox,
            "caption_source": "text",
        }
        extracted += 1

    return extracted
|
||
|
||
|
||
# ── Core entry point ───────────────────────────────────────────────────
|
||
|
||
|
||
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Extract Figure/Table screenshots from a PDF and write ``manifest.json``.

    Images from a previous run (.png/.jpg/.jpeg) and the old manifest in the
    output directory are deleted first. Pages are processed independently; a
    failure on one page is logged and skipped. Manifest entries are keyed by
    filename and labelled with the caption's own ID (e.g. "Figure 3").

    Args:
        arxiv_id: paper ID.
        pdf_path: PDF path. Defaults to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of images extracted (0 if the PDF does not exist).
    """
    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"

    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)
    manifest_path = images_dest / "manifest.json"

    # Remove output from a previous extraction (regular files only).
    for old_file in images_dest.iterdir():
        if old_file.is_file() and old_file.suffix.lower() in (".png", ".jpg", ".jpeg"):
            old_file.unlink()
    if manifest_path.exists():
        manifest_path.unlink()

    extracted = 0
    manifest: dict[str, dict] = {}
    seen_labels: set[str] = set()

    with pymupdf.open(str(pdf_path)) as doc:
        for page_idx in range(doc.page_count):
            try:
                page_boxes = detect_page_layout(doc[page_idx])
                extracted += _process_page(
                    doc,
                    page_idx,
                    page_boxes,
                    images_dest=images_dest,
                    manifest=manifest,
                    seen_labels=seen_labels,
                    arxiv_id=arxiv_id,
                )
            except Exception:
                logger.warning(
                    "Failed to process page %d for %s",
                    page_idx + 1,
                    arxiv_id,
                    exc_info=True,
                )

    # ensure_ascii=False emits raw non-ASCII caption text, so the encoding must
    # be pinned. link_figures_with_images reads the manifest back as UTF-8, and
    # the platform default encoding may differ.
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    if extracted > 0:
        logger.info(
            "Extracted %d figure/table screenshots from PDF for %s",
            extracted,
            arxiv_id,
        )

    return extracted
|
||
|
||
|
||
# ── Figure ↔ Image 关联 ────────────────────────────────────────────────
|
||
|
||
|
||
def _normalize_figure_id(raw_id: str) -> str:
    """Normalize a Figure/Table ID to the "Figure N" / "Table N" form.

    Accepts "Figure"/"Fig."/"Fig" and "Table"/"Tab."/"Tab" in any case, with
    optional leading whitespace. The number may carry an appendix letter
    ("C1", upper-cased) and dotted parts ("3.5"), matching the IDs the caption
    extractor produces. Trailing sub-figure letters are dropped, so
    "Figure 3a" becomes "Figure 3". Unrecognized IDs are returned unchanged.

    >>> _normalize_figure_id("Fig.1")
    'Figure 1'
    >>> _normalize_figure_id("table c1")
    'Table C1'
    """
    number = r"([A-Z]?\d+(?:\.\d+)*)"
    m = re.match(r"\s*Fig(?:ure)?\.?\s*" + number, raw_id, re.IGNORECASE)
    if m:
        return f"Figure {m.group(1).upper()}"
    m = re.match(r"\s*Tab(?:le)?\.?\s*" + number, raw_id, re.IGNORECASE)
    if m:
        return f"Table {m.group(1).upper()}"
    return raw_id
|
||
|
||
|
||
def _is_figure_type(fig_id: str) -> bool:
    """Return True unless ``fig_id`` names a table.

    Any ID starting with "Table"/"Tab" as a whole word (any case) counts as a
    table, including appendix IDs such as "Table C1", which the previous
    ``Table\\s*\\d+`` check misclassified as figures.
    """
    return re.match(r"\s*Tab(?:le)?\b", fig_id, re.IGNORECASE) is None
|
||
|
||
|
||
def _image_sort_key(name: str) -> tuple[int, int]:
    """Sort key ordering extracted images by the number in their filename.

    - ``figure_3.jpg`` / ``table_12.jpg`` give ``(0, number)``.
    - Appendix names like ``table_a3.jpg`` / ``table_c1.jpg`` give
      ``(letter_rank, number)`` with a=1, b=2, ..., so they follow the
      main-body ones in letter order.
    - Anything else gives ``(0, 0)``, as before.

    Matching ignores case.
    """
    m = re.search(r"(?:figure|table)_(\d+)", name, re.IGNORECASE)
    if m:
        return (0, int(m.group(1)))
    m = re.search(r"(?:figure|table)_([a-z])(\d+)", name, re.IGNORECASE)
    if m:
        return (ord(m.group(1).lower()) - ord("a") + 1, int(m.group(2)))
    return (0, 0)
|
||
|
||
|
||
def link_figures_with_images(
    figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
    """Attach extracted image URLs to summary figure metadata.

    Strategy:
        1. Exact ID match against the labels in ``manifest.json``. The summary
           ID is normalized first; the raw (stripped) ID is a fallback.
        2. Ordinal fallback for figures still unmatched: the N-th unmatched
           Figure gets the N-th not-yet-assigned figure image, and likewise
           for Tables. Images already attached by strategy 1 (or by the
           caller) are excluded, so one image is never given to two figures.

    Args:
        figures: figure dicts with an ``"id"`` key; ``"image_url"`` is set
            in place.
        images: extracted image dicts with ``"name"`` and ``"url"`` keys.
        arxiv_id: paper ID, used to locate the manifest and build URLs.

    Returns:
        The same ``figures`` list.
    """
    if not figures or not images:
        return figures

    manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"

    # ── Strategy 1: exact ID match via the manifest ──
    id_to_url: dict[str, str] = {}
    if manifest_path.exists():
        try:
            manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (OSError, ValueError, TypeError):
            logger.warning("Unreadable image manifest for %s: %s", arxiv_id, manifest_path)
            manifest = {}
        if not isinstance(manifest, dict):
            manifest = {}
        for filename, info in manifest.items():
            if not isinstance(info, dict):
                continue
            url = f"/papers/{arxiv_id}/images/{filename}"
            # Prefer the "label" field (current format).
            label = info.get("label", "")
            if label:
                id_to_url[label] = url
            # Also accept the legacy "figures"/"tables" lists.
            for fig_id in info.get("figures", []) + info.get("tables", []):
                if fig_id not in id_to_url:
                    id_to_url[fig_id] = url

    for fig in figures:
        raw_id = fig.get("id") or ""
        url = id_to_url.get(_normalize_figure_id(raw_id)) or id_to_url.get(raw_id.strip())
        if url:
            fig["image_url"] = url

    # ── Strategy 2: ordinal fallback for figures the manifest could not match ──
    unmatched = [f for f in figures if not f.get("image_url")]
    if not unmatched:
        return figures

    # Exclude images already attached. This assumes image dicts use the same URL
    # format as the manifest-derived URLs above. If they differ, nothing is
    # excluded, which matches the previous behaviour.
    used_urls = {f["image_url"] for f in figures if f.get("image_url")}
    available = [img for img in images if img.get("url") not in used_urls]

    # Split by type: Figure vs Table.
    fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id") or "")]
    table_type_unmatched = [
        f for f in unmatched if not _is_figure_type(f.get("id") or "")
    ]

    # Split the remaining images by type, ordered by the number in the filename.
    fig_images = sorted(
        (img for img in available if "table" not in img["name"].lower()),
        key=lambda img: _image_sort_key(img["name"]),
    )
    table_images = sorted(
        (img for img in available if "table" in img["name"].lower()),
        key=lambda img: _image_sort_key(img["name"]),
    )

    for fig, img in zip(fig_type_unmatched, fig_images):
        fig["image_url"] = img["url"]
    for fig, img in zip(table_type_unmatched, table_images):
        fig["image_url"] = img["url"]

    return figures
|