Files
daily-paper/app/services/pdf_image_extractor.py
T
Rain-Bus 18f44ac244 feat: improve PDF image extraction with caption-based labeling and fallback matching
- Enhance pdf_image_extractor with caption text extraction near images/tables
- Add figure/table type correction based on caption content
- Implement sequential numbering fallback for unmatched items
- Improve figure linking in pages with manifest ID matching and fallback strategies
- Remove docling dependency, add dev dependency group
2026-06-09 14:07:21 +08:00

437 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""PDF 图片与表格提取 — 从 PDF 中提取嵌入图片和表格截图。
策略:
1. 提取 PDF 中嵌入的图片(图表、插图等),按页面位置排序
2. 检测表格区域,渲染为截图
3. 为每张图/表格提取附近的说明文字(caption),从中识别 Figure N / Table N
4. 根据 caption 内容矫正类型:标注为 "Figure" 的表格区域 → 归为图片
5. 序号匹配兜底:第 N 张图 → Figure N(学术论文图表严格按顺序出现)
6. 保存 manifest.json 供后续与 AI 总结的 figures 字段匹配
"""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from app.services.pdf_downloader import paper_dir
from app.utils import TMP_DIR
logger = logging.getLogger(__name__)
# 最小面积阈值(像素),小于此值的图片视为图标/装饰
_MIN_AREA = 10_000 # ~100x100
_MIN_DIM = 80
# Caption 搜索区域 — Figure caption 在图下方,Table caption 在图上方
_CAPTION_MARGIN = 10 # 贴边距离
_CAPTION_MAX_DISTANCE = 250 # 最远搜索距离
_CAPTION_SIDE_PADDING = 40 # 左右扩展
# Figure/Table 标注正则
_FIGURE_CAPTION_RE = re.compile(
r'\b(?:Fig\.?|Figure)\s*(\d+)\b', re.IGNORECASE
)
_TABLE_CAPTION_RE = re.compile(
r'\bTable\s*(\d+)\b', re.IGNORECASE
)
def _extract_caption_text(page, bbox, page_height: float, *,
                          search_above: bool = False,
                          search_both: bool = False) -> str | None:
    """Extract caption-like text near an image or table region.

    The search window spans the bbox width widened by ``_CAPTION_SIDE_PADDING``
    on each side. Vertically it runs from ``_CAPTION_MARGIN`` to
    ``_CAPTION_MAX_DISTANCE`` away from the bbox edge, clamped to the page.

    Args:
        page: PyMuPDF page that contains the region.
        bbox: Rect-like region exposing ``x0``/``y0``/``x1``/``y1``.
        page_height: Page height, used to clamp the window below the bbox.
        search_above: Search above the bbox (table captions usually sit
            there) instead of below it (figure captions usually sit there).
        search_both: Search both sides. Prefer the side whose text contains
            a Figure/Table marker; otherwise return the longer text.

    Returns:
        Text of all text blocks intersecting the window, joined by spaces.
        When no text block intersects, falls back to ``page.get_textbox``
        (only if it yields at least 5 non-blank characters). Returns ``None``
        when nothing usable is found.
    """
    import pymupdf

    x0 = max(0, bbox.x0 - _CAPTION_SIDE_PADDING)
    x1 = bbox.x1 + _CAPTION_SIDE_PADDING
    # The page's blocks do not depend on the window: fetch them once rather
    # than once per searched side.
    blocks = page.get_text("blocks")

    def _search(y0: float, y1: float) -> str | None:
        # A window collapsed by clamping (bbox at the page edge) cannot hold a
        # caption; don't hand an empty/inverted rect to PyMuPDF.
        if y1 <= y0:
            return None
        rect = pymupdf.Rect(x0, y0, x1, y1)
        parts: list[str] = []
        for block in blocks:
            if len(block) < 5:
                continue
            # Block tuples are (x0, y0, x1, y1, text, block_no, block_type).
            # block_type 1 marks an image block whose "text" is image metadata,
            # not caption text.
            if len(block) >= 7 and block[6] != 0:
                continue
            if pymupdf.Rect(block[:4]).intersects(rect):
                text = str(block[4]).strip()
                if text:
                    parts.append(text)
        if parts:
            return " ".join(parts)
        text = page.get_textbox(rect)
        if text and len(text.strip()) >= 5:
            return text.strip()
        return None

    def _has_marker(text: str | None) -> bool:
        return bool(text) and bool(
            _FIGURE_CAPTION_RE.search(text) or _TABLE_CAPTION_RE.search(text)
        )

    above_window = (max(0, bbox.y0 - _CAPTION_MAX_DISTANCE),
                    max(0, bbox.y0 - _CAPTION_MARGIN))
    below_window = (bbox.y1 + _CAPTION_MARGIN,
                    min(page_height, bbox.y1 + _CAPTION_MAX_DISTANCE))

    if search_both:
        above = _search(*above_window)
        below = _search(*below_window)
        # Prefer the side that carries a Figure/Table label.
        if _has_marker(above):
            return above
        if _has_marker(below):
            return below
        # Otherwise return the longer text (above wins ties).
        if above and below:
            return above if len(above) >= len(below) else below
        return above or below

    return _search(*(above_window if search_above else below_window))
def _identify_label(caption_text: str | None) -> str | None:
    """Identify the "Figure N" / "Table N" label in caption text.

    When the text contains both a Figure and a Table marker, the one that
    occurs first wins. Caption text typically starts with the item's own
    label, while cross-references ("... compared with Figure 2") come later.
    Always preferring "Figure" would mislabel such a table as a figure.

    Args:
        caption_text: Text gathered near an image/table, or ``None``.

    Returns:
        ``"Figure N"`` or ``"Table N"`` with normalized spelling (for example
        "Fig. 3" becomes "Figure 3"), or ``None`` when there is no marker.
    """
    if not caption_text:
        return None
    candidates: list[tuple[int, str, str]] = []
    for pattern, kind in ((_FIGURE_CAPTION_RE, "Figure"),
                          (_TABLE_CAPTION_RE, "Table")):
        match = pattern.search(caption_text)
        if match:
            candidates.append((match.start(), kind, match.group(1)))
    if not candidates:
        return None
    # The two markers cannot start at the same offset, so min() is unambiguous.
    _, kind, number = min(candidates)
    return f"{kind} {number}"
def _is_figure_caption(caption_text: str | None) -> bool:
    """Return True when the caption text is labelled as a Figure.

    Used to correct ``find_tables`` misdetections: a "table" region whose
    caption identifies it as a Figure is reclassified as an image.

    The decision is delegated to ``_identify_label`` so that reclassification
    always agrees with the label the item receives. A caption for which
    ``_identify_label`` picks a Table label is therefore not treated as a
    figure caption, even if it also mentions a figure.
    """
    label = _identify_label(caption_text)
    return label is not None and label.startswith("Figure")
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Extract embedded images and table screenshots from a PDF, then write a manifest.

    Matching strategy:
    1. Save each image / table region, gather nearby caption text and read a
       Figure/Table number from it.
    2. A table region whose caption is labelled "Figure" is reclassified as an
       image. So is a table without a "Table N" marker whose ``y0`` lies
       within 50 page units of an image on the same page; it inherits that
       image's label.
    3. Items still without a number are sorted by (page, vertical position)
       and numbered sequentially, skipping numbers already claimed by
       captions.

    PNG files and ``manifest.json`` are written to
    ``paper_dir(arxiv_id) / "images"``. The manifest is UTF-8 and maps each
    filename to page, type, label, caption_text and a one-element
    ``figures`` / ``tables`` list.

    Args:
        arxiv_id: Paper ID.
        pdf_path: PDF path; defaults to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of images + tables saved; 0 when the PDF does not exist.
    """
    import pymupdf

    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)

    # ── Pass 1: collect every image and table ──
    image_items: list[dict] = []
    table_items: list[dict] = []
    seen_digests: set[bytes] = set()
    # The context manager closes the document even when a page raises.
    with pymupdf.open(str(pdf_path)) as doc:
        for page_num, page in enumerate(doc):
            image_items.extend(
                _collect_page_images(doc, page, page_num, images_dest, seen_digests)
            )
            tables_as_images, tables = _collect_page_tables(page, page_num, images_dest)
            image_items.extend(tables_as_images)
            table_items.extend(tables)
    # Every collected item corresponds to exactly one saved PNG.
    extracted = len(image_items) + len(table_items)

    # ── Pass 2: correct find_tables false positives next to images ──
    table_items = _reclassify_tables_near_images(image_items, table_items)

    # ── Pass 3: sort by (page, y0), then number the leftovers sequentially ──
    def reading_order(item: dict) -> tuple:
        return item["page"], item["y0"]

    image_items.sort(key=reading_order)
    table_items.sort(key=reading_order)
    _assign_fallback_labels(image_items, _FIGURE_CAPTION_RE, "Figure")
    _assign_fallback_labels(table_items, _TABLE_CAPTION_RE, "Table")

    # ── Pass 4: build and save the manifest ──
    manifest: dict[str, dict] = {}
    for item in image_items:
        manifest[item["filename"]] = {
            "page": item["page"],
            "type": "image",
            "label": item["label"],
            "caption_text": item.get("caption_text"),
            "figures": [item["label"]],
        }
    for item in table_items:
        manifest[item["filename"]] = {
            "page": item["page"],
            "type": "table",
            "label": item["label"],
            "caption_text": item.get("caption_text"),
            "tables": [item["label"]],
        }
    manifest_path = images_dest / "manifest.json"
    # Explicit UTF-8: ensure_ascii=False keeps non-ASCII caption text as-is,
    # and filter_images_by_summary reads this file back as UTF-8.
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    all_items = image_items + table_items
    captioned = sum(1 for it in all_items if it["caption_text"])
    label_matched = sum(
        1 for it in all_items
        if it["caption_text"] and _identify_label(it["caption_text"])
    )
    if extracted > 0:
        logger.info(
            "Extracted %d items from PDF for %s "
            "(%d images, %d tables, %d with captions, %d label-matched)",
            extracted, arxiv_id,
            len(image_items), len(table_items), captioned, label_matched,
        )
    return extracted


