"""PDF figure and table extraction via caption-anchored page-region screenshots.

Core idea: academic papers are typeset very regularly. A Figure caption sits
below its figure and a Table caption sits above its table. So we work
backwards: find the caption text, crop the page region above/below it, and
render that region to a PNG.

Advantages over extracting embedded bitmaps:
- Composite figures are not split into fragments (the whole region is cropped).
- Vector figures can be captured too (the page render contains everything).
- No dependency on ``find_tables()`` (captions are matched as plain text).
"""

from __future__ import annotations

import json
import logging
import re
from pathlib import Path

from app.services.pdf_downloader import paper_dir
from app.utils import TMP_DIR

logger = logging.getLogger(__name__)

# ── Crop-region parameters (all values in PDF points) ──────────────────
# Figure: how far above the caption to look for the figure.
_FIGURE_MAX_HEIGHT = 450  # maximum upward search range
_FIGURE_MIN_HEIGHT = 50  # minimum height of a usable crop
_FIGURE_DEFAULT_HEIGHT = 280  # fallback height when nothing is found above

# Table: how far below the caption to look for the table.
_TABLE_MAX_HEIGHT = 500  # maximum downward search range
_TABLE_MIN_HEIGHT = 30

# Horizontal padding around the caption (in two-column papers the caption
# may be narrower than the content it describes).
_REGION_SIDE_PADDING = 10
# Tables are usually wider than their caption text, so search more widely.
_TABLE_SIDE_PADDING = 60

# Roughly 1.5x the body line spacing: a vertical gap larger than this is
# treated as whitespace separating content (30pt proved too loose for the
# dense layout of academic papers).
_CONTENT_GAP_THRESHOLD = 20

# ── Caption regexes ────────────────────────────────────────────────────
# The line must START with Figure/Table (so in-text references such as
# "see Figure 3" do not match). Three caption formats are supported:
#   "Figure 1: Title" / "Figure 1. Title" / "Figure 1 Title"
# The third (no punctuation) form requires an uppercase letter to follow, to
# reject body text such as "Figure 1 shows ...".
#
# Case-insensitivity is scoped to the keyword with (?i:...) on purpose: a
# global re.IGNORECASE would make [A-Z] match lowercase letters as well and
# silently disable the uppercase check.
_CAPTION_RE = re.compile(
    r"^(?i:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=[A-Z]))",
)
_TABLE_CAPTION_RE = re.compile(
    r"^(?i:Table)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=[A-Z]))",
)

# ── Stop signals: region detection stops at any of the following ──
# The next Figure/Table caption ("Table 2:", "Figure 3:", "Figure 4 Title").
_CAPTION_STOP_RE = re.compile(
    r"^(?i:Table|Fig\.?|Figure)\s+\d+\s*(?:[:\.]\s*|\s+[A-Z])",
)
# A section header ("6.2 Evolution", "D.1 Dependency", "7 Conclusion").
_SECTION_STOP_RE = re.compile(
    r"^(\d{1,2}(?:\.\d+)?\s+[A-Z][a-z]|[A-Z]\.\d+\s+[A-Z][a-z])"
)

# Patterns used to normalize figure/table ids coming from the summary.
_FIGURE_REF_RE = re.compile(r"(?:Fig\.?|Figure)\s*(\d+)", re.IGNORECASE)
_TABLE_REF_RE = re.compile(r"Table\s*(\d+)", re.IGNORECASE)

# (manifest type, label prefix, caption pattern), tried in this order.
_CAPTION_KINDS = (
    ("figure", "Figure", _CAPTION_RE),
    ("table", "Table", _TABLE_CAPTION_RE),
)


def _first_line(text: object) -> str:
    """Return the stripped first line of a text block's content."""
    return str(text).strip().split("\n")[0].strip()


