feat: improve PDF image extraction with caption-based labeling and fallback matching

- Enhance pdf_image_extractor with caption text extraction near images/tables
- Add figure/table type correction based on caption content
- Implement sequential numbering fallback for unmatched items
- Improve figure linking in pages with manifest ID matching and fallback strategies
- Remove docling dependency, add dev dependency group
This commit is contained in:
2026-06-09 14:07:21 +08:00
parent 32978b3fc5
commit 18f44ac244
4 changed files with 343 additions and 1593 deletions
+69 -21
View File
@@ -273,38 +273,86 @@ def _link_figures_with_images(
) -> list[dict]:
"""将 summary figures 元数据与提取的图片文件关联。
通过 manifest.json 中的 figure ID 匹配,给每个 figure 加上 image_url。
策略:
1. 优先用 manifest.json 的 label 做 ID 精确匹配
2. 未匹配的 figure 用序号兜底:第 N 个 Figure → 第 N 张提取图
"""
if not figures or not images:
return figures
manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"
if not manifest_path.exists():
return figures
try:
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
except (ValueError, TypeError):
return figures
# 构建 figure_id -> image_url 的映射
# ── 策略 1manifest ID 精确匹配 ──
id_to_url: dict[str, str] = {}
for filename, info in manifest.items():
url = f"/papers/{arxiv_id}/images/{filename}"
for fig_id in info.get("figures", []) + info.get("tables", []):
id_to_url[fig_id] = url
if manifest_path.exists():
try:
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
except (ValueError, TypeError):
manifest = {}
for filename, info in manifest.items():
url = f"/papers/{arxiv_id}/images/{filename}"
# 优先用 label 字段(新格式)
label = info.get("label", "")
if label:
id_to_url[label] = url
# 也兼容 figures/tables 列表(旧格式)
for fig_id in info.get("figures", []) + info.get("tables", []):
if fig_id not in id_to_url:
id_to_url[fig_id] = url
# 归一化 summary figures 的 ID
for fig in figures:
raw_id = fig.get("id", "")
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
if m:
normalized = f"Figure {m.group(1)}"
else:
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
normalized = f"Table {m2.group(1)}" if m2 else raw_id
normalized = _normalize_figure_id(raw_id)
if normalized in id_to_url:
fig["image_url"] = id_to_url[normalized]
# ── 策略 2:序号兜底(manifest 匹配不到时) ──
unmatched = [f for f in figures if not f.get("image_url")]
if not unmatched:
return figures
# 按类型分流:Figure vs Table
fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))]
table_type_unmatched = [f for f in unmatched if not _is_figure_type(f.get("id", ""))]
# 提取的图片也按类型分流,按文件名排序
def _sort_key(name: str) -> tuple[int, int]:
m = re.search(r'page(\d+)_(?:img|table)(\d+)', name)
if m:
return (int(m.group(1)), int(m.group(2)))
return (0, 0)
fig_images = sorted(
[img for img in images if "table" not in img["name"].lower()],
key=lambda img: _sort_key(img["name"]),
)
table_images = sorted(
[img for img in images if "table" in img["name"].lower()],
key=lambda img: _sort_key(img["name"]),
)
for i, fig in enumerate(fig_type_unmatched):
if i < len(fig_images):
fig["image_url"] = fig_images[i]["url"]
for i, fig in enumerate(table_type_unmatched):
if i < len(table_images):
fig["image_url"] = table_images[i]["url"]
return figures
def _normalize_figure_id(raw_id: str) -> str:
"""归一化 Figure/Table ID'Figure 1'/'Fig.1''Figure 1'"""
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
if m:
return f"Figure {m.group(1)}"
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
if m2:
return f"Table {m2.group(1)}"
return raw_id
def _is_figure_type(fig_id: str) -> bool:
"""判断是否为 Figure 类型(非 Table)。"""
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
+256 -76
View File
@@ -1,11 +1,12 @@
"""PDF 图片与表格提取 — 从 PDF 中提取嵌入图片和表格截图。
策略:
1. 提取 PDF 中嵌入的图片(图表、插图等)
1. 提取 PDF 中嵌入的图片(图表、插图等),按页面位置排序
2. 检测表格区域,渲染为截图
3. 同时搜索页面中的 Figure/Table 标注,记录到 manifest
4. 过滤掉过小的图片
5. 保存到 data/papers/{arxiv_id}/images/
3. 为每张图/表格提取附近的说明文字(caption),从中识别 Figure N / Table N
4. 根据 caption 内容矫正类型:标注为 "Figure" 的表格区域 → 归为图片
5. 序号匹配兜底:第 N 张图 → Figure N(学术论文图表严格按顺序出现)
6. 保存 manifest.json 供后续与 AI 总结的 figures 字段匹配
"""
from __future__ import annotations
@@ -24,38 +25,113 @@ logger = logging.getLogger(__name__)
_MIN_AREA = 10_000 # ~100x100
_MIN_DIM = 80
# Figure/Table 标注与图片/表格的最大垂直距离(点)
_MAX_LABEL_DISTANCE = 120
# Caption 搜索区域 — Figure caption 在图下方,Table caption 在图上方
_CAPTION_MARGIN = 10 # 贴边距离
_CAPTION_MAX_DISTANCE = 250 # 最远搜索距离
_CAPTION_SIDE_PADDING = 40 # 左右扩展
# Figure/Table 标注正则
_FIGURE_RE = re.compile(r'\b(?:Fig\.?|Figure)\s*(\d+)\b', re.IGNORECASE)
_TABLE_RE = re.compile(r'\bTable\s*(\d+)\b', re.IGNORECASE)
# Figure/Table 标注正则
_FIGURE_CAPTION_RE = re.compile(
r'\b(?:Fig\.?|Figure)\s*(\d+)\b', re.IGNORECASE
)
_TABLE_CAPTION_RE = re.compile(
r'\bTable\s*(\d+)\b', re.IGNORECASE
)
def _find_nearby_labels(
rects: list, labels: dict[str, list[tuple[int, float]]], page_num: int
) -> list[str]:
"""查找与给定矩形区域在位置上接近的 Figure/Table 标注
def _extract_caption_text(page, bbox, page_height: float, *,
search_above: bool = False,
search_both: bool = False) -> str | None:
"""从图片/表格附近区域提取 caption 文字
匹配逻辑:标注的垂直位置 (y) 需在图片/表格的上下 _MAX_LABEL_DISTANCE 点范围内。
search_above=True:搜索上方(Table caption 通常在上)
默认搜索下方(Figure caption 通常在下)
search_both=True:上下都搜,返回包含 Figure/Table 标注的那边
"""
matched: list[str] = []
for rect in rects:
y_min, y_max = rect.y0, rect.y1
import pymupdf
for label_key, positions in labels.items():
for label_page, label_y in positions:
if label_page == page_num:
# 标注在图片/表格上方或下方的距离
distance = min(abs(label_y - y_min), abs(label_y - y_max))
if distance <= _MAX_LABEL_DISTANCE:
if label_key not in matched:
matched.append(label_key)
return matched
x0 = max(0, bbox.x0 - _CAPTION_SIDE_PADDING)
x1 = bbox.x1 + _CAPTION_SIDE_PADDING
def _search(y0: float, y1: float) -> str | None:
rect = pymupdf.Rect(x0, y0, x1, y1)
blocks = page.get_text("blocks")
parts: list[str] = []
for block in blocks:
if len(block) < 5:
continue
block_rect = pymupdf.Rect(block[:4])
if block_rect.intersects(rect):
text = str(block[4]).strip()
if text:
parts.append(text)
if parts:
return " ".join(parts)
text = page.get_textbox(rect)
if text and len(text.strip()) >= 5:
return text.strip()
return None
if search_both:
# 上方
above_y1 = max(0, bbox.y0 - _CAPTION_MARGIN)
above_y0 = max(0, bbox.y0 - _CAPTION_MAX_DISTANCE)
above = _search(above_y0, above_y1)
# 下方
below_y0 = bbox.y1 + _CAPTION_MARGIN
below_y1 = min(page_height, bbox.y1 + _CAPTION_MAX_DISTANCE)
below = _search(below_y0, below_y1)
# 优先返回包含 Figure/Table 标注的那边
if above and (_FIGURE_CAPTION_RE.search(above) or _TABLE_CAPTION_RE.search(above)):
return above
if below and (_FIGURE_CAPTION_RE.search(below) or _TABLE_CAPTION_RE.search(below)):
return below
# 否则返回更长的
if above and below:
return above if len(above) >= len(below) else below
return above or below
if search_above:
y1 = max(0, bbox.y0 - _CAPTION_MARGIN)
y0 = max(0, bbox.y0 - _CAPTION_MAX_DISTANCE)
else:
y0 = bbox.y1 + _CAPTION_MARGIN
y1 = min(page_height, bbox.y1 + _CAPTION_MAX_DISTANCE)
return _search(y0, y1)
def _identify_label(caption_text: str | None) -> str | None:
"""从 caption 文字中识别 Figure N / Table N 编号。"""
if not caption_text:
return None
m = _FIGURE_CAPTION_RE.search(caption_text)
if m:
return f"Figure {m.group(1)}"
m = _TABLE_CAPTION_RE.search(caption_text)
if m:
return f"Table {m.group(1)}"
return None
def _is_figure_caption(caption_text: str | None) -> bool:
"""判断 caption 是否标注为 Figure(用于矫正 find_tables 的误判)。"""
if not caption_text:
return False
return bool(_FIGURE_CAPTION_RE.search(caption_text))
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
"""从 PDF 提取嵌入图片和表格截图,同时生成 manifest。
"""从 PDF 提取嵌入图片和表格截图,生成 manifest。
匹配策略:
1. 提取图片→提取 caption 文字→从中识别 Figure/Table 编号
2. 表格区域若 caption 标注为 "Figure",则重分类为图片
3. 未能从 caption 识别编号的,按(页码, 纵向位置)排序后用序号匹配兜底
Args:
arxiv_id: 论文 ID
@@ -80,39 +156,15 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
extracted = 0
seen_hashes: set[int] = set()
# 扫描每页的 Figure/Table 标注位置
# figure_labels: {key: [(page_num, y_center)]} — 记录标注在页面中的垂直位置
figure_labels: dict[str, list[tuple[int, float]]] = {}
table_labels: dict[str, list[tuple[int, float]]] = {}
# ── 第一遍:收集所有图片和表格 ──
image_items: list[dict] = []
table_items: list[dict] = []
for page_num in range(len(doc)):
page = doc[page_num]
text_dict = page.get_text("dict")
for block in text_dict.get("blocks", []):
if block.get("type") != 0: # 只看文本块
continue
block_text = ""
for line in block.get("lines", []):
for span in line.get("spans", []):
block_text += span.get("text", "")
for m in _FIGURE_RE.finditer(block_text):
key = f"Figure {m.group(1)}"
bbox = block.get("bbox", [0, 0, 0, 0])
y_center = (bbox[1] + bbox[3]) / 2
figure_labels.setdefault(key, []).append((page_num, y_center))
for m in _TABLE_RE.finditer(block_text):
key = f"Table {m.group(1)}"
bbox = block.get("bbox", [0, 0, 0, 0])
y_center = (bbox[1] + bbox[3]) / 2
table_labels.setdefault(key, []).append((page_num, y_center))
page_height = page.rect.height
# 记录每个提取文件的元信息
manifest: dict[str, dict] = {}
for page_num in range(len(doc)):
page = doc[page_num]
# ── 1. 提取嵌入图片 ──
# 1. 提取嵌入图片
image_list = page.get_images(full=True)
for img_index, img_info in enumerate(image_list):
xref = img_info[0]
@@ -131,6 +183,11 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
continue
seen_hashes.add(img_hash)
img_rects = page.get_image_rects(xref)
if not img_rects:
continue
bbox = img_rects[0]
if pix.n >= 5:
try:
pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
@@ -140,14 +197,19 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
filename = f"page{page_num + 1}_img{img_index + 1}.png"
pix.save(str(images_dest / filename))
extracted += 1
logger.debug("Image: %s (%dx%d)", filename, pix.width, pix.height)
# 查找该图片位置附近的 Figure 标注
img_rects = page.get_image_rects(xref)
matched = _find_nearby_labels(img_rects, figure_labels, page_num)
manifest[filename] = {"page": page_num + 1, "type": "image", "figures": matched}
caption_text = _extract_caption_text(page, bbox, page_height)
label = _identify_label(caption_text)
# ── 2. 提取表格截图 ──
image_items.append({
"filename": filename,
"page": page_num + 1,
"y0": bbox.y0,
"caption_text": caption_text,
"label": label,
})
# 2. 提取表格截图(同时搜索上方 caption,Table 标题通常在表格上方)
try:
tables = page.find_tables()
except Exception:
@@ -160,8 +222,15 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
continue
margin = 5
x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
clip_rect = pymupdf.Rect(x0 - margin, y0 - margin, x1 + margin, y1 + margin)
if hasattr(bbox, 'x0'):
x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
table_rect = bbox
else:
x0, y0, x1, y1 = bbox
table_rect = pymupdf.Rect(x0, y0, x1, y1)
clip_rect = pymupdf.Rect(
x0 - margin, y0 - margin, x1 + margin, y1 + margin
)
zoom = 2
mat = pymupdf.Matrix(zoom, zoom)
@@ -176,28 +245,133 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
filename = f"page{page_num + 1}_table{table_index + 1}.png"
pix.save(str(images_dest / filename))
extracted += 1
logger.debug("Table: %s (%dx%d)", filename, pix.width, pix.height)
# 查找该表格位置附近的 Table 标注
table_rect = pymupdf.Rect(x0, y0, x1, y1)
matched = _find_nearby_labels([table_rect], table_labels, page_num)
manifest[filename] = {"page": page_num + 1, "type": "table", "tables": matched}
# Table caption 上下都搜(学术论文惯例:Table 标题在上方,但实际排版各异)
caption_text = _extract_caption_text(
page, table_rect, page_height, search_both=True,
)
label = _identify_label(caption_text)
item = {
"filename": filename,
"page": page_num + 1,
"y0": y0,
"caption_text": caption_text,
"label": label,
}
# 关键:caption 标注为 Figure → 重分类为图片
if _is_figure_caption(caption_text):
image_items.append(item)
else:
table_items.append(item)
doc.close()
# ── 第二遍:矫正 find_tables 的误判 ──
# 如果表格与同页的图片高度重叠(复合图表的子区域),且 caption 不含 "Table"
# 则重分类为图片,归入邻近图片的 label
for t_item in table_items[:]:
t_page = t_item["page"]
t_y0 = t_item["y0"]
same_page_images = [i for i in image_items if i["page"] == t_page]
if not same_page_images:
continue
# 检查是否有重叠的图片
nearby = [
i for i in same_page_images
if abs(i["y0"] - t_y0) < 50
]
if nearby and not (t_item["caption_text"] and _TABLE_CAPTION_RE.search(t_item["caption_text"])):
# 重分类为图片,继承邻近图片的 label
neighbor_label = nearby[0].get("label")
t_item["label"] = neighbor_label
image_items.append(t_item)
table_items.remove(t_item)
# ── 第三遍:按 (page, y0) 排序 → 序号匹配兜底 ──
image_items.sort(key=lambda it: (it["page"], it["y0"]))
table_items.sort(key=lambda it: (it["page"], it["y0"]))
# 统计已通过 caption 确认的 Figure/Table 编号,避免序号重复分配
used_figure_nums: set[int] = set()
used_table_nums: set[int] = set()
for item in image_items:
if item["label"]:
m = _FIGURE_CAPTION_RE.search(item["label"])
if m:
used_figure_nums.add(int(m.group(1)))
for item in table_items:
if item["label"]:
m = _TABLE_CAPTION_RE.search(item["label"])
if m:
used_table_nums.add(int(m.group(1)))
# 为未识别编号的图片分配序号(跳过已占用的编号)
next_fig = 1
for item in image_items:
if item["label"] is None:
while next_fig in used_figure_nums:
next_fig += 1
item["label"] = f"Figure {next_fig}"
used_figure_nums.add(next_fig)
next_tbl = 1
for item in table_items:
if item["label"] is None:
while next_tbl in used_table_nums:
next_tbl += 1
item["label"] = f"Table {next_tbl}"
used_table_nums.add(next_tbl)
# ── 第三遍:构建 manifest ──
manifest: dict[str, dict] = {}
for item in image_items:
manifest[item["filename"]] = {
"page": item["page"],
"type": "image",
"label": item["label"],
"caption_text": item.get("caption_text"),
"figures": [item["label"]],
}
for item in table_items:
manifest[item["filename"]] = {
"page": item["page"],
"type": "table",
"label": item["label"],
"caption_text": item.get("caption_text"),
"tables": [item["label"]],
}
# 保存 manifest
manifest_path = images_dest / "manifest.json"
manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2))
manifest_path.write_text(
json.dumps(manifest, ensure_ascii=False, indent=2)
)
captioned = sum(
1 for it in image_items + table_items if it["caption_text"]
)
label_matched = sum(
1 for it in image_items + table_items
if it["caption_text"] and _identify_label(it["caption_text"])
)
if extracted > 0:
logger.info("Extracted %d images+tables from PDF for %s", extracted, arxiv_id)
logger.info(
"Extracted %d items from PDF for %s "
"(%d images, %d tables, %d with captions, %d label-matched)",
extracted, arxiv_id,
len(image_items), len(table_items), captioned, label_matched,
)
return extracted
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
"""根据 summary 中的 figures 字段过滤提取的图片/表格。
用 manifest.json 匹配,不需要 PDF 文件
用 manifest.json 中的 label 匹配,保留被 AI 总结引用的图片
"""
if not figures:
return 0
@@ -229,11 +403,14 @@ def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
return len(all_files)
# 根据 manifest 判断每个文件是否被引用
# 根据 manifest 的 label 字段匹配
keep_filenames: set[str] = set()
for filename, info in manifest.items():
file_refs = info.get("figures", []) + info.get("tables", [])
for ref in file_refs:
label = info.get("label", "")
if label in referenced_ids:
keep_filenames.add(filename)
continue
for ref in info.get("figures", []) + info.get("tables", []):
if ref in referenced_ids:
keep_filenames.add(filename)
break
@@ -252,5 +429,8 @@ def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
removed += 1
kept = len(all_files) - removed
logger.info("Filtered images for %s: kept %d, removed %d (refs=%s)", arxiv_id, kept, removed, referenced_ids)
logger.info(
"Filtered images for %s: kept %d, removed %d (refs=%s)",
arxiv_id, kept, removed, referenced_ids,
)
return kept
+6 -1
View File
@@ -19,7 +19,6 @@ dependencies = [
"pymupdf>=1.25",
"itsdangerous>=2.2.0",
"bleach>=6.4.0",
"docling>=2.99.0",
]
[project.optional-dependencies]
@@ -34,3 +33,9 @@ build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["app"]
[dependency-groups]
dev = [
"pytest>=9.0.3",
"pytest-asyncio>=1.4.0",
]
Generated
+12 -1495
View File
File diff suppressed because it is too large Load Diff