Files
daily-paper/app/services/pdf_image_extractor.py
T

536 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""PDF 图片与表格提取 — 基于 caption 定位的页面区域截图。
核心思路:学术论文排版极其规整,Figure caption 在图下方,Table caption 在表格上方。
因此反过来:先找 caption 文字 → 向上/向下截取页面区域 → 渲染为 PNG。
优势(相比提取嵌入位图):
- 复合图表不会被拆成碎片(整块截取)
- 矢量图也能截取(页面渲染包含一切)
- 不依赖 find_tables()(纯文本匹配 caption)
"""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from app.services.pdf_downloader import paper_dir
from app.utils import TMP_DIR
logger = logging.getLogger(__name__)
# ── Crop-region parameters ─────────────────────────────────────────────
# Figure: range (in points) searched ABOVE the caption for the figure body
_FIGURE_MAX_HEIGHT = 450 # maximum upward search range
_FIGURE_MIN_HEIGHT = 50 # minimum valid crop height
_FIGURE_DEFAULT_HEIGHT = 280 # default figure height when no content block is found above
# Table: range searched BELOW the caption for the table body
_TABLE_MAX_HEIGHT = 500 # maximum downward search range
_TABLE_MIN_HEIGHT = 30
# Horizontal padding around the caption (in two-column papers the caption may be narrower than the table)
_REGION_SIDE_PADDING = 10
# Tables are usually wider than their caption text, so use a larger horizontal padding
_TABLE_SIDE_PADDING = 60
# ~1.5x body line spacing ≈ blank-gap threshold (academic papers are tightly typeset; 30pt is too loose)
_CONTENT_GAP_THRESHOLD = 20
# ── Caption regexes ────────────────────────────────────────────────────
# A caption must START with Figure/Table (so in-text references such as
# "see Figure 3" are not matched). Three caption formats are supported:
#   "Figure 1: Title" / "Figure 1. Title" / "Figure 1 Title" (no punctuation)
# The third form requires the next character to be an UPPERCASE letter, which
# rejects body text like "Figure 1 shows...".
#
# Case-insensitivity is scoped to the keyword with ``(?i:...)`` instead of a
# global re.IGNORECASE: a global flag would also make ``[A-Z]`` match
# lowercase letters and silently disable the uppercase guard.
_CAPTION_RE = re.compile(
    r"^(?i:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=[A-Z]))",
)
_TABLE_CAPTION_RE = re.compile(
    r"^(?i:Table)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=[A-Z]))",
)
# ── Stop signals: boundary detection stops as soon as one of these is hit ──
# The next Figure/Table caption (e.g. "Table 2:", "Figure 3:", "Figure 4 Title")
_CAPTION_STOP_RE = re.compile(
    r"^(?i:Table|Fig\.?|Figure)\s+\d+\s*(?:[:\.]\s*|\s+[A-Z])",
)
# Section header (e.g. "6.2 Evolution", "D.1 Dependency", "7 Conclusion")
_SECTION_STOP_RE = re.compile(
    r"^(\d{1,2}(?:\.\d+)?\s+[A-Z][a-z]|[A-Z]\.\d+\s+[A-Z][a-z])"
)
def _estimate_column_x(caption: dict) -> tuple[float, float]:
    """Estimate the horizontal bounds (col_x0, col_x1) of the caption's column.

    Reads ``page_width``, ``caption_x0`` and ``caption_x1`` from *caption*.

    - Caption wider than 65% of the page: single-column or spanning layout,
      the full page width is returned.
    - Caption centred on the page (centre within 8% of the page midline):
      treated as spanning; a wide range around the caption is returned.
    - Otherwise the left or right half of the page, whichever contains the
      caption's centre.
    """
    page_w = caption["page_width"]
    left = caption["caption_x0"]
    right = caption["caption_x1"]

    # Wide caption → single column or a caption spanning both columns.
    if right - left > page_w * 0.65:
        return 0, page_w

    midline = page_w / 2
    center = (left + right) / 2

    # Centred caption → probably a spanning table; widen around the caption.
    if abs(center - midline) / page_w < 0.08:
        margin = _TABLE_SIDE_PADDING * 2
        return max(0, left - margin), min(page_w, right + margin)

    # Two-column layout: pick the half that holds the caption.
    return (0, midline) if center < midline else (midline, page_w)
def _find_captions(doc) -> list[dict]:
    """Scan the whole document and collect every Figure/Table caption.

    Args:
        doc: An opened document; iterated page by page. Each page must
            provide ``rect.width``, ``rect.height`` and ``get_text("blocks")``.

    Returns:
        One dict per matching text block, in page order, with the keys
        ``type`` ("figure" or "table"), ``num``, ``label``, ``page_num``
        (0-based), ``caption_y0``/``caption_y1``/``caption_x0``/``caption_x1``
        (block bounding box), ``caption_text`` (full stripped block text),
        ``page_width`` and ``page_height``.
    """
    label_prefix = {"figure": "Figure", "table": "Table"}
    captions: list[dict] = []

    for page_num, page in enumerate(doc):
        page_width = page.rect.width
        page_height = page.rect.height

        for block in page.get_text("blocks"):
            if len(block) < 5:
                continue
            text = str(block[4]).strip()
            if not text:
                continue

            # Match on the block's first line only, so that a block holding
            # several paragraphs cannot confuse the caption patterns.
            first_line = text.split("\n")[0].strip()

            # The figure pattern is tried first, then the table pattern.
            kind = "figure"
            m = _CAPTION_RE.match(first_line)
            if m is None:
                kind = "table"
                m = _TABLE_CAPTION_RE.match(first_line)
            if m is None:
                continue

            bx0, by0, bx1, by1 = block[0], block[1], block[2], block[3]
            captions.append(
                {
                    "type": kind,
                    "num": int(m.group(1)),
                    "label": f"{label_prefix[kind]} {m.group(1)}",
                    "page_num": page_num,
                    "caption_y0": by0,
                    "caption_y1": by1,
                    "caption_x0": bx0,
                    "caption_x1": bx1,
                    "caption_text": text,
                    "page_width": page_width,
                    "page_height": page_height,
                }
            )
    return captions
def _find_figure_top(page, caption: dict) -> float:
    """Scan upward from a figure caption and return the y of the figure's top edge.

    Strategy:
    1. Prefer embedded images: the topmost image that horizontally overlaps
       the caption and ends above it defines the figure top.
    2. Without images, fall back to text-block gap detection (pure vector
       figures, e.g. TikZ / matplotlib PDF output): walk upward through the
       blocks above the caption until a vertical gap larger than
       ``_CONTENT_GAP_THRESHOLD`` is found.

    The result is clamped so that it never reaches into another Figure/Table
    caption above it in the same column, never exceeds ``_FIGURE_MAX_HEIGHT``
    above the caption, and is never negative.

    Args:
        page: Page object providing ``get_text("blocks")`` and ``get_image_info()``.
        caption: Caption dict with ``caption_x0/x1``, ``caption_y0`` and
            ``page_width``.

    Returns:
        The y coordinate of the figure's upper boundary.
    """
    caption_y = caption["caption_y0"]
    window_top = caption_y - _FIGURE_MAX_HEIGHT
    col_x0, col_x1 = _estimate_column_x(caption)
    cx0 = max(col_x0, caption["caption_x0"] - _REGION_SIDE_PADDING)
    cx1 = min(col_x1, caption["caption_x1"] + _REGION_SIDE_PADDING)

    # Fetch the page's blocks once; both the cutoff scan and strategy 2 use them.
    text_blocks = [b for b in page.get_text("blocks") if len(b) >= 5]

    # ── Nearest Figure/Table caption above, in the same column ──
    # Keep the LOWEST matching caption (block order is not guaranteed to be
    # top-to-bottom) and cut at its BOTTOM edge so its text stays out of the
    # crop. Captions that lie entirely in the other column are ignored.
    caption_cutoff: float | None = None
    for b in text_blocks:
        bx0, bx1, by1 = b[0], b[2], b[3]
        if by1 >= caption_y or by1 <= window_top:
            continue
        if bx1 <= col_x0 or bx0 >= col_x1:
            continue
        first_line = str(b[4]).strip().split("\n")[0].strip()
        if not _CAPTION_STOP_RE.match(first_line):
            continue
        if caption_cutoff is None or by1 > caption_cutoff:
            caption_cutoff = by1

    # ── Strategy 1: embedded images ──
    topmost_image_y: float | None = None
    for img_info in page.get_image_info():
        bbox = img_info.get("bbox")
        if bbox is None:
            continue
        # bbox may be a rect-like object or a plain 4-sequence.
        if hasattr(bbox, "x0"):
            ix0, iy0, ix1, iy1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
        else:
            ix0, iy0, ix1, iy1 = bbox[0], bbox[1], bbox[2], bbox[3]
        if not (window_top < iy1 <= caption_y):
            continue
        if not (ix1 > cx0 and ix0 < cx1):
            continue
        if caption_cutoff is not None and iy0 < caption_cutoff:
            continue  # belongs to another figure above
        if topmost_image_y is None or iy0 < topmost_image_y:
            topmost_image_y = iy0

    if topmost_image_y is not None:
        figure_top = topmost_image_y
    else:
        # ── Strategy 2: text-block gap detection (pure vector figures) ──
        above_blocks: list[tuple[float, float, float, float]] = []
        for b in text_blocks:
            bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
            if not (window_top < by1 <= caption_y):
                continue
            if not (bx1 > cx0 and bx0 < cx1):
                continue
            # Skip blocks that start well inside the other column.
            if col_x0 > 0 and bx0 < col_x0 - _REGION_SIDE_PADDING * 2:
                continue
            above_blocks.append((bx0, by0, bx1, by1))

        if not above_blocks:
            # Nothing above the caption: assume a default figure height. This
            # still goes through the clamps below, so a caption cutoff applies.
            figure_top = caption_y - _FIGURE_DEFAULT_HEIGHT
        else:
            # Walk upward (largest y0 first); stop at the first large gap.
            above_blocks.sort(key=lambda r: r[1], reverse=True)
            prev_bottom = caption_y
            for rect in above_blocks:
                if prev_bottom - rect[3] > _CONTENT_GAP_THRESHOLD:
                    figure_top = prev_bottom - 5
                    break
                prev_bottom = rect[1]
            else:
                # No gap found: everything collected is part of the figure.
                figure_top = above_blocks[-1][1]

    # Never reach into the caption above.
    if caption_cutoff is not None:
        figure_top = max(figure_top, caption_cutoff + 5)
    # Cap the maximum height.
    if caption_y - figure_top > _FIGURE_MAX_HEIGHT:
        figure_top = caption_y - _FIGURE_MAX_HEIGHT
    return max(0, figure_top)
def _find_table_region(page, caption: dict) -> tuple[float, float, float]:
    """Scan downward from a table caption to find the table's bottom and width.

    Strategy:
    1. Collect the text blocks below the caption (table content is text),
       top-to-bottom, stopping at the next caption or section header.
    2. Find the bottom of the contiguous content (stop at the first large gap).
    3. Measure the horizontal extent of that content (tables are usually
       wider than their caption).

    Args:
        page: Page object providing ``get_text("blocks")``.
        caption: Caption dict with ``caption_x0/x1``, ``caption_y1``,
            ``page_width`` and ``page_height``.

    Returns:
        ``(x0, bottom, x1)`` — left, bottom and right edges of the crop
        region. The top edge is chosen by the caller from the caption position.
    """
    caption_y = caption["caption_y1"]  # the caption's bottom is the scan origin
    caption_x0 = caption["caption_x0"]
    caption_x1 = caption["caption_x1"]
    page_height = caption["page_height"]
    page_width = caption["page_width"]

    # Estimate the caption's column so a two-column page is not scanned across columns.
    col_x0, col_x1 = _estimate_column_x(caption)
    search_x0 = max(col_x0, caption_x0 - _TABLE_SIDE_PADDING)
    search_x1 = min(col_x1, caption_x1 + _TABLE_SIDE_PADDING)
    scan_limit = caption_y + _TABLE_MAX_HEIGHT

    # Sort top-to-bottom BEFORE scanning: the stop-signal `break` below is only
    # meaningful in vertical order, and extraction order is not guaranteed to
    # be vertical.
    ordered = sorted(
        (b for b in page.get_text("blocks") if len(b) >= 5),
        key=lambda b: b[1],
    )

    below_blocks: list[tuple[float, float, float, float]] = []
    for b in ordered:
        bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
        if by0 <= caption_y:
            continue
        if by0 >= scan_limit:
            break  # sorted by y0, so nothing further can qualify
        if not (bx1 > search_x0 and bx0 < search_x1):
            continue
        # Two-column pages: drop blocks that start well inside the other
        # column (table rows start in or near the caption's column).
        if col_x0 > 0 and bx0 < col_x0 - _TABLE_SIDE_PADDING:
            continue
        # Stop signal: the next caption or a section header ends the table.
        first_line = str(b[4]).strip().split("\n")[0].strip()
        if _CAPTION_STOP_RE.match(first_line) or _SECTION_STOP_RE.match(first_line):
            break
        below_blocks.append((bx0, by0, bx1, by1))

    if not below_blocks:
        # No content below: fall back to a minimal height and the caption width.
        return (
            max(0, caption_x0 - _REGION_SIDE_PADDING),
            min(page_height, caption_y + _TABLE_MIN_HEIGHT),
            min(page_width, caption_x1 + _REGION_SIDE_PADDING),
        )

    # ── Bottom of the contiguous content ──
    prev_y = caption_y
    bottom = below_blocks[-1][3] + 5  # bottom of the last block + margin
    for _, by0, _, by1 in below_blocks:
        if by0 - prev_y > _CONTENT_GAP_THRESHOLD:
            bottom = prev_y + 5
            break
        prev_y = by1

    # Cap the maximum height.
    if bottom - caption_y > _TABLE_MAX_HEIGHT:
        bottom = caption_y + _TABLE_MAX_HEIGHT

    # ── Horizontal extent ──
    # Only blocks above `bottom` count; blocks after the gap are body text
    # and may be wider than the table.
    table_blocks = [b for b in below_blocks if b[1] < bottom]
    if not table_blocks:
        table_blocks = below_blocks[:1]  # use at least the first block
    content_x0 = min(caption_x0, min(b[0] for b in table_blocks))
    content_x1 = max(caption_x1, max(b[2] for b in table_blocks))

    # Small padding keeps neighbouring-column content (e.g. a Figure in the
    # other column) out of the crop; the range is deliberately not clipped to
    # the column because a caption may start across columns.
    x0 = max(0, content_x0 - _REGION_SIDE_PADDING)
    x1 = min(page_width, content_x1 + _REGION_SIDE_PADDING)
    return (x0, bottom, x1)
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Render every captioned Figure/Table of a PDF to PNG and write a manifest.

    Strategy: find captions → locate the region → render a page-region screenshot.
    Previously extracted ``*.png`` files and ``manifest.json`` in the target
    directory are removed first.

    Args:
        arxiv_id: Paper ID.
        pdf_path: Path to the PDF; defaults to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of images written (0 if the PDF is missing or has no captions).
    """
    import pymupdf

    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)
    manifest_path = images_dest / "manifest.json"

    # Remove leftovers from a previous extraction.
    for old_file in images_dest.glob("*.png"):
        old_file.unlink()
    if manifest_path.exists():
        manifest_path.unlink()

    extracted = 0
    manifest: dict[str, dict] = {}
    zoom = 3  # 3x rendering for sharp output

    doc = pymupdf.open(str(pdf_path))
    try:
        captions = _find_captions(doc)
        if not captions:
            logger.info("No Figure/Table captions found in PDF for %s", arxiv_id)
            return 0

        # De-duplicate: several blocks may match the same label (e.g. body
        # text starting with "Figure 7"); keep the first match per label.
        seen_labels: dict[str, dict] = {}
        for cap in captions:
            seen_labels.setdefault(cap["label"], cap)
        unique_captions = list(seen_labels.values())

        mat = pymupdf.Matrix(zoom, zoom)
        for cap in unique_captions:
            page = doc[cap["page_num"]]
            pw = cap["page_width"]

            if cap["type"] == "figure":
                # Figure: the image sits above the caption → find its top edge,
                # then leave 5pt extra so frame borders are not clipped.
                top = max(0, _find_figure_top(page, cap) - 5)
                bottom = cap["caption_y1"] + 5  # include the caption
                # Horizontal range: caption width plus padding.
                x0 = max(0, cap["caption_x0"] - _REGION_SIDE_PADDING)
                x1 = min(pw, cap["caption_x1"] + _REGION_SIDE_PADDING)
                height = bottom - top
                if height < _FIGURE_MIN_HEIGHT:
                    logger.debug(
                        "Figure %s too small (%.0fpt), skipping", cap["label"], height
                    )
                    continue
            else:
                # Table: the body sits below the caption → find bottom and width.
                x0, bottom, x1 = _find_table_region(page, cap)
                top = max(0, cap["caption_y0"] - 3)  # include the caption, small margin
                height = bottom - top
                if height < _TABLE_MIN_HEIGHT:
                    logger.debug(
                        "Table %s too small (%.0fpt), skipping", cap["label"], height
                    )
                    continue

            # Render the clipped region; a failure skips this caption only.
            clip = pymupdf.Rect(x0, top, x1, bottom)
            try:
                pix = page.get_pixmap(matrix=mat, clip=clip)
            except Exception:
                logger.debug(
                    "Failed to render %s region for %s",
                    cap["label"],
                    arxiv_id,
                    exc_info=True,
                )
                continue

            filename = f"{cap['label'].replace(' ', '_').lower()}.png"
            pix.save(str(images_dest / filename))
            extracted += 1

            cap_preview = cap["caption_text"][:200] if cap["caption_text"] else ""
            manifest[filename] = {
                "page": cap["page_num"] + 1,
                "type": cap["type"],
                "label": cap["label"],
                "caption_text": cap_preview,
                "figures" if cap["type"] == "figure" else "tables": [cap["label"]],
            }
            logger.debug(
                "Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) h=%.0fpt → %s",
                cap["label"],
                cap["page_num"] + 1,
                x0,
                top,
                x1,
                bottom,
                height,
                filename,
            )
    finally:
        # Release the document even if extraction or saving raised.
        doc.close()

    # Written as UTF-8 explicitly: the manifest holds non-ASCII caption text
    # (ensure_ascii=False) and is read back with encoding="utf-8".
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    if extracted > 0:
        logger.info(
            "Extracted %d figure/table screenshots from PDF for %s "
            "(from %d captions found, %d unique)",
            extracted,
            arxiv_id,
            len(captions),
            len(unique_captions),
        )
    return extracted
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
    """Keep only the extracted images that the summary's ``figures`` list references.

    Each entry's ``id`` (e.g. "Figure 3", "Fig. 3", "Table 2") is normalised to
    "Figure N" / "Table N" and matched against the labels in ``manifest.json``.
    PNG files whose manifest entry is not referenced are deleted. If no valid
    id is found, or nothing in the manifest matches, all files are kept.

    Args:
        arxiv_id: Paper ID.
        figures: Figure/table entries from the summary; each is expected to
            be a dict with an ``id`` key. Entries that are not dicts, or
            whose ``id`` is missing or ``None``, are ignored.

    Returns:
        Number of PNG files remaining (0 if *figures* is empty, or the image
        directory, manifest, or PNG files are missing).
    """
    if not figures:
        return 0

    images_dir = paper_dir(arxiv_id) / "images"
    manifest_path = images_dir / "manifest.json"
    if not images_dir.exists() or not manifest_path.exists():
        return 0

    all_files = [f for f in images_dir.iterdir() if f.suffix == ".png"]
    if not all_files:
        return 0

    manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))

    # Collect every Figure/Table id referenced by the summary (normalised).
    referenced_ids: set[str] = set()
    for fig in figures:
        # Tolerate malformed entries: a non-dict item or a null / non-string
        # id must not abort the whole filtering pass.
        raw_id = fig.get("id") if isinstance(fig, dict) else None
        fig_id = "" if raw_id is None else str(raw_id).strip()
        m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", fig_id, re.IGNORECASE)
        if m:
            referenced_ids.add(f"Figure {m.group(1)}")
            continue
        m = re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
        if m:
            referenced_ids.add(f"Table {m.group(1)}")

    if not referenced_ids:
        logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
        return len(all_files)

    # Match on the manifest's label, then on its figures/tables lists.
    keep_filenames: set[str] = set()
    for filename, info in manifest.items():
        if info.get("label", "") in referenced_ids:
            keep_filenames.add(filename)
            continue
        refs = info.get("figures", []) + info.get("tables", [])
        if any(ref in referenced_ids for ref in refs):
            keep_filenames.add(filename)

    if not keep_filenames:
        logger.warning(
            "No manifest matches for %s (refs=%s), keeping all",
            arxiv_id,
            referenced_ids,
        )
        return len(all_files)

    removed = 0
    for f in all_files:
        if f.name not in keep_filenames:
            f.unlink()
            removed += 1

    kept = len(all_files) - removed
    logger.info(
        "Filtered images for %s: kept %d, removed %d (refs=%s)",
        arxiv_id,
        kept,
        removed,
        referenced_ids,
    )
    return kept