def _estimate_column_x(caption: dict) -> tuple[float, float]:
    """Estimate the horizontal bounds ``(col_x0, col_x1)`` of the caption's column.

    In a two-column paper the caption is much narrower than the page, which
    lets us decide whether it lives in the left or the right column.

    - Caption wider than 65% of the page: single-column or spanning caption,
      the full page width is returned.
    - Caption centered on the page (center within 8% of the page midline):
      treated as spanning; a wide range around the caption is returned.
    - Otherwise the half of the page containing the caption center.
    """
    pw = caption["page_width"]
    x0 = caption["caption_x0"]
    x1 = caption["caption_x1"]

    if x1 - x0 > pw * 0.65:
        return 0, pw

    center = (x0 + x1) / 2
    if abs(center - pw / 2) / pw < 0.08:
        # Probably a table spanning both columns: widen generously.
        return (
            max(0, x0 - _TABLE_SIDE_PADDING * 2),
            min(pw, x1 + _TABLE_SIDE_PADDING * 2),
        )

    if center < pw / 2:
        return 0, pw / 2
    return pw / 2, pw


def _find_captions(doc) -> list[dict]:
    """Scan the whole document and collect every Figure/Table caption.

    Only the first line of each text block is matched, so that a block
    holding several paragraphs cannot produce a spurious match further down.

    Returns:
        One dict per caption with keys ``type``, ``num``, ``label``,
        ``page_num`` (0-based), ``caption_x0/y0/x1/y1``, ``caption_text``,
        ``page_width`` and ``page_height``, in page/block order.
    """
    captions: list[dict] = []
    for page_num, page in enumerate(doc):
        page_width = page.rect.width
        page_height = page.rect.height

        for block in page.get_text("blocks"):
            if len(block) < 5:
                continue
            text = str(block[4]).strip()
            if not text:
                continue

            first_line = text.split("\n")[0].strip()
            for kind, prefix, pattern in _CAPTION_KINDS:
                m = pattern.match(first_line)
                if m is None:
                    continue
                captions.append(
                    {
                        "type": kind,
                        "num": int(m.group(1)),
                        "label": f"{prefix} {m.group(1)}",
                        "page_num": page_num,
                        "caption_y0": block[1],
                        "caption_y1": block[3],
                        "caption_x0": block[0],
                        "caption_x1": block[2],
                        "caption_text": text,
                        "page_width": page_width,
                        "page_height": page_height,
                    }
                )
                break

    return captions


def _find_figure_top(page, caption: dict) -> float:
    """Scan upwards from a Figure caption and return the figure's top edge.

    Strategy:
    1. Prefer embedded images: the topmost image above the caption (within
       the caption's horizontal range) gives the figure top.
    2. Without images, fall back to text-block gap detection (for pure
       vector figures such as TikZ or matplotlib PDF output): walk upward
       through text blocks until a vertical gap larger than
       ``_CONTENT_GAP_THRESHOLD`` is found.

    The result never reaches above another caption that sits between the
    search limit and this caption, and never more than
    ``_FIGURE_MAX_HEIGHT`` above this caption.
    """
    caption_y = caption["caption_y0"]
    search_limit = caption_y - _FIGURE_MAX_HEIGHT
    col_x0, col_x1 = _estimate_column_x(caption)
    cx0 = max(col_x0, caption["caption_x0"] - _REGION_SIDE_PADDING)
    cx1 = min(col_x1, caption["caption_x1"] + _REGION_SIDE_PADDING)

    # Fetched once; both the cutoff search and strategy 2 need the blocks.
    text_blocks = [b for b in page.get_text("blocks") if len(b) >= 5]

    # Nearest Figure/Table caption above this one on the same page (several
    # figures on one page). Only captions that overlap our horizontal range
    # count, so a caption in the other column cannot truncate this figure.
    # The cutoff is that caption's BOTTOM edge so its text stays out of the
    # crop.
    caption_cutoff: float | None = None
    for b in text_blocks:
        bx0, bx1, by1 = b[0], b[2], b[3]
        if by1 >= caption_y or by1 <= search_limit:
            continue
        if bx1 <= cx0 or bx0 >= cx1:
            continue
        if not _CAPTION_STOP_RE.match(_first_line(b[4])):
            continue
        if caption_cutoff is None or by1 > caption_cutoff:
            caption_cutoff = by1

    # ── Strategy 1: embedded images ──
    topmost_image_y: float | None = None
    for img_info in page.get_image_info():
        bbox = img_info.get("bbox")
        if bbox is None:
            continue
        if hasattr(bbox, "x0"):
            ix0, iy0, ix1, iy1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1
        else:
            ix0, iy0, ix1, iy1 = bbox[0], bbox[1], bbox[2], bbox[3]

        if not (search_limit < iy1 <= caption_y):
            continue
        if not (ix1 > cx0 and ix0 < cx1):
            continue
        if caption_cutoff is not None and iy0 < caption_cutoff:
            continue  # belongs to the figure above
        if topmost_image_y is None or iy0 < topmost_image_y:
            topmost_image_y = iy0

    if topmost_image_y is not None:
        figure_top = topmost_image_y
    else:
        # ── Strategy 2: text-block gap detection (pure vector figures) ──
        above_blocks: list[tuple[float, float, float, float]] = []
        for b in text_blocks:
            bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3]
            if not (search_limit < by1 <= caption_y):
                continue
            if not (bx1 > cx0 and bx0 < cx1):
                continue
            # Skip blocks that start well inside the other column.
            if col_x0 > 0 and bx0 < col_x0 - _REGION_SIDE_PADDING * 2:
                continue
            above_blocks.append((bx0, by0, bx1, by1))

        if not above_blocks:
            return max(0, caption_y - _FIGURE_DEFAULT_HEIGHT)

        # Walk from the caption upward (largest y0 first).
        above_blocks.sort(key=lambda blk: blk[1], reverse=True)
        prev_bottom = caption_y
        for blk in above_blocks:
            if prev_bottom - blk[3] > _CONTENT_GAP_THRESHOLD:
                figure_top = prev_bottom - 5
                break
            prev_bottom = blk[1]
        else:
            # No gap found: the topmost collected block is the figure top.
            figure_top = above_blocks[-1][1]

    # Do not reach into the caption above.
    if caption_cutoff is not None:
        figure_top = max(figure_top, caption_cutoff + 5)

    # Enforce the maximum height.
    if caption_y - figure_top > _FIGURE_MAX_HEIGHT:
        figure_top = caption_y - _FIGURE_MAX_HEIGHT

    return max(0, figure_top)