def _collect_page_images(doc, page, page_num: int, images_dest: Path,
                         seen_digests: set[bytes]) -> list[dict]:
    """Save one page's embedded raster images as PNG and return their items.

    Skips images below ``_MIN_DIM`` / ``_MIN_AREA``, images whose pixels were
    already saved earlier in the document, and images with no placement
    rectangle on the page. Saved digests are recorded in *seen_digests*,
    which is updated in place.
    """
    import hashlib
    import pymupdf

    items: list[dict] = []
    page_height = page.rect.height
    for img_index, img_info in enumerate(page.get_images(full=True)):
        xref = img_info[0]
        try:
            pix = pymupdf.Pixmap(doc, xref)
        except Exception:
            logger.debug("Cannot load image xref %s on page %d",
                         xref, page_num + 1, exc_info=True)
            continue
        if pix.width < _MIN_DIM or pix.height < _MIN_DIM:
            continue
        if pix.width * pix.height < _MIN_AREA:
            continue
        # Fingerprint the full raw pixel data plus geometry. Hashing only the
        # first 1 KiB of a PNG encoding could merge distinct images that start
        # alike, and PNG-encoding just to fingerprint is wasted work.
        hasher = hashlib.sha1(f"{pix.width}x{pix.height}x{pix.n}|".encode())
        hasher.update(pix.samples)
        digest = hasher.digest()
        if digest in seen_digests:
            continue
        img_rects = page.get_image_rects(xref)
        if not img_rects:
            continue
        bbox = img_rects[0]
        # PNG cannot store CMYK data. pix.n counts the alpha channel too, so
        # compare colour components only: plain CMYK has n == 4, alpha == 0.
        if pix.n - pix.alpha > 3:
            try:
                pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
            except Exception:
                logger.debug("Cannot convert image xref %s to RGB", xref,
                             exc_info=True)
                continue
        filename = f"page{page_num + 1}_img{img_index + 1}.png"
        try:
            pix.save(str(images_dest / filename))
        except Exception:
            logger.warning("Failed to save %s", filename, exc_info=True)
            continue
        # Mark as seen only once saved, so a later copy can still be extracted
        # if this one could not be.
        seen_digests.add(digest)
        caption_text = _extract_caption_text(page, bbox, page_height)
        items.append({
            "filename": filename,
            "page": page_num + 1,
            "y0": bbox.y0,
            "caption_text": caption_text,
            "label": _identify_label(caption_text),
        })
    return items


