Files
daily-paper/app/services/pdf_image_extractor.py
T
Rain-Bus b42e9149e5 feat: improve PDF extraction with image clustering, find_tables() integration, and JPEG output
- Add subfigure clustering in _find_figure_top(): collect all images near caption, cluster by Y proximity, use largest cluster's min y
- Add _find_figure_horizontal(): determine crop range from caption + embedded image union
- Refactor _find_table_region() to use page.find_tables() as primary method with segment merging, fallback to block-based detection
- Extract _scan_blocks_direction() for bidirectional block scanning with table data density awareness
- Add _TABLE_DATA_GAP_THRESHOLD for denser gap tolerance after table data blocks
- Fix caption regex to use (?-i:[A-Z]) for correct case-insensitive matching
- Switch image output from PNG to JPEG (5-10x smaller for web delivery)
- Update cleanup and filter to handle both .png and .jpg formats
- Reformat imports and conditional expressions in pages.py
2026-06-10 23:17:03 +08:00

751 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""PDF 图片与表格提取 — 基于 caption 定位的页面区域截图。
核心思路:学术论文排版极其规整,Figure caption 在图下方,Table caption 在表格上方。
因此反过来:先找 caption 文字 → 向上/向下截取页面区域 → 渲染为 PNG。
优势(相比提取嵌入位图):
- 复合图表不会被拆成碎片(整块截取)
- 矢量图也能截取(页面渲染包含一切)
- 不依赖 find_tables()(纯文本匹配 caption
"""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from app.services.pdf_downloader import paper_dir
from app.utils import TMP_DIR
logger = logging.getLogger(__name__)
# ── Crop-region parameters (all values in PDF points) ──────────────────
# Figure: how far above the caption to search for the graphic.
_FIGURE_MAX_HEIGHT = 450  # maximum upward search range
_FIGURE_MIN_HEIGHT = 50  # minimum height for a figure crop to be kept
_FIGURE_DEFAULT_HEIGHT = 280  # fallback figure height when no content block is found above
# Table: how far below the caption to search for the table.
_TABLE_MAX_HEIGHT = 500  # maximum search range / crop height for tables
_TABLE_MIN_HEIGHT = 30  # minimum height for a table crop to be kept
# Horizontal padding around the caption (in two-column papers the caption
# may be narrower than the table/figure it belongs to).
_REGION_SIDE_PADDING = 10
# Tables are usually wider than their caption text. _estimate_column_x widens
# page-centred captions by twice this amount on each side.
_TABLE_SIDE_PADDING = 60
# Whitespace-gap threshold, roughly 1.5x body line spacing (academic papers
# are tightly typeset; 30pt proved too loose).
_CONTENT_GAP_THRESHOLD = 20
# Tighter gap threshold applied right after a dense table-data block: the
# paragraph spacing after a table is often only 12-18pt.
_TABLE_DATA_GAP_THRESHOLD = 12
# ── Caption regexes ────────────────────────────────────────────────────
# Anchored at the start of the line (avoids body references such as
# "see Figure 3"). Three caption forms are supported:
#   "Figure 1: Title" / "Figure 1. Title" / "Figure 1 Title" (no punctuation)
# The third form requires an uppercase letter next. The (?-i:...) group keeps
# that check case-sensitive despite re.IGNORECASE, which rules out body text
# such as "Figure 1 shows ...".
_CAPTION_RE = re.compile(
    r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
    re.IGNORECASE,
)
_TABLE_CAPTION_RE = re.compile(
    r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))",
    re.IGNORECASE,
)
# ── Stop signals: boundary scans stop as soon as they reach one of these ──
# The next Figure/Table caption (e.g. "Table 2:", "Figure 3:", "Figure 4 Title").
# NOTE(review): unlike _CAPTION_RE, the trailing [A-Z] here is not wrapped in
# (?-i:...). Under re.IGNORECASE it therefore also matches lowercase, so body
# text such as "Table 2 shows ..." counts as a stop as well. Confirm intended.
_CAPTION_STOP_RE = re.compile(
    r"^(?:Table|Fig\.?|Figure)\s+\d+\s*(?:[:\.]\s*|\s+[A-Z])",
    re.IGNORECASE,
)
# Section headers (e.g. "6.2 Evolution", "D.1 Dependency", "7 Conclusion").
# No IGNORECASE flag: a capitalised word must follow the section number.
_SECTION_STOP_RE = re.compile(
    r"^(\d{1,2}(?:\.\d+)?\s+[A-Z][a-z]|[A-Z]\.\d+\s+[A-Z][a-z])"
)
def _estimate_column_x(caption: dict) -> tuple[float, float]:
"""估计 caption 所在列的水平边界(col_x0, col_x1)。
双栏论文中 caption 宽度远小于页面宽度,据此判断左右列。
单栏或跨栏 caption(宽度 >65% 页宽)返回整页宽度。
caption 居中对齐(中心接近页面中线)时按跨栏处理,使用宽范围。
"""
pw = caption["page_width"]
caption_w = caption["caption_x1"] - caption["caption_x0"]
# caption 宽度 >65% 页宽 → 单栏或跨栏
if caption_w > pw * 0.65:
return 0, pw
cx = (caption["caption_x0"] + caption["caption_x1"]) / 2
# caption 居中(中心距页面中线 <8%)→ 可能是跨栏表格,使用宽范围
if abs(cx - pw / 2) / pw < 0.08:
return (
max(0, caption["caption_x0"] - _TABLE_SIDE_PADDING * 2),
min(pw, caption["caption_x1"] + _TABLE_SIDE_PADDING * 2),
)
if cx < pw / 2:
return 0, pw / 2
else:
return pw / 2, pw
def _find_captions(doc) -> list[dict]:
    """Scan every page of *doc* and collect Figure/Table caption blocks.

    Only the first line of each text block is matched, so a block that merely
    contains a caption further down is ignored. Figure patterns are tried
    before Table patterns, and each block yields at most one caption.

    Returns:
        One dict per caption, in page/block order, holding its type
        (``"figure"``/``"table"``), number, label (``"Figure N"``/``"Table N"``),
        0-based page index, block bbox (``caption_x0/y0/x1/y1``), the full
        block text and the page size.
    """
    patterns = (
        ("figure", _CAPTION_RE, "Figure"),
        ("table", _TABLE_CAPTION_RE, "Table"),
    )
    found: list[dict] = []
    for page_index in range(len(doc)):
        page = doc[page_index]
        width, height = page.rect.width, page.rect.height
        for block in page.get_text("blocks"):
            if len(block) < 5:
                continue
            text = str(block[4]).strip()
            if not text:
                continue
            # Match only the first line so multi-paragraph blocks don't interfere.
            head = text.split("\n")[0].strip()
            for kind, pattern, prefix in patterns:
                match = pattern.match(head)
                if match is None:
                    continue
                number = match.group(1)
                found.append(
                    {
                        "type": kind,
                        "num": int(number),
                        "label": f"{prefix} {number}",
                        "page_num": page_index,
                        "caption_y0": block[1],
                        "caption_y1": block[3],
                        "caption_x0": block[0],
                        "caption_x1": block[2],
                        "caption_text": text,
                        "page_width": width,
                        "page_height": height,
                    }
                )
                break
    return found
def _find_figure_top(page, caption: dict) -> float:
    """Scan upwards from a Figure caption and return the figure's top edge (pt).

    Strategy:
      1. Embedded images: collect the bboxes of images above the caption that
         overlap its horizontal window. Cluster their top edges along Y
         (subfigure grids) and use the smallest top of the largest cluster.
      2. Without usable images, fall back to text-block gap detection (pure
         vector figures such as TikZ or matplotlib PDF output). If no text
         block is found either, assume ``_FIGURE_DEFAULT_HEIGHT``.

    The result is then clamped to lie below the nearest other caption above
    this one in the same column (several figures on one page), to at most
    ``_FIGURE_MAX_HEIGHT`` above the caption, and to be non-negative.
    """
    caption_y = caption["caption_y0"]
    col_x0, col_x1 = _estimate_column_x(caption)
    # Horizontal window used to match images/text: the caption widened by
    # _REGION_SIDE_PADDING, but never beyond its column.
    cx0 = max(col_x0, caption["caption_x0"] - _REGION_SIDE_PADDING)
    cx1 = min(col_x1, caption["caption_x1"] + _REGION_SIDE_PADDING)
    search_floor = caption_y - _FIGURE_MAX_HEIGHT
    blocks = page.get_text("blocks")

    # Bottom edge of the nearest other caption above this one, in the same
    # column. Blocks are not guaranteed to be sorted top-to-bottom and may
    # belong to the other column, so take the maximum bottom over all
    # column-overlapping matches instead of the first match. Using the bottom
    # (not the top) keeps the previous caption's text out of this crop.
    caption_cutoff: float | None = None
    for b in blocks:
        if len(b) < 5:
            continue
        bx0, bx1, by1 = b[0], b[2], b[3]
        if by1 >= caption_y or by1 <= search_floor:
            continue
        if not (bx1 > col_x0 and bx0 < col_x1):
            continue
        first_line = str(b[4]).strip().split("\n")[0].strip()
        if _CAPTION_STOP_RE.match(first_line):
            if caption_cutoff is None or by1 > caption_cutoff:
                caption_cutoff = by1

    # ── Strategy 1: locate via clustered embedded images ──
    image_tops: list[float] = []
    for img_info in page.get_image_info():
        bbox = img_info.get("bbox")
        if bbox is None:
            continue
        if hasattr(bbox, "x0"):
            ix0, iy0, ix1, iy1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
        else:
            ix0, iy0, ix1, iy1 = bbox[0], bbox[1], bbox[2], bbox[3]
        # Image bottom must be above the caption and inside the search window.
        if not (search_floor < iy1 <= caption_y):
            continue
        # Image must overlap the caption's horizontal window.
        if not (ix1 > cx0 and ix0 < cx1):
            continue
        # Images starting above the previous caption belong to another figure.
        if caption_cutoff is not None and iy0 < caption_cutoff:
            continue
        # Skip tiny images (< 15pt wide or tall), e.g. logos/symbols.
        if (ix1 - ix0) < 15 or (iy1 - iy0) < 15:
            continue
        image_tops.append(iy0)

    if image_tops:
        # Single-pass clustering: consecutive tops closer than 40% of the
        # search range are treated as one group of subfigures.
        image_tops.sort()
        cluster_gap = _FIGURE_MAX_HEIGHT * 0.4
        clusters: list[list[float]] = [[image_tops[0]]]
        for yt in image_tops[1:]:
            if yt - clusters[-1][-1] < cluster_gap:
                clusters[-1].append(yt)
            else:
                clusters.append([yt])
        # Largest cluster (most images) wins; ties go to the earliest (highest).
        figure_top = min(max(clusters, key=len))
    else:
        # ── Strategy 2: text-block gap detection (vector figures) ──
        above_blocks: list[tuple[float, float, float, float]] = []
        for b in blocks:
            if len(b) < 5:
                continue
            bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
            if not (search_floor < by1 <= caption_y):
                continue
            if not (bx1 > cx0 and bx0 < cx1):
                continue
            # Skip blocks starting more than 2x padding left of a non-left column.
            if col_x0 > 0 and bx0 < col_x0 - _REGION_SIDE_PADDING * 2:
                continue
            above_blocks.append((bx0, by0, bx1, by1))

        if not above_blocks:
            figure_top = caption_y - _FIGURE_DEFAULT_HEIGHT
        else:
            # Walk upwards (nearest block first) until a gap larger than
            # _CONTENT_GAP_THRESHOLD separates the figure from content above it.
            above_blocks.sort(key=lambda blk: blk[1], reverse=True)
            figure_top = above_blocks[-1][1]
            prev_bottom = caption_y
            for blk in above_blocks:
                if prev_bottom - blk[3] > _CONTENT_GAP_THRESHOLD:
                    figure_top = prev_bottom - 5
                    break
                prev_bottom = blk[1]

    # Never reach above the previous caption on the same page/column.
    if caption_cutoff is not None:
        figure_top = max(figure_top, caption_cutoff + 5)
    # Cap the figure height.
    if caption_y - figure_top > _FIGURE_MAX_HEIGHT:
        figure_top = caption_y - _FIGURE_MAX_HEIGHT
    return max(0, figure_top)
def _find_figure_horizontal(
    page, caption: dict, top: float, bottom: float
) -> tuple[float, float]:
    """Return the horizontal crop range ``(x0, x1)`` for a Figure.

    The range is the union of the caption's width and the horizontal extent of
    every embedded image that is at least 15pt wide, vertically overlaps
    ``(top, bottom)`` and horizontally overlaps the caption's column. The
    union is padded by ``_REGION_SIDE_PADDING`` and clamped to the page. This
    keeps figures that are wider than their caption from being cut off.
    """
    page_w = caption["page_width"]
    left, right = caption["caption_x0"], caption["caption_x1"]
    col_left, col_right = _estimate_column_x(caption)

    for info in page.get_image_info():
        box = info.get("bbox")
        if box is None:
            continue
        if hasattr(box, "x0"):
            bx0, by0, bx1, by1 = box.x0, box.y0, box.x1, box.y1
        else:
            bx0, by0, bx1, by1 = box[0], box[1], box[2], box[3]
        in_region = by0 < bottom and by1 > top and bx1 > col_left and bx0 < col_right
        # Ignore images outside the crop/column and tiny icons.
        if not in_region or (bx1 - bx0) < 15:
            continue
        left = min(left, bx0)
        right = max(right, bx1)

    padded_left = max(0, left - _REGION_SIDE_PADDING)
    padded_right = min(page_w, right + _REGION_SIDE_PADDING)
    return padded_left, padded_right
def _find_table_region(page, caption: dict) -> tuple[float, float, float, float]:
    """Locate the crop region of a Table relative to its caption.

    Returns:
        ``(x0, top, bottom, x1)``: the left, top, bottom and right edges of
        the crop region in page points. Note the element order. Callers
        unpack it as ``x0, top, bottom, x1``.

    Strategy:
      1. ``page.find_tables()``: gather table segments whose top edge lies
         between 5pt above and 200pt below the caption's bottom edge and that
         overlap the caption's column. Merge vertically adjacent segments
         (gap < 30pt), because academic tables are often detected as a header
         segment plus one or more data segments. The first merged group
         (closest to the caption) is used, and its top edge is the caption's top.
      2. If ``find_tables()`` fails or finds nothing usable, fall back to
         text-block gap detection (``_find_table_region_by_blocks``), which
         also scans upwards for tables whose caption sits below the data.
    """
    caption_y = caption["caption_y1"]  # the scan starts at the caption's bottom edge
    caption_x0 = caption["caption_x0"]
    caption_x1 = caption["caption_x1"]
    page_width = caption["page_width"]

    # ── Strategy 1: structural detection via find_tables() + segment merging ──
    try:
        tables = page.find_tables()
    except Exception:
        # Best-effort: fall back to block detection, but leave a trace rather
        # than swallowing the failure silently.
        logger.debug(
            "page.find_tables() failed for %s; using block fallback",
            caption.get("label"),
            exc_info=True,
        )
        tables = None

    if tables and tables.tables:
        # Restrict to the caption's column so two-column layouts don't pick up
        # tables from the other column.
        col_x0, col_x1 = _estimate_column_x(caption)
        segments: list[tuple[float, float, float, float]] = []
        for t in tables.tables:
            tb = t.bbox
            if isinstance(tb, (list, tuple)):
                tx0, ty0, tx1, ty1 = (float(v) for v in tb[:4])
            else:
                tx0, ty0, tx1, ty1 = (
                    float(tb.x0),
                    float(tb.y0),
                    float(tb.x1),
                    float(tb.y1),
                )
            # Segment must start near the caption's bottom edge and share its column.
            if caption_y - 5 <= ty0 < caption_y + 200 and tx1 > col_x0 and tx0 < col_x1:
                segments.append((tx0, ty0, tx1, ty1))

        if segments:
            # Merge segments downwards from the one nearest the caption while
            # the vertical gap stays below 30pt. Later groups belong to other tables.
            segments.sort(key=lambda s: s[1])
            tx0, _, tx1, ty1 = segments[0]
            merged_count = 1
            for sx0, sy0, sx1, sy1 in segments[1:]:
                if sy0 - ty1 >= 30:
                    break
                tx0 = min(tx0, sx0)
                tx1 = max(tx1, sx1)
                ty1 = max(ty1, sy1)
                merged_count += 1

            # Cap the table height below the caption.
            ty1 = min(ty1, caption_y + _TABLE_MAX_HEIGHT)
            x0 = max(0, min(caption_x0, tx0) - _REGION_SIDE_PADDING)
            x1 = min(page_width, max(caption_x1, tx1) + _REGION_SIDE_PADDING)
            top = caption["caption_y0"]
            logger.debug(
                "Table detected by find_tables() (%d segments merged): "
                "(%.0f,%.0f)-(%.0f,%.0f)",
                merged_count,
                x0,
                top,
                x1,
                ty1,
            )
            return (x0, top, ty1, x1)

    # ── Strategy 2: text-block gap detection ──
    return _find_table_region_by_blocks(page, caption)
def _scan_blocks_direction(
blocks: list,
start_y: float,
col_x0: float,
col_x1: float,
direction: int,
max_range: float,
) -> list[tuple[float, float, float, float]]:
"""从 start_y 向上(direction=-1)或向下(direction=1)扫描文本块。
收集间隙连续的块,遇到 stop 信号(caption / section header)或大间隙即停。
用 current_top/current_bottom 追踪连通区域边界,正确处理 y 重叠块。
Returns:
收集到的块列表 [(x0, y0, x1, y1), ...]
"""
# 过滤在扫描范围内的块
if direction > 0: # 向下
candidates = [
b
for b in blocks
if len(b) >= 5
and b[1] > start_y
and b[1] < start_y + max_range
and b[2] > col_x0
and b[0] < col_x1
]
candidates.sort(key=lambda b: b[1]) # 按 y0 升序
else: # 向上
candidates = [
b
for b in blocks
if len(b) >= 5
and b[3] <= start_y
and b[1] > start_y - max_range
and b[2] > col_x0
and b[0] < col_x1
]
candidates.sort(key=lambda b: b[3], reverse=True) # 按 y1 降序(底部离 start_y 最近的在前)
if not candidates:
return []
# 从 start_y 开始,追踪连通区域边界
connected: list[tuple[float, float, float, float]] = []
boundary = start_y # 当前连通区域离 start_y 最近端的 y 坐标
prev_was_dense_table = False
for b in candidates:
bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
text = str(b[4]).strip()
first_line = text.split("\n")[0].strip()
# stop 信号
if _CAPTION_STOP_RE.match(first_line) or _SECTION_STOP_RE.match(first_line):
break
# 检查当前块是否与连通区域相连(间隙 < 阈值)
if direction > 0:
gap = by0 - boundary
else:
gap = boundary - by1
# 密集表格数据块后使用更低的间隙阈值
threshold = (
_TABLE_DATA_GAP_THRESHOLD
if prev_was_dense_table
else _CONTENT_GAP_THRESHOLD
)
if gap > threshold:
break
connected.append((bx0, by0, bx1, by1))
# 更新连通区域边界
if direction > 0:
boundary = by1 # 向下扩展
else:
boundary = min(boundary, by0) # 向上扩展
# 判断当前块是否为密集表格数据(行密度高)
lines = [l for l in text.split("\n") if l.strip()]
block_height = by1 - by0
prev_was_dense_table = (
len(lines) >= 4
and block_height > 0
and len(lines) / block_height >= 0.08
)
return connected
def _find_table_region_by_blocks(
    page, caption: dict
) -> tuple[float, float, float, float]:
    """Text-block gap detection, used as the fallback for ``find_tables()``.

    Scans downwards from the caption's bottom edge to find the table's lower
    edge, and upwards from its top edge to find the upper edge (tables whose
    caption sits below the data). Both directions use ``_scan_blocks_direction``.

    Returns:
        ``(x0, top, bottom, x1)`` in page points. ``bottom`` gets 5pt of
        padding and is clamped so the region is at most ``_TABLE_MAX_HEIGHT``
        tall. ``x0``/``x1`` span the caption plus every collected block, padded
        by ``_REGION_SIDE_PADDING`` and clamped to the page.
    """
    blocks = page.get_text("blocks")
    caption_y0 = caption["caption_y0"]
    caption_y1 = caption["caption_y1"]
    col_x0, col_x1 = _estimate_column_x(caption)

    below = _scan_blocks_direction(
        blocks, caption_y1, col_x0, col_x1, direction=1, max_range=_TABLE_MAX_HEIGHT
    )
    above = _scan_blocks_direction(
        blocks, caption_y0, col_x0, col_x1, direction=-1, max_range=_TABLE_MAX_HEIGHT
    )

    # Vertical extent: the caption plus whatever was connected above/below it.
    top = min((b[1] for b in above), default=caption_y0)
    bottom = max((b[3] for b in below), default=caption_y1) + 5  # bottom padding
    bottom = min(bottom, top + _TABLE_MAX_HEIGHT)

    # Horizontal extent: the caption plus every collected block.
    collected = above + below
    content_x0 = min([caption["caption_x0"], *(b[0] for b in collected)])
    content_x1 = max([caption["caption_x1"], *(b[2] for b in collected)])

    x0 = max(0, content_x0 - _REGION_SIDE_PADDING)
    x1 = min(caption["page_width"], content_x1 + _REGION_SIDE_PADDING)
    return (x0, top, bottom, x1)
def _caption_region(page, cap: dict) -> tuple[float, float, float, float] | None:
    """Return the ``(x0, top, x1, bottom)`` clip rectangle for *cap*, or None if too small."""
    if cap["type"] == "figure":
        # Figure: the graphic sits above its caption, so search upwards. Keep an
        # extra 5pt above so frames/decorative rules are not clipped.
        top = max(0, _find_figure_top(page, cap) - 5)
        bottom = cap["caption_y1"] + 5  # include the caption itself
        # Horizontal range: union of the caption and the embedded images.
        x0, x1 = _find_figure_horizontal(page, cap, top, bottom)
        height = bottom - top
        if height < _FIGURE_MIN_HEIGHT:
            logger.debug(
                "Figure %s too small (%.0fpt), skipping", cap["label"], height
            )
            return None
    else:
        # Table: find_tables() first, then the bidirectional block-scan fallback.
        x0, tbl_top, bottom, x1 = _find_table_region(page, cap)
        top = max(0, tbl_top - 5)  # 5pt margin above the caption / data above it
        height = bottom - top
        if height < _TABLE_MIN_HEIGHT:
            logger.debug(
                "Table %s too small (%.0fpt), skipping", cap["label"], height
            )
            return None
    return x0, top, x1, bottom


def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Extract Figure/Table screenshots from a paper's PDF and write a manifest.

    Strategy: find captions → locate the region for each → render that page
    region to a JPEG file in ``paper_dir(arxiv_id) / "images"``. Images
    (.png/.jpg/.jpeg) and the manifest from a previous run are deleted first.
    When captions are found, ``manifest.json`` (UTF-8) maps each file name to
    its 1-based page, type, label, caption preview (first 200 characters) and
    a ``figures``/``tables`` reference list.

    Args:
        arxiv_id: Paper ID.
        pdf_path: PDF path; defaults to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of images written (0 if the PDF is missing or has no captions).
    """
    import pymupdf

    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)
    manifest_path = images_dest / "manifest.json"

    # Remove the previous run's output (both .png and .jpg) so stale images
    # cannot survive.
    for old_file in images_dest.iterdir():
        if old_file.is_file() and old_file.suffix.lower() in (".png", ".jpg", ".jpeg"):
            old_file.unlink()
    if manifest_path.exists():
        manifest_path.unlink()

    zoom = 3  # render at 3x for legibility
    mat = pymupdf.Matrix(zoom, zoom)
    extracted = 0
    manifest: dict[str, dict] = {}

    doc = pymupdf.open(str(pdf_path))
    try:
        captions = _find_captions(doc)
        if not captions:
            logger.info("No Figure/Table captions found in PDF for %s", arxiv_id)
            return 0

        # Dedupe: several blocks can map to the same label (e.g. "Fig. 7" and
        # "Figure 7"). Keep the first occurrence in document order.
        # NOTE(review): this assumes the first match is the real caption.
        seen_labels: dict[str, dict] = {}
        for cap in captions:
            seen_labels.setdefault(cap["label"], cap)
        unique_captions = list(seen_labels.values())

        for cap in unique_captions:
            page = doc[cap["page_num"]]
            region = _caption_region(page, cap)
            if region is None:
                continue
            x0, top, x1, bottom = region

            # Render and encode. Either step can fail on odd clips; skip just
            # this caption instead of aborting the whole extraction.
            try:
                pix = page.get_pixmap(matrix=mat, clip=pymupdf.Rect(x0, top, x1, bottom))
                # JPEG is 5-10x smaller than PNG, which suits web delivery.
                jpeg_bytes = pix.tobytes("jpeg")
            except Exception:
                logger.debug(
                    "Failed to render %s region for %s",
                    cap["label"],
                    arxiv_id,
                    exc_info=True,
                )
                continue

            filename = f"{cap['label'].replace(' ', '_').lower()}.jpg"
            (images_dest / filename).write_bytes(jpeg_bytes)
            extracted += 1

            cap_preview = cap["caption_text"][:200] if cap["caption_text"] else ""
            manifest[filename] = {
                "page": cap["page_num"] + 1,
                "type": cap["type"],
                "label": cap["label"],
                "caption_text": cap_preview,
                "figures" if cap["type"] == "figure" else "tables": [cap["label"]],
            }
            logger.debug(
                "Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) h=%.0fpt → %s",
                cap["label"],
                cap["page_num"] + 1,
                x0,
                top,
                x1,
                bottom,
                bottom - top,
                filename,
            )
    finally:
        # Always release the document, even if a region helper raises.
        doc.close()

    # Explicit UTF-8: the manifest keeps non-ASCII caption text
    # (ensure_ascii=False) and readers load it with encoding="utf-8".
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    if extracted > 0:
        logger.info(
            "Extracted %d figure/table screenshots from PDF for %s "
            "(from %d captions found, %d unique)",
            extracted,
            arxiv_id,
            len(captions),
            len(unique_captions),
        )
    return extracted
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
"""根据 summary 中的 figures 字段过滤提取的图片/表格。
用 manifest.json 中的 label 匹配,保留被 AI 总结引用的图片。
"""
if not figures:
return 0
images_dir = paper_dir(arxiv_id) / "images"
manifest_path = images_dir / "manifest.json"
if not images_dir.exists() or not manifest_path.exists():
return 0
all_files = [
f for f in images_dir.iterdir() if f.suffix.lower() in (".png", ".jpg", ".jpeg")
]
if not all_files:
return 0
manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))
# 收集 summary 中引用的所有 Figure/Table ID(归一化)
referenced_ids: set[str] = set()
for fig in figures:
fig_id = fig.get("id", "")
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", fig_id, re.IGNORECASE)
if m:
referenced_ids.add(f"Figure {m.group(1)}")
m2 = re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
if m2:
referenced_ids.add(f"Table {m2.group(1)}")
if not referenced_ids:
logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
return len(all_files)
# 根据 manifest 的 label 字段匹配
keep_filenames: set[str] = set()
for filename, info in manifest.items():
label = info.get("label", "")
if label in referenced_ids:
keep_filenames.add(filename)
continue
for ref in info.get("figures", []) + info.get("tables", []):
if ref in referenced_ids:
keep_filenames.add(filename)
break
if not keep_filenames:
logger.warning(
"No manifest matches for %s (refs=%s), keeping all",
arxiv_id,
referenced_ids,
)
return len(all_files)
removed = 0
for f in all_files:
if f.name not in keep_filenames:
f.unlink()
removed += 1
kept = len(all_files) - removed
logger.info(
"Filtered images for %s: kept %d, removed %d (refs=%s)",
arxiv_id,
kept,
removed,
referenced_ids,
)
return kept