Files
daily-paper/app/services/pdf_image_extractor.py
T
Rain-Bus 0d293422ac feat: enhance UI, refactor services, improve templates and tests
- Replace image_extractor with pdf_image_extractor service
- Enhance pi_client with expanded API capabilities
- Improve summarizer service with additional features
- Update admin routes with more endpoints
- Add login page template
- Enhance detail page with comprehensive layout
- Improve search and trends pages
- Update base template with additional elements
- Refactor tests for better coverage
- Add validate_summary script
- Update project configuration and dependencies
2026-06-07 19:38:58 +08:00

262 lines
9.0 KiB
Python

"""PDF 图片与表格提取 — 从 PDF 中提取嵌入图片和表格截图。
策略:
1. 提取 PDF 中嵌入的图片(图表、插图等)
2. 检测表格区域,渲染为截图
3. 同时搜索页面中的 Figure/Table 标注,记录到 manifest
4. 过滤掉过小的图片
5. 保存到 data/papers/{arxiv_id}/images/
"""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from app.services.pdf_downloader import paper_dir
logger = logging.getLogger(__name__)

# Size filters for embedded raster images: anything narrower or shorter than
# _MIN_DIM pixels, or with fewer than _MIN_AREA pixels in total, is treated as
# an icon/decoration and skipped.
_MIN_DIM = 80
_MIN_AREA = 100 * 100

# Maximum vertical distance (PDF points) between a "Figure N"/"Table N"
# mention and an image/table edge for the mention to be attributed to it.
_MAX_LABEL_DISTANCE = 120

# Case-insensitive patterns for figure/table mentions in page text
# ("Fig. 3", "Figure 3", "Table 2"); group 1 captures the number.
_FIGURE_RE = re.compile(r'\b(?:Fig\.?|Figure)\s*(\d+)\b', re.IGNORECASE)
_TABLE_RE = re.compile(r'\bTable\s*(\d+)\b', re.IGNORECASE)
def _find_nearby_labels(
    rects: list, labels: dict[str, list[tuple[int, float]]], page_num: int
) -> list[str]:
    """Return the label keys recorded close (vertically) to any of *rects*.

    A label key (e.g. "Figure 3") qualifies when one of its recorded
    ``(page, y)`` positions is on *page_num* and ``y`` lies within
    ``_MAX_LABEL_DISTANCE`` of either horizontal edge (``y0`` or ``y1``) of at
    least one rect. Only the vertical coordinate is compared; horizontal
    position is ignored.

    Args:
        rects: Areas to test. Each is either a sequence ``(x0, y0, x1, y1)``
            or an object exposing ``y0``/``y1`` attributes (e.g. a PyMuPDF Rect).
        labels: Maps a label key to the ``(page, y)`` pairs where it was seen.
        page_num: Page index that *rects* belong to.

    Returns:
        Matching keys without duplicates, in first-match order.
    """
    hits: dict[str, None] = {}  # dict used as an insertion-ordered set
    for area in rects:
        if isinstance(area, (list, tuple)):
            top, bottom = area[1], area[3]
        else:
            top, bottom = area.y0, area.y1
        for key, spots in labels.items():
            if key in hits:
                continue
            if any(
                spot_page == page_num
                and min(abs(spot_y - top), abs(spot_y - bottom)) <= _MAX_LABEL_DISTANCE
                for spot_page, spot_y in spots
            ):
                hits[key] = None
    return list(hits)
# Padding (PDF points) added around a detected table before rendering it.
_TABLE_MARGIN = 5
# Zoom factor applied when rendering table crops to a pixmap.
_TABLE_ZOOM = 2
# Minimum table crop height in unzoomed points (multiplied by _TABLE_ZOOM
# before being compared with the rendered pixmap height).
_MIN_TABLE_HEIGHT = 30


def _block_text(block: dict) -> str:
    """Concatenate the text of one PyMuPDF ``get_text("dict")`` text block.

    Spans inside a line are joined directly. Lines are joined with a space so
    that a mention wrapped across lines (e.g. "see" / "Table 2") keeps the word
    boundary the label regexes require.
    """
    return " ".join(
        "".join(span.get("text", "") for span in line.get("spans", []))
        for line in block.get("lines", [])
    )