def _collect_page_tables(page, page_num: int,
                         images_dest: Path) -> tuple[list[dict], list[dict]]:
    """Render each table found by ``page.find_tables()`` to PNG at 2x zoom.

    Caption text is searched both above and below each table, because table
    caption placement varies.

    Returns:
        ``(as_images, as_tables)``. Tables whose caption is labelled as a
        Figure (``_is_figure_caption``) are returned as image items; all
        others are returned as table items.
    """
    import pymupdf

    as_images: list[dict] = []
    as_tables: list[dict] = []
    try:
        found = page.find_tables()
    except Exception:
        logger.debug("find_tables failed on page %d", page_num + 1, exc_info=True)
        return as_images, as_tables
    if not found or not found.tables:
        return as_images, as_tables

    page_height = page.rect.height
    zoom = 2
    matrix = pymupdf.Matrix(zoom, zoom)
    margin = 5  # padding around the table when rendering, in page units
    for table_index, table in enumerate(found.tables):
        bbox = table.bbox
        if not bbox:
            continue
        if hasattr(bbox, "x0"):
            x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
            table_rect = bbox
        else:
            x0, y0, x1, y1 = bbox
            table_rect = pymupdf.Rect(x0, y0, x1, y1)
        clip_rect = pymupdf.Rect(x0 - margin, y0 - margin, x1 + margin, y1 + margin)
        try:
            pix = page.get_pixmap(matrix=matrix, clip=clip_rect)
        except Exception:
            logger.debug("Cannot render table %d on page %d",
                         table_index + 1, page_num + 1, exc_info=True)
            continue
        # Size limits apply to rendered pixels, hence the zoom factor.
        if pix.width < _MIN_DIM * zoom or pix.height < 30 * zoom:
            continue
        filename = f"page{page_num + 1}_table{table_index + 1}.png"
        try:
            pix.save(str(images_dest / filename))
        except Exception:
            logger.warning("Failed to save %s", filename, exc_info=True)
            continue
        caption_text = _extract_caption_text(
            page, table_rect, page_height, search_both=True,
        )
        item = {
            "filename": filename,
            "page": page_num + 1,
            "y0": y0,
            "caption_text": caption_text,
            "label": _identify_label(caption_text),
        }
        # A "table" whose caption says Figure is reclassified as an image.
        if _is_figure_caption(caption_text):
            as_images.append(item)
        else:
            as_tables.append(item)
    return as_images, as_tables