def _find_table_region(page, caption: dict) -> tuple[float, float, float]:
    """Scan downwards from a Table caption to find the table's extent.

    Returns:
        ``(x0, bottom, x1)``: the left, bottom and right edges of the crop
        region. The top edge is chosen by the caller from the caption
        position.

    Strategy:
    1. Collect the text blocks below the caption (table content is text).
    2. Find the bottom of the contiguous content (stop at a large gap, at
       the next caption, or at a section header).
    3. Derive the horizontal extent from those blocks (tables are usually
       wider than their caption).
    """
    caption_y = caption["caption_y1"]  # the scan starts at the caption bottom
    caption_x0 = caption["caption_x0"]
    caption_x1 = caption["caption_x1"]
    page_height = caption["page_height"]
    page_width = caption["page_width"]

    # Column bounds keep the search from grabbing the other column.
    col_x0, col_x1 = _estimate_column_x(caption)
    search_x0 = max(col_x0, caption_x0 - _TABLE_SIDE_PADDING)
    search_x1 = min(col_x1, caption_x1 + _TABLE_SIDE_PADDING)

    candidates = []
    for b in page.get_text("blocks"):
        if len(b) < 5:
            continue
        bx0, by0, bx1 = b[0], b[1], b[2]
        if not (caption_y < by0 < caption_y + _TABLE_MAX_HEIGHT):
            continue
        if not (bx1 > search_x0 and bx0 < search_x1):
            continue
        # Two-column papers: drop body paragraphs that start in the other
        # column (bx0 far left of col_x0). Table rows start inside the
        # column or near its edge.
        if col_x0 > 0 and bx0 < col_x0 - _TABLE_SIDE_PADDING:
            continue
        candidates.append(b)

    # Scan top-to-bottom so that a stop signal only cuts off what is
    # actually below it, regardless of the order the blocks were emitted in.
    candidates.sort(key=lambda blk: blk[1])

    below_blocks: list[tuple[float, float, float, float]] = []
    for b in candidates:
        first_line = _first_line(b[4])
        if _CAPTION_STOP_RE.match(first_line) or _SECTION_STOP_RE.match(
            first_line
        ):
            break
        below_blocks.append((b[0], b[1], b[2], b[3]))

    if not below_blocks:
        # Nothing below: fall back to the minimum height and caption width.
        return (
            max(0, caption_x0 - _REGION_SIDE_PADDING),
            min(page_height, caption_y + _TABLE_MIN_HEIGHT),
            min(page_width, caption_x1 + _REGION_SIDE_PADDING),
        )

    # ── Bottom of the contiguous content ──
    prev_y = caption_y
    bottom = below_blocks[-1][3] + 5  # bottom of the last block plus margin
    for blk in below_blocks:
        if blk[1] - prev_y > _CONTENT_GAP_THRESHOLD:
            bottom = prev_y + 5
            break
        prev_y = blk[3]

    # Enforce the maximum height.
    if bottom - caption_y > _TABLE_MAX_HEIGHT:
        bottom = caption_y + _TABLE_MAX_HEIGHT

    # ── Horizontal extent ──
    # Only blocks above the detected bottom count: blocks after the gap
    # belong to the body text and may be wider.
    table_blocks = [blk for blk in below_blocks if blk[1] < bottom]
    if not table_blocks:
        table_blocks = below_blocks[:1]  # always use at least the first block

    content_x0 = min(caption_x0, min(blk[0] for blk in table_blocks))
    content_x1 = max(caption_x1, max(blk[2] for blk in table_blocks))

    # Small padding only, clamped to the page, so a neighbouring column's
    # content (e.g. a Figure in the other column) is not pulled in. The
    # column bounds are deliberately not applied here because a caption may
    # start across columns.
    x0 = max(0, content_x0 - _REGION_SIDE_PADDING)
    x1 = min(page_width, content_x1 + _REGION_SIDE_PADDING)

    return (x0, bottom, x1)


