refactor: restructure services and add image/pdf extraction utilities

- Add image_extractor, pdf_downloader, pi_client, trends services
- Add shared utils module
- Refactor summarizer, embedder, routes for cleaner separation
- Update tests to match new service structure
This commit is contained in:
2026-06-06 00:00:55 +08:00
parent ba9afa212c
commit 85c4cfb9e8
22 changed files with 843 additions and 780 deletions
+5 -8
View File
@@ -16,13 +16,10 @@ from app.models import (
Paper,
TaskLock,
)
from app.utils import PAPERS_DIR, TMP_DIR
logger = logging.getLogger(__name__)
_DATA_DIR = Path("data")
_TMP_DIR = _DATA_DIR / "tmp"
_PAPERS_DIR = _DATA_DIR / "papers"
# 临时文件最大保留时间(小时)
_MAX_TMP_AGE_HOURS = 24
@@ -39,7 +36,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
Returns:
清理统计 {"scanned": int, "removed": int, "errors": list[str]}
"""
if not _TMP_DIR.exists():
if not TMP_DIR.exists():
return {"scanned": 0, "removed": 0, "errors": []}
now = datetime.now(timezone.utc)
@@ -48,7 +45,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
removed = 0
errors: list[str] = []
for entry in _TMP_DIR.iterdir():
for entry in TMP_DIR.iterdir():
if not entry.is_dir():
continue
scanned += 1
@@ -147,13 +144,13 @@ async def delete_papers_by_date_range(
logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True)
# 2. 删除本地文件 data/papers/{arxiv_id}/
paper_dir = _PAPERS_DIR / arxiv_id
paper_dir = PAPERS_DIR / arxiv_id
if paper_dir.exists():
shutil.rmtree(paper_dir)
logger.debug("Removed paper dir: %s", paper_dir)
# 3. 删除临时文件 data/tmp/{arxiv_id}/
tmp_dir = _TMP_DIR / arxiv_id
tmp_dir = TMP_DIR / arxiv_id
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
logger.debug("Removed tmp dir: %s", tmp_dir)
+2 -9
View File
@@ -16,6 +16,7 @@ from app.models import (
PaperTag,
SummaryStatus,
)
from app.utils import make_http_client
logger = logging.getLogger(__name__)
@@ -34,15 +35,7 @@ async def fetch_daily(target_date: str, top_n: int | None = None) -> list[dict]:
url = f"{settings.HF_API_BASE}/daily_papers"
params = {"date": target_date}
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
) as client:
async with make_http_client() as client:
for attempt in range(1, settings.HTTP_MAX_RETRIES + 1):
try:
logger.info("Fetching HF Daily Papers: date=%s attempt=%d", target_date, attempt)
+70 -54
View File
@@ -5,8 +5,6 @@ from __future__ import annotations
import logging
from pathlib import Path
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
@@ -14,66 +12,82 @@ from app.models import Paper
logger = logging.getLogger(__name__)
# ── 单例客户端和 collection ─────────────────────────────────────────────
_client = None
_collection = None
# ── ChromaDB 管理器(替代全局可变状态)──────────────────────────────────
def _chroma_dir() -> Path:
return Path(settings.CHROMA_DIR)
class ChromaManager:
"""封装 ChromaDB 客户端和 collection 的生命周期。"""
def __init__(self) -> None:
self._client = None
self._collection = None
def init(self) -> None:
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
if not settings.CHROMA_ENABLED:
logger.debug("ChromaDB disabled, skip init")
return
if self._client is not None:
return
try:
import chromadb
chroma_path = Path(settings.CHROMA_DIR)
chroma_path.mkdir(parents=True, exist_ok=True)
self._client = chromadb.PersistentClient(path=str(chroma_path))
self._collection = self._get_or_create_collection()
logger.info("ChromaDB initialized at %s", chroma_path)
except Exception:
logger.exception("Failed to initialize ChromaDB")
self._client = None
self._collection = None
def _get_or_create_collection(self):
"""获取或创建 papers_embeddings collection。"""
try:
col = self._client.get_collection("papers_embeddings")
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
return col
except Exception:
pass
col = self._client.create_collection(
name="papers_embeddings",
metadata={"hnsw:space": "cosine"},
)
logger.info("ChromaDB collection 'papers_embeddings' created")
return col
def get_collection(self):
"""返回当前 collection,未初始化则自动初始化。"""
if not settings.CHROMA_ENABLED:
return None
if self._collection is None:
self.init()
return self._collection
def reset(self) -> None:
"""重置状态(供测试使用)。"""
self._client = None
self._collection = None
# 模块级单例
_chroma = ChromaManager()
def init_chroma() -> None:
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection"""
global _client, _collection
if not settings.CHROMA_ENABLED:
logger.debug("ChromaDB disabled, skip init")
return
if _client is not None:
return
try:
import chromadb
chroma_path = _chroma_dir()
chroma_path.mkdir(parents=True, exist_ok=True)
_client = chromadb.PersistentClient(path=str(chroma_path))
_collection = _get_or_create_collection()
logger.info("ChromaDB initialized at %s", chroma_path)
except Exception:
logger.exception("Failed to initialize ChromaDB")
_client = None
_collection = None
def _get_or_create_collection():
"""获取或创建 papers_embeddings collection,维度不匹配时记录日志并跳过。"""
import chromadb
try:
col = _client.get_collection("papers_embeddings")
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
return col
except Exception:
pass
col = _client.create_collection(
name="papers_embeddings",
metadata={"hnsw:space": "cosine"},
)
logger.info("ChromaDB collection 'papers_embeddings' created")
return col
"""初始化 ChromaDB(供 lifespan 调用)"""
_chroma.init()
def get_collection():
"""返回当前 collection,未初始化则返回 None。"""
if not settings.CHROMA_ENABLED:
return None
if _collection is None:
init_chroma()
return _collection
return _chroma.get_collection()
# ── Embedding API 调用 ──────────────────────────────────────────────────
@@ -90,6 +104,8 @@ def _get_embedding(text: str) -> list[float] | None:
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
return None
from app.utils import make_http_client
url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
headers = {"Content-Type": "application/json"}
if settings.EMBED_API_KEY:
@@ -101,7 +117,7 @@ def _get_embedding(text: str) -> list[float] | None:
}
try:
with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
with make_http_client(sync=True) as client:
resp = client.post(url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
+83
View File
@@ -0,0 +1,83 @@
"""LaTeX 图片提取 — 从 arXiv 源码中扫描 \\includegraphics 并提取图片文件。"""
from __future__ import annotations
import logging
import re
import shutil
from pathlib import Path
from app.services.pdf_downloader import download_source_zip, paper_dir, tmp_dir
logger = logging.getLogger(__name__)
_INCLUDEGRAPHICS_RE = re.compile(
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
async def extract_images_from_source(arxiv_id: str) -> int:
"""从 LaTeX 源码中提取图片文件。
流程:
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
2. 扫描 .tex 文件中的 \\includegraphics
3. 复制图片到 data/papers/{arxiv_id}/images/
4. 清理源码临时文件
Returns:
提取的图片数量
"""
tmp_source = tmp_dir(arxiv_id) / "source"
images_dest = paper_dir(arxiv_id) / "images"
try:
# 下载源码 zip(如果还没下载)
if not tmp_source.exists():
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
await download_source_zip(arxiv_id, source_url, tmp_source)
if not tmp_source.exists():
return 0
# 扫描 .tex 文件,收集图片路径
image_paths: set[str] = set()
for tex_file in tmp_source.rglob("*.tex"):
try:
content = tex_file.read_text(encoding="utf-8", errors="replace")
for match in _INCLUDEGRAPHICS_RE.finditer(content):
img_path = match.group(1).strip()
image_paths.add(img_path)
except Exception:
continue
if not image_paths:
return 0
# 查找并复制图片
images_dest.mkdir(parents=True, exist_ok=True)
copied = 0
for img_rel in image_paths:
# 尝试在源码目录中找到文件
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
candidate = tmp_source / (img_rel + ext)
if candidate.is_file():
dest_name = candidate.name
# 避免文件名冲突
dest = images_dest / dest_name
if dest.exists():
stem = dest.stem
suffix = dest.suffix
dest = images_dest / f"{stem}_{copied}{suffix}"
shutil.copy2(candidate, dest)
copied += 1
break
if copied > 0:
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
return copied
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
return 0
+105
View File
@@ -0,0 +1,105 @@
"""PDF 下载与源码下载 — 从 arXiv 下载论文 PDF 和 LaTeX 源码包。"""
from __future__ import annotations
import logging
import shutil
import zipfile
from pathlib import Path
from app.utils import PAPERS_DIR, TMP_DIR, make_http_client
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PdfDownloadError(Exception):
pass
# ── 路径工具 ────────────────────────────────────────────────────────────
def paper_dir(arxiv_id: str) -> Path:
return PAPERS_DIR / arxiv_id
def tmp_dir(arxiv_id: str) -> Path:
return TMP_DIR / arxiv_id
# ── PDF 下载 ────────────────────────────────────────────────────────────
async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
if not pdf_url:
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
dest_dir = tmp_dir(arxiv_id)
dest_dir.mkdir(parents=True, exist_ok=True)
dest = dest_dir / "paper.pdf"
try:
async with make_http_client(follow_redirects=True) as client:
resp = await client.get(pdf_url)
resp.raise_for_status()
dest.write_bytes(resp.content)
except Exception as exc:
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
return dest
# ── 源码下载 ────────────────────────────────────────────────────────────
async def download_source_zip(arxiv_id: str, source_url: str, dest_dir: Path) -> None:
"""下载 arXiv 源码并解压。"""
dest_dir.mkdir(parents=True, exist_ok=True)
zip_path = tmp_dir(arxiv_id) / "source.zip"
try:
async with make_http_client(follow_redirects=True) as client:
resp = await client.get(source_url)
resp.raise_for_status()
zip_path.write_bytes(resp.content)
except Exception as exc:
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
return
try:
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(dest_dir)
logger.debug("Extracted source for %s", arxiv_id)
except zipfile.BadZipFile:
# 可能是 tar.gz
import tarfile
try:
with tarfile.open(zip_path, "r:*") as tf:
tf.extractall(dest_dir, filter="data")
logger.debug("Extracted source (tar) for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
finally:
if zip_path.exists():
zip_path.unlink()
# ── 临时文件清理 ────────────────────────────────────────────────────────
def cleanup_tmp(arxiv_id: str) -> None:
"""清理 data/tmp/{arxiv_id}/ 目录。"""
td = tmp_dir(arxiv_id)
if td.exists():
try:
shutil.rmtree(td)
logger.debug("Cleaned tmp: %s", arxiv_id)
except Exception:
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
+160
View File
@@ -0,0 +1,160 @@
"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。"""
from __future__ import annotations
import asyncio
import json
import logging
import re
from pathlib import Path
from app.config import settings
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PiTimeoutError(Exception):
pass
class PiProcessError(Exception):
def __init__(self, returncode: int, stderr: str):
self.returncode = returncode
self.stderr = stderr
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
class JsonNotFoundError(Exception):
pass
# ── meta.json ───────────────────────────────────────────────────────────
def write_meta_json(paper) -> Path:
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
from app.services.pdf_downloader import paper_dir
d = paper_dir(paper.arxiv_id)
d.mkdir(parents=True, exist_ok=True)
meta_path = d / "meta.json"
authors = [a.name for a in paper.authors]
tags = [t.tag for t in paper.tags]
meta = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"abstract": paper.abstract or "",
"published_at": paper.published_at.isoformat() if paper.published_at else None,
"authors": authors,
"tags": tags,
"upvotes": paper.upvotes,
}
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
return meta_path
# ── pi CLI 调用 ────────────────────────────────────────────────────────
async def call_pi(meta_path: Path, pdf_path: Path) -> str:
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
arxiv_id = meta_path.parent.name
cmd = [
settings.PI_BIN,
"-p",
"--no-tools",
"--skill",
settings.SUMMARY_SKILL,
"请深度解读以下论文,并按指定 JSON schema 输出:",
f"@{meta_path}",
f"@{pdf_path}",
]
logger.info("Calling pi for %s", arxiv_id)
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(),
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise PiTimeoutError(
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
)
if proc.returncode != 0:
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
return stdout.decode("utf-8", errors="replace")
# ── JSON 提取 ──────────────────────────────────────────────────────────
def extract_json(raw_output: str) -> dict:
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
# 策略 1:整体直接解析
stripped = raw_output.strip()
try:
result = json.loads(stripped)
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
pass
# 策略 2:提取 ```json ... ``` 代码块
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
for match in fence_pattern.finditer(raw_output):
try:
result = json.loads(match.group(1).strip())
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
continue
# 策略 3:匹配包含 title_zh 的最大 {...} 块
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
for match in brace_pattern.finditer(raw_output):
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
continue
# 更宽松:找到最大的 { ... } 平衡块
best = None
best_len = 0
for i, ch in enumerate(raw_output):
if ch != "{":
continue
depth = 0
for j in range(i, len(raw_output)):
if raw_output[j] == "{":
depth += 1
elif raw_output[j] == "}":
depth -= 1
if depth == 0:
candidate = raw_output[i : j + 1]
if len(candidate) > best_len:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
best = parsed
best_len = len(candidate)
except json.JSONDecodeError:
pass
break
if best is not None:
return best
raise JsonNotFoundError("no JSON object found in pi output")
+39 -362
View File
@@ -1,18 +1,15 @@
"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结"""
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import shutil
from datetime import datetime, timezone
from pathlib import Path
import httpx
from pydantic import ValidationError
from sqlalchemy import select, text
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
@@ -25,216 +22,31 @@ from app.models import (
SummaryStatus,
TaskLock,
)
from app.services.image_extractor import extract_images_from_source
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp,
download_pdf,
paper_dir,
)
from app.services.pi_client import (
JsonNotFoundError,
PiProcessError,
PiTimeoutError,
call_pi,
extract_json,
write_meta_json,
)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.utils import PAPERS_DIR, release_lock
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PdfDownloadError(Exception):
pass
class PiTimeoutError(Exception):
pass
class PiProcessError(Exception):
def __init__(self, returncode: int, stderr: str):
self.returncode = returncode
self.stderr = stderr
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
class JsonNotFoundError(Exception):
pass
# ── 路径工具 ────────────────────────────────────────────────────────────
_DATA_DIR = Path("data")
_PAPERS_DIR = _DATA_DIR / "papers"
_TMP_DIR = _DATA_DIR / "tmp"
def _paper_dir(arxiv_id: str) -> Path:
return _PAPERS_DIR / arxiv_id
def _tmp_dir(arxiv_id: str) -> Path:
return _TMP_DIR / arxiv_id
# ── PDF 下载 ────────────────────────────────────────────────────────────
async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
if not pdf_url:
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
tmp = _tmp_dir(arxiv_id)
tmp.mkdir(parents=True, exist_ok=True)
dest = tmp / "paper.pdf"
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
try:
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
follow_redirects=True,
) as client:
resp = await client.get(pdf_url)
resp.raise_for_status()
dest.write_bytes(resp.content)
except Exception as exc:
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
return dest
# ── meta.json ───────────────────────────────────────────────────────────
def _write_meta_json(paper: Paper) -> Path:
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
d = _paper_dir(paper.arxiv_id)
d.mkdir(parents=True, exist_ok=True)
meta_path = d / "meta.json"
authors = [a.name for a in paper.authors]
tags = [t.tag for t in paper.tags]
meta = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"abstract": paper.abstract or "",
"published_at": paper.published_at.isoformat() if paper.published_at else None,
"authors": authors,
"tags": tags,
"upvotes": paper.upvotes,
}
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
return meta_path
# ── pi CLI 调用 ────────────────────────────────────────────────────────
async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
cmd = [
settings.PI_BIN,
"-p",
"--no-tools",
"--skill",
settings.SUMMARY_SKILL,
"请深度解读以下论文,并按指定 JSON schema 输出:",
f"@{meta_path}",
f"@{pdf_path}",
]
logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4]))
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(),
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise PiTimeoutError(
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
)
if proc.returncode != 0:
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
return stdout.decode("utf-8", errors="replace")
def paper_id_from_path(meta_path: Path) -> str:
"""从 meta.json 路径反推 arxiv_id。"""
return meta_path.parent.name
# ── JSON 提取 ──────────────────────────────────────────────────────────
def _extract_json(raw_output: str) -> dict:
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
# 策略 1:整体直接解析
stripped = raw_output.strip()
try:
result = json.loads(stripped)
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
pass
# 策略 2:提取 ```json ... ``` 代码块
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
for match in fence_pattern.finditer(raw_output):
try:
result = json.loads(match.group(1).strip())
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
continue
# 策略 3:匹配包含 title_zh 的最大 {...} 块
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
# 先尝试一层嵌套;如果没命中再用更宽松的策略
for match in brace_pattern.finditer(raw_output):
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
continue
# 更宽松:找到最大的 { ... } 平衡块
best = None
best_len = 0
for i, ch in enumerate(raw_output):
if ch != "{":
continue
depth = 0
for j in range(i, len(raw_output)):
if raw_output[j] == "{":
depth += 1
elif raw_output[j] == "}":
depth -= 1
if depth == 0:
candidate = raw_output[i : j + 1]
if len(candidate) > best_len:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
best = parsed
best_len = len(candidate)
except json.JSONDecodeError:
pass
break
if best is not None:
return best
raise JsonNotFoundError("no JSON object found in pi output")
# ── 错误分类 ────────────────────────────────────────────────────────────
@@ -284,6 +96,8 @@ def _update_summary_in_db(
raw_output: str,
) -> None:
"""将校验后的总结写入 DBpaper_summaries + papers + paper_tags + FTS5。"""
from sqlalchemy import text
now = datetime.now(timezone.utc)
# 1. paper_summariesupsert
@@ -298,9 +112,9 @@ def _update_summary_in_db(
# 2. papers 表
paper.title_zh = schema.title_zh
paper.summary_quality = quality
paper_dir = _paper_dir(paper.arxiv_id)
paper.summary_path = str(paper_dir / "summary.json")
paper.raw_output_path = str(paper_dir / "raw_output.txt")
p_dir = paper_dir(paper.arxiv_id)
paper.summary_path = str(p_dir / "summary.json")
paper.raw_output_path = str(p_dir / "raw_output.txt")
# 3. AI 标签
existing_tag_names = {t.tag for t in paper.tags}
@@ -332,7 +146,7 @@ def _update_summary_in_db(
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
"""保存 summary.json 和 raw_output.txt。"""
d = _paper_dir(arxiv_id)
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
@@ -343,143 +157,11 @@ def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
"""仅保存 raw_output.txt(失败时)。"""
d = _paper_dir(arxiv_id)
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
def _cleanup_tmp(arxiv_id: str) -> None:
"""清理 data/tmp/{arxiv_id}/ 目录。"""
tmp = _tmp_dir(arxiv_id)
if tmp.exists():
try:
shutil.rmtree(tmp)
logger.debug("Cleaned tmp: %s", arxiv_id)
except Exception:
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
# ── LaTeX 图片提取(Phase 5)───────────────────────────────────────────
_INCLUDEGRAPHICS_RE = re.compile(
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
"""从 LaTeX 源码中提取图片文件。
流程:
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
2. 扫描 .tex 文件中的 \\includegraphics
3. 复制图片到 data/papers/{arxiv_id}/images/
4. 清理源码临时文件
Returns:
提取的图片数量
"""
tmp_source = _tmp_dir(arxiv_id) / "source"
images_dest = _paper_dir(arxiv_id) / "images"
try:
# 下载源码 zip(如果还没下载)
if not tmp_source.exists():
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
await _download_source_zip(arxiv_id, source_url, tmp_source)
if not tmp_source.exists():
return 0
# 扫描 .tex 文件,收集图片路径
image_paths: set[str] = set()
for tex_file in tmp_source.rglob("*.tex"):
try:
content = tex_file.read_text(encoding="utf-8", errors="replace")
for match in _INCLUDEGRAPHICS_RE.finditer(content):
img_path = match.group(1).strip()
image_paths.add(img_path)
except Exception:
continue
if not image_paths:
return 0
# 查找并复制图片
images_dest.mkdir(parents=True, exist_ok=True)
copied = 0
for img_rel in image_paths:
# 尝试在源码目录中找到文件
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
candidate = tmp_source / (img_rel + ext)
if candidate.is_file():
dest_name = candidate.name
# 避免文件名冲突
dest = images_dest / dest_name
if dest.exists():
stem = dest.stem
suffix = dest.suffix
dest = images_dest / f"{stem}_{copied}{suffix}"
shutil.copy2(candidate, dest)
copied += 1
break
if copied > 0:
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
return copied
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
return 0
async def _download_source_zip(
arxiv_id: str, source_url: str, dest_dir: Path
) -> None:
"""下载 arXiv 源码并解压。"""
import zipfile
dest_dir.mkdir(parents=True, exist_ok=True)
zip_path = _tmp_dir(arxiv_id) / "source.zip"
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
try:
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
follow_redirects=True,
) as client:
resp = await client.get(source_url)
resp.raise_for_status()
zip_path.write_bytes(resp.content)
except Exception as exc:
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
return
try:
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(dest_dir)
logger.debug("Extracted source for %s", arxiv_id)
except zipfile.BadZipFile:
# 可能是 tar.gz
import tarfile
try:
with tarfile.open(zip_path, "r:*") as tf:
tf.extractall(dest_dir)
logger.debug("Extracted source (tar) for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
finally:
if zip_path.exists():
zip_path.unlink()
# ── 单篇总结 ────────────────────────────────────────────────────────────
@@ -491,6 +173,8 @@ async def summarize_one(
force: bool = False,
) -> dict:
"""总结单篇论文的完整流程。"""
import asyncio
arxiv_id = paper.arxiv_id
# 获取或创建 summary_status
@@ -520,6 +204,8 @@ async def summarize_one(
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
import asyncio
arxiv_id = paper.arxiv_id
status = paper.summary_status
now = datetime.now(timezone.utc)
@@ -532,16 +218,16 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
raw_output = ""
try:
# 写 meta.json
meta_path = _write_meta_json(paper)
meta_path = write_meta_json(paper)
# 下载 PDF
await _download_pdf(arxiv_id, paper.pdf_url)
await download_pdf(arxiv_id, paper.pdf_url)
# 调用 pi
raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf")
raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf")
# 提取 JSON
json_data = _extract_json(raw_output)
json_data = extract_json(raw_output)
# Pydantic 校验
schema = SummarySchema.model_validate(json_data)
@@ -564,7 +250,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
try:
await _extract_images_from_source(arxiv_id)
await extract_images_from_source(arxiv_id)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
@@ -625,7 +311,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
}
finally:
_cleanup_tmp(arxiv_id)
cleanup_tmp(arxiv_id)
# ── 单篇入口 ────────────────────────────────────────────────────────────
@@ -690,6 +376,8 @@ async def summarize_batch(
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
import asyncio
now = datetime.now(timezone.utc)
# TaskLock 防重入
@@ -741,7 +429,7 @@ async def summarize_batch(
log_entry.papers_found = 0
log_entry.papers_new = 0
log_entry.completed_at = datetime.now(timezone.utc)
_release_lock(db, lock)
release_lock(db, lock)
return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}
# 并发控制
@@ -813,15 +501,4 @@ async def summarize_batch(
return {"status": "failed", "error": str(exc)}
finally:
_release_lock(db, lock)
def _release_lock(db: Session, lock: TaskLock) -> None:
"""释放 TaskLock。"""
try:
lock.status = "finished"
lock.released_at = datetime.now(timezone.utc)
db.commit()
except Exception:
db.rollback()
logger.warning("Failed to release summarize lock", exc_info=True)
release_lock(db, lock)
+81
View File
@@ -0,0 +1,81 @@
"""趋势统计服务 — 按日论文数量、热门标签、Upvotes 分布、总结完成率。"""
from __future__ import annotations
from datetime import date, timedelta
from sqlalchemy import text
from sqlalchemy.orm import Session
def get_trends_data(db: Session) -> dict:
"""从 DB 聚合趋势数据。"""
thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
# 1. 按日论文数量(近 30 天)
daily_rows = db.execute(text("""
SELECT paper_date, COUNT(*) as cnt
FROM papers
WHERE paper_date >= :start_date
GROUP BY paper_date
ORDER BY paper_date ASC
"""), {"start_date": thirty_days_ago}).fetchall()
daily_counts = [
{"date": str(row[0]), "count": row[1]}
for row in daily_rows
]
# 2. 热门标签 Top 20
tag_rows = db.execute(text("""
SELECT tag, COUNT(*) as cnt
FROM paper_tags
GROUP BY tag
ORDER BY cnt DESC
LIMIT 20
""")).fetchall()
top_tags = [
{"tag": row[0], "count": row[1]}
for row in tag_rows
]
# 3. Upvotes 分布
upvote_rows = db.execute(text("""
SELECT
CASE
WHEN upvotes >= 100 THEN '100+'
WHEN upvotes >= 50 THEN '50-99'
WHEN upvotes >= 20 THEN '20-49'
WHEN upvotes >= 10 THEN '10-19'
WHEN upvotes >= 5 THEN '5-9'
ELSE '0-4'
END as bucket,
COUNT(*) as cnt
FROM papers
GROUP BY bucket
ORDER BY MIN(upvotes) DESC
""")).fetchall()
upvotes_dist = [
{"range": row[0], "count": row[1]}
for row in upvote_rows
]
# 4. 总结完成率
summary_rows = db.execute(text("""
SELECT
COALESCE(ss.status, 'none') as status,
COUNT(*) as cnt
FROM papers p
LEFT JOIN summary_status ss ON ss.paper_id = p.id
GROUP BY status
""")).fetchall()
summary_completion = [
{"status": row[0], "count": row[1]}
for row in summary_rows
]
return {
"daily_counts": daily_counts,
"top_tags": top_tags,
"upvotes_dist": upvotes_dist,
"summary_completion": summary_completion,
}
+48 -3
View File
@@ -1,12 +1,13 @@
"""用户数据服务 — 收藏、阅读状态、个人笔记。无账号体系,数据写入本地 SQLite。"""
"""用户数据服务 — 收藏、阅读状态、个人笔记、阅读列表查询。无账号体系,数据写入本地 SQLite。"""
from __future__ import annotations
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from sqlalchemy import or_
from sqlalchemy.orm import Session, joinedload
from app.models import Paper, UserBookmark, UserNote, UserReadingStatus
from app.models import Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
# ── 收藏 ──────────────────────────────────────────────────────────────
@@ -113,3 +114,47 @@ def save_note(db: Session, arxiv_id: str, content: str) -> dict:
"content": content,
"updated_at": now.isoformat(),
}
# ── 阅读列表 ──────────────────────────────────────────────────────────
def query_reading_list(
db: Session,
filter_type: str,
tag: str | None,
) -> list[Paper]:
"""根据筛选条件查询阅读列表。"""
# 基础:有任意用户数据的论文
base = db.query(Paper).filter(
or_(
Paper.bookmark.has(),
Paper.reading_status.has(),
Paper.note.has(),
)
)
# 应用筛选
if filter_type == "has_note":
base = base.filter(Paper.note.has())
elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
base = base.filter(
Paper.reading_status.has(UserReadingStatus.status == filter_type)
)
# 应用标签
if tag:
base = base.filter(Paper.tags.any(PaperTag.tag == tag))
return (
base.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
joinedload(Paper.note),
)
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
.all()
)