def _reclassify_tables_near_images(image_items: list[dict],
                                   table_items: list[dict],
                                   y_tolerance: float = 50) -> list[dict]:
    """Move table items that sit next to an image into *image_items*.

    A table item is moved when its caption text has no "Table N" marker and
    some item already in *image_items* on the same page has a ``y0`` within
    *y_tolerance* of the table's ``y0``. The moved item inherits the label of
    the first such image, which may be ``None``.

    Moved items are appended to *image_items* in place, so later tables are
    also compared against them. The remaining table items are returned in
    their original order.
    """
    remaining: list[dict] = []
    for t_item in table_items:
        caption = t_item["caption_text"]
        if caption and _TABLE_CAPTION_RE.search(caption):
            remaining.append(t_item)
            continue
        neighbor = next(
            (img for img in image_items
             if img["page"] == t_item["page"]
             and abs(img["y0"] - t_item["y0"]) < y_tolerance),
            None,
        )
        if neighbor is None:
            remaining.append(t_item)
            continue
        t_item["label"] = neighbor.get("label")
        image_items.append(t_item)
    return remaining


def _assign_fallback_labels(items: list[dict], label_re: re.Pattern,
                            prefix: str) -> None:
    """Label every item whose label is ``None`` with the next free "<prefix> N".

    Numbers already present in labels matched by *label_re* are skipped, so
    sequential numbering never duplicates a caption-derived number. *items*
    should already be in reading order; labels are assigned in place.
    """
    used: set[int] = set()
    for item in items:
        match = label_re.search(item["label"]) if item["label"] else None
        if match:
            used.add(int(match.group(1)))
    next_num = 1
    for item in items:
        if item["label"] is not None:
            continue
        while next_num in used:
            next_num += 1
        item["label"] = f"{prefix} {next_num}"
        used.add(next_num)
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
    """Delete extracted images/tables that the AI summary does not reference.

    Each summary figure ID (``fig["id"]``, e.g. "Fig. 3" or "Table 2") is
    normalized to "Figure N" / "Table N". It is compared with every manifest
    entry's ``label`` and its ``figures`` / ``tables`` lists. PNG files in the
    images directory whose name is not a matching manifest key are deleted.
    ``manifest.json`` itself is not modified.

    Everything is kept when no summary ID can be normalized, the manifest
    cannot be read or is not a JSON object, or no manifest entry matches.

    Args:
        arxiv_id: Paper ID.
        figures: The summary's ``figures`` entries. Entries that are not
            dicts, or whose ``id`` is missing or not a string, are ignored.

    Returns:
        Number of PNG files left in the images directory. Returns 0 without
        touching anything when *figures* is empty, the images directory or
        manifest is missing, or there are no PNG files.
    """
    if not figures:
        return 0
    images_dir = paper_dir(arxiv_id) / "images"
    manifest_path = images_dir / "manifest.json"
    if not images_dir.exists() or not manifest_path.exists():
        return 0
    all_files = [f for f in images_dir.iterdir() if f.suffix == ".png"]
    if not all_files:
        return 0

    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # ValueError covers JSON and UTF-8 decode errors
        logger.warning(
            "Unreadable manifest for %s, keeping all images", arxiv_id,
            exc_info=True,
        )
        return len(all_files)
    if not isinstance(manifest, dict):
        logger.warning("Malformed manifest for %s, keeping all images", arxiv_id)
        return len(all_files)

    # Collect every Figure/Table ID referenced by the summary (normalized).
    referenced_ids: set[str] = set()
    for fig in figures:
        fig_id = fig.get("id") if isinstance(fig, dict) else None
        if not isinstance(fig_id, str):
            continue
        m = re.match(r'(?:Fig\.?|Figure)\s*(\d+)', fig_id, re.IGNORECASE)
        if m:
            referenced_ids.add(f"Figure {m.group(1)}")
        m2 = re.match(r'Table\s*(\d+)', fig_id, re.IGNORECASE)
        if m2:
            referenced_ids.add(f"Table {m2.group(1)}")
    if not referenced_ids:
        logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
        return len(all_files)

    # Keep a file when its label or any of its listed figure/table IDs is referenced.
    keep_filenames: set[str] = set()
    for filename, info in manifest.items():
        if not isinstance(info, dict):
            continue
        refs = [
            info.get("label", ""),
            *(info.get("figures") or []),
            *(info.get("tables") or []),
        ]
        if any(isinstance(ref, str) and ref in referenced_ids for ref in refs):
            keep_filenames.add(filename)
    if not keep_filenames:
        # Deleting everything would be worse than keeping unreferenced images.
        logger.warning(
            "No manifest matches for %s (refs=%s), keeping all",
            arxiv_id, referenced_ids,
        )
        return len(all_files)

    removed = 0
    for f in all_files:
        if f.name not in keep_filenames:
            f.unlink()
            removed += 1
    kept = len(all_files) - removed
    logger.info(
        "Filtered images for %s: kept %d, removed %d (refs=%s)",
        arxiv_id, kept, removed, referenced_ids,
    )
    return kept