def _collect_labels(
    doc,
) -> tuple[dict[str, list[tuple[int, float]]], dict[str, list[tuple[int, float]]]]:
    """Index every "Figure N" / "Table N" mention in *doc* by page and height.

    Returns:
        ``(figure_labels, table_labels)``: each maps a normalized key such as
        "Figure 3" to the ``(page_index, y_center)`` of every text block that
        mentions it. ``page_index`` is zero-based.
    """
    figure_labels: dict[str, list[tuple[int, float]]] = {}
    table_labels: dict[str, list[tuple[int, float]]] = {}
    for page_num, page in enumerate(doc):
        for block in page.get_text("dict").get("blocks", []):
            if block.get("type") != 0:  # only text blocks (type 0) carry labels
                continue
            text = _block_text(block)
            bbox = block.get("bbox", (0, 0, 0, 0))
            y_center = (bbox[1] + bbox[3]) / 2
            for m in _FIGURE_RE.finditer(text):
                figure_labels.setdefault(f"Figure {m.group(1)}", []).append((page_num, y_center))
            for m in _TABLE_RE.finditer(text):
                table_labels.setdefault(f"Table {m.group(1)}", []).append((page_num, y_center))
    return figure_labels, table_labels


def _save_pixmap(pix, dest: Path) -> bool:
    """Write *pix* to *dest*; log and return False instead of raising on failure."""
    try:
        pix.save(str(dest))
    except Exception as exc:  # best-effort: one bad image must not abort the paper
        logger.warning("Failed to save %s: %s", dest.name, exc)
        return False
    return True


def _save_page_images(
    doc,
    page,
    page_num: int,
    images_dest: Path,
    figure_labels: dict[str, list[tuple[int, float]]],
    seen_xrefs: set[int],
    seen_digests: set[bytes],
) -> dict[str, dict]:
    """Save the qualifying embedded images of one page as PNG.

    *seen_xrefs* / *seen_digests* are shared across pages and updated in place
    so each distinct image is written only once per document.

    Returns:
        Manifest entries (file name -> metadata) for the images written.
    """
    import pymupdf

    entries: dict[str, dict] = {}
    for img_index, img_info in enumerate(page.get_images(full=True)):
        xref = img_info[0]
        # The same image object is often placed on several pages; one copy suffices.
        if xref in seen_xrefs:
            continue
        seen_xrefs.add(xref)
        try:
            pix = pymupdf.Pixmap(doc, xref)
        except Exception:
            continue
        if pix.width < _MIN_DIM or pix.height < _MIN_DIM:
            continue
        if pix.width * pix.height < _MIN_AREA:
            continue
        # Different xrefs can still carry identical pixels; compare content digests.
        digest = pix.digest
        if digest in seen_digests:
            continue
        seen_digests.add(digest)
        # PNG supports only gray/RGB (+alpha). Anything with more colour
        # channels (e.g. CMYK, with or without alpha) must be converted first.
        if pix.n - pix.alpha > 3:
            try:
                pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
            except Exception:
                continue
        filename = f"page{page_num + 1}_img{img_index + 1}.png"
        if not _save_pixmap(pix, images_dest / filename):
            continue
        logger.debug("Image: %s (%dx%d)", filename, pix.width, pix.height)
        # Attribute nearby "Figure N" mentions to this image.
        matched = _find_nearby_labels(page.get_image_rects(xref), figure_labels, page_num)
        entries[filename] = {"page": page_num + 1, "type": "image", "figures": matched}
    return entries


def _save_page_tables(
    page,
    page_num: int,
    images_dest: Path,
    table_labels: dict[str, list[tuple[int, float]]],
) -> dict[str, dict]:
    """Render every table detected on one page to a PNG snapshot.

    Returns:
        Manifest entries (file name -> metadata) for the tables written.
    """
    import pymupdf

    try:
        tables = page.find_tables()
    except Exception:
        return {}
    if not tables or not tables.tables:
        return {}

    entries: dict[str, dict] = {}
    mat = pymupdf.Matrix(_TABLE_ZOOM, _TABLE_ZOOM)
    for table_index, table in enumerate(tables.tables):
        if not table.bbox:
            continue
        # Rect() accepts both a 4-sequence and an existing Rect.
        table_rect = pymupdf.Rect(table.bbox)
        clip_rect = pymupdf.Rect(
            table_rect.x0 - _TABLE_MARGIN,
            table_rect.y0 - _TABLE_MARGIN,
            table_rect.x1 + _TABLE_MARGIN,
            table_rect.y1 + _TABLE_MARGIN,
        )
        try:
            pix = page.get_pixmap(matrix=mat, clip=clip_rect)
        except Exception:
            continue
        if pix.width < _MIN_DIM * _TABLE_ZOOM or pix.height < _MIN_TABLE_HEIGHT * _TABLE_ZOOM:
            continue
        filename = f"page{page_num + 1}_table{table_index + 1}.png"
        if not _save_pixmap(pix, images_dest / filename):
            continue
        logger.debug("Table: %s (%dx%d)", filename, pix.width, pix.height)
        # Attribute nearby "Table N" mentions using the unpadded table area.
        matched = _find_nearby_labels([table_rect], table_labels, page_num)
        entries[filename] = {"page": page_num + 1, "type": "table", "tables": matched}
    return entries


