"""PDF 图片与表格提取 — 两阶段流水线。 Phase 1: DocLayout-YOLO 检测 figure/table 区域 → 渲染为 JPEG(通用标签) Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名 相比旧方案(正则匹配 caption): - 不再依赖正则,用 LLM 输出的 ID 直接搜索 PDF 文本 - page.search_for() 精确搜索 + 空间距离过滤,避免正文引用误匹配 - 通用标签兜底,LLM 没提到的图表不会被丢弃 """ from __future__ import annotations import json import logging import re from pathlib import Path import pymupdf from app.services.layout_detector import LayoutBox, detect_page_layout from app.services.pdf_downloader import paper_dir from app.utils import PAPERS_DIR, TMP_DIR logger = logging.getLogger(__name__) # 截图区域的外边距(单位: pt) _REGION_PADDING = 5 # 渲染倍率(3x 保证清晰度) _RENDER_ZOOM = 3 # 相邻 box 聚类间距(单位: pt)— 同一 figure/table 的碎片间距通常 < 15pt _CLUSTER_GAP = 15 # 最小 bbox 面积(单位: pt²)— 过滤 icon/logo 等微小误检 _MIN_BOX_AREA = 2000 # Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt) _LABEL_MATCH_DISTANCE = 100 # ── Box 聚类 ───────────────────────────────────────────────────────── class _BoxCluster: """合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。""" __slots__ = ("x0", "y0", "x1", "y1", "boxclass") def __init__(self, boxes: list): self.x0 = min(b.x0 for b in boxes) self.y0 = min(b.y0 for b in boxes) self.x1 = max(b.x1 for b in boxes) self.y1 = max(b.y1 for b in boxes) raw = boxes[0].boxclass self.boxclass = "table" if raw == "table-fallback" else raw def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]: """将相邻的同类型 box 合并为聚类。""" if not boxes: return [] n = len(boxes) parent = list(range(n)) def find(x: int) -> int: while parent[x] != x: parent[x] = parent[parent[x]] x = parent[x] return x def union(a: int, b: int) -> None: ra, rb = find(a), find(b) if ra != rb: parent[ra] = rb for i in range(n): bi = boxes[i] for j in range(i + 1, n): bj = boxes[j] if bi.boxclass != bj.boxclass: continue h_gap = max(0.0, max(bi.x0, bj.x0) - min(bi.x1, bj.x1)) v_gap = max(0.0, max(bi.y0, bj.y0) - min(bi.y1, bj.y1)) h_overlap = bi.x1 > bj.x0 - gap and bj.x1 > bi.x0 - gap v_overlap = bi.y1 > bj.y0 - gap and bj.y1 > bi.y0 - gap if (h_gap <= gap and v_overlap) or (v_gap <= gap and h_overlap): union(i, j) groups: dict[int, list] = {} for i in range(n): groups.setdefault(find(i), []).append(boxes[i]) return [_BoxCluster(members) for members in groups.values()] # ── Phase 1: 检测 + 渲染 ────────────────────────────────────────────── def _render_box( page, box: _BoxCluster, images_dest: Path, filename: str, cap_type: str, page_num: int, ) -> bool: """渲染单个 box 区域并保存 JPEG,成功返回 True。""" page_width = page.rect.width clip = pymupdf.Rect( max(0, box.x0 - _REGION_PADDING), max(0, box.y0 - _REGION_PADDING), min(page_width, box.x1 + _REGION_PADDING), box.y1 + _REGION_PADDING, ) mat = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM) try: pix = page.get_pixmap(matrix=mat, clip=clip) except Exception: return False (images_dest / filename).write_bytes(pix.tobytes("jpeg")) return True def _process_page( doc, page_idx: int, page_boxes: list[LayoutBox], images_dest: Path, manifest: dict, seen_labels: set, arxiv_id: str, ) -> int: """处理单页:检测 → 聚类 → 渲染,全部用通用标签。""" page = doc[page_idx] page_num = page_idx + 1 fig_counter = 0 tbl_counter = 0 # 收集本页的 table/picture box(跳过极小区域) raw_boxes = [] for box in page_boxes: if box.boxclass not in ("table", "table-fallback", "picture"): continue w = box.x1 - box.x0 h = box.y1 - box.y0 if w < 20 or h < 20 or w * h < _MIN_BOX_AREA: continue raw_boxes.append(box) if not raw_boxes: return 0 # 聚类:将同一 figure/table 的碎片 box 合并 clusters = _cluster_boxes(raw_boxes) extracted = 0 for cluster in clusters: cap_type = "figure" if cluster.boxclass == "picture" else "table" if cap_type == "figure": fig_counter += 1 label = f"Figure (p{page_num}-{fig_counter})" else: tbl_counter += 1 label = f"Table (p{page_num}-{tbl_counter})" if label in seen_labels: continue seen_labels.add(label) filename = f"{label.replace(' ', '_').lower()}.jpg" if not _render_box(page, cluster, images_dest, filename, cap_type, page_num): continue manifest[filename] = { "page": page_num, "type": cap_type, "label": label, "box": [ round(float(cluster.x0), 1), round(float(cluster.y0), 1), round(float(cluster.x1), 1), round(float(cluster.y1), 1), ], } extracted += 1 return extracted # ── Phase 1 核心入口 ─────────────────────────────────────────────────── def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int: """Phase 1: 从 PDF 提取 Figure/Table 截图,生成通用标签的 manifest。 Args: arxiv_id: 论文 ID pdf_path: PDF 路径,默认 data/tmp/{arxiv_id}/paper.pdf Returns: 提取的图片数量 """ if pdf_path is None: pdf_path = TMP_DIR / arxiv_id / "paper.pdf" if not pdf_path.exists(): logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path) return 0 images_dest = paper_dir(arxiv_id) / "images" images_dest.mkdir(parents=True, exist_ok=True) # 清理上次提取的旧图片 for old_file in images_dest.iterdir(): if old_file.suffix.lower() in (".png", ".jpg", ".jpeg"): old_file.unlink() if (images_dest / "manifest.json").exists(): (images_dest / "manifest.json").unlink() with pymupdf.open(str(pdf_path)) as doc: extracted = 0 manifest: dict[str, dict] = {} seen_labels: set[str] = set() for page_idx in range(doc.page_count): try: page_boxes = detect_page_layout(doc[page_idx]) extracted += _process_page( doc, page_idx, page_boxes, images_dest=images_dest, manifest=manifest, seen_labels=seen_labels, arxiv_id=arxiv_id, ) except Exception: logger.warning( "Failed to process page %d for %s", page_idx + 1, arxiv_id, exc_info=True, ) continue # 保存 manifest manifest_path = images_dest / "manifest.json" manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2)) if extracted > 0: logger.info( "Extracted %d figure/table screenshots from PDF for %s", extracted, arxiv_id, ) return extracted # ── Phase 2: 用 summary 的 figures ID 定位并重命名 ───────────────────── def _distance_text_to_box(rect: pymupdf.Rect, box: list[float]) -> float | None: """计算搜索到的文本 rect 到 box 的距离。超出阈值返回 None。 判断逻辑:rect 中心与 box 的垂直距离 + 水平重叠检查。 """ rect_cx = (rect.x0 + rect.x1) / 2 rect_cy = (rect.y0 + rect.y1) / 2 bx0, by0, bx1, by1 = box # 水平重叠:rect 中心在 box 水平范围内(或接近) if not (bx0 - 20 <= rect_cx <= bx1 + 20): return None # 垂直距离 if rect_cy < by0: dist = by0 - rect_cy elif rect_cy > by1: dist = rect_cy - by1 else: dist = 0 return dist if dist <= _LABEL_MATCH_DISTANCE else None def _search_variants(fig_id: str) -> list[str]: """为 figure/table ID 生成搜索变体。 "Figure 1" → ["Figure 1", "Fig. 1", "Fig 1"] "Fig. 1" → ["Fig. 1", "Figure 1", "Fig 1"] "Table A1" → ["Table A1"] """ variants = [fig_id] m = re.match(r"(Fig\.?|Figure)\s+(\d+.*)", fig_id, re.IGNORECASE) if m: num_part = m.group(2) variants.extend( [ f"Figure {num_part}", f"Fig. {num_part}", f"Fig {num_part}", ] ) # 去重保序 seen = set() result = [] for v in variants: if v not in seen: seen.add(v) result.append(v) return result def label_images_by_summary( arxiv_id: str, figures: list[dict], pdf_path: Path | None = None, ) -> int: """Phase 2: 用 summary 的 figures ID 在 PDF 中搜索定位,重命名图片。 对 summary 中的每个 figure/table ID: 1. page.search_for(id) 在所有页面搜索文本位置 2. 计算搜索位置与 manifest 中 box 坐标的距离 3. 最近匹配 → 重命名文件、更新 manifest Args: arxiv_id: 论文 ID figures: summary 的 figures 列表,每项含 id/caption/description 等 pdf_path: PDF 路径 Returns: 成功重命名的图片数量 """ if not figures: return 0 if pdf_path is None: pdf_path = TMP_DIR / arxiv_id / "paper.pdf" if not pdf_path.exists(): return 0 images_dest = paper_dir(arxiv_id) / "images" manifest_path = images_dest / "manifest.json" if not manifest_path.exists(): return 0 manifest: dict[str, dict] = json.loads(manifest_path.read_text(encoding="utf-8")) if not manifest: return 0 # 构建候选列表:只对通用标签的条目做匹配 candidates: dict[str, dict] = {} # filename → {page, box, ...} for fname, info in manifest.items(): if "(p" in info.get("label", ""): candidates[fname] = info if not candidates: return 0 with pymupdf.open(str(pdf_path)) as doc: # 收集所有匹配候选:(fig_id, fig_index, filename, distance) matches: list[tuple[str, int, str, float]] = [] for fig_idx, fig in enumerate(figures): fig_id = fig.get("id", "") if not fig_id: continue # 生成搜索变体:Figure 1 / Fig. 1 / Fig 1 等 search_terms = _search_variants(fig_id) # 在所有页面搜索该文本(含变体) search_hits: list[tuple[int, pymupdf.Rect]] = [] # (page_num_1based, Rect) for page_idx in range(doc.page_count): page = doc[page_idx] seen_rects: set[tuple[float, float]] = set() for term in search_terms: for r in page.search_for(term): key = (round(r.x0, 1), round(r.y0, 1)) if key not in seen_rects: seen_rects.add(key) search_hits.append((page_idx + 1, r)) if not search_hits: continue # 对每个候选 manifest 条目,找最近的搜索命中 for fname, info in candidates.items(): box = info.get("box") if not box: continue manifest_page = info.get("page", 0) best_dist: float | None = None for hit_page, rect in search_hits: # 只匹配同页面 if hit_page != manifest_page: continue dist = _distance_text_to_box(rect, box) if dist is not None and (best_dist is None or dist < best_dist): best_dist = dist if best_dist is not None: matches.append((fig_id, fig_idx, fname, best_dist)) if not matches: logger.info("No label matches for %s", arxiv_id) return 0 # 去冲突:按距离排序,每个 fig_id 和每个 filename 只匹配一次 matches.sort(key=lambda x: x[3]) used_fig_ids: set[int] = set() used_filenames: set[str] = set() renames: list[tuple[str, str, str]] = [] # (old_fname, new_fname, fig_id) for fig_id, fig_idx, fname, dist in matches: if fig_idx in used_fig_ids or fname in used_filenames: continue used_fig_ids.add(fig_idx) used_filenames.add(fname) new_fname = f"{fig_id.replace(' ', '_').lower()}.jpg" renames.append((fname, new_fname, fig_id)) # 执行重命名 labeled = 0 new_manifest: dict[str, dict] = {} for fname, info in manifest.items(): if fname in used_filenames: continue # 未匹配的保持原样 new_manifest[fname] = info for old_fname, new_fname, fig_id in renames: old_path = images_dest / old_fname new_path = images_dest / new_fname if not old_path.exists(): continue # 搬运 manifest 信息 info = manifest[old_fname].copy() cap_type = info.get("type", "figure") # 读取 caption 文本(从 figures 列表) caption_text = "" for fig in figures: if fig.get("id") == fig_id: caption_text = fig.get("caption", "") break info["label"] = fig_id info["caption_text"] = caption_text[:200] if caption_text else "" info.setdefault("figures" if cap_type == "figure" else "tables", []).append( fig_id ) # 重命名文件 if new_fname != old_fname: old_path.rename(new_path) new_manifest[new_fname] = info labeled += 1 # 写回 manifest manifest_path.write_text(json.dumps(new_manifest, ensure_ascii=False, indent=2)) logger.info( "Labeled %d/%d images for %s using summary figures", labeled, len(manifest), arxiv_id, ) return labeled # ── Figure ↔ Image 关联 ──────────────────────────────────────────────── def _normalize_figure_id(raw_id: str) -> str: """归一化 Figure/Table ID:'Figure 1'/'Fig.1' → 'Figure 1'。""" m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE) if m: return f"Figure {m.group(1)}" m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE) if m2: return f"Table {m2.group(1)}" return raw_id def _is_figure_type(fig_id: str) -> bool: """判断是否为 Figure 类型(非 Table)。""" return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE) def _image_sort_key(name: str) -> tuple[int, int]: """按文件名中的编号排序提取的图片。""" # 新格式:figure_1.jpg, table_1.jpg m = re.search(r"(?:figure|table)_(\d+)", name) if m: return (0, int(m.group(1))) return (0, 0) def link_figures_with_images( figures: list[dict], images: list[dict], arxiv_id: str ) -> list[dict]: """将 summary figures 元数据与提取的图片文件关联。 策略: 1. 优先用 manifest.json 的 label 做 ID 精确匹配 2. 未匹配的 figure 用序号兜底:第 N 个 Figure → 第 N 张提取图 """ if not figures or not images: return figures manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json" # ── 策略 1:manifest ID 精确匹配 ── id_to_url: dict[str, str] = {} if manifest_path.exists(): try: manifest = json.loads(manifest_path.read_text(encoding="utf-8")) except (ValueError, TypeError): manifest = {} for filename, info in manifest.items(): url = f"/papers/{arxiv_id}/images/{filename}" # 优先用 label 字段(新格式) label = info.get("label", "") if label: id_to_url[label] = url # 也兼容 figures/tables 列表(旧格式) for fig_id in info.get("figures", []) + info.get("tables", []): if fig_id not in id_to_url: id_to_url[fig_id] = url for fig in figures: raw_id = fig.get("id", "") normalized = _normalize_figure_id(raw_id) if normalized in id_to_url: fig["image_url"] = id_to_url[normalized] # ── 策略 2:序号兜底(manifest 匹配不到时) ── unmatched = [f for f in figures if not f.get("image_url")] if not unmatched: return figures # 按类型分流:Figure vs Table fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))] table_type_unmatched = [ f for f in unmatched if not _is_figure_type(f.get("id", "")) ] # 提取的图片按类型分流,按文件名中的编号排序 fig_images = sorted( [img for img in images if "table" not in img["name"].lower()], key=lambda img: _image_sort_key(img["name"]), ) table_images = sorted( [img for img in images if "table" in img["name"].lower()], key=lambda img: _image_sort_key(img["name"]), ) for i, fig in enumerate(fig_type_unmatched): if i < len(fig_images): fig["image_url"] = fig_images[i]["url"] for i, fig in enumerate(table_type_unmatched): if i < len(table_images): fig["image_url"] = table_images[i]["url"] return figures