18f44ac244
- Enhance pdf_image_extractor with caption text extraction near images/tables - Add figure/table type correction based on caption content - Implement sequential numbering fallback for unmatched items - Improve figure linking in pages with manifest ID matching and fallback strategies - Remove docling dependency, add dev dependency group
437 lines
14 KiB
Python
437 lines
14 KiB
Python
"""PDF 图片与表格提取 — 从 PDF 中提取嵌入图片和表格截图。
|
||
|
||
策略:
|
||
1. 提取 PDF 中嵌入的图片(图表、插图等),按页面位置排序
|
||
2. 检测表格区域,渲染为截图
|
||
3. 为每张图/表格提取附近的说明文字(caption),从中识别 Figure N / Table N
|
||
4. 根据 caption 内容矫正类型:标注为 "Figure" 的表格区域 → 归为图片
|
||
5. 序号匹配兜底:第 N 张图 → Figure N(学术论文图表严格按顺序出现)
|
||
6. 保存 manifest.json 供后续与 AI 总结的 figures 字段匹配
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from pathlib import Path
|
||
|
||
from app.services.pdf_downloader import paper_dir
|
||
from app.utils import TMP_DIR
|
||
|
||
logger = logging.getLogger(__name__)

# Minimum pixel area (width * height of the decoded image); smaller embedded
# images are treated as icons/decorations and skipped.
_MIN_AREA = 10_000  # ~100x100
# Minimum width and minimum height in pixels for an embedded image.
_MIN_DIM = 80

# Caption search region, in page coordinates. Figure captions usually sit
# below the figure, table captions usually sit above the table.
_CAPTION_MARGIN = 10  # gap between the bbox edge and the start of the search region
_CAPTION_MAX_DISTANCE = 250  # farthest distance from the bbox edge that is searched
_CAPTION_SIDE_PADDING = 40  # widening of the search region on the left and right

# Regexes for "Figure N" / "Fig. N" / "Fig N" and "Table N" labels
# (case-insensitive); group(1) is the number.
# NOTE(review): the trailing \b means sub-figure forms such as "Figure 3a"
# do not match; confirm this is intended.
_FIGURE_CAPTION_RE = re.compile(
    r'\b(?:Fig\.?|Figure)\s*(\d+)\b', re.IGNORECASE
)
_TABLE_CAPTION_RE = re.compile(
    r'\bTable\s*(\d+)\b', re.IGNORECASE
)
|
||
|
||
|
||
def _extract_caption_text(page, bbox, page_height: float, *,
                          search_above: bool = False,
                          search_both: bool = False) -> str | None:
    """Extract caption text from the area next to an image or table.

    The search region is ``bbox`` widened by ``_CAPTION_SIDE_PADDING`` on the
    left and right. Vertically it starts ``_CAPTION_MARGIN`` away from the
    bbox edge and extends up to ``_CAPTION_MAX_DISTANCE`` away, clamped to
    the page.

    Args:
        page: PyMuPDF page that contains ``bbox``.
        bbox: Rect-like object (``x0``/``y0``/``x1``/``y1``) of the image or
            table on the page.
        page_height: Page height, used to clamp the region below ``bbox``.
        search_above: Search the region above ``bbox`` (table captions
            usually sit above the table). By default the region below is
            searched (figure captions usually sit below the figure).
        search_both: Search above and below. The side whose text contains
            a Figure/Table label wins (checked above first); otherwise the
            longer text is returned. Takes precedence over ``search_above``.

    Returns:
        The text of every text block intersecting the region, joined with
        single spaces. If there is none, the ``page.get_textbox`` output for
        the region when it has at least 5 characters. Otherwise ``None``.
    """
    import pymupdf

    x0 = max(0, bbox.x0 - _CAPTION_SIDE_PADDING)
    x1 = bbox.x1 + _CAPTION_SIDE_PADDING
    # Extract the page's text blocks once; the both-sides mode searches
    # two regions against the same block list.
    blocks = page.get_text("blocks")

    def _search(y0: float, y1: float) -> str | None:
        # An empty or inverted region (bbox touching the page edge) cannot
        # hold a caption; don't hand an invalid rect to PyMuPDF.
        if y1 <= y0:
            return None
        rect = pymupdf.Rect(x0, y0, x1, y1)
        parts: list[str] = []
        for block in blocks:
            # Block tuples are (x0, y0, x1, y1, text, block_no, block_type).
            if len(block) < 5:
                continue
            # block_type 1 is an image block whose "text" is a placeholder
            # like "<image: ...>", never caption text.
            if len(block) >= 7 and block[6] != 0:
                continue
            if not pymupdf.Rect(block[:4]).intersects(rect):
                continue
            text = str(block[4]).strip()
            if text:
                parts.append(text)
        if parts:
            return " ".join(parts)

        # Fallback: plain text inside the region.
        text = page.get_textbox(rect)
        if text and len(text.strip()) >= 5:
            return text.strip()
        return None

    def _has_label(text: str | None) -> bool:
        return bool(text) and bool(
            _FIGURE_CAPTION_RE.search(text) or _TABLE_CAPTION_RE.search(text)
        )

    above_y0 = max(0, bbox.y0 - _CAPTION_MAX_DISTANCE)
    above_y1 = max(0, bbox.y0 - _CAPTION_MARGIN)
    below_y0 = bbox.y1 + _CAPTION_MARGIN
    below_y1 = min(page_height, bbox.y1 + _CAPTION_MAX_DISTANCE)

    if search_both:
        above = _search(above_y0, above_y1)
        below = _search(below_y0, below_y1)
        # Prefer the side that actually carries a Figure/Table label.
        if _has_label(above):
            return above
        if _has_label(below):
            return below
        if above and below:
            return above if len(above) >= len(below) else below
        return above or below

    if search_above:
        return _search(above_y0, above_y1)
    return _search(below_y0, below_y1)
|
||
|
||
|
||
def _identify_label(caption_text: str | None) -> str | None:
    """Identify the "Figure N" / "Table N" label in caption text.

    A caption begins with its own label; later mentions are usually
    cross-references (e.g. "Table 2: Results of the model in Fig. 3").
    The label that occurs earliest in the text therefore wins, instead of
    any Figure mention overriding a leading Table label.

    Args:
        caption_text: Text found near an image/table, or ``None``.

    Returns:
        ``"Figure N"`` or ``"Table N"``, or ``None`` when the text is empty
        or contains no recognizable label.
    """
    if not caption_text:
        return None

    fig = _FIGURE_CAPTION_RE.search(caption_text)
    tbl = _TABLE_CAPTION_RE.search(caption_text)

    if fig is not None and (tbl is None or fig.start() <= tbl.start()):
        return f"Figure {fig.group(1)}"
    if tbl is not None:
        return f"Table {tbl.group(1)}"
    return None
|
||
|
||
|
||
def _is_figure_caption(caption_text: str | None) -> bool:
    """Return True when the caption's label is a Figure label.

    Used to correct ``find_tables`` false positives: a "table" region whose
    caption is labelled as a Figure is reclassified as an image. Deciding
    via ``_identify_label`` keeps the reclassification consistent with the
    label that is stored for the item.

    Args:
        caption_text: Text found near the region, or ``None``.

    Returns:
        False for empty/``None`` text or when the identified label is not a
        Figure label.
    """
    label = _identify_label(caption_text)
    return label is not None and label.startswith("Figure")
|
||
|
||
|
||
def _collect_page_images(doc, page, page_num: int, images_dest: Path,
                         seen_hashes: set[int]) -> list[dict]:
    """Save one page's embedded raster images as PNG files and describe them.

    Skipped images:
        * those below the ``_MIN_DIM`` / ``_MIN_AREA`` thresholds;
        * those with no placement rect on the page;
        * those whose pixel data was already saved. ``seen_hashes`` is shared
          across pages and updated here.

    Returns:
        One dict per saved file, with these keys:
            * ``filename``
            * ``page`` (1-based)
            * ``y0`` (top of the first placement rect)
            * ``caption_text``
            * ``label``
    """
    import pymupdf

    items: list[dict] = []
    page_height = page.rect.height
    for img_index, img_info in enumerate(page.get_images(full=True)):
        xref = img_info[0]
        try:
            pix = pymupdf.Pixmap(doc, xref)
        except Exception:
            logger.debug("Skipping undecodable image xref=%s on page %d",
                         xref, page_num + 1)
            continue

        if pix.width < _MIN_DIM or pix.height < _MIN_DIM:
            continue
        if pix.width * pix.height < _MIN_AREA:
            continue

        img_rects = page.get_image_rects(xref)
        if not img_rects:
            continue
        bbox = img_rects[0]

        # PNG can only store gray/RGB (+alpha), so CMYK and other
        # 4+-component colorspaces must be converted before anything is
        # encoded. (`pix.n >= 5` missed alpha-less CMYK, where n == 4.)
        if pix.n - pix.alpha >= 4:
            try:
                pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
            except Exception:
                logger.debug("Cannot convert image xref=%s on page %d to RGB",
                             xref, page_num + 1)
                continue

        # Deduplicate on the raw pixel data; no PNG encoding required.
        img_hash = hash(pix.samples)
        if img_hash in seen_hashes:
            continue

        filename = f"page{page_num + 1}_img{img_index + 1}.png"
        try:
            pix.save(str(images_dest / filename))
        except Exception:
            logger.debug("Could not save %s", filename, exc_info=True)
            continue
        # Mark as seen only once saved, so a failed/unplaced occurrence does
        # not suppress a later valid one.
        seen_hashes.add(img_hash)

        caption_text = _extract_caption_text(page, bbox, page_height)
        items.append({
            "filename": filename,
            "page": page_num + 1,
            "y0": bbox.y0,
            "caption_text": caption_text,
            "label": _identify_label(caption_text),
        })
    return items


def _collect_page_tables(page, page_num: int, images_dest: Path) -> list[dict]:
    """Render the table regions found by ``page.find_tables()`` as PNG files.

    Rendering and filtering:
        * Each table is rendered at 2x zoom with a 5-unit margin around it.
        * Renders smaller than ``2 * _MIN_DIM`` x 60 pixels are dropped.
        * Captions are searched both above and below the table.

    Returns:
        One dict per saved file, with the same keys as
        ``_collect_page_images``.
    """
    import pymupdf

    try:
        tables = page.find_tables()
    except Exception:
        logger.debug("find_tables failed on page %d", page_num + 1, exc_info=True)
        return []
    if not tables or not tables.tables:
        return []

    items: list[dict] = []
    page_height = page.rect.height
    margin = 5
    zoom = 2
    matrix = pymupdf.Matrix(zoom, zoom)
    for table_index, table in enumerate(tables.tables):
        if not table.bbox:
            continue
        # table.bbox may be a Rect or a plain (x0, y0, x1, y1) tuple.
        table_rect = pymupdf.Rect(table.bbox)
        clip_rect = pymupdf.Rect(
            table_rect.x0 - margin, table_rect.y0 - margin,
            table_rect.x1 + margin, table_rect.y1 + margin,
        )
        try:
            pix = page.get_pixmap(matrix=matrix, clip=clip_rect)
        except Exception:
            continue

        # Size limits are scaled by the render zoom.
        if pix.width < _MIN_DIM * zoom or pix.height < 30 * zoom:
            continue

        filename = f"page{page_num + 1}_table{table_index + 1}.png"
        pix.save(str(images_dest / filename))

        # Academic convention puts table titles above the table, but real
        # layouts vary, so both sides are searched.
        caption_text = _extract_caption_text(
            page, table_rect, page_height, search_both=True,
        )
        items.append({
            "filename": filename,
            "page": page_num + 1,
            "y0": table_rect.y0,
            "caption_text": caption_text,
            "label": _identify_label(caption_text),
        })
    return items


def _absorb_tables_near_images(image_items: list[dict],
                               table_items: list[dict]) -> None:
    """Move ``find_tables`` hits that sit next to an image into ``image_items``.

    A table item is moved when both of these hold:
        * some image item (including items moved earlier in this pass) is on
          the same page with a ``y0`` less than 50 units away;
        * its caption has no "Table N" label.

    Such a table is most likely a sub-region of a composite figure. A moved
    item inherits the label of the first such image. Both lists are
    modified in place.
    """
    for t_item in list(table_items):
        nearby = [
            i for i in image_items
            if i["page"] == t_item["page"] and abs(i["y0"] - t_item["y0"]) < 50
        ]
        if not nearby:
            continue
        caption = t_item["caption_text"]
        if caption and _TABLE_CAPTION_RE.search(caption):
            continue
        t_item["label"] = nearby[0].get("label")
        image_items.append(t_item)
        table_items.remove(t_item)


def _fill_missing_labels(items: list[dict], label_re: re.Pattern,
                         prefix: str) -> None:
    """Give every unlabelled item ``f"{prefix} {n}"`` using the smallest free n.

    Numbers already present in existing labels (matched with ``label_re``)
    are never reassigned. ``items`` must already be in document order; this
    is the sequential fallback (the Nth unlabelled figure becomes Figure N
    unless that number is taken).
    """
    used: set[int] = set()
    for item in items:
        if item["label"]:
            m = label_re.search(item["label"])
            if m:
                used.add(int(m.group(1)))

    next_num = 1
    for item in items:
        if item["label"] is None:
            while next_num in used:
                next_num += 1
            item["label"] = f"{prefix} {next_num}"
            used.add(next_num)


def _build_manifest(image_items: list[dict],
                    table_items: list[dict]) -> dict[str, dict]:
    """Map each saved filename to its page/type/label/caption metadata."""
    manifest: dict[str, dict] = {}
    for item in image_items:
        manifest[item["filename"]] = {
            "page": item["page"],
            "type": "image",
            "label": item["label"],
            "caption_text": item.get("caption_text"),
            "figures": [item["label"]],
        }
    for item in table_items:
        manifest[item["filename"]] = {
            "page": item["page"],
            "type": "table",
            "label": item["label"],
            "caption_text": item.get("caption_text"),
            "tables": [item["label"]],
        }
    return manifest


def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Extract embedded images and table screenshots from a PDF and write a manifest.

    Matching strategy:
        1. Extract each image/table, pull the nearby caption text, and
           recognize a Figure/Table number in it.
        2. Table regions whose caption is labelled "Figure" are reclassified
           as images. So are tables next to an image whose caption has no
           "Table N" label.
        3. Items with no recognized number are sorted by (page, vertical
           position) and get sequential numbers that skip numbers already
           taken.

    PNG files and ``manifest.json`` are written to
    ``paper_dir(arxiv_id) / "images"``.

    Args:
        arxiv_id: Paper ID.
        pdf_path: PDF path; defaults to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of image + table files saved (0 when the PDF is missing).
    """
    import pymupdf

    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"

    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)

    # ── Pass 1: collect all images and tables ──
    image_items: list[dict] = []
    table_items: list[dict] = []
    seen_hashes: set[int] = set()
    # The context manager closes the document even if a page fails mid-way.
    with pymupdf.open(str(pdf_path)) as doc:
        for page_num, page in enumerate(doc):
            image_items.extend(
                _collect_page_images(doc, page, page_num, images_dest, seen_hashes)
            )
            for item in _collect_page_tables(page, page_num, images_dest):
                # A caption labelled "Figure" means find_tables misfired.
                if _is_figure_caption(item["caption_text"]):
                    image_items.append(item)
                else:
                    table_items.append(item)

    # Every saved file produced exactly one item.
    extracted = len(image_items) + len(table_items)

    # ── Pass 2: correct find_tables hits that belong to a composite figure ──
    _absorb_tables_near_images(image_items, table_items)

    # ── Pass 3: sort by (page, y0), then fill in sequential labels ──
    image_items.sort(key=lambda it: (it["page"], it["y0"]))
    table_items.sort(key=lambda it: (it["page"], it["y0"]))
    _fill_missing_labels(image_items, _FIGURE_CAPTION_RE, "Figure")
    _fill_missing_labels(table_items, _TABLE_CAPTION_RE, "Table")

    # ── Pass 4: build and save the manifest ──
    manifest = _build_manifest(image_items, table_items)
    manifest_path = images_dest / "manifest.json"
    # Explicit UTF-8: ensure_ascii=False may emit non-ASCII caption text,
    # and the manifest is read back with encoding="utf-8".
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    all_items = image_items + table_items
    captioned = sum(1 for it in all_items if it["caption_text"])
    label_matched = sum(
        1 for it in all_items
        if it["caption_text"] and _identify_label(it["caption_text"])
    )

    if extracted > 0:
        logger.info(
            "Extracted %d items from PDF for %s "
            "(%d images, %d tables, %d with captions, %d label-matched)",
            extracted, arxiv_id,
            len(image_items), len(table_items), captioned, label_matched,
        )

    return extracted