def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Extract embedded images and table snapshots from a PDF and write a manifest.

    Files are written as PNG into ``paper_dir(arxiv_id) / "images"`` together
    with ``manifest.json``. The manifest maps each file name to its 1-based
    page, its type ("image" or "table"), and the "Figure N" / "Table N" labels
    found vertically near it on the same page.

    Args:
        arxiv_id: Paper ID.
        pdf_path: PDF location; defaults to ``data/tmp/{arxiv_id}/paper.pdf``.

    Returns:
        Number of image + table files written (0 if the PDF does not exist).
    """
    import pymupdf

    if pdf_path is None:
        pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf"
    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)

    manifest: dict[str, dict] = {}
    # The context manager closes the document even if a page raises.
    with pymupdf.open(str(pdf_path)) as doc:
        # Labels are indexed for the whole document first, then matched per page.
        figure_labels, table_labels = _collect_labels(doc)
        seen_xrefs: set[int] = set()
        seen_digests: set[bytes] = set()
        for page_num, page in enumerate(doc):
            manifest.update(
                _save_page_images(
                    doc, page, page_num, images_dest,
                    figure_labels, seen_xrefs, seen_digests,
                )
            )
            manifest.update(_save_page_tables(page, page_num, images_dest, table_labels))

    manifest_path = images_dest / "manifest.json"
    # Explicit UTF-8 so the file matches how filter_images_by_summary reads it.
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    # Every written file has exactly one manifest entry (file names are unique).
    extracted = len(manifest)
    if extracted > 0:
        logger.info("Extracted %d images+tables from PDF for %s", extracted, arxiv_id)
    return extracted
def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
    """Delete extracted images/tables that the summary's figure list does not reference.

    Works purely from ``images/manifest.json`` (as written by
    ``extract_images_from_pdf``); the PDF itself is not needed. Each entry of
    *figures* contributes its ``"id"`` (a plain string entry is used as the id
    directly). Ids such as "Fig. 2", "Figure 2" or "Table 1" are normalized to
    "Figure N" / "Table N" and compared with the labels stored per file.

    Filtering is conservative: if no usable id is found, no manifest entry
    matches, or the manifest cannot be read, nothing is deleted.

    Args:
        arxiv_id: Paper ID.
        figures: Figure entries from the summary.

    Returns:
        Number of PNG files kept. NOTE: returns 0 immediately when *figures*
        is empty, without inspecting (or deleting from) the images directory.
    """
    if not figures:
        return 0

    images_dir = paper_dir(arxiv_id) / "images"
    manifest_path = images_dir / "manifest.json"
    if not images_dir.exists() or not manifest_path.exists():
        return 0

    all_files = [f for f in images_dir.iterdir() if f.suffix == ".png"]
    if not all_files:
        return 0

    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # ValueError covers JSON and Unicode decode errors
        manifest = None
    if not isinstance(manifest, dict):
        logger.warning("Unreadable manifest for %s (%s), keeping all", arxiv_id, manifest_path)
        return len(all_files)

    # Normalize every referenced id from the summary to "Figure N" / "Table N".
    referenced_ids: set[str] = set()
    for fig in figures:
        raw_id = fig.get("id") if isinstance(fig, dict) else fig
        if not isinstance(raw_id, str):
            continue
        fig_id = raw_id.strip()
        m = re.match(r'(?:Fig\.?|Figure)\s*(\d+)', fig_id, re.IGNORECASE)
        if m:
            referenced_ids.add(f"Figure {m.group(1)}")
        m2 = re.match(r'Table\s*(\d+)', fig_id, re.IGNORECASE)
        if m2:
            referenced_ids.add(f"Table {m2.group(1)}")

    if not referenced_ids:
        logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
        return len(all_files)

    # A file is kept when any of its manifest labels is referenced.
    keep_filenames: set[str] = set()
    for filename, info in manifest.items():
        if not isinstance(info, dict):
            continue
        file_refs = (info.get("figures") or []) + (info.get("tables") or [])
        if referenced_ids.intersection(file_refs):
            keep_filenames.add(filename)

    if not keep_filenames:
        # Labels could not be matched at all; deleting everything would be worse.
        logger.warning(
            "No manifest matches for %s (refs=%s), keeping all",
            arxiv_id, sorted(referenced_ids),
        )
        return len(all_files)

    removed = 0
    for f in all_files:
        if f.name in keep_filenames:
            continue
        try:
            f.unlink()
        except OSError as exc:
            logger.warning("Could not remove %s: %s", f, exc)
            continue
        removed += 1

    kept = len(all_files) - removed
    logger.info(
        "Filtered images for %s: kept %d, removed %d (refs=%s)",
        arxiv_id, kept, removed, sorted(referenced_ids),
    )
    return kept