b42e9149e5
- Add subfigure clustering in _find_figure_top(): collect all images near caption, cluster by Y proximity, use largest cluster's min y - Add _find_figure_horizontal(): determine crop range from caption + embedded image union - Refactor _find_table_region() to use page.find_tables() as primary method with segment merging, fallback to block-based detection - Extract _scan_blocks_direction() for bidirectional block scanning with table data density awareness - Add _TABLE_DATA_GAP_THRESHOLD for denser gap tolerance after table data blocks - Fix caption regex to use (?-i:[A-Z]) for correct case-insensitive matching - Switch image output from PNG to JPEG (5-10x smaller for web delivery) - Update cleanup and filter to handle both .png and .jpg formats - Reformat imports and conditional expressions in pages.py
751 lines
26 KiB
Python
751 lines
26 KiB
Python
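"""PDF figure and table extraction: caption-anchored page-region screenshots.

Core idea: academic papers are typeset very regularly. A Figure caption sits
below its figure and a Table caption sits above its table. So the process
works backwards: find the caption text first, crop the page region above or
below it, and render that region to a JPEG.

Advantages over extracting embedded bitmaps:
- Composite figures are not split into fragments (the whole region is cropped).
- Vector graphics are captured too (page rendering includes everything).
- Caption detection is plain text matching and does not rely on find_tables();
  find_tables() is only used afterwards to refine table bounds, with a
  text-block fallback when it finds nothing.
"""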
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from pathlib import Path
|
||
|
||
from app.services.pdf_downloader import paper_dir
|
||
from app.utils import TMP_DIR
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ── Crop-region parameters ─────────────────────────────────────────────

# Figure: search range above the caption, in points.
_FIGURE_MAX_HEIGHT: int = 450  # furthest distance searched upwards
_FIGURE_MIN_HEIGHT: int = 50  # smallest crop height considered valid
_FIGURE_DEFAULT_HEIGHT: int = 280  # figure height assumed when no content block is found above

# Table: search range below the caption.
_TABLE_MAX_HEIGHT: int = 500  # furthest distance searched downwards
_TABLE_MIN_HEIGHT: int = 30

# Horizontal widening around the caption (in two-column papers the caption
# can be narrower than the table).
_REGION_SIDE_PADDING: int = 10

# Tables are usually wider than their caption text, so widen further.
_TABLE_SIDE_PADDING: int = 60

# Whitespace-gap threshold, roughly 1.5x the body line spacing (academic
# papers are tightly typeset; 30pt turned out to be too loose).
_CONTENT_GAP_THRESHOLD: int = 20

# Tighter threshold used right after a dense table-data block: the gap to the
# following paragraph is often only 12-18pt.
_TABLE_DATA_GAP_THRESHOLD: int = 12
|
||
|
||
|
||
# ── Caption patterns ───────────────────────────────────────────────────

# A caption must START with Figure/Table (so in-text mentions such as
# "see Figure 3" are not matched). Three caption formats are supported:
#   "Figure 1: Title" / "Figure 1. Title" / "Figure 1 Title" (space only).
# The third form requires a following UPPERCASE letter, which rules out prose
# like "Figure 1 shows ...". The whole pattern is case-insensitive, so that
# letter is wrapped in (?-i:...) to keep it case-sensitive.
_CAPTION_RE = re.compile(
    r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
    re.IGNORECASE,
)
_TABLE_CAPTION_RE = re.compile(
    r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
    re.IGNORECASE,
)

# ── Stop signals: region scanning halts as soon as one of these is met ──

# The next Figure/Table caption (e.g. "Table 2:", "Figure 3:", "Figure 4 Title").
# The uppercase test uses (?-i:...) like the caption patterns above; a plain
# [A-Z] under IGNORECASE would also accept "Figure 3 shows ..." body text.
_CAPTION_STOP_RE = re.compile(
    r"^(?:Table|Fig\.?|Figure)\s+\d+\s*(?:[:\.]\s*|\s+(?-i:[A-Z]))",
    re.IGNORECASE,
)
# Section header (e.g. "6.2 Evolution", "D.1 Dependency", "7 Conclusion").
_SECTION_STOP_RE = re.compile(
    r"^(\d{1,2}(?:\.\d+)?\s+[A-Z][a-z]|[A-Z]\.\d+\s+[A-Z][a-z])"
)
|
||
|
||
|
||
def _estimate_column_x(caption: dict) -> tuple[float, float]:
    """Estimate the horizontal bounds ``(col_x0, col_x1)`` of the caption's column.

    In a two-column paper the caption is much narrower than the page, which
    tells the left column from the right one.

    * Caption wider than 65% of the page: single-column or column-spanning,
      so the full page width is returned.
    * Caption centred on the page (centre within 8% of the midline): treated
      as column-spanning; the caption range widened by twice
      ``_TABLE_SIDE_PADDING`` on each side (clamped to the page) is returned.
    * Otherwise the left or right half of the page, depending on which side
      the caption centre falls.
    """
    page_w = caption["page_width"]
    left, right = caption["caption_x0"], caption["caption_x1"]

    # Wide caption: single-column layout or a caption spanning both columns.
    if right - left > page_w * 0.65:
        return 0, page_w

    half = page_w / 2
    centre = (left + right) / 2

    # Centred caption: probably a column-spanning table, so use a wide range.
    if abs(centre - half) / page_w < 0.08:
        widen = _TABLE_SIDE_PADDING * 2
        return max(0, left - widen), min(page_w, right + widen)

    return (0, half) if centre < half else (half, page_w)
|
||
|
||
|
||
def _find_captions(doc) -> list[dict]:
    """Scan every page of ``doc`` and collect all Figure/Table captions.

    Only the first line of each text block is matched, so a block holding
    several paragraphs cannot confuse the match. Each hit becomes a dict with
    the caption kind, number, label, zero-based page index, the block's
    bounding box, the full block text and the page dimensions.
    """
    # The Figure pattern is tried first; a block yields at most one caption.
    kinds = (
        ("figure", "Figure", _CAPTION_RE),
        ("table", "Table", _TABLE_CAPTION_RE),
    )
    found: list[dict] = []

    for page_index in range(len(doc)):
        page = doc[page_index]
        width = page.rect.width
        height = page.rect.height

        for blk in page.get_text("blocks"):
            if len(blk) < 5:
                continue
            body = str(blk[4]).strip()
            if not body:
                continue

            left, upper, right, lower = blk[0], blk[1], blk[2], blk[3]
            headline = body.split("\n")[0].strip()

            for kind, prefix, pattern in kinds:
                hit = pattern.match(headline)
                if hit is None:
                    continue
                found.append(
                    {
                        "type": kind,
                        "num": int(hit.group(1)),
                        "label": f"{prefix} {hit.group(1)}",
                        "page_num": page_index,
                        "caption_y0": upper,
                        "caption_y1": lower,
                        "caption_x0": left,
                        "caption_x1": right,
                        "caption_text": body,
                        "page_width": width,
                        "page_height": height,
                    }
                )
                break

    return found
|
||
|
||
|
||
def _find_figure_top(page, caption: dict) -> float:
    """Scan upwards from a Figure caption and return the figure's top edge.

    Strategy:
      1. Prefer embedded images: collect the top edges of all images above the
         caption that overlap its horizontal range, cluster them along Y, and
         take the smallest top of the most populous cluster (this handles
         sub-figure grids).
      2. With no usable image, fall back to text-block gap detection (pure
         vector figures such as TikZ or matplotlib PDF output).

    The result is then clamped by the nearest other Figure/Table caption above
    on the same page, and by ``_FIGURE_MAX_HEIGHT``.

    Args:
        page: page object providing ``get_text("blocks")`` and
            ``get_image_info()``.
        caption: caption record; reads ``caption_y0``, ``caption_x0``,
            ``caption_x1`` and (via ``_estimate_column_x``) ``page_width``.

    Returns:
        The y coordinate of the figure's upper bound, never negative.
    """
    caption_y = caption["caption_y0"]
    col_x0, col_x1 = _estimate_column_x(caption)
    cx0 = max(col_x0, caption["caption_x0"] - _REGION_SIDE_PADDING)
    cx1 = min(col_x1, caption["caption_x1"] + _REGION_SIDE_PADDING)
    search_limit = caption_y - _FIGURE_MAX_HEIGHT

    # Fetched once; both the cut-off search and the text fallback use it.
    text_blocks = [b for b in page.get_text("blocks") if len(b) >= 5]

    # Nearest other Figure/Table caption above this one on the same page.
    # Blocks arrive in page order, so taking the first match would usually
    # pick the FARTHEST caption; keep the one whose bottom edge is lowest.
    # NOTE(review): the cut-off is that caption's TOP edge, so the clamp at
    # the end can leave the neighbouring caption inside the crop - confirm
    # whether its bottom edge was intended.
    caption_cutoff: float | None = None
    nearest_bottom: float | None = None
    for b in text_blocks:
        block_top, block_bottom = b[1], b[3]
        if block_bottom >= caption_y or block_bottom <= search_limit:
            continue
        first_line = str(b[4]).strip().split("\n")[0].strip()
        if not _CAPTION_STOP_RE.match(first_line):
            continue
        if nearest_bottom is None or block_bottom > nearest_bottom:
            nearest_bottom = block_bottom
            caption_cutoff = block_top

    # ── Strategy 1: cluster embedded images ──
    image_tops: list[float] = []
    for img_info in page.get_image_info():
        bbox = img_info.get("bbox")
        if bbox is None:
            continue
        if hasattr(bbox, "x0"):
            ix0, iy0, ix1, iy1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
        else:
            ix0, iy0, ix1, iy1 = bbox[0], bbox[1], bbox[2], bbox[3]

        # The image bottom must lie above the caption and inside the range.
        if not (iy1 <= caption_y and iy1 > search_limit):
            continue
        # The image must overlap the caption's horizontal range.
        if not (ix1 > cx0 and ix0 < cx1):
            continue
        # Skip images that belong to another figure further up.
        if caption_cutoff is not None and iy0 < caption_cutoff:
            continue
        # Skip tiny icons (width or height under 15pt, usually logos/symbols).
        if (ix1 - ix0) < 15 or (iy1 - iy0) < 15:
            continue

        image_tops.append(iy0)

    if image_tops:
        # Single-pass clustering: consecutive tops closer than 40% of the
        # maximum height belong to the same cluster (sub-figures).
        image_tops.sort()
        cluster_gap = _FIGURE_MAX_HEIGHT * 0.4
        clusters: list[list[float]] = [[image_tops[0]]]
        for y_top in image_tops[1:]:
            if y_top - clusters[-1][-1] < cluster_gap:
                clusters[-1].append(y_top)
            else:
                clusters.append([y_top])
        # The cluster with the most images wins; its smallest top is the bound.
        figure_top = min(max(clusters, key=len))
    else:
        # ── Strategy 2: text-block gap detection (pure vector figures) ──
        above_blocks: list[tuple[float, float, float, float]] = []
        for b in text_blocks:
            bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
            if not (by1 <= caption_y and by1 > search_limit):
                continue
            if not (bx1 > cx0 and bx0 < cx1):
                continue
            # In a right-hand column, drop blocks that start well inside
            # the left column.
            if col_x0 > 0 and bx0 < col_x0 - _REGION_SIDE_PADDING * 2:
                continue
            above_blocks.append((bx0, by0, bx1, by1))

        if not above_blocks:
            return max(0, caption_y - _FIGURE_DEFAULT_HEIGHT)

        # Walk upwards from the caption until a gap wider than the threshold.
        above_blocks.sort(key=lambda blk: blk[1], reverse=True)
        prev_bottom = caption_y
        for blk in above_blocks:
            if prev_bottom - blk[3] > _CONTENT_GAP_THRESHOLD:
                figure_top = prev_bottom - 5
                break
            prev_bottom = blk[1]
        else:
            # No gap found: the topmost collected block bounds the figure.
            figure_top = above_blocks[-1][1]

    # Do not reach past the neighbouring caption on the same page.
    if caption_cutoff is not None:
        figure_top = max(figure_top, caption_cutoff + 5)

    # Enforce the maximum height.
    if caption_y - figure_top > _FIGURE_MAX_HEIGHT:
        figure_top = caption_y - _FIGURE_MAX_HEIGHT

    return max(0, figure_top)
|
||
|
||
|
||
def _find_figure_horizontal(
    page, caption: dict, top: float, bottom: float
) -> tuple[float, float]:
    """Work out the horizontal crop range ``(x0, x1)`` of a Figure.

    The range is the union of the caption's width and the widths of the
    embedded images that intersect the vertical span ``top``..``bottom`` and
    the caption's column, so a figure wider than its caption is not cut off.
    ``_REGION_SIDE_PADDING`` is added on both sides and the result is clamped
    to the page.
    """
    page_w = caption["page_width"]
    left = caption["caption_x0"]
    right = caption["caption_x1"]
    col_x0, col_x1 = _estimate_column_x(caption)

    for info in page.get_image_info():
        box = info.get("bbox")
        if box is None:
            continue
        if hasattr(box, "x0"):
            ix0, iy0, ix1, iy1 = box.x0, box.y0, box.x1, box.y1
        else:
            ix0, iy0, ix1, iy1 = box[0], box[1], box[2], box[3]

        # Ignore images outside the crop rows or outside the caption's column.
        outside_rows = iy0 >= bottom or iy1 <= top
        outside_column = ix1 <= col_x0 or ix0 >= col_x1
        if outside_rows or outside_column:
            continue
        # Ignore narrow icons.
        if (ix1 - ix0) < 15:
            continue

        left = min(left, ix0)
        right = max(right, ix1)

    padded_left = max(0, left - _REGION_SIDE_PADDING)
    padded_right = min(page_w, right + _REGION_SIDE_PADDING)
    return padded_left, padded_right
|
||
|
||
|
||
def _find_table_region(page, caption: dict) -> tuple[float, float, float, float]:
    """Locate the crop region of a Table, starting from its caption.

    Strategy:
      1. Use ``page.find_tables()`` to collect the table segments that start
         just below the caption in the caption's column, and merge adjacent
         segments (tables in papers are often split into a header segment
         plus data segments).
      2. If that yields nothing, fall back to text-block gap detection.

    Args:
        page: page object; must provide ``find_tables()`` for strategy 1.
        caption: caption record; reads ``caption_y0``, ``caption_y1``,
            ``caption_x0``, ``caption_x1`` and ``page_width``.

    Returns:
        ``(x0, top, bottom, x1)``: the left, top, bottom and right edges of
        the region. On the ``find_tables()`` path ``top`` is the caption's
        own top edge.
    """
    caption_y = caption["caption_y1"]  # the scan starts at the caption bottom
    caption_x0 = caption["caption_x0"]
    caption_x1 = caption["caption_x1"]
    page_width = caption["page_width"]

    # ── Strategy 1: structural detection + merging of adjacent segments ──
    try:
        tables = page.find_tables()
    except Exception:
        # Best effort: any failure here just means the fallback is used,
        # but keep a trace so the failure is not invisible.
        logger.debug(
            "find_tables() failed, falling back to block detection",
            exc_info=True,
        )
        tables = None

    if tables and tables.tables:
        # Column bounds of the caption, so that a two-column paper does not
        # pick up tables from the other column.
        col_x0, col_x1 = _estimate_column_x(caption)

        segments: list[tuple[float, float, float, float]] = []
        for t in tables.tables:
            tb = t.bbox
            if isinstance(tb, (list, tuple)):
                tx0, ty0, tx1, ty1 = (
                    float(tb[0]),
                    float(tb[1]),
                    float(tb[2]),
                    float(tb[3]),
                )
            else:
                tx0, ty0, tx1, ty1 = (
                    float(tb.x0),
                    float(tb.y0),
                    float(tb.x1),
                    float(tb.y1),
                )

            # The segment must start near the caption bottom (5pt above to
            # 200pt below) and overlap the caption's column.
            if (
                ty0 >= caption_y - 5
                and ty0 < caption_y + 200
                and tx1 > col_x0
                and tx0 < col_x1
            ):
                segments.append((tx0, ty0, tx1, ty1))

        if segments:
            # Sort by top edge and merge neighbours separated by less than
            # 30pt: they are consecutive parts of one table.
            segments.sort(key=lambda s: s[1])
            merged: list[tuple[float, float, float, float]] = [segments[0]]
            for seg in segments[1:]:
                prev = merged[-1]
                gap = seg[1] - prev[3]  # this top minus previous bottom
                if gap < 30:
                    merged[-1] = (
                        min(prev[0], seg[0]),
                        min(prev[1], seg[1]),
                        max(prev[2], seg[2]),
                        max(prev[3], seg[3]),
                    )
                else:
                    merged.append(seg)

            # The first merged run is the table closest to the caption.
            tx0, ty0, tx1, ty1 = merged[0]

            # Enforce the maximum height.
            if ty1 - caption_y > _TABLE_MAX_HEIGHT:
                ty1 = caption_y + _TABLE_MAX_HEIGHT

            x0 = max(0, min(caption_x0, tx0) - _REGION_SIDE_PADDING)
            x1 = min(page_width, max(caption_x1, tx1) + _REGION_SIDE_PADDING)
            logger.debug(
                "Table detected by find_tables() (%d segments merged): "
                "(%.0f,%.0f)-(%.0f,%.0f)",
                len(segments),
                x0,
                caption_y,
                x1,
                ty1,
            )
            return (x0, caption["caption_y0"], ty1, x1)

    # ── Strategy 2: text-block gap detection ──
    x0, t_top, t_bottom, x1 = _find_table_region_by_blocks(page, caption)
    return (x0, t_top, t_bottom, x1)
|
||
|
||
|
||
def _scan_blocks_direction(
    blocks: list,
    start_y: float,
    col_x0: float,
    col_x1: float,
    direction: int,
    max_range: float,
) -> list[tuple[float, float, float, float]]:
    """Collect the run of text blocks connected to ``start_y`` in one direction.

    Blocks are visited outwards from ``start_y``. Scanning stops at the first
    stop signal (another caption or a section header) or at the first gap
    wider than the current threshold. A running ``boundary`` tracks the far
    edge of the connected region, so vertically overlapping blocks do not
    pull the frontier back.

    Args:
        blocks: block tuples whose first five items are
            ``(x0, y0, x1, y1, text)``; shorter entries are ignored.
        start_y: y coordinate to scan from.
        col_x0: left edge of the column; blocks must overlap it horizontally.
        col_x1: right edge of the column.
        direction: ``> 0`` scans downwards, anything else scans upwards.
        max_range: furthest distance from ``start_y`` that is considered.

    Returns:
        The collected blocks as ``[(x0, y0, x1, y1), ...]`` in visiting order.
    """
    downward = direction > 0

    # Keep only blocks inside the scan window that overlap the column.
    if downward:
        candidates = [
            b
            for b in blocks
            if len(b) >= 5
            and b[1] > start_y
            and b[1] < start_y + max_range
            and b[2] > col_x0
            and b[0] < col_x1
        ]
        candidates.sort(key=lambda b: b[1])  # top edge ascending
    else:
        candidates = [
            b
            for b in blocks
            if len(b) >= 5
            and b[3] <= start_y
            and b[1] > start_y - max_range
            and b[2] > col_x0
            and b[0] < col_x1
        ]
        # Bottom edge descending: the block nearest to start_y comes first.
        candidates.sort(key=lambda b: b[3], reverse=True)

    if not candidates:
        return []

    connected: list[tuple[float, float, float, float]] = []
    boundary = start_y  # far edge of the connected region so far
    prev_was_dense_table = False

    for b in candidates:
        bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
        text = str(b[4]).strip()
        first_line = text.split("\n")[0].strip()

        # Stop signals.
        if _CAPTION_STOP_RE.match(first_line) or _SECTION_STOP_RE.match(first_line):
            break

        # Distance between this block and the connected region.
        gap = by0 - boundary if downward else boundary - by1

        # A dense table-data block is followed by a tighter gap tolerance.
        threshold = (
            _TABLE_DATA_GAP_THRESHOLD
            if prev_was_dense_table
            else _CONTENT_GAP_THRESHOLD
        )
        if gap > threshold:
            break

        connected.append((bx0, by0, bx1, by1))

        # Advance the frontier monotonically. A short block overlapping a
        # taller one must not move it back, otherwise the next block sees
        # an inflated gap and the scan stops too early.
        if downward:
            boundary = max(boundary, by1)
        else:
            boundary = min(boundary, by0)

        # Dense table data: at least 4 non-empty lines at a density of at
        # least 0.08 lines per point of block height.
        lines = [line for line in text.split("\n") if line.strip()]
        block_height = by1 - by0
        prev_was_dense_table = (
            len(lines) >= 4
            and block_height > 0
            and len(lines) / block_height >= 0.08
        )

    return connected
|
||
|
||
|
||
def _find_table_region_by_blocks(
    page, caption: dict
) -> tuple[float, float, float, float]:
    """Text-block gap detection, used as the fallback for ``find_tables()``.

    Scans downwards from the caption bottom for the table's lower edge and
    upwards from the caption top for its upper edge (this covers layouts where
    the caption sits below the data). Both scans go through
    ``_scan_blocks_direction``.

    Args:
        page: page object providing ``get_text("blocks")``.
        caption: caption record; reads ``caption_y0``, ``caption_y1``,
            ``caption_x0``, ``caption_x1`` and ``page_width``.

    Returns:
        ``(x0, top, bottom, x1)``. ``bottom`` includes 5pt of padding and the
        height is capped at ``_TABLE_MAX_HEIGHT``.
    """
    blocks = page.get_text("blocks")
    caption_y0 = caption["caption_y0"]
    caption_y1 = caption["caption_y1"]
    caption_x0 = caption["caption_x0"]
    caption_x1 = caption["caption_x1"]
    page_width = caption["page_width"]

    col_x0, col_x1 = _estimate_column_x(caption)

    # Scan both ways from the caption.
    below = _scan_blocks_direction(
        blocks, caption_y1, col_x0, col_x1, direction=1, max_range=_TABLE_MAX_HEIGHT
    )
    above = _scan_blocks_direction(
        blocks, caption_y0, col_x0, col_x1, direction=-1, max_range=_TABLE_MAX_HEIGHT
    )

    # Vertical bounds; with nothing collected the caption itself is the bound.
    top = min(b[1] for b in above) if above else caption_y0
    scan_bottom = max(b[3] for b in below) if below else caption_y1
    bottom = scan_bottom + 5  # bottom padding

    if bottom - top > _TABLE_MAX_HEIGHT:
        bottom = top + _TABLE_MAX_HEIGHT

    # Horizontal bounds: the caption plus every collected block.
    all_blocks = above + below
    if all_blocks:
        content_x0 = min(caption_x0, min(b[0] for b in all_blocks))
        content_x1 = max(caption_x1, max(b[2] for b in all_blocks))
    else:
        content_x0 = caption_x0
        content_x1 = caption_x1

    x0 = max(0, content_x0 - _REGION_SIDE_PADDING)
    x1 = min(page_width, content_x1 + _REGION_SIDE_PADDING)

    return (x0, top, bottom, x1)
|
||
|
||
|
||
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Render Figure/Table screenshots from a PDF and write a manifest.

    Pipeline: find captions, locate each region, render the page clip.
    Previously extracted ``.png``/``.jpg``/``.jpeg`` files and the old
    ``manifest.json`` in the destination directory are removed first.

    Args:
        arxiv_id: paper identifier.
        pdf_path: path to the PDF; defaults to
            ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        The number of images written. ``0`` when the PDF is missing or
        contains no captions (no manifest is written in those cases).
    """
    import pymupdf

    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"

    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)
    manifest_path = images_dest / "manifest.json"

    # Remove leftovers from a previous run (both .png and .jpg outputs).
    for old_file in images_dest.iterdir():
        if old_file.suffix.lower() in (".png", ".jpg", ".jpeg"):
            old_file.unlink()
    if manifest_path.exists():
        manifest_path.unlink()

    extracted = 0
    manifest: dict[str, dict] = {}
    zoom = 3  # render at 3x for sharpness

    doc = pymupdf.open(str(pdf_path))
    # try/finally: the document must be closed even if a page fails mid-loop.
    try:
        captions = _find_captions(doc)

        if not captions:
            logger.info("No Figure/Table captions found in PDF for %s", arxiv_id)
            return 0

        # De-duplicate: several blocks may match the same label (e.g. body
        # text that begins with "Figure 7"). The first block seen for each
        # label is kept.
        seen_labels: dict[str, dict] = {}
        for cap in captions:
            seen_labels.setdefault(cap["label"], cap)
        unique_captions = list(seen_labels.values())

        for cap in unique_captions:
            page = doc[cap["page_num"]]

            if cap["type"] == "figure":
                # Figure: the picture is above the caption, so search upwards.
                top = _find_figure_top(page, cap)
                # 5pt extra on top so frame borders / rules are not cut off.
                top = max(0, top - 5)
                bottom = cap["caption_y1"] + 5  # include the caption itself
                # Horizontal range: union of caption width and image widths.
                x0, x1 = _find_figure_horizontal(page, cap, top, bottom)

                height = bottom - top
                if height < _FIGURE_MIN_HEIGHT:
                    logger.debug(
                        "Figure %s too small (%.0fpt), skipping", cap["label"], height
                    )
                    continue
            else:
                # Table: find_tables() first, block-level fallback otherwise.
                x0, tbl_top, bottom, x1 = _find_table_region(page, cap)
                top = max(0, tbl_top - 5)  # 5pt margin above the region

                height = bottom - top
                if height < _TABLE_MIN_HEIGHT:
                    logger.debug(
                        "Table %s too small (%.0fpt), skipping", cap["label"], height
                    )
                    continue

            # Render the clip.
            clip = pymupdf.Rect(x0, top, x1, bottom)
            mat = pymupdf.Matrix(zoom, zoom)
            try:
                pix = page.get_pixmap(matrix=mat, clip=clip)
            except Exception:
                # Best effort per region: skip it, but keep the traceback.
                logger.debug(
                    "Failed to render %s region for %s",
                    cap["label"],
                    arxiv_id,
                    exc_info=True,
                )
                continue

            # Save as JPEG (much smaller than PNG, better for web delivery).
            filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
            jpeg_path = images_dest / filename
            jpeg_path.write_bytes(pix.tobytes("jpeg"))
            extracted += 1

            cap_preview = cap["caption_text"][:200] if cap["caption_text"] else ""
            manifest[filename] = {
                "page": cap["page_num"] + 1,
                "type": cap["type"],
                "label": cap["label"],
                "caption_text": cap_preview,
                "figures" if cap["type"] == "figure" else "tables": [cap["label"]],
            }
            logger.debug(
                "Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) h=%.0fpt → %s",
                cap["label"],
                cap["page_num"] + 1,
                x0,
                top,
                x1,
                bottom,
                height,
                filename,
            )
    finally:
        doc.close()

    # The manifest keeps non-ASCII caption text verbatim (ensure_ascii=False),
    # so write it as UTF-8 explicitly, matching how it is read back, instead
    # of relying on the platform default encoding.
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    if extracted > 0:
        logger.info(
            "Extracted %d figure/table screenshots from PDF for %s "
            "(from %d captions found, %d unique)",
            extracted,
            arxiv_id,
            len(captions),
            len(unique_captions),
        )

    return extracted
|
||
|
||
|
||
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
    """Keep only the extracted images that the summary refers to.

    Each entry's ``"id"`` (e.g. "Figure 3", "Fig. 2", "Table 1") is normalised
    to a ``"Figure N"`` / ``"Table N"`` label and matched against the labels
    recorded in ``manifest.json``. Image files that are not referenced are
    deleted. The manifest file itself is not rewritten.

    Args:
        arxiv_id: paper identifier.
        figures: entries from the summary; each is read via ``.get("id")``.

    Returns:
        The number of image files kept. ``0`` when ``figures`` is empty or
        when the image directory, manifest or image files are missing. When
        no id is usable or nothing matches the manifest, every file is kept
        and the total count is returned.
    """
    if not figures:
        return 0

    images_dir = paper_dir(arxiv_id) / "images"
    manifest_path = images_dir / "manifest.json"

    if not images_dir.exists() or not manifest_path.exists():
        return 0

    all_files = [
        f for f in images_dir.iterdir() if f.suffix.lower() in (".png", ".jpg", ".jpeg")
    ]
    if not all_files:
        return 0

    manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))

    # Collect every Figure/Table id referenced by the summary, normalised.
    referenced_ids: set[str] = set()
    for fig in figures:
        # "id" may be absent, null or not a string; coerce it so re.match
        # never receives None.
        fig_id = str(fig.get("id") or "").strip()
        m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", fig_id, re.IGNORECASE)
        if m:
            referenced_ids.add(f"Figure {m.group(1)}")
        m2 = re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
        if m2:
            referenced_ids.add(f"Table {m2.group(1)}")

    if not referenced_ids:
        logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
        return len(all_files)

    # Match on the manifest's label, then on its figures/tables lists.
    keep_filenames: set[str] = set()
    for filename, info in manifest.items():
        if info.get("label", "") in referenced_ids:
            keep_filenames.add(filename)
            continue
        refs = info.get("figures", []) + info.get("tables", [])
        if any(ref in referenced_ids for ref in refs):
            keep_filenames.add(filename)

    if not keep_filenames:
        # Nothing matched: deleting everything would be worse than keeping all.
        logger.warning(
            "No manifest matches for %s (refs=%s), keeping all",
            arxiv_id,
            referenced_ids,
        )
        return len(all_files)

    removed = 0
    for f in all_files:
        if f.name not in keep_filenames:
            f.unlink()
            removed += 1

    kept = len(all_files) - removed
    logger.info(
        "Filtered images for %s: kept %d, removed %d (refs=%s)",
        arxiv_id,
        kept,
        removed,
        referenced_ids,
    )
    return kept
|