refactor: restructure services and add image/pdf extraction utilities

- Add image_extractor, pdf_downloader, pi_client, trends services
- Add shared utils module
- Refactor summarizer, embedder, routes for cleaner separation
- Update tests to match new service structure
This commit is contained in:
2026-06-06 00:00:55 +08:00
parent ba9afa212c
commit 85c4cfb9e8
22 changed files with 843 additions and 780 deletions
+39 -362
View File
@@ -1,18 +1,15 @@
"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结"""
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import shutil
from datetime import datetime, timezone
from pathlib import Path
import httpx
from pydantic import ValidationError
from sqlalchemy import select, text
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
@@ -25,216 +22,31 @@ from app.models import (
SummaryStatus,
TaskLock,
)
from app.services.image_extractor import extract_images_from_source
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp,
download_pdf,
paper_dir,
)
from app.services.pi_client import (
JsonNotFoundError,
PiProcessError,
PiTimeoutError,
call_pi,
extract_json,
write_meta_json,
)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.utils import PAPERS_DIR, release_lock
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PdfDownloadError(Exception):
pass
class PiTimeoutError(Exception):
pass
class PiProcessError(Exception):
def __init__(self, returncode: int, stderr: str):
self.returncode = returncode
self.stderr = stderr
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
class JsonNotFoundError(Exception):
pass
# ── 路径工具 ────────────────────────────────────────────────────────────
_DATA_DIR = Path("data")
_PAPERS_DIR = _DATA_DIR / "papers"
_TMP_DIR = _DATA_DIR / "tmp"
def _paper_dir(arxiv_id: str) -> Path:
return _PAPERS_DIR / arxiv_id
def _tmp_dir(arxiv_id: str) -> Path:
return _TMP_DIR / arxiv_id
# ── PDF 下载 ────────────────────────────────────────────────────────────
async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
if not pdf_url:
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
tmp = _tmp_dir(arxiv_id)
tmp.mkdir(parents=True, exist_ok=True)
dest = tmp / "paper.pdf"
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
try:
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
follow_redirects=True,
) as client:
resp = await client.get(pdf_url)
resp.raise_for_status()
dest.write_bytes(resp.content)
except Exception as exc:
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
return dest
# ── meta.json ───────────────────────────────────────────────────────────
def _write_meta_json(paper: Paper) -> Path:
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
d = _paper_dir(paper.arxiv_id)
d.mkdir(parents=True, exist_ok=True)
meta_path = d / "meta.json"
authors = [a.name for a in paper.authors]
tags = [t.tag for t in paper.tags]
meta = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"abstract": paper.abstract or "",
"published_at": paper.published_at.isoformat() if paper.published_at else None,
"authors": authors,
"tags": tags,
"upvotes": paper.upvotes,
}
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
return meta_path
# ── pi CLI 调用 ────────────────────────────────────────────────────────
async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
cmd = [
settings.PI_BIN,
"-p",
"--no-tools",
"--skill",
settings.SUMMARY_SKILL,
"请深度解读以下论文,并按指定 JSON schema 输出:",
f"@{meta_path}",
f"@{pdf_path}",
]
logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4]))
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(),
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise PiTimeoutError(
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
)
if proc.returncode != 0:
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
return stdout.decode("utf-8", errors="replace")
def paper_id_from_path(meta_path: Path) -> str:
"""从 meta.json 路径反推 arxiv_id。"""
return meta_path.parent.name
# ── JSON 提取 ──────────────────────────────────────────────────────────
def _extract_json(raw_output: str) -> dict:
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
# 策略 1:整体直接解析
stripped = raw_output.strip()
try:
result = json.loads(stripped)
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
pass
# 策略 2:提取 ```json ... ``` 代码块
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
for match in fence_pattern.finditer(raw_output):
try:
result = json.loads(match.group(1).strip())
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
continue
# 策略 3:匹配包含 title_zh 的最大 {...} 块
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
# 先尝试一层嵌套;如果没命中再用更宽松的策略
for match in brace_pattern.finditer(raw_output):
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
continue
# 更宽松:找到最大的 { ... } 平衡块
best = None
best_len = 0
for i, ch in enumerate(raw_output):
if ch != "{":
continue
depth = 0
for j in range(i, len(raw_output)):
if raw_output[j] == "{":
depth += 1
elif raw_output[j] == "}":
depth -= 1
if depth == 0:
candidate = raw_output[i : j + 1]
if len(candidate) > best_len:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
best = parsed
best_len = len(candidate)
except json.JSONDecodeError:
pass
break
if best is not None:
return best
raise JsonNotFoundError("no JSON object found in pi output")
# ── 错误分类 ────────────────────────────────────────────────────────────
@@ -284,6 +96,8 @@ def _update_summary_in_db(
raw_output: str,
) -> None:
"""将校验后的总结写入 DBpaper_summaries + papers + paper_tags + FTS5。"""
from sqlalchemy import text
now = datetime.now(timezone.utc)
# 1. paper_summariesupsert
@@ -298,9 +112,9 @@ def _update_summary_in_db(
# 2. papers 表
paper.title_zh = schema.title_zh
paper.summary_quality = quality
paper_dir = _paper_dir(paper.arxiv_id)
paper.summary_path = str(paper_dir / "summary.json")
paper.raw_output_path = str(paper_dir / "raw_output.txt")
p_dir = paper_dir(paper.arxiv_id)
paper.summary_path = str(p_dir / "summary.json")
paper.raw_output_path = str(p_dir / "raw_output.txt")
# 3. AI 标签
existing_tag_names = {t.tag for t in paper.tags}
@@ -332,7 +146,7 @@ def _update_summary_in_db(
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
"""保存 summary.json 和 raw_output.txt。"""
d = _paper_dir(arxiv_id)
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
@@ -343,143 +157,11 @@ def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
"""仅保存 raw_output.txt(失败时)。"""
d = _paper_dir(arxiv_id)
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
def _cleanup_tmp(arxiv_id: str) -> None:
"""清理 data/tmp/{arxiv_id}/ 目录。"""
tmp = _tmp_dir(arxiv_id)
if tmp.exists():
try:
shutil.rmtree(tmp)
logger.debug("Cleaned tmp: %s", arxiv_id)
except Exception:
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
# ── LaTeX 图片提取(Phase 5)───────────────────────────────────────────
_INCLUDEGRAPHICS_RE = re.compile(
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
"""从 LaTeX 源码中提取图片文件。
流程:
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
2. 扫描 .tex 文件中的 \\includegraphics
3. 复制图片到 data/papers/{arxiv_id}/images/
4. 清理源码临时文件
Returns:
提取的图片数量
"""
tmp_source = _tmp_dir(arxiv_id) / "source"
images_dest = _paper_dir(arxiv_id) / "images"
try:
# 下载源码 zip(如果还没下载)
if not tmp_source.exists():
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
await _download_source_zip(arxiv_id, source_url, tmp_source)
if not tmp_source.exists():
return 0
# 扫描 .tex 文件,收集图片路径
image_paths: set[str] = set()
for tex_file in tmp_source.rglob("*.tex"):
try:
content = tex_file.read_text(encoding="utf-8", errors="replace")
for match in _INCLUDEGRAPHICS_RE.finditer(content):
img_path = match.group(1).strip()
image_paths.add(img_path)
except Exception:
continue
if not image_paths:
return 0
# 查找并复制图片
images_dest.mkdir(parents=True, exist_ok=True)
copied = 0
for img_rel in image_paths:
# 尝试在源码目录中找到文件
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
candidate = tmp_source / (img_rel + ext)
if candidate.is_file():
dest_name = candidate.name
# 避免文件名冲突
dest = images_dest / dest_name
if dest.exists():
stem = dest.stem
suffix = dest.suffix
dest = images_dest / f"{stem}_{copied}{suffix}"
shutil.copy2(candidate, dest)
copied += 1
break
if copied > 0:
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
return copied
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
return 0
async def _download_source_zip(
arxiv_id: str, source_url: str, dest_dir: Path
) -> None:
"""下载 arXiv 源码并解压。"""
import zipfile
dest_dir.mkdir(parents=True, exist_ok=True)
zip_path = _tmp_dir(arxiv_id) / "source.zip"
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
try:
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
follow_redirects=True,
) as client:
resp = await client.get(source_url)
resp.raise_for_status()
zip_path.write_bytes(resp.content)
except Exception as exc:
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
return
try:
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(dest_dir)
logger.debug("Extracted source for %s", arxiv_id)
except zipfile.BadZipFile:
# 可能是 tar.gz
import tarfile
try:
with tarfile.open(zip_path, "r:*") as tf:
tf.extractall(dest_dir)
logger.debug("Extracted source (tar) for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
finally:
if zip_path.exists():
zip_path.unlink()
# ── 单篇总结 ────────────────────────────────────────────────────────────
@@ -491,6 +173,8 @@ async def summarize_one(
force: bool = False,
) -> dict:
"""总结单篇论文的完整流程。"""
import asyncio
arxiv_id = paper.arxiv_id
# 获取或创建 summary_status
@@ -520,6 +204,8 @@ async def summarize_one(
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
import asyncio
arxiv_id = paper.arxiv_id
status = paper.summary_status
now = datetime.now(timezone.utc)
@@ -532,16 +218,16 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
raw_output = ""
try:
# 写 meta.json
meta_path = _write_meta_json(paper)
meta_path = write_meta_json(paper)
# 下载 PDF
await _download_pdf(arxiv_id, paper.pdf_url)
await download_pdf(arxiv_id, paper.pdf_url)
# 调用 pi
raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf")
raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf")
# 提取 JSON
json_data = _extract_json(raw_output)
json_data = extract_json(raw_output)
# Pydantic 校验
schema = SummarySchema.model_validate(json_data)
@@ -564,7 +250,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
try:
await _extract_images_from_source(arxiv_id)
await extract_images_from_source(arxiv_id)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
@@ -625,7 +311,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
}
finally:
_cleanup_tmp(arxiv_id)
cleanup_tmp(arxiv_id)
# ── 单篇入口 ────────────────────────────────────────────────────────────
@@ -690,6 +376,8 @@ async def summarize_batch(
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
import asyncio
now = datetime.now(timezone.utc)
# TaskLock 防重入
@@ -741,7 +429,7 @@ async def summarize_batch(
log_entry.papers_found = 0
log_entry.papers_new = 0
log_entry.completed_at = datetime.now(timezone.utc)
_release_lock(db, lock)
release_lock(db, lock)
return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}
# 并发控制
@@ -813,15 +501,4 @@ async def summarize_batch(
return {"status": "failed", "error": str(exc)}
finally:
_release_lock(db, lock)
def _release_lock(db: Session, lock: TaskLock) -> None:
"""释放 TaskLock。"""
try:
lock.status = "finished"
lock.released_at = datetime.now(timezone.utc)
db.commit()
except Exception:
db.rollback()
logger.warning("Failed to release summarize lock", exc_info=True)
release_lock(db, lock)