"""PDF 图片与表格提取 — 从 PDF 中提取嵌入图片和表格截图。 策略: 1. 提取 PDF 中嵌入的图片(图表、插图等) 2. 检测表格区域,渲染为截图 3. 同时搜索页面中的 Figure/Table 标注,记录到 manifest 4. 过滤掉过小的图片 5. 保存到 data/papers/{arxiv_id}/images/ """ from __future__ import annotations import json import logging import re from pathlib import Path from app.services.pdf_downloader import paper_dir logger = logging.getLogger(__name__) # 最小面积阈值(像素),小于此值的图片视为图标/装饰 _MIN_AREA = 10_000 # ~100x100 _MIN_DIM = 80 # Figure/Table 标注与图片/表格的最大垂直距离(点) _MAX_LABEL_DISTANCE = 120 # Figure/Table 标注的正则 _FIGURE_RE = re.compile(r'\b(?:Fig\.?|Figure)\s*(\d+)\b', re.IGNORECASE) _TABLE_RE = re.compile(r'\bTable\s*(\d+)\b', re.IGNORECASE) def _find_nearby_labels( rects: list, labels: dict[str, list[tuple[int, float]]], page_num: int ) -> list[str]: """查找与给定矩形区域在位置上接近的 Figure/Table 标注。 匹配逻辑:标注的垂直位置 (y) 需在图片/表格的上下 _MAX_LABEL_DISTANCE 点范围内。 """ matched: list[str] = [] for rect in rects: if isinstance(rect, (list, tuple)): y_min, y_max = rect[1], rect[3] else: y_min, y_max = rect.y0, rect.y1 for label_key, positions in labels.items(): for label_page, label_y in positions: if label_page == page_num: # 标注在图片/表格上方或下方的距离 distance = min(abs(label_y - y_min), abs(label_y - y_max)) if distance <= _MAX_LABEL_DISTANCE: if label_key not in matched: matched.append(label_key) return matched def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int: """从 PDF 提取嵌入图片和表格截图,同时生成 manifest。 Args: arxiv_id: 论文 ID pdf_path: PDF 路径,默认 data/tmp/{arxiv_id}/paper.pdf Returns: 提取的图片+表格数量 """ import pymupdf if pdf_path is None: pdf_path = Path("data/tmp") / arxiv_id / "paper.pdf" if not pdf_path.exists(): logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path) return 0 images_dest = paper_dir(arxiv_id) / "images" images_dest.mkdir(parents=True, exist_ok=True) doc = pymupdf.open(str(pdf_path)) extracted = 0 seen_hashes: set[int] = set() # 扫描每页的 Figure/Table 标注位置 # figure_labels: {key: [(page_num, y_center)]} — 记录标注在页面中的垂直位置 figure_labels: dict[str, list[tuple[int, float]]] = {} table_labels: dict[str, list[tuple[int, float]]] = {} for page_num in range(len(doc)): page = doc[page_num] text_dict = page.get_text("dict") for block in text_dict.get("blocks", []): if block.get("type") != 0: # 只看文本块 continue block_text = "" for line in block.get("lines", []): for span in line.get("spans", []): block_text += span.get("text", "") for m in _FIGURE_RE.finditer(block_text): key = f"Figure {m.group(1)}" bbox = block.get("bbox", [0, 0, 0, 0]) y_center = (bbox[1] + bbox[3]) / 2 figure_labels.setdefault(key, []).append((page_num, y_center)) for m in _TABLE_RE.finditer(block_text): key = f"Table {m.group(1)}" bbox = block.get("bbox", [0, 0, 0, 0]) y_center = (bbox[1] + bbox[3]) / 2 table_labels.setdefault(key, []).append((page_num, y_center)) # 记录每个提取文件的元信息 manifest: dict[str, dict] = {} for page_num in range(len(doc)): page = doc[page_num] # ── 1. 提取嵌入图片 ── image_list = page.get_images(full=True) for img_index, img_info in enumerate(image_list): xref = img_info[0] try: pix = pymupdf.Pixmap(doc, xref) except Exception: continue if pix.width < _MIN_DIM or pix.height < _MIN_DIM: continue if pix.width * pix.height < _MIN_AREA: continue img_hash = hash(pix.tobytes()[:1024]) if img_hash in seen_hashes: continue seen_hashes.add(img_hash) if pix.n >= 5: try: pix = pymupdf.Pixmap(pymupdf.csRGB, pix) except Exception: continue filename = f"page{page_num + 1}_img{img_index + 1}.png" pix.save(str(images_dest / filename)) extracted += 1 logger.debug("Image: %s (%dx%d)", filename, pix.width, pix.height) # 查找该图片位置附近的 Figure 标注 img_rects = page.get_image_rects(xref) matched = _find_nearby_labels(img_rects, figure_labels, page_num) manifest[filename] = {"page": page_num + 1, "type": "image", "figures": matched} # ── 2. 提取表格截图 ── try: tables = page.find_tables() except Exception: tables = None if tables and tables.tables: for table_index, table in enumerate(tables.tables): bbox = table.bbox if not bbox: continue margin = 5 if isinstance(bbox, (list, tuple)): x0, y0, x1, y1 = bbox else: x0, y0, x1, y1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1 clip_rect = pymupdf.Rect(x0 - margin, y0 - margin, x1 + margin, y1 + margin) zoom = 2 mat = pymupdf.Matrix(zoom, zoom) try: pix = page.get_pixmap(matrix=mat, clip=clip_rect) except Exception: continue if pix.width < _MIN_DIM * 2 or pix.height < 30 * 2: continue filename = f"page{page_num + 1}_table{table_index + 1}.png" pix.save(str(images_dest / filename)) extracted += 1 logger.debug("Table: %s (%dx%d)", filename, pix.width, pix.height) # 查找该表格位置附近的 Table 标注 table_rect = pymupdf.Rect(x0, y0, x1, y1) matched = _find_nearby_labels([table_rect], table_labels, page_num) manifest[filename] = {"page": page_num + 1, "type": "table", "tables": matched} doc.close() # 保存 manifest manifest_path = images_dest / "manifest.json" manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2)) if extracted > 0: logger.info("Extracted %d images+tables from PDF for %s", extracted, arxiv_id) return extracted def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int: """根据 summary 中的 figures 字段过滤提取的图片/表格。 用 manifest.json 匹配,不需要 PDF 文件。 """ if not figures: return 0 images_dir = paper_dir(arxiv_id) / "images" manifest_path = images_dir / "manifest.json" if not images_dir.exists() or not manifest_path.exists(): return 0 all_files = [f for f in images_dir.iterdir() if f.suffix == ".png"] if not all_files: return 0 manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8")) # 收集 summary 中引用的所有 Figure/Table ID(归一化) referenced_ids: set[str] = set() for fig in figures: fig_id = fig.get("id", "") m = re.match(r'(?:Fig\.?|Figure)\s*(\d+)', fig_id, re.IGNORECASE) if m: referenced_ids.add(f"Figure {m.group(1)}") m2 = re.match(r'Table\s*(\d+)', fig_id, re.IGNORECASE) if m2: referenced_ids.add(f"Table {m2.group(1)}") if not referenced_ids: logger.warning("No valid figure/table IDs in summary for %s", arxiv_id) return len(all_files) # 根据 manifest 判断每个文件是否被引用 keep_filenames: set[str] = set() for filename, info in manifest.items(): file_refs = info.get("figures", []) + info.get("tables", []) for ref in file_refs: if ref in referenced_ids: keep_filenames.add(filename) break if not keep_filenames: logger.warning( "No manifest matches for %s (refs=%s), keeping all", arxiv_id, referenced_ids, ) return len(all_files) removed = 0 for f in all_files: if f.name not in keep_filenames: f.unlink() removed += 1 kept = len(all_files) - removed logger.info("Filtered images for %s: kept %d, removed %d (refs=%s)", arxiv_id, kept, removed, referenced_ids) return kept