90fe705e8f
- 核心变更: - app/services/layout_detector.py: 重写布局检测器,从 PicoDet-S_layout_3cls 迁移到 DocLayout-YOLO (DocStructBench, imgsz=1024) - 支持多设备推理 (CPU/CUDA/DirectML/OpenVINO 等),自动探测最优设备 - 预处理改为 letterbox (保比例缩放+灰边 padding),坐标还原使用 (model_coord - padding) / ratio 公式 - 后处理解析 YOLOv10 end-to-end 输出 [N,6]=[x1,y1,x2,y2,conf,cls] - 类别映射改为按 class name 动态匹配 (figure/figure_group→picture, table/table_group→table) - 新增文件: - scripts/export_doclayout_yolo_onnx.py: DocLayout-YOLO ONNX 导出脚本 (独立 venv 运行) - tests/test_layout_detector.py: 布局检测器完整测试 (35 个用例) - 配置更新: - .env.example: 更新布局检测配置 (新增 LAYOUT_IMGSZ, LAYOUT_DEVICE, LAYOUT_DEVICE_ID) - app/config.py: Settings 类对应字段 - pyproject.toml: 新增 export 依赖组 (torch, doclayout-yolo, onnx 等) - 删除旧文件: - scripts/export_picodet_onnx.py: 旧 PicoDet 导出脚本 - 文档更新: - README.md: 更新环境变量说明 - 相关服务注释更新 (pdf_image_extractor.py, summary_persister.py, reextract_images.py) 此重构遵循项目初期开发阶段规范,大胆调整数据模型,无需向后兼容。
579 lines
18 KiB
Python
579 lines
18 KiB
Python
"""PDF 图片与表格提取 — 两阶段流水线。
|
||
|
||
Phase 1: DocLayout-YOLO 检测 figure/table 区域 → 渲染为 JPEG(通用标签)
|
||
Phase 2: 用 LLM summary 的 figures[].id 在 PDF 中搜索定位 → 匹配到 box → 重命名
|
||
|
||
相比旧方案(正则匹配 caption):
|
||
- 不再依赖正则,用 LLM 输出的 ID 直接搜索 PDF 文本
|
||
- page.search_for() 精确搜索 + 空间距离过滤,避免正文引用误匹配
|
||
- 通用标签兜底,LLM 没提到的图表不会被丢弃
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from pathlib import Path
|
||
|
||
import pymupdf
|
||
|
||
from app.services.layout_detector import LayoutBox, detect_page_layout
|
||
from app.services.pdf_downloader import paper_dir
|
||
from app.utils import PAPERS_DIR, TMP_DIR
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 截图区域的外边距(单位: pt)
|
||
_REGION_PADDING = 5
|
||
# 渲染倍率(3x 保证清晰度)
|
||
_RENDER_ZOOM = 3
|
||
# 相邻 box 聚类间距(单位: pt)— 同一 figure/table 的碎片间距通常 < 15pt
|
||
_CLUSTER_GAP = 15
|
||
# 最小 bbox 面积(单位: pt²)— 过滤 icon/logo 等微小误检
|
||
_MIN_BOX_AREA = 2000
|
||
# Phase 2: 搜索文本到 box 的最大匹配距离(单位: pt)
|
||
_LABEL_MATCH_DISTANCE = 100
|
||
|
||
|
||
# ── Box 聚类 ─────────────────────────────────────────────────────────
|
||
|
||
|
||
class _BoxCluster:
|
||
"""合并后的布局区域(由一个或多个相邻 LayoutBox 组成)。"""
|
||
|
||
__slots__ = ("x0", "y0", "x1", "y1", "boxclass")
|
||
|
||
def __init__(self, boxes: list):
|
||
self.x0 = min(b.x0 for b in boxes)
|
||
self.y0 = min(b.y0 for b in boxes)
|
||
self.x1 = max(b.x1 for b in boxes)
|
||
self.y1 = max(b.y1 for b in boxes)
|
||
raw = boxes[0].boxclass
|
||
self.boxclass = "table" if raw == "table-fallback" else raw
|
||
|
||
|
||
def _cluster_boxes(boxes: list, gap: float = _CLUSTER_GAP) -> list[_BoxCluster]:
    """Merge neighbouring boxes of the same class into clusters.

    Two boxes are linked when their ``boxclass`` values are equal and they lie
    within ``gap`` of each other along one axis while overlapping (with a
    ``gap`` tolerance) along the other. Links are transitive; connected boxes
    are grouped with a union-find structure.

    Args:
        boxes: Objects exposing ``x0``, ``y0``, ``x1``, ``y1`` and ``boxclass``.
        gap: Maximum distance between two boxes that still get merged.

    Returns:
        One ``_BoxCluster`` per connected group; an empty list for no input.
    """
    if not boxes:
        return []

    count = len(boxes)
    parent = list(range(count))

    def _root(node: int) -> int:
        # Walk up to the representative, halving the path on the way.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    def _adjacent(first, second) -> bool:
        # Only boxes with the exact same raw class are ever linked.
        if first.boxclass != second.boxclass:
            return False
        h_gap = max(0.0, max(first.x0, second.x0) - min(first.x1, second.x1))
        v_gap = max(0.0, max(first.y0, second.y0) - min(first.y1, second.y1))
        h_overlap = first.x1 > second.x0 - gap and second.x1 > first.x0 - gap
        v_overlap = first.y1 > second.y0 - gap and second.y1 > first.y0 - gap
        return (h_gap <= gap and v_overlap) or (v_gap <= gap and h_overlap)

    for i, first in enumerate(boxes):
        for j in range(i + 1, count):
            if _adjacent(first, boxes[j]):
                root_i, root_j = _root(i), _root(j)
                if root_i != root_j:
                    parent[root_i] = root_j

    # Bucket the boxes by representative, keeping first-seen order.
    members_by_root: dict[int, list] = {}
    for idx, box in enumerate(boxes):
        members_by_root.setdefault(_root(idx), []).append(box)

    return [_BoxCluster(members) for members in members_by_root.values()]
|
||
|
||
|
||
# ── Phase 1: 检测 + 渲染 ──────────────────────────────────────────────
|
||
|
||
|
||
def _render_box(
    page,
    box: _BoxCluster,
    images_dest: Path,
    filename: str,
    cap_type: str,
    page_num: int,
) -> bool:
    """Render one clustered region of ``page`` and save it as a JPEG.

    The region is grown by ``_REGION_PADDING`` on every side and clamped to
    the page rectangle on all four edges.

    Args:
        page: The PyMuPDF page to render from.
        box: Region to capture, in page coordinates.
        images_dest: Directory the JPEG is written into.
        filename: Name of the output file inside ``images_dest``.
        cap_type: ``"figure"`` or ``"table"``; used for diagnostics only.
        page_num: 1-based page number; used for diagnostics only.

    Returns:
        True when the image was written, False when rendering or JPEG
        encoding failed (the failure is logged).
    """
    page_rect = page.rect
    clip = pymupdf.Rect(
        max(0, box.x0 - _REGION_PADDING),
        max(0, box.y0 - _REGION_PADDING),
        min(page_rect.width, box.x1 + _REGION_PADDING),
        # Clamp the bottom edge too, mirroring the right-edge clamp.
        min(page_rect.height, box.y1 + _REGION_PADDING),
    )
    matrix = pymupdf.Matrix(_RENDER_ZOOM, _RENDER_ZOOM)
    try:
        jpeg_bytes = page.get_pixmap(matrix=matrix, clip=clip).tobytes("jpeg")
    except Exception:
        # Best effort: one unrenderable region must not abort the page, but
        # the failure should be visible in the logs.
        logger.warning(
            "Failed to render %s region %s on page %d",
            cap_type,
            filename,
            page_num,
            exc_info=True,
        )
        return False

    (images_dest / filename).write_bytes(jpeg_bytes)
    return True
|
||
|
||
|
||
def _process_page(
    doc,
    page_idx: int,
    page_boxes: list[LayoutBox],
    images_dest: Path,
    manifest: dict,
    seen_labels: set,
    arxiv_id: str,
) -> int:
    """Handle one page: filter boxes, cluster them, render each cluster.

    Every rendered cluster gets a generic label such as ``Figure (p3-1)`` or
    ``Table (p3-2)`` and an entry in ``manifest`` keyed by its file name.

    Args:
        doc: Open PyMuPDF document.
        page_idx: 0-based index of the page to process.
        page_boxes: Layout boxes detected on that page.
        images_dest: Directory the screenshots are written into.
        manifest: Mapping updated in place with one entry per saved image.
        seen_labels: Labels already used; updated in place.
        arxiv_id: Paper ID (not used by this function).

    Returns:
        Number of images saved for this page.
    """
    page = doc[page_idx]
    page_num = page_idx + 1

    def _worth_keeping(candidate) -> bool:
        # Only table/picture detections that are not tiny are rendered.
        if candidate.boxclass not in ("table", "table-fallback", "picture"):
            return False
        w = candidate.x1 - candidate.x0
        h = candidate.y1 - candidate.y0
        return not (w < 20 or h < 20 or w * h < _MIN_BOX_AREA)

    raw_boxes = [candidate for candidate in page_boxes if _worth_keeping(candidate)]
    if not raw_boxes:
        return 0

    fig_counter = 0
    tbl_counter = 0
    extracted = 0

    # Fragments of the same figure/table are merged before rendering.
    for cluster in _cluster_boxes(raw_boxes):
        if cluster.boxclass == "picture":
            cap_type = "figure"
            fig_counter += 1
            label = f"Figure (p{page_num}-{fig_counter})"
        else:
            cap_type = "table"
            tbl_counter += 1
            label = f"Table (p{page_num}-{tbl_counter})"

        if label in seen_labels:
            continue
        seen_labels.add(label)

        filename = f"{label.replace(' ', '_').lower()}.jpg"
        rendered = _render_box(page, cluster, images_dest, filename, cap_type, page_num)
        if not rendered:
            continue

        corners = (cluster.x0, cluster.y0, cluster.x1, cluster.y1)
        manifest[filename] = {
            "page": page_num,
            "type": cap_type,
            "label": label,
            "box": [round(float(value), 1) for value in corners],
        }
        extracted += 1

    return extracted
|
||
|
||
|
||
# ── Phase 1 核心入口 ───────────────────────────────────────────────────
|
||
|
||
|
||
def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int:
    """Phase 1: screenshot figures/tables from a PDF and write a manifest.

    Images from a previous run (``.png``/``.jpg``/``.jpeg``) and the old
    ``manifest.json`` are removed first. Each page is processed
    independently; a page that fails is logged and skipped. The manifest is
    always written, even when nothing was extracted.

    Args:
        arxiv_id: Paper ID.
        pdf_path: PDF location; defaults to ``TMP_DIR / arxiv_id / "paper.pdf"``.

    Returns:
        Number of images extracted (0 when the PDF does not exist).
    """
    if pdf_path is None:
        pdf_path = TMP_DIR / arxiv_id / "paper.pdf"

    if not pdf_path.exists():
        logger.warning("PDF not found for %s: %s", arxiv_id, pdf_path)
        return 0

    images_dest = paper_dir(arxiv_id) / "images"
    images_dest.mkdir(parents=True, exist_ok=True)
    manifest_path = images_dest / "manifest.json"

    # Remove the output of any previous extraction.
    for old_file in images_dest.iterdir():
        if old_file.suffix.lower() in (".png", ".jpg", ".jpeg"):
            old_file.unlink()
    manifest_path.unlink(missing_ok=True)

    extracted = 0
    manifest: dict[str, dict] = {}
    seen_labels: set[str] = set()

    with pymupdf.open(str(pdf_path)) as doc:
        for page_idx in range(doc.page_count):
            try:
                page_boxes = detect_page_layout(doc[page_idx])
                extracted += _process_page(
                    doc,
                    page_idx,
                    page_boxes,
                    images_dest=images_dest,
                    manifest=manifest,
                    seen_labels=seen_labels,
                    arxiv_id=arxiv_id,
                )
            except Exception:
                # Best effort per page: keep going with the remaining pages.
                logger.warning(
                    "Failed to process page %d for %s",
                    page_idx + 1,
                    arxiv_id,
                    exc_info=True,
                )
                continue

    # Written as UTF-8 explicitly because every reader of the manifest opens
    # it with encoding="utf-8" and the JSON is dumped with ensure_ascii=False.
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    if extracted > 0:
        logger.info(
            "Extracted %d figure/table screenshots from PDF for %s",
            extracted,
            arxiv_id,
        )

    return extracted
|
||
|
||
|
||
# ── Phase 2: 用 summary 的 figures ID 定位并重命名 ─────────────────────
|
||
|
||
|
||
def _distance_text_to_box(rect: pymupdf.Rect, box: list[float]) -> float | None:
    """Measure how far a text hit lies above or below a box.

    The centre of ``rect`` must fall inside the horizontal span of ``box``
    (widened by 20 on each side); otherwise there is no match. The distance
    is the vertical gap between the centre and the box, 0 when the centre is
    vertically inside it.

    Args:
        rect: Rectangle of the text hit.
        box: ``[x0, y0, x1, y1]`` of the candidate region.

    Returns:
        The vertical distance, or None when the hit is horizontally off or
        farther away than ``_LABEL_MATCH_DISTANCE``.
    """
    rect_cx = (rect.x0 + rect.x1) / 2
    rect_cy = (rect.y0 + rect.y1) / 2
    left, top, right, bottom = box

    # Horizontal check: the centre must sit within (or close to) the box span.
    horizontally_aligned = left - 20 <= rect_cx <= right + 20
    if not horizontally_aligned:
        return None

    # Vertical gap between the centre and the nearest horizontal edge.
    if rect_cy < top:
        gap = top - rect_cy
    elif rect_cy > bottom:
        gap = rect_cy - bottom
    else:
        gap = 0

    if gap <= _LABEL_MATCH_DISTANCE:
        return gap
    return None
|
||
|
||
|
||
def _search_variants(fig_id: str) -> list[str]:
|
||
"""为 figure/table ID 生成搜索变体。
|
||
|
||
"Figure 1" → ["Figure 1", "Fig. 1", "Fig 1"]
|
||
"Fig. 1" → ["Fig. 1", "Figure 1", "Fig 1"]
|
||
"Table A1" → ["Table A1"]
|
||
"""
|
||
variants = [fig_id]
|
||
|
||
m = re.match(r"(Fig\.?|Figure)\s+(\d+.*)", fig_id, re.IGNORECASE)
|
||
if m:
|
||
num_part = m.group(2)
|
||
variants.extend(
|
||
[
|
||
f"Figure {num_part}",
|
||
f"Fig. {num_part}",
|
||
f"Fig {num_part}",
|
||
]
|
||
)
|
||
|
||
# 去重保序
|
||
seen = set()
|
||
result = []
|
||
for v in variants:
|
||
if v not in seen:
|
||
seen.add(v)
|
||
result.append(v)
|
||
return result
|
||
|
||
|
||
def label_images_by_summary(
|
||
arxiv_id: str,
|
||
figures: list[dict],
|
||
pdf_path: Path | None = None,
|
||
) -> int:
|
||
"""Phase 2: 用 summary 的 figures ID 在 PDF 中搜索定位,重命名图片。
|
||
|
||
对 summary 中的每个 figure/table ID:
|
||
1. page.search_for(id) 在所有页面搜索文本位置
|
||
2. 计算搜索位置与 manifest 中 box 坐标的距离
|
||
3. 最近匹配 → 重命名文件、更新 manifest
|
||
|
||
Args:
|
||
arxiv_id: 论文 ID
|
||
figures: summary 的 figures 列表,每项含 id/caption/description 等
|
||
pdf_path: PDF 路径
|
||
|
||
Returns:
|
||
成功重命名的图片数量
|
||
"""
|
||
if not figures:
|
||
return 0
|
||
|
||
if pdf_path is None:
|
||
pdf_path = TMP_DIR / arxiv_id / "paper.pdf"
|
||
if not pdf_path.exists():
|
||
return 0
|
||
|
||
images_dest = paper_dir(arxiv_id) / "images"
|
||
manifest_path = images_dest / "manifest.json"
|
||
if not manifest_path.exists():
|
||
return 0
|
||
|
||
manifest: dict[str, dict] = json.loads(manifest_path.read_text(encoding="utf-8"))
|
||
if not manifest:
|
||
return 0
|
||
|
||
# 构建候选列表:只对通用标签的条目做匹配
|
||
candidates: dict[str, dict] = {} # filename → {page, box, ...}
|
||
for fname, info in manifest.items():
|
||
if "(p" in info.get("label", ""):
|
||
candidates[fname] = info
|
||
|
||
if not candidates:
|
||
return 0
|
||
|
||
with pymupdf.open(str(pdf_path)) as doc:
|
||
# 收集所有匹配候选:(fig_id, fig_index, filename, distance)
|
||
matches: list[tuple[str, int, str, float]] = []
|
||
|
||
for fig_idx, fig in enumerate(figures):
|
||
fig_id = fig.get("id", "")
|
||
if not fig_id:
|
||
continue
|
||
|
||
# 生成搜索变体:Figure 1 / Fig. 1 / Fig 1 等
|
||
search_terms = _search_variants(fig_id)
|
||
|
||
# 在所有页面搜索该文本(含变体)
|
||
search_hits: list[tuple[int, pymupdf.Rect]] = [] # (page_num_1based, Rect)
|
||
for page_idx in range(doc.page_count):
|
||
page = doc[page_idx]
|
||
seen_rects: set[tuple[float, float]] = set()
|
||
for term in search_terms:
|
||
for r in page.search_for(term):
|
||
key = (round(r.x0, 1), round(r.y0, 1))
|
||
if key not in seen_rects:
|
||
seen_rects.add(key)
|
||
search_hits.append((page_idx + 1, r))
|
||
|
||
if not search_hits:
|
||
continue
|
||
|
||
# 对每个候选 manifest 条目,找最近的搜索命中
|
||
for fname, info in candidates.items():
|
||
box = info.get("box")
|
||
if not box:
|
||
continue
|
||
manifest_page = info.get("page", 0)
|
||
|
||
best_dist: float | None = None
|
||
for hit_page, rect in search_hits:
|
||
# 只匹配同页面
|
||
if hit_page != manifest_page:
|
||
continue
|
||
dist = _distance_text_to_box(rect, box)
|
||
if dist is not None and (best_dist is None or dist < best_dist):
|
||
best_dist = dist
|
||
|
||
if best_dist is not None:
|
||
matches.append((fig_id, fig_idx, fname, best_dist))
|
||
|
||
if not matches:
|
||
logger.info("No label matches for %s", arxiv_id)
|
||
return 0
|
||
|
||
# 去冲突:按距离排序,每个 fig_id 和每个 filename 只匹配一次
|
||
matches.sort(key=lambda x: x[3])
|
||
used_fig_ids: set[int] = set()
|
||
used_filenames: set[str] = set()
|
||
renames: list[tuple[str, str, str]] = [] # (old_fname, new_fname, fig_id)
|
||
|
||
for fig_id, fig_idx, fname, dist in matches:
|
||
if fig_idx in used_fig_ids or fname in used_filenames:
|
||
continue
|
||
used_fig_ids.add(fig_idx)
|
||
used_filenames.add(fname)
|
||
new_fname = f"{fig_id.replace(' ', '_').lower()}.jpg"
|
||
renames.append((fname, new_fname, fig_id))
|
||
|
||
# 执行重命名
|
||
labeled = 0
|
||
new_manifest: dict[str, dict] = {}
|
||
|
||
for fname, info in manifest.items():
|
||
if fname in used_filenames:
|
||
continue
|
||
# 未匹配的保持原样
|
||
new_manifest[fname] = info
|
||
|
||
for old_fname, new_fname, fig_id in renames:
|
||
old_path = images_dest / old_fname
|
||
new_path = images_dest / new_fname
|
||
if not old_path.exists():
|
||
continue
|
||
|
||
# 搬运 manifest 信息
|
||
info = manifest[old_fname].copy()
|
||
cap_type = info.get("type", "figure")
|
||
|
||
# 读取 caption 文本(从 figures 列表)
|
||
caption_text = ""
|
||
for fig in figures:
|
||
if fig.get("id") == fig_id:
|
||
caption_text = fig.get("caption", "")
|
||
break
|
||
|
||
info["label"] = fig_id
|
||
info["caption_text"] = caption_text[:200] if caption_text else ""
|
||
info.setdefault("figures" if cap_type == "figure" else "tables", []).append(
|
||
fig_id
|
||
)
|
||
|
||
# 重命名文件
|
||
if new_fname != old_fname:
|
||
old_path.rename(new_path)
|
||
new_manifest[new_fname] = info
|
||
labeled += 1
|
||
|
||
# 写回 manifest
|
||
manifest_path.write_text(json.dumps(new_manifest, ensure_ascii=False, indent=2))
|
||
|
||
logger.info(
|
||
"Labeled %d/%d images for %s using summary figures",
|
||
labeled,
|
||
len(manifest),
|
||
arxiv_id,
|
||
)
|
||
return labeled
|
||
|
||
|
||
# ── Figure ↔ Image 关联 ────────────────────────────────────────────────
|
||
|
||
|
||
def _normalize_figure_id(raw_id: str) -> str:
|
||
"""归一化 Figure/Table ID:'Figure 1'/'Fig.1' → 'Figure 1'。"""
|
||
m = re.match(r"(?:Fig\.?|Figure)\s*(\d+)", raw_id, re.IGNORECASE)
|
||
if m:
|
||
return f"Figure {m.group(1)}"
|
||
m2 = re.match(r"Table\s*(\d+)", raw_id, re.IGNORECASE)
|
||
if m2:
|
||
return f"Table {m2.group(1)}"
|
||
return raw_id
|
||
|
||
|
||
def _is_figure_type(fig_id: str) -> bool:
|
||
"""判断是否为 Figure 类型(非 Table)。"""
|
||
return not re.match(r"Table\s*(\d+)", fig_id, re.IGNORECASE)
|
||
|
||
|
||
def _image_sort_key(name: str) -> tuple[int, int]:
|
||
"""按文件名中的编号排序提取的图片。"""
|
||
# 新格式:figure_1.jpg, table_1.jpg
|
||
m = re.search(r"(?:figure|table)_(\d+)", name)
|
||
if m:
|
||
return (0, int(m.group(1)))
|
||
# 旧格式:page2_img1.png, page5_table1.png, figure_1.png
|
||
m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name)
|
||
if m2:
|
||
return (int(m2.group(1)), int(m2.group(2)))
|
||
return (0, 0)
|
||
|
||
|
||
def link_figures_with_images(
    figures: list[dict], images: list[dict], arxiv_id: str
) -> list[dict]:
    """Attach an ``image_url`` to summary figures using the extracted images.

    Strategy:
      1. Look the figure ID up in ``manifest.json``: first the ID exactly as
         written (this is what Phase 2 stores as the label), then its
         normalised form.
      2. Figures still without an image fall back to ordinal matching: the
         N-th unmatched figure (or table) gets the N-th extracted figure
         (or table) image that has not been linked already.

    Args:
        figures: Summary figure dicts; updated in place.
        images: Extracted image dicts carrying ``name`` and ``url``.
        arxiv_id: Paper ID, used to locate the manifest and build URLs.

    Returns:
        The same ``figures`` list.
    """
    if not figures or not images:
        return figures

    def _id_of(fig: dict) -> str:
        # IDs come from LLM output; tolerate None and non-string values.
        value = fig.get("id", "")
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    manifest_path = PAPERS_DIR / arxiv_id / "images" / "manifest.json"

    # -- Strategy 1: ID lookup through the manifest --
    manifest: dict = {}
    if manifest_path.exists():
        try:
            loaded = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (OSError, ValueError, TypeError):
            loaded = {}
        if isinstance(loaded, dict):
            manifest = loaded

    id_to_url: dict[str, str] = {}
    url_to_name: dict[str, str] = {}
    for filename, info in manifest.items():
        if not isinstance(info, dict):
            continue
        url = f"/papers/{arxiv_id}/images/{filename}"
        url_to_name[url] = filename
        # The label field (new format) takes priority.
        label = info.get("label", "")
        if label:
            id_to_url[label] = url
        # Legacy figures/tables lists only fill IDs that are still missing.
        for fig_id in info.get("figures", []) + info.get("tables", []):
            if fig_id not in id_to_url:
                id_to_url[fig_id] = url

    for fig in figures:
        raw_id = _id_of(fig)
        # Exact ID first, because the label is stored un-normalised.
        for key in (raw_id, _normalize_figure_id(raw_id)):
            if key and key in id_to_url:
                fig["image_url"] = id_to_url[key]
                break

    # -- Strategy 2: ordinal fallback for whatever is still unmatched --
    unmatched = [f for f in figures if not f.get("image_url")]
    if not unmatched:
        return figures

    # Images that are already linked must not be handed out a second time.
    used_urls = {f["image_url"] for f in figures if f.get("image_url")}
    used_names = {url_to_name[url] for url in used_urls if url in url_to_name}
    free_images = [
        img
        for img in images
        if img.get("url") not in used_urls and img.get("name") not in used_names
    ]

    fig_type_unmatched = [f for f in unmatched if _is_figure_type(_id_of(f))]
    table_type_unmatched = [f for f in unmatched if not _is_figure_type(_id_of(f))]

    # Split the remaining images by type, ordered by the number in the name.
    fig_images = sorted(
        [img for img in free_images if "table" not in img["name"].lower()],
        key=lambda img: _image_sort_key(img["name"]),
    )
    table_images = sorted(
        [img for img in free_images if "table" in img["name"].lower()],
        key=lambda img: _image_sort_key(img["name"]),
    )

    for fig, img in zip(fig_type_unmatched, fig_images):
        fig["image_url"] = img["url"]

    for fig, img in zip(table_type_unmatched, table_images):
        fig["image_url"] = img["url"]

    return figures
|