|
||
|
||
|
||
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
    """Delete extracted images/tables that the summary's ``figures`` do not reference.

    How matching works:
        * Each entry's ``"id"`` (e.g. "Fig. 3", "Table 2") is normalized to
          "Figure N" / "Table N". A bare string entry is also accepted as
          the id.
        * The normalized ids are matched against the ``label`` and the
          ``figures``/``tables`` lists of each entry in
          ``images/manifest.json``.
        * PNG files with a matching manifest entry are kept; every other
          PNG in the directory is deleted. ``manifest.json`` itself is not
          modified.

    Filtering is best-effort. When no usable id is found, the manifest
    cannot be read, or nothing matches, all images are kept.

    Args:
        arxiv_id: Paper ID.
        figures: The summary's ``figures`` list.

    Returns:
        Number of PNG files kept. Returns 0 when ``figures`` is empty or
        there is no images directory, manifest, or PNG file.
        NOTE(review): with an empty ``figures`` nothing is deleted, yet 0 is
        returned; confirm callers expect that.
    """
    if not figures:
        return 0

    images_dir = paper_dir(arxiv_id) / "images"
    manifest_path = images_dir / "manifest.json"

    if not images_dir.exists() or not manifest_path.exists():
        return 0

    all_files = [f for f in images_dir.iterdir() if f.suffix == ".png"]
    if not all_files:
        return 0

    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        logger.warning("Unreadable manifest for %s (%s), keeping all", arxiv_id, exc)
        return len(all_files)
    if not isinstance(manifest, dict):
        logger.warning("Malformed manifest for %s, keeping all", arxiv_id)
        return len(all_files)

    # Collect every Figure/Table id referenced by the summary, normalized.
    referenced_ids: set[str] = set()
    for fig in figures:
        raw_id = fig.get("id") if isinstance(fig, dict) else fig
        # The summary is model output: tolerate missing/None/non-str ids.
        fig_id = str(raw_id or "").strip()
        m = re.match(r'(?:Fig\.?|Figure)\s*(\d+)', fig_id, re.IGNORECASE)
        if m:
            referenced_ids.add(f"Figure {m.group(1)}")
        m2 = re.match(r'Table\s*(\d+)', fig_id, re.IGNORECASE)
        if m2:
            referenced_ids.add(f"Table {m2.group(1)}")

    if not referenced_ids:
        logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
        return len(all_files)

    # Match against each manifest entry's label and figures/tables lists.
    keep_filenames: set[str] = set()
    for filename, info in manifest.items():
        if not isinstance(info, dict):
            continue
        refs = [info.get("label", "")]
        refs += info.get("figures") or []
        refs += info.get("tables") or []
        if any(ref in referenced_ids for ref in refs):
            keep_filenames.add(filename)

    if not keep_filenames:
        logger.warning(
            "No manifest matches for %s (refs=%s), keeping all",
            arxiv_id, referenced_ids,
        )
        return len(all_files)

    removed = 0
    for f in all_files:
        if f.name not in keep_filenames:
            f.unlink(missing_ok=True)
            removed += 1

    kept = len(all_files) - removed
    logger.info(
        "Filtered images for %s: kept %d, removed %d (refs=%s)",
        arxiv_id, kept, removed, referenced_ids,
    )
    return kept
|