"""PDF 图片与表格提取 — 基于 caption 定位的页面区域截图。 核心思路:学术论文排版极其规整,Figure caption 在图下方,Table caption 在表格上方。 因此反过来:先找 caption 文字 → 向上/向下截取页面区域 → 渲染为 PNG。 优势(相比提取嵌入位图): - 复合图表不会被拆成碎片(整块截取) - 矢量图也能截取(页面渲染包含一切) - 不依赖 find_tables()(纯文本匹配 caption) """ from __future__ import annotations import json import logging import re from pathlib import Path from app.services.pdf_downloader import paper_dir from app.utils import TMP_DIR logger = logging.getLogger(__name__) # ── 截取区域参数 ─────────────────────────────────────────────────────── # Figure: caption 上方搜索图的范围(点) _FIGURE_MAX_HEIGHT = 450 # 最大向上搜索范围 _FIGURE_MIN_HEIGHT = 50 # 最小有效截图高度 _FIGURE_DEFAULT_HEIGHT = 280 # 上方未找到内容块时的默认图高度 # Table: caption 下方搜索表格的范围 _TABLE_MAX_HEIGHT = 500 # 最大向下搜索范围 _TABLE_MIN_HEIGHT = 30 # caption 左右扩展(双栏论文中 caption 可能比表格窄) _REGION_SIDE_PADDING = 10 # 表格通常比 caption 文字宽,使用更大的水平扩展 _TABLE_SIDE_PADDING = 60 # 正文行距的 ~1.5 倍 ≈ 空白间隙阈值(学术论文紧密排版,30pt 太宽松) _CONTENT_GAP_THRESHOLD = 20 # 密集表格数据块后的过渡阈值:表格块之后的段落间距常只有 12-18pt _TABLE_DATA_GAP_THRESHOLD = 12 # ── Caption 正则 ─────────────────────────────────────────────────────── # 要求以 Figure/Table 开头(避免匹配正文中的 "see Figure 3" 等) # 支持三种 caption 格式: # "Figure 1: Title" / "Figure 1. Title" / "Figure 1 Title"(无标点,空格分隔) # 第三种需要后续紧跟大写字母(排除 "Figure 1 shows..." 等正文引用) _CAPTION_RE = re.compile( r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))", re.IGNORECASE, ) _TABLE_CAPTION_RE = re.compile( r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))", re.IGNORECASE, ) # ── 停止信号:表格边界检测遇到以下内容时立即停止 ── # 下一个 Figure/Table caption(如 "Table 2:" "Figure 3:" "Figure 4 Title") _CAPTION_STOP_RE = re.compile( r"^(?:Table|Fig\.?|Figure)\s+\d+\s*(?:[:\.]\s*|\s+[A-Z])", re.IGNORECASE, ) # Section header(如 "6.2 Evolution" "D.1 Dependency" "7 Conclusion") _SECTION_STOP_RE = re.compile( r"^(\d{1,2}(?:\.\d+)?\s+[A-Z][a-z]|[A-Z]\.\d+\s+[A-Z][a-z])" ) def _estimate_column_x(caption: dict) -> tuple[float, float]: """估计 caption 所在列的水平边界(col_x0, col_x1)。 双栏论文中 caption 宽度远小于页面宽度,据此判断左右列。 单栏或跨栏 caption(宽度 >65% 页宽)返回整页宽度。 caption 居中对齐(中心接近页面中线)时按跨栏处理,使用宽范围。 """ pw = caption["page_width"] caption_w = caption["caption_x1"] - caption["caption_x0"] # caption 宽度 >65% 页宽 → 单栏或跨栏 if caption_w > pw * 0.65: return 0, pw cx = (caption["caption_x0"] + caption["caption_x1"]) / 2 # caption 居中(中心距页面中线 <8%)→ 可能是跨栏表格,使用宽范围 if abs(cx - pw / 2) / pw < 0.08: return ( max(0, caption["caption_x0"] - _TABLE_SIDE_PADDING * 2), min(pw, caption["caption_x1"] + _TABLE_SIDE_PADDING * 2), ) if cx < pw / 2: return 0, pw / 2 else: return pw / 2, pw def _find_captions(doc) -> list[dict]: """扫描整个文档,找到所有 Figure/Table caption 的位置和信息。""" captions = [] for page_num in range(len(doc)): page = doc[page_num] page_width = page.rect.width page_height = page.rect.height blocks = page.get_text("blocks") for block in blocks: if len(block) < 5: continue text = str(block[4]).strip() if not text: continue bx0, by0, bx1, by1 = block[0], block[1], block[2], block[3] # 只取 block 第一行做匹配(避免 block 包含多段文字干扰) first_line = text.split("\n")[0].strip() m = _CAPTION_RE.match(first_line) if m: captions.append( { "type": "figure", "num": int(m.group(1)), "label": f"Figure {m.group(1)}", "page_num": page_num, "caption_y0": by0, "caption_y1": by1, "caption_x0": bx0, "caption_x1": bx1, "caption_text": text, "page_width": page_width, "page_height": page_height, } ) continue m = _TABLE_CAPTION_RE.match(first_line) if m: captions.append( { "type": "table", "num": int(m.group(1)), "label": f"Table {m.group(1)}", "page_num": page_num, "caption_y0": by0, "caption_y1": by1, "caption_x0": bx0, "caption_x1": bx1, "caption_text": text, "page_width": page_width, "page_height": page_height, } ) return captions def _find_figure_top(page, caption: dict) -> float: """向上扫描页面,找到 Figure 的上边界。 策略: 1. 优先用嵌入图片定位 — 收集 caption 上方所有相关图片 bbox, 按 Y 轴聚类后取最大簇的最小 y 作为上界(处理 subfigure 组合图) 2. 无图片时回退到文本块间隙检测(处理纯矢量图如 TikZ/matplotlib PDF) """ caption_y = caption["caption_y0"] col_x0, col_x1 = _estimate_column_x(caption) cx0 = max(col_x0, caption["caption_x0"] - _REGION_SIDE_PADDING) cx1 = min(col_x1, caption["caption_x1"] + _REGION_SIDE_PADDING) # 同页上方最近的 Figure/Table caption(多 figure 同页时截断) _caption_cutoff: float | None = None for b in page.get_text("blocks"): if len(b) < 5: continue by0, by1 = b[1], b[3] if by1 >= caption_y or by1 <= caption_y - _FIGURE_MAX_HEIGHT: continue first_line = str(b[4]).strip().split("\n")[0].strip() if _CAPTION_STOP_RE.match(first_line): _caption_cutoff = by0 break # ── 策略 1:嵌入图片聚类定位 ── # 收集 caption 上方搜索范围内所有与 caption 水平区域重叠的图片 image_tops: list[float] = [] for img_info in page.get_image_info(): bbox = img_info.get("bbox") if bbox is None: continue if hasattr(bbox, "x0"): ix0, iy0, ix1, iy1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1 else: ix0, iy0, ix1, iy1 = bbox[0], bbox[1], bbox[2], bbox[3] # 图片底部必须在 caption 上方、且在搜索范围内 if not (iy1 <= caption_y and iy1 > caption_y - _FIGURE_MAX_HEIGHT): continue # 图片水平范围与 caption 所在列有重叠 if not (ix1 > cx0 and ix0 < cx1): continue # 跳过属于上方另一个 figure 的图片 if _caption_cutoff is not None and iy0 < _caption_cutoff: continue # 跳过极小图标(宽度或高度 <15pt,通常是 logo/符号) if (ix1 - ix0) < 15 or (iy1 - iy0) < 15: continue image_tops.append(iy0) if image_tops: # 聚类:将 Y 轴接近的图片视为同一组(subfigure),最大簇的最小 y 即图上界 image_tops.sort() # 用简单单遍聚类:相邻图片 top 差 < 最大高度的 40% 视为同簇 cluster_gap = _FIGURE_MAX_HEIGHT * 0.4 clusters: list[list[float]] = [[image_tops[0]]] for yt in image_tops[1:]: if yt - clusters[-1][-1] < cluster_gap: clusters[-1].append(yt) else: clusters.append([yt]) # 取最大簇(图片数最多的)的最小 y biggest = max(clusters, key=len) figure_top = min(biggest) else: # ── 策略 2:文本块间隙检测(纯矢量图) ── above_blocks: list[tuple[float, float, float, float]] = [] for b in page.get_text("blocks"): if len(b) < 5: continue bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3] if by1 <= caption_y and by1 > caption_y - _FIGURE_MAX_HEIGHT: if bx1 > cx0 and bx0 < cx1: if col_x0 > 0 and bx0 < col_x0 - _REGION_SIDE_PADDING * 2: continue above_blocks.append((bx0, by0, bx1, by1)) if not above_blocks: return max(0, caption_y - _FIGURE_DEFAULT_HEIGHT) above_blocks.sort(key=lambda b: b[1], reverse=True) prev_bottom = caption_y for b in above_blocks: if prev_bottom - b[3] > _CONTENT_GAP_THRESHOLD: figure_top = prev_bottom - 5 break prev_bottom = b[1] else: figure_top = above_blocks[-1][1] # 同页 caption 截断 if _caption_cutoff is not None: figure_top = max(figure_top, _caption_cutoff + 5) # 限制最大高度 if caption_y - figure_top > _FIGURE_MAX_HEIGHT: figure_top = caption_y - _FIGURE_MAX_HEIGHT return max(0, figure_top) def _find_figure_horizontal( page, caption: dict, top: float, bottom: float ) -> tuple[float, float]: """确定 Figure 的水平裁剪范围。 取 caption 宽度和图片实际宽度的并集,避免截断比 caption 更宽的图。 """ pw = caption["page_width"] x0 = caption["caption_x0"] x1 = caption["caption_x1"] # 收集裁剪区域内所有嵌入图片的水平范围 col_x0, col_x1 = _estimate_column_x(caption) for img_info in page.get_image_info(): bbox = img_info.get("bbox") if bbox is None: continue if hasattr(bbox, "x0"): ix0, iy0, ix1, iy1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1 else: ix0, iy0, ix1, iy1 = bbox[0], bbox[1], bbox[2], bbox[3] # 图片在裁剪区域内且在 caption 所在列 if iy0 < bottom and iy1 > top and ix1 > col_x0 and ix0 < col_x1: if (ix1 - ix0) < 15: continue # 跳过小图标 x0 = min(x0, ix0) x1 = max(x1, ix1) return max(0, x0 - _REGION_SIDE_PADDING), min(pw, x1 + _REGION_SIDE_PADDING) def _find_table_region(page, caption: dict) -> tuple[float, float, float, float]: """向下扫描页面,找到 Table 的下边界和水平范围。 返回: (x0, bottom, x1) — 裁剪区域的左、下、右边界。 上边界由调用方根据 caption 位置确定。 策略: 1. 用 page.find_tables() 收集 caption 下方所有相邻的表格段,合并为一个完整区域 (学术论文表格常被拆成表头行 + 数据行等多个 find_tables 段) 2. 未命中时回退到文本块间隙检测 """ caption_y = caption["caption_y1"] # caption 底部作为扫描起点 caption_x0 = caption["caption_x0"] caption_x1 = caption["caption_x1"] page_width = caption["page_width"] # ── 策略 1: find_tables() 结构化检测 + 合并相邻段 ── try: tables = page.find_tables() except Exception: tables = None if tables and tables.tables: # 确定 caption 所在栏的范围(防止双栏论文中跨栏收集) col_x0, col_x1 = _estimate_column_x(caption) # 收集 caption 下方附近且在同一栏内的表格段 bbox segments: list[tuple[float, float, float, float]] = [] for t in tables.tables: tb = t.bbox if isinstance(tb, (list, tuple)): tx0, ty0, tx1, ty1 = ( float(tb[0]), float(tb[1]), float(tb[2]), float(tb[3]), ) else: tx0, ty0, tx1, ty1 = ( float(tb.x0), float(tb.y0), float(tb.x1), float(tb.y1), ) # 表格段上边在 caption 底部附近,且与 caption 同栏 if ( ty0 >= caption_y - 5 and ty0 < caption_y + 200 and tx1 > col_x0 and tx0 < col_x1 ): segments.append((tx0, ty0, tx1, ty1)) if segments: # 按 y 排序,合并相邻段(gap < 30pt 视为同一表格的连续部分) segments.sort(key=lambda s: s[1]) merged: list[tuple[float, float, float, float]] = [segments[0]] for seg in segments[1:]: prev = merged[-1] gap = seg[1] - prev[3] # 当前段 top - 上一段 bottom if gap < 30: # 合并:取并集范围 merged[-1] = ( min(prev[0], seg[0]), min(prev[1], seg[1]), max(prev[2], seg[2]), max(prev[3], seg[3]), ) else: merged.append(seg) # 取第一个合并段(最靠近 caption 的完整表格) final = merged[0] tx0, ty0, tx1, ty1 = final # 限制最大高度 if ty1 - caption_y > _TABLE_MAX_HEIGHT: ty1 = caption_y + _TABLE_MAX_HEIGHT x0 = max(0, min(caption_x0, tx0) - _REGION_SIDE_PADDING) x1 = min(page_width, max(caption_x1, tx1) + _REGION_SIDE_PADDING) logger.debug( "Table detected by find_tables() (%d segments merged): " "(%.0f,%.0f)-(%.0f,%.0f)", len(segments), x0, caption_y, x1, ty1, ) return (x0, caption["caption_y0"], ty1, x1) # ── 策略 2: 回退到文本块间隙检测 ── x0, t_top, t_bottom, x1 = _find_table_region_by_blocks(page, caption) return (x0, t_top, t_bottom, x1) def _scan_blocks_direction( blocks: list, start_y: float, col_x0: float, col_x1: float, direction: int, max_range: float, ) -> list[tuple[float, float, float, float]]: """从 start_y 向上(direction=-1)或向下(direction=1)扫描文本块。 收集间隙连续的块,遇到 stop 信号(caption / section header)或大间隙即停。 用 current_top/current_bottom 追踪连通区域边界,正确处理 y 重叠块。 Returns: 收集到的块列表 [(x0, y0, x1, y1), ...] """ # 过滤在扫描范围内的块 if direction > 0: # 向下 candidates = [ b for b in blocks if len(b) >= 5 and b[1] > start_y and b[1] < start_y + max_range and b[2] > col_x0 and b[0] < col_x1 ] candidates.sort(key=lambda b: b[1]) # 按 y0 升序 else: # 向上 candidates = [ b for b in blocks if len(b) >= 5 and b[3] <= start_y and b[1] > start_y - max_range and b[2] > col_x0 and b[0] < col_x1 ] candidates.sort(key=lambda b: b[3], reverse=True) # 按 y1 降序(底部离 start_y 最近的在前) if not candidates: return [] # 从 start_y 开始,追踪连通区域边界 connected: list[tuple[float, float, float, float]] = [] boundary = start_y # 当前连通区域离 start_y 最近端的 y 坐标 prev_was_dense_table = False for b in candidates: bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3] text = str(b[4]).strip() first_line = text.split("\n")[0].strip() # stop 信号 if _CAPTION_STOP_RE.match(first_line) or _SECTION_STOP_RE.match(first_line): break # 检查当前块是否与连通区域相连(间隙 < 阈值) if direction > 0: gap = by0 - boundary else: gap = boundary - by1 # 密集表格数据块后使用更低的间隙阈值 threshold = ( _TABLE_DATA_GAP_THRESHOLD if prev_was_dense_table else _CONTENT_GAP_THRESHOLD ) if gap > threshold: break connected.append((bx0, by0, bx1, by1)) # 更新连通区域边界 if direction > 0: boundary = by1 # 向下扩展 else: boundary = min(boundary, by0) # 向上扩展 # 判断当前块是否为密集表格数据(行密度高) lines = [l for l in text.split("\n") if l.strip()] block_height = by1 - by0 prev_was_dense_table = ( len(lines) >= 4 and block_height > 0 and len(lines) / block_height >= 0.08 ) return connected def _find_table_region_by_blocks( page, caption: dict ) -> tuple[float, float, float]: """文本块间隙检测 — 作为 find_tables() 的 fallback。 向下扫描找表格下边界,向上扫描找表格上边界(处理 caption 在数据下方)。 使用 _scan_blocks_direction 统一双向扫描逻辑。 """ blocks = page.get_text("blocks") caption_y0 = caption["caption_y0"] caption_y1 = caption["caption_y1"] caption_x0 = caption["caption_x0"] caption_x1 = caption["caption_x1"] page_width = caption["page_width"] page_height = caption["page_height"] col_x0, col_x1 = _estimate_column_x(caption) # 向下扫描 below = _scan_blocks_direction( blocks, caption_y1, col_x0, col_x1, direction=1, max_range=_TABLE_MAX_HEIGHT ) # 向上扫描 above = _scan_blocks_direction( blocks, caption_y0, col_x0, col_x1, direction=-1, max_range=_TABLE_MAX_HEIGHT ) # 确定上下边界 scan_top = min(b[1] for b in above) if above else caption_y0 scan_bottom = max(b[3] for b in below) if below else caption_y1 top = scan_top bottom = scan_bottom + 5 # 底部 padding if bottom - top > _TABLE_MAX_HEIGHT: bottom = top + _TABLE_MAX_HEIGHT # 水平范围:caption + 所有纳入块 all_blocks = above + below if all_blocks: content_x0 = min(caption_x0, min(b[0] for b in all_blocks)) content_x1 = max(caption_x1, max(b[2] for b in all_blocks)) else: content_x0 = caption_x0 content_x1 = caption_x1 x0 = max(0, content_x0 - _REGION_SIDE_PADDING) x1 = min(page_width, content_x1 + _REGION_SIDE_PADDING) return (x0, top, bottom, x1) def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int: """从 PDF 提取 Figure/Table 截图,生成 manifest。 策略:找 caption → 定位区域 → 渲染页面截图。 Args: arxiv_id: 论文 ID pdf_path: PDF 路径,默认 data/tmp/{arxiv_id}/paper.pdf Returns: 提取的图片数量 """ import pymupdf if pdf_path is None: pdf_path = TMP_DIR / arxiv_id / "paper.pdf" if not pdf_path.exists(): logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path) return 0 images_dest = paper_dir(arxiv_id) / "images" images_dest.mkdir(parents=True, exist_ok=True) # 清理上次提取的旧图片,避免残留(同时清理 .png 和 .jpg) for old_file in images_dest.iterdir(): if old_file.suffix.lower() in (".png", ".jpg", ".jpeg"): old_file.unlink() if (images_dest / "manifest.json").exists(): (images_dest / "manifest.json").unlink() doc = pymupdf.open(str(pdf_path)) captions = _find_captions(doc) if not captions: logger.info("No Figure/Table captions found in PDF for %s", arxiv_id) doc.close() return 0 # 去重:同一页同一 label 可能匹配到多个 block(如正文引用 "Figure 7") # 保留每个 (type, num) 的第一个匹配(即真正的 caption) seen_labels: dict[str, dict] = {} for cap in captions: key = cap["label"] if key not in seen_labels: seen_labels[key] = cap unique_captions = list(seen_labels.values()) extracted = 0 manifest: dict[str, dict] = {} zoom = 3 # 3x 渲染,保证清晰度 for cap in unique_captions: page = doc[cap["page_num"]] if cap["type"] == "figure": # Figure: caption 上方是图 → 向上找图的上边界 top = _find_figure_top(page, cap) # 上方多留 5pt 边距,确保图框边框、装饰线等不被截断 top = max(0, top - 5) bottom = cap["caption_y1"] + 5 # 包含 caption # 水平范围:取 caption 宽度和图片实际宽度的并集 x0, x1 = _find_figure_horizontal(page, cap, top, bottom) height = bottom - top if height < _FIGURE_MIN_HEIGHT: logger.debug( "Figure %s too small (%.0fpt), skipping", cap["label"], height ) continue else: # Table: 找表格区域(find_tables() → 块级 fallback,双向扫描) x0, tbl_top, bottom, x1 = _find_table_region(page, cap) top = max(0, tbl_top - 5) # 包含 caption 及上方数据,留 5pt margin height = bottom - top if height < _TABLE_MIN_HEIGHT: logger.debug( "Table %s too small (%.0fpt), skipping", cap["label"], height ) continue # 渲染截取 clip = pymupdf.Rect(x0, top, x1, bottom) mat = pymupdf.Matrix(zoom, zoom) try: pix = page.get_pixmap(matrix=mat, clip=clip) except Exception: logger.debug("Failed to render %s region for %s", cap["label"], arxiv_id) continue # 保存为 JPEG(比 PNG 小 5-10 倍,适合网络传输) filename = f"{cap['label'].replace(' ', '_').lower()}.jpg" jpeg_path = images_dest / filename jpeg_bytes = pix.tobytes("jpeg") jpeg_path.write_bytes(jpeg_bytes) extracted += 1 cap_preview = cap["caption_text"][:200] if cap["caption_text"] else "" manifest[filename] = { "page": cap["page_num"] + 1, "type": cap["type"], "label": cap["label"], "caption_text": cap_preview, "figures" if cap["type"] == "figure" else "tables": [cap["label"]], } logger.debug( "Rendered %s: page %d, region (%.0f,%.0f)-(%.0f,%.0f) h=%.0fpt → %s", cap["label"], cap["page_num"] + 1, x0, top, x1, bottom, height, filename, ) doc.close() # 保存 manifest manifest_path = images_dest / "manifest.json" manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2)) if extracted > 0: logger.info( "Extracted %d figure/table screenshots from PDF for %s " "(from %d captions found, %d unique)", extracted, arxiv_id, len(captions), len(unique_captions), ) return extracted def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int: """根据 summary 中的 figures 字段过滤提取的图片/表格。 用 manifest.json 中的 label 匹配,保留被 AI 总结引用的图片。 """ if not figures: return 0 images_dir = paper_dir(arxiv_id) / "images" manifest_path = images_dir / "manifest.json" if not images_dir.exists() or not manifest_path.exists(): return 0 all_files = [ f for f in images_dir.iterdir() if f.suffix.lower() in (".png", ".jpg", ".jpeg") ] if not all_files: return 0 manifest: dict = json.loads(manifest_path.read_text(encoding="utf-8")) # 收集 summary 中引用的所有 Figure/Table ID(归一化) referenced_ids: set[str] = set() for fig in figures: fig_id = fig.get("id", "") m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", fig_id, re.IGNORECASE) if m: referenced_ids.add(f"Figure {m.group(1)}") m2 = re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE) if m2: referenced_ids.add(f"Table {m2.group(1)}") if not referenced_ids: logger.warning("No valid figure/table IDs in summary for %s", arxiv_id) return len(all_files) # 根据 manifest 的 label 字段匹配 keep_filenames: set[str] = set() for filename, info in manifest.items(): label = info.get("label", "") if label in referenced_ids: keep_filenames.add(filename) continue for ref in info.get("figures", []) + info.get("tables", []): if ref in referenced_ids: keep_filenames.add(filename) break if not keep_filenames: logger.warning( "No manifest matches for %s (refs=%s), keeping all", arxiv_id, referenced_ids, ) return len(all_files) removed = 0 for f in all_files: if f.name not in keep_filenames: f.unlink() removed += 1 kept = len(all_files) - removed logger.info( "Filtered images for %s: kept %d, removed %d (refs=%s)", arxiv_id, kept, removed, referenced_ids, ) return kept