diff --git a/app/routes/pages.py b/app/routes/pages.py index 950e5de..bc3411e 100644 --- a/app/routes/pages.py +++ b/app/routes/pages.py @@ -15,7 +15,13 @@ from sqlalchemy.orm import Session, joinedload from app.config import settings from app.database import get_db from app.models import PAPER_FULL_LOAD, Paper -from app.utils import PAPERS_DIR, safe_json_loads, templates, today_str, latest_paper_date +from app.utils import ( + PAPERS_DIR, + safe_json_loads, + templates, + today_str, + latest_paper_date, +) logger = logging.getLogger(__name__) @@ -52,15 +58,9 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)): .all() ) - dates_raw = ( - db.execute( - select(Paper.paper_date) - .distinct() - .order_by(Paper.paper_date.desc()) - .limit(30) - ) - .all() - ) + dates_raw = db.execute( + select(Paper.paper_date).distinct().order_by(Paper.paper_date.desc()).limit(30) + ).all() available_dates = [ d[0].isoformat() if isinstance(d[0], date) else str(d[0]) for d in dates_raw ] @@ -140,11 +140,7 @@ def paper_detail(arxiv_id: str, request: Request, db: Session = Depends(get_db)) table_figures.append(fig) elif not is_table and section == "method" and fig.get("image_url"): method_figures.append(fig) - elif ( - not is_table - and section == "results" - and fig.get("image_url") - ): + elif not is_table and section == "results" and fig.get("image_url"): results_figures.append(fig) else: gallery_figures.append(fig) @@ -330,16 +326,18 @@ def _link_figures_with_images( # 按类型分流:Figure vs Table fig_type_unmatched = [f for f in unmatched if _is_figure_type(f.get("id", ""))] - table_type_unmatched = [f for f in unmatched if not _is_figure_type(f.get("id", ""))] + table_type_unmatched = [ + f for f in unmatched if not _is_figure_type(f.get("id", "")) + ] # 提取的图片按类型分流,按文件名中的编号排序 def _sort_key(name: str) -> tuple[int, int]: - # 新格式:figure_1.png, table_1.png - m = re.search(r'(?:figure|table)_(\d+)', name) + # 新格式:figure_1.jpg, table_1.jpg + m = 
re.search(r"(?:figure|table)_(\d+)", name) if m: return (0, int(m.group(1))) - # 旧格式:page2_img1.png, page5_table1.png - m2 = re.search(r'page(\d+)_(?:img|table)(\d+)', name) + # 旧格式:page2_img1.png, page5_table1.png, figure_1.png + m2 = re.search(r"page(\d+)_(?:img|table)(\d+)", name) if m2: return (int(m2.group(1)), int(m2.group(2))) return (0, 0) diff --git a/app/services/pdf_image_extractor.py b/app/services/pdf_image_extractor.py index f8758ef..384171c 100644 --- a/app/services/pdf_image_extractor.py +++ b/app/services/pdf_image_extractor.py @@ -39,6 +39,8 @@ _TABLE_SIDE_PADDING = 60 # 正文行距的 ~1.5 倍 ≈ 空白间隙阈值(学术论文紧密排版,30pt 太宽松) _CONTENT_GAP_THRESHOLD = 20 +# 密集表格数据块后的过渡阈值:表格块之后的段落间距常只有 12-18pt +_TABLE_DATA_GAP_THRESHOLD = 12 # ── Caption 正则 ─────────────────────────────────────────────────────── @@ -48,11 +50,11 @@ _CONTENT_GAP_THRESHOLD = 20 # "Figure 1: Title" / "Figure 1. Title" / "Figure 1 Title"(无标点,空格分隔) # 第三种需要后续紧跟大写字母(排除 "Figure 1 shows..." 等正文引用) _CAPTION_RE = re.compile( - r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=[A-Z]))", + r"^(?:Fig\.?|Figure)\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))", re.IGNORECASE, ) _TABLE_CAPTION_RE = re.compile( - r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=[A-Z]))", + r"^Table\s+(\d+)\s*(?:[:\.]\s*|\s+(?=(?-i:[A-Z])))", re.IGNORECASE, ) @@ -163,7 +165,8 @@ def _find_figure_top(page, caption: dict) -> float: """向上扫描页面,找到 Figure 的上边界。 策略: - 1. 优先用嵌入图片定位(绝大多数 figure 包含嵌入图片,图片边界即 figure 边界) + 1. 优先用嵌入图片定位 — 收集 caption 上方所有相关图片 bbox, + 按 Y 轴聚类后取最大簇的最小 y 作为上界(处理 subfigure 组合图) 2. 
无图片时回退到文本块间隙检测(处理纯矢量图如 TikZ/matplotlib PDF) """ caption_y = caption["caption_y0"] @@ -184,8 +187,9 @@ def _find_figure_top(page, caption: dict) -> float: _caption_cutoff = by0 break - # ── 策略 1:嵌入图片定位(覆盖绝大多数 figure) ── - topmost_image_y: float | None = None + # ── 策略 1:嵌入图片聚类定位 ── + # 收集 caption 上方搜索范围内所有与 caption 水平区域重叠的图片 + image_tops: list[float] = [] for img_info in page.get_image_info(): bbox = img_info.get("bbox") if bbox is None: @@ -194,15 +198,36 @@ def _find_figure_top(page, caption: dict) -> float: ix0, iy0, ix1, iy1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1 else: ix0, iy0, ix1, iy1 = bbox[0], bbox[1], bbox[2], bbox[3] - if iy1 <= caption_y and iy1 > caption_y - _FIGURE_MAX_HEIGHT: - if ix1 > cx0 and ix0 < cx1: - if _caption_cutoff is not None and iy0 < _caption_cutoff: - continue # 属于上方另一个 figure - if topmost_image_y is None or iy0 < topmost_image_y: - topmost_image_y = iy0 - if topmost_image_y is not None: - figure_top = topmost_image_y + # 图片底部必须在 caption 上方、且在搜索范围内 + if not (iy1 <= caption_y and iy1 > caption_y - _FIGURE_MAX_HEIGHT): + continue + # 图片水平范围与 caption 所在列有重叠 + if not (ix1 > cx0 and ix0 < cx1): + continue + # 跳过属于上方另一个 figure 的图片 + if _caption_cutoff is not None and iy0 < _caption_cutoff: + continue + # 跳过极小图标(宽度或高度 <15pt,通常是 logo/符号) + if (ix1 - ix0) < 15 or (iy1 - iy0) < 15: + continue + + image_tops.append(iy0) + + if image_tops: + # 聚类:将 Y 轴接近的图片视为同一组(subfigure),最大簇的最小 y 即图上界 + image_tops.sort() + # 用简单单遍聚类:相邻图片 top 差 < 最大高度的 40% 视为同簇 + cluster_gap = _FIGURE_MAX_HEIGHT * 0.4 + clusters: list[list[float]] = [[image_tops[0]]] + for yt in image_tops[1:]: + if yt - clusters[-1][-1] < cluster_gap: + clusters[-1].append(yt) + else: + clusters.append([yt]) + # 取最大簇(图片数最多的)的最小 y + biggest = max(clusters, key=len) + figure_top = min(biggest) else: # ── 策略 2:文本块间隙检测(纯矢量图) ── above_blocks: list[tuple[float, float, float, float]] = [] @@ -240,6 +265,37 @@ def _find_figure_top(page, caption: dict) -> float: return max(0, figure_top) +def 
_find_figure_horizontal( + page, caption: dict, top: float, bottom: float +) -> tuple[float, float]: + """确定 Figure 的水平裁剪范围。 + + 取 caption 宽度和图片实际宽度的并集,避免截断比 caption 更宽的图。 + """ + pw = caption["page_width"] + x0 = caption["caption_x0"] + x1 = caption["caption_x1"] + + # 收集裁剪区域内所有嵌入图片的水平范围 + col_x0, col_x1 = _estimate_column_x(caption) + for img_info in page.get_image_info(): + bbox = img_info.get("bbox") + if bbox is None: + continue + if hasattr(bbox, "x0"): + ix0, iy0, ix1, iy1 = bbox.x0, bbox.y0, bbox.x1, bbox.y1 + else: + ix0, iy0, ix1, iy1 = bbox[0], bbox[1], bbox[2], bbox[3] + # 图片在裁剪区域内且在 caption 所在列 + if iy0 < bottom and iy1 > top and ix1 > col_x0 and ix0 < col_x1: + if (ix1 - ix0) < 15: + continue # 跳过小图标 + x0 = min(x0, ix0) + x1 = max(x1, ix1) + + return max(0, x0 - _REGION_SIDE_PADDING), min(pw, x1 + _REGION_SIDE_PADDING) + + def _find_table_region(page, caption: dict) -> tuple[float, float, float, float]: """向下扫描页面,找到 Table 的下边界和水平范围。 @@ -247,82 +303,238 @@ def _find_table_region(page, caption: dict) -> tuple[float, float, float, float] 上边界由调用方根据 caption 位置确定。 策略: - 1. 收集 caption 下方的文本块(表格内容是文本) - 2. 找到连续内容区域的底部(遇到大间隙时停止) - 3. 同时检测表格内容的水平范围(表格通常比 caption 宽) + 1. 用 page.find_tables() 收集 caption 下方所有相邻的表格段,合并为一个完整区域 + (学术论文表格常被拆成表头行 + 数据行等多个 find_tables 段) + 2. 
未命中时回退到文本块间隙检测 """ - blocks = page.get_text("blocks") caption_y = caption["caption_y1"] # caption 底部作为扫描起点 caption_x0 = caption["caption_x0"] caption_x1 = caption["caption_x1"] - page_height = caption["page_height"] page_width = caption["page_width"] - # 估计 caption 所在列的水平边界,避免双栏论文跨列抓取 - col_x0, col_x1 = _estimate_column_x(caption) - search_x0 = max(col_x0, caption_x0 - _TABLE_SIDE_PADDING) - search_x1 = min(col_x1, caption_x1 + _TABLE_SIDE_PADDING) + # ── 策略 1: find_tables() 结构化检测 + 合并相邻段 ── + try: + tables = page.find_tables() + except Exception: + tables = None - below_blocks: list[tuple[float, float, float, float]] = [] - for b in blocks: - if len(b) < 5: - continue + if tables and tables.tables: + # 确定 caption 所在栏的范围(防止双栏论文中跨栏收集) + col_x0, col_x1 = _estimate_column_x(caption) + + # 收集 caption 下方附近且在同一栏内的表格段 bbox + segments: list[tuple[float, float, float, float]] = [] + for t in tables.tables: + tb = t.bbox + if isinstance(tb, (list, tuple)): + tx0, ty0, tx1, ty1 = ( + float(tb[0]), + float(tb[1]), + float(tb[2]), + float(tb[3]), + ) + else: + tx0, ty0, tx1, ty1 = ( + float(tb.x0), + float(tb.y0), + float(tb.x1), + float(tb.y1), + ) + + # 表格段上边在 caption 底部附近,且与 caption 同栏 + if ( + ty0 >= caption_y - 5 + and ty0 < caption_y + 200 + and tx1 > col_x0 + and tx0 < col_x1 + ): + segments.append((tx0, ty0, tx1, ty1)) + + if segments: + # 按 y 排序,合并相邻段(gap < 30pt 视为同一表格的连续部分) + segments.sort(key=lambda s: s[1]) + merged: list[tuple[float, float, float, float]] = [segments[0]] + for seg in segments[1:]: + prev = merged[-1] + gap = seg[1] - prev[3] # 当前段 top - 上一段 bottom + if gap < 30: + # 合并:取并集范围 + merged[-1] = ( + min(prev[0], seg[0]), + min(prev[1], seg[1]), + max(prev[2], seg[2]), + max(prev[3], seg[3]), + ) + else: + merged.append(seg) + + # 取第一个合并段(最靠近 caption 的完整表格) + final = merged[0] + tx0, ty0, tx1, ty1 = final + + # 限制最大高度 + if ty1 - caption_y > _TABLE_MAX_HEIGHT: + ty1 = caption_y + _TABLE_MAX_HEIGHT + x0 = max(0, min(caption_x0, tx0) - _REGION_SIDE_PADDING) 
+ x1 = min(page_width, max(caption_x1, tx1) + _REGION_SIDE_PADDING) + logger.debug( + "Table detected by find_tables() (%d segments merged): " + "(%.0f,%.0f)-(%.0f,%.0f)", + len(segments), + x0, + caption_y, + x1, + ty1, + ) + return (x0, caption["caption_y0"], ty1, x1) + + # ── 策略 2: 回退到文本块间隙检测 ── + x0, t_top, t_bottom, x1 = _find_table_region_by_blocks(page, caption) + return (x0, t_top, t_bottom, x1) + + +def _scan_blocks_direction( + blocks: list, + start_y: float, + col_x0: float, + col_x1: float, + direction: int, + max_range: float, +) -> list[tuple[float, float, float, float]]: + """从 start_y 向上(direction=-1)或向下(direction=1)扫描文本块。 + + 收集间隙连续的块,遇到 stop 信号(caption / section header)或大间隙即停。 + 用 current_top/current_bottom 追踪连通区域边界,正确处理 y 重叠块。 + + Returns: + 收集到的块列表 [(x0, y0, x1, y1), ...] + """ + # 过滤在扫描范围内的块 + if direction > 0: # 向下 + candidates = [ + b + for b in blocks + if len(b) >= 5 + and b[1] > start_y + and b[1] < start_y + max_range + and b[2] > col_x0 + and b[0] < col_x1 + ] + candidates.sort(key=lambda b: b[1]) # 按 y0 升序 + else: # 向上 + candidates = [ + b + for b in blocks + if len(b) >= 5 + and b[3] <= start_y + and b[1] > start_y - max_range + and b[2] > col_x0 + and b[0] < col_x1 + ] + candidates.sort(key=lambda b: b[3], reverse=True) # 按 y1 降序(底部离 start_y 最近的在前) + + if not candidates: + return [] + + # 从 start_y 开始,追踪连通区域边界 + connected: list[tuple[float, float, float, float]] = [] + boundary = start_y # 当前连通区域离 start_y 最近端的 y 坐标 + prev_was_dense_table = False + + for b in candidates: bx0, by0, bx1, by1 = b[0], b[1], b[2], b[3] - if by0 > caption_y and by0 < caption_y + _TABLE_MAX_HEIGHT: - if bx1 > search_x0 and bx0 < search_x1: - # 双栏论文:排除跨列正文段落(宽度 >> 列宽,起点在另一列) - # 表格行起点在列内或列边界附近;正文段落起点在另一列(bx0 远小于 col_x0) - if col_x0 > 0 and bx0 < col_x0 - _TABLE_SIDE_PADDING: - continue - # 停止信号:遇到下一个 caption 或 section header 立即停止 - text = str(b[4]).strip() - first_line = text.split("\n")[0].strip() - if _CAPTION_STOP_RE.match(first_line) or 
_SECTION_STOP_RE.match( - first_line - ): - break - below_blocks.append((bx0, by0, bx1, by1)) + text = str(b[4]).strip() + first_line = text.split("\n")[0].strip() - if not below_blocks: - # 没有内容 → 使用默认高度和 caption 宽度 - return ( - max(0, caption_x0 - _REGION_SIDE_PADDING), - min(page_height, caption_y + _TABLE_MIN_HEIGHT), - min(page_width, caption_x1 + _REGION_SIDE_PADDING), + # stop 信号 + if _CAPTION_STOP_RE.match(first_line) or _SECTION_STOP_RE.match(first_line): + break + + # 检查当前块是否与连通区域相连(间隙 < 阈值) + if direction > 0: + gap = by0 - boundary + else: + gap = boundary - by1 + + # 密集表格数据块后使用更低的间隙阈值 + threshold = ( + _TABLE_DATA_GAP_THRESHOLD + if prev_was_dense_table + else _CONTENT_GAP_THRESHOLD + ) + if gap > threshold: + break + + connected.append((bx0, by0, bx1, by1)) + + # 更新连通区域边界 + if direction > 0: + boundary = max(boundary, by1) # extend downward; max() keeps a shorter y-overlapping block from pulling the boundary back up + else: + boundary = min(boundary, by0) # 向上扩展 + + # 判断当前块是否为密集表格数据(行密度高) + lines = [l for l in text.split("\n") if l.strip()] + block_height = by1 - by0 + prev_was_dense_table = ( + len(lines) >= 4 + and block_height > 0 + and len(lines) / block_height >= 0.08 ) - # ── 找到连续内容区域的底部 ── - below_blocks.sort(key=lambda b: b[1]) # 按 y 升序 + return connected - prev_y = caption_y - bottom = below_blocks[-1][3] + 5 # 最后一块的底部 + margin - for b in below_blocks: - gap = b[1] - prev_y # b[1] = by0 - if gap > _CONTENT_GAP_THRESHOLD: - bottom = prev_y + 5 - break - prev_y = b[3] # b[3] = by1 +def _find_table_region_by_blocks( + page, caption: dict +) -> tuple[float, float, float, float]: + """文本块间隙检测 — 作为 find_tables() 的 fallback。 - # 限制最大高度 - if bottom - caption_y > _TABLE_MAX_HEIGHT: - bottom = caption_y + _TABLE_MAX_HEIGHT + 向下扫描找表格下边界,向上扫描找表格上边界(处理 caption 在数据下方)。 + 使用 _scan_blocks_direction 统一双向扫描逻辑。 + """ + blocks = page.get_text("blocks") + caption_y0 = caption["caption_y0"] + caption_y1 = caption["caption_y1"] + caption_x0 = caption["caption_x0"] + caption_x1 = caption["caption_x1"] + page_width = caption["page_width"] + page_height = 
caption["page_height"] - # ── 检测表格内容的水平范围 ── - # 只用 gap 之前的 block 计算水平范围(gap 之后的 block 属于正文,可能更宽) - table_blocks = [b for b in below_blocks if b[1] < bottom] - if not table_blocks: - table_blocks = below_blocks[:1] # 至少用第一个 block - content_x0 = min(caption_x0, min(b[0] for b in table_blocks)) - content_x1 = max(caption_x1, max(b[2] for b in table_blocks)) + col_x0, col_x1 = _estimate_column_x(caption) + + # 向下扫描 + below = _scan_blocks_direction( + blocks, caption_y1, col_x0, col_x1, direction=1, max_range=_TABLE_MAX_HEIGHT + ) + # 向上扫描 + above = _scan_blocks_direction( + blocks, caption_y0, col_x0, col_x1, direction=-1, max_range=_TABLE_MAX_HEIGHT + ) + + # 确定上下边界 + scan_top = min(b[1] for b in above) if above else caption_y0 + scan_bottom = max(b[3] for b in below) if below else caption_y1 + + top = scan_top + bottom = scan_bottom + 5 # 底部 padding + + if bottom - top > _TABLE_MAX_HEIGHT: + bottom = top + _TABLE_MAX_HEIGHT + + # 水平范围:caption + 所有纳入块 + all_blocks = above + below + if all_blocks: + content_x0 = min(caption_x0, min(b[0] for b in all_blocks)) + content_x1 = max(caption_x1, max(b[2] for b in all_blocks)) + else: + content_x0 = caption_x0 + content_x1 = caption_x1 - # 添加边距,不超出页面 - # 使用较小 padding,避免将相邻列内容(如同页另一列的 Figure)带入截图; - # 同时不限制列边界 — 双栏论文中 caption 可能跨列起始 x0 = max(0, content_x0 - _REGION_SIDE_PADDING) x1 = min(page_width, content_x1 + _REGION_SIDE_PADDING) - return (x0, bottom, x1) + return (x0, top, bottom, x1) def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int: @@ -349,9 +561,10 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int: images_dest = paper_dir(arxiv_id) / "images" images_dest.mkdir(parents=True, exist_ok=True) - # 清理上次提取的旧图片,避免残留 - for old_file in images_dest.glob("*.png"): - old_file.unlink() + # 清理上次提取的旧图片,避免残留(同时清理 .png 和 .jpg) + for old_file in images_dest.iterdir(): + if old_file.suffix.lower() in (".png", ".jpg", ".jpeg"): + old_file.unlink() if (images_dest / 
"manifest.json").exists(): (images_dest / "manifest.json").unlink() @@ -379,7 +592,6 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int: for cap in unique_captions: page = doc[cap["page_num"]] - pw = cap["page_width"] if cap["type"] == "figure": # Figure: caption 上方是图 → 向上找图的上边界 @@ -387,10 +599,8 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int: # 上方多留 5pt 边距,确保图框边框、装饰线等不被截断 top = max(0, top - 5) bottom = cap["caption_y1"] + 5 # 包含 caption - # 水平范围:caption 宽度 + 边距(图和 caption 通常等宽) - # 但也要考虑图内容的实际宽度 - x0 = max(0, cap["caption_x0"] - _REGION_SIDE_PADDING) - x1 = min(pw, cap["caption_x1"] + _REGION_SIDE_PADDING) + # 水平范围:取 caption 宽度和图片实际宽度的并集 + x0, x1 = _find_figure_horizontal(page, cap, top, bottom) height = bottom - top if height < _FIGURE_MIN_HEIGHT: @@ -400,9 +610,9 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int: continue else: - # Table: caption 下方是表格 → 向下找表格的下边界和水平范围 - x0, bottom, x1 = _find_table_region(page, cap) - top = max(0, cap["caption_y0"] - 3) # 包含 caption,上边留少许 margin + # Table: 找表格区域(find_tables() → 块级 fallback,双向扫描) + x0, tbl_top, bottom, x1 = _find_table_region(page, cap) + top = max(0, tbl_top - 5) # 包含 caption 及上方数据,留 5pt margin height = bottom - top if height < _TABLE_MIN_HEIGHT: @@ -420,8 +630,11 @@ def extract_images_from_pdf(arxiv_id: str, pdf_path: Path | None = None) -> int: logger.debug("Failed to render %s region for %s", cap["label"], arxiv_id) continue - filename = f"{cap['label'].replace(' ', '_').lower()}.png" - pix.save(str(images_dest / filename)) + # 保存为 JPEG(比 PNG 小 5-10 倍,适合网络传输) + filename = f"{cap['label'].replace(' ', '_').lower()}.jpg" + jpeg_path = images_dest / filename + jpeg_bytes = pix.tobytes("jpeg") + jpeg_path.write_bytes(jpeg_bytes) extracted += 1 cap_preview = cap["caption_text"][:200] if cap["caption_text"] else "" @@ -477,7 +690,9 @@ def filter_images_by_summary(arxiv_id: str, figures: list[dict]) -> int: if not 
images_dir.exists() or not manifest_path.exists(): return 0 - all_files = [f for f in images_dir.iterdir() if f.suffix == ".png"] + all_files = [ + f for f in images_dir.iterdir() if f.suffix.lower() in (".png", ".jpg", ".jpeg") + ] if not all_files: return 0