def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Render Figure/Table screenshots from a PDF and write a manifest.

    Pipeline: find captions, locate each region, render the page clip to
    ``<label>.png`` inside the paper's ``images`` directory, and write
    ``manifest.json`` describing every rendered file. PNGs and the manifest
    from a previous run are deleted first.

    Args:
        arxiv_id: Paper id.
        pdf_path: Path of the PDF; defaults to ``TMP_DIR/<arxiv_id>/paper.pdf``.

    Returns:
        Number of images rendered (0 when the PDF is missing or has no
        captions; no manifest is written in those cases).
    """
    import pymupdf

    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"

    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)

    # Remove leftovers from a previous extraction.
    for old_file in images_dest.glob("*.png"):
        old_file.unlink()
    manifest_path = images_dest / "manifest.json"
    if manifest_path.exists():
        manifest_path.unlink()

    extracted = 0
    manifest: dict[str, dict] = {}
    zoom = 3  # 3x rendering keeps the screenshots sharp

    doc = pymupdf.open(str(pdf_path))
    try:
        captions = _find_captions(doc)
        if not captions:
            logger.info("No Figure/Table captions found in PDF for %s", arxiv_id)
            return 0

        # De-duplicate: the same label may match several blocks (e.g. a body
        # paragraph starting with "Figure 7"). Keep the first match per label.
        seen_labels: dict[str, dict] = {}
        for cap in captions:
            seen_labels.setdefault(cap["label"], cap)
        unique_captions = list(seen_labels.values())

        mat = pymupdf.Matrix(zoom, zoom)

        for cap in unique_captions:
            page = doc[cap["page_num"]]
            pw = cap["page_width"]

            if cap["type"] == "figure":
                # Figure: the picture is above the caption.
                top = _find_figure_top(page, cap)
                # Extra 5pt on top so frames/decorative rules are not clipped.
                top = max(0, top - 5)
                bottom = cap["caption_y1"] + 5  # include the caption itself

                # Horizontal range: caption width plus padding.
                x0 = max(0, cap["caption_x0"] - _REGION_SIDE_PADDING)
                x1 = min(pw, cap["caption_x1"] + _REGION_SIDE_PADDING)

                height = bottom - top
                if height < _FIGURE_MIN_HEIGHT:
                    logger.debug(
                        "Figure %s too small (%.0fpt), skipping", cap["label"], height
                    )
                    continue
            else:
                # Table: the table is below the caption.
                x0, bottom, x1 = _find_table_region(page, cap)
                # Include the caption, with a small margin above it.
                top = max(0, cap["caption_y0"] - 3)

                height = bottom - top
                if height < _TABLE_MIN_HEIGHT:
                    logger.debug(
                        "Table %s too small (%.0fpt), skipping", cap["label"], height
                    )
                    continue

            # Render the clip. A failure on one region must not abort the
            # rest, so it is logged (with traceback) and skipped.
            clip = pymupdf.Rect(x0, top, x1, bottom)
            try:
                pix = page.get_pixmap(matrix=mat, clip=clip)
            except Exception:
                logger.debug(
                    "Failed to render %s region for %s",
                    cap["label"],
                    arxiv_id,
                    exc_info=True,
                )
                continue

            filename = f"{cap['label'].replace(' ', '_').lower()}.png"
            pix.save(str(images_dest / filename))
            extracted += 1

            cap_preview = cap["caption_text"][:200] if cap["caption_text"] else ""
            manifest[filename] = {
                "page": cap["page_num"] + 1,
                "type": cap["type"],
                "label": cap["label"],
                "caption_text": cap_preview,
                "figures" if cap["type"] == "figure" else "tables": [cap["label"]],
            }

            logger.debug(
                "Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) h=%.0fpt → %s",
                cap["label"],
                cap["page_num"] + 1,
                x0,
                top,
                x1,
                bottom,
                height,
                filename,
            )
    finally:
        # Always release the document, even if locating or saving raised.
        doc.close()

    # Written as UTF-8 explicitly: the manifest holds non-ASCII caption text
    # (ensure_ascii=False) and is read back as UTF-8.
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    if extracted > 0:
        logger.info(
            "Extracted %d figure/table screenshots from PDF for %s "
            "(from %d captions found, %d unique)",
            extracted,
            arxiv_id,
            len(captions),
            len(unique_captions),
        )
    return extracted


def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int:
    """Delete extracted images that the summary does not reference.

    Each entry of ``figures`` is expected to carry an ``"id"`` such as
    ``"Figure 3"``, ``"Fig. 3"`` or ``"Table 2"``. Ids are normalized to
    ``"Figure N"`` / ``"Table N"`` and matched against the ``label`` (and the
    ``figures`` / ``tables`` lists) of each ``manifest.json`` entry. PNG
    files without a match are deleted.

    Returns:
        Number of PNG files kept. Returns 0 when ``figures`` is empty or
        nothing was extracted. When no id can be parsed, or nothing in the
        manifest matches, every file is kept and that count is returned.
    """
    if not figures:
        return 0

    images_dir = paper_dir(arxiv_id) / "images"
    manifest_path = images_dir / "manifest.json"

    if not images_dir.exists() or not manifest_path.exists():
        return 0

    all_files = [f for f in images_dir.iterdir() if f.suffix == ".png"]
    if not all_files:
        return 0

    manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8"))

    # Normalized Figure/Table ids referenced by the summary.
    referenced_ids: set[str] = set()
    for fig in figures:
        # A missing or null id must not crash the regex match.
        fig_id = str(fig.get("id") or "")
        m = _FIGURE_REF_RE.match(fig_id)
        if m:
            referenced_ids.add(f"Figure {m.group(1)}")
        m2 = _TABLE_REF_RE.match(fig_id)
        if m2:
            referenced_ids.add(f"Table {m2.group(1)}")

    if not referenced_ids:
        logger.warning("No valid figure/table IDs in summary for %s", arxiv_id)
        return len(all_files)

    # Match on the manifest label first, then on the figures/tables lists.
    keep_filenames: set[str] = set()
    for filename, info in manifest.items():
        if info.get("label", "") in referenced_ids:
            keep_filenames.add(filename)
            continue
        refs = info.get("figures", []) + info.get("tables", [])
        if any(ref in referenced_ids for ref in refs):
            keep_filenames.add(filename)

    if not keep_filenames:
        logger.warning(
            "No manifest matches for %s (refs=%s), keeping all",
            arxiv_id,
            referenced_ids,
        )
        return len(all_files)

    removed = 0
    for f in all_files:
        if f.name not in keep_filenames:
            f.unlink()
            removed += 1

    kept = len(all_files) - removed
    logger.info(
        "Filtered images for %s: kept %d, removed %d (refs=%s)",
        arxiv_id,
        kept,
        removed,
        referenced_ids,
    )
    return kept