refactor: restructure services and add image/pdf extraction utilities
- Add image_extractor, pdf_downloader, pi_client, trends services - Add shared utils module - Refactor summarizer, embedder, routes for cleaner separation - Update tests to match new service structure
This commit is contained in:
@@ -16,13 +16,10 @@ from app.models import (
|
||||
Paper,
|
||||
TaskLock,
|
||||
)
|
||||
from app.utils import PAPERS_DIR, TMP_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DATA_DIR = Path("data")
|
||||
_TMP_DIR = _DATA_DIR / "tmp"
|
||||
_PAPERS_DIR = _DATA_DIR / "papers"
|
||||
|
||||
# 临时文件最大保留时间(小时)
|
||||
_MAX_TMP_AGE_HOURS = 24
|
||||
|
||||
@@ -39,7 +36,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
|
||||
Returns:
|
||||
清理统计 {"scanned": int, "removed": int, "errors": list[str]}
|
||||
"""
|
||||
if not _TMP_DIR.exists():
|
||||
if not TMP_DIR.exists():
|
||||
return {"scanned": 0, "removed": 0, "errors": []}
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
@@ -48,7 +45,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
|
||||
removed = 0
|
||||
errors: list[str] = []
|
||||
|
||||
for entry in _TMP_DIR.iterdir():
|
||||
for entry in TMP_DIR.iterdir():
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
scanned += 1
|
||||
@@ -147,13 +144,13 @@ async def delete_papers_by_date_range(
|
||||
logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True)
|
||||
|
||||
# 2. 删除本地文件 data/papers/{arxiv_id}/
|
||||
paper_dir = _PAPERS_DIR / arxiv_id
|
||||
paper_dir = PAPERS_DIR / arxiv_id
|
||||
if paper_dir.exists():
|
||||
shutil.rmtree(paper_dir)
|
||||
logger.debug("Removed paper dir: %s", paper_dir)
|
||||
|
||||
# 3. 删除临时文件 data/tmp/{arxiv_id}/
|
||||
tmp_dir = _TMP_DIR / arxiv_id
|
||||
tmp_dir = TMP_DIR / arxiv_id
|
||||
if tmp_dir.exists():
|
||||
shutil.rmtree(tmp_dir)
|
||||
logger.debug("Removed tmp dir: %s", tmp_dir)
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.models import (
|
||||
PaperTag,
|
||||
SummaryStatus,
|
||||
)
|
||||
from app.utils import make_http_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,15 +35,7 @@ async def fetch_daily(target_date: str, top_n: int | None = None) -> list[dict]:
|
||||
url = f"{settings.HF_API_BASE}/daily_papers"
|
||||
params = {"date": target_date}
|
||||
|
||||
transport = None
|
||||
if settings.http_proxy:
|
||||
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=settings.HTTP_TIMEOUT_SECONDS,
|
||||
headers={"User-Agent": settings.HTTP_USER_AGENT},
|
||||
transport=transport,
|
||||
) as client:
|
||||
async with make_http_client() as client:
|
||||
for attempt in range(1, settings.HTTP_MAX_RETRIES + 1):
|
||||
try:
|
||||
logger.info("Fetching HF Daily Papers: date=%s attempt=%d", target_date, attempt)
|
||||
|
||||
+70
-54
@@ -5,8 +5,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.config import settings
|
||||
@@ -14,66 +12,82 @@ from app.models import Paper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── 单例客户端和 collection ─────────────────────────────────────────────
|
||||
_client = None
|
||||
_collection = None
|
||||
|
||||
# ── ChromaDB 管理器(替代全局可变状态)──────────────────────────────────
|
||||
|
||||
|
||||
def _chroma_dir() -> Path:
|
||||
return Path(settings.CHROMA_DIR)
|
||||
class ChromaManager:
|
||||
"""封装 ChromaDB 客户端和 collection 的生命周期。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
def init(self) -> None:
|
||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
|
||||
if not settings.CHROMA_ENABLED:
|
||||
logger.debug("ChromaDB disabled, skip init")
|
||||
return
|
||||
|
||||
if self._client is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
|
||||
chroma_path = Path(settings.CHROMA_DIR)
|
||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
||||
self._collection = self._get_or_create_collection()
|
||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize ChromaDB")
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
def _get_or_create_collection(self):
|
||||
"""获取或创建 papers_embeddings collection。"""
|
||||
try:
|
||||
col = self._client.get_collection("papers_embeddings")
|
||||
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
|
||||
return col
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
col = self._client.create_collection(
|
||||
name="papers_embeddings",
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
)
|
||||
logger.info("ChromaDB collection 'papers_embeddings' created")
|
||||
return col
|
||||
|
||||
def get_collection(self):
|
||||
"""返回当前 collection,未初始化则自动初始化。"""
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return None
|
||||
if self._collection is None:
|
||||
self.init()
|
||||
return self._collection
|
||||
|
||||
def reset(self) -> None:
|
||||
"""重置状态(供测试使用)。"""
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
|
||||
# 模块级单例
|
||||
_chroma = ChromaManager()
|
||||
|
||||
|
||||
def init_chroma() -> None:
|
||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
|
||||
global _client, _collection
|
||||
if not settings.CHROMA_ENABLED:
|
||||
logger.debug("ChromaDB disabled, skip init")
|
||||
return
|
||||
|
||||
if _client is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
|
||||
chroma_path = _chroma_dir()
|
||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_client = chromadb.PersistentClient(path=str(chroma_path))
|
||||
_collection = _get_or_create_collection()
|
||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize ChromaDB")
|
||||
_client = None
|
||||
_collection = None
|
||||
|
||||
|
||||
def _get_or_create_collection():
|
||||
"""获取或创建 papers_embeddings collection,维度不匹配时记录日志并跳过。"""
|
||||
import chromadb
|
||||
|
||||
try:
|
||||
col = _client.get_collection("papers_embeddings")
|
||||
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
|
||||
return col
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
col = _client.create_collection(
|
||||
name="papers_embeddings",
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
)
|
||||
logger.info("ChromaDB collection 'papers_embeddings' created")
|
||||
return col
|
||||
"""初始化 ChromaDB(供 lifespan 调用)。"""
|
||||
_chroma.init()
|
||||
|
||||
|
||||
def get_collection():
|
||||
"""返回当前 collection,未初始化则返回 None。"""
|
||||
if not settings.CHROMA_ENABLED:
|
||||
return None
|
||||
if _collection is None:
|
||||
init_chroma()
|
||||
return _collection
|
||||
return _chroma.get_collection()
|
||||
|
||||
|
||||
# ── Embedding API 调用 ──────────────────────────────────────────────────
|
||||
@@ -90,6 +104,8 @@ def _get_embedding(text: str) -> list[float] | None:
|
||||
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
|
||||
return None
|
||||
|
||||
from app.utils import make_http_client
|
||||
|
||||
url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if settings.EMBED_API_KEY:
|
||||
@@ -101,7 +117,7 @@ def _get_embedding(text: str) -> list[float] | None:
|
||||
}
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
|
||||
with make_http_client(sync=True) as client:
|
||||
resp = client.post(url, json=payload, headers=headers)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
"""LaTeX 图片提取 — 从 arXiv 源码中扫描 \\includegraphics 并提取图片文件。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.pdf_downloader import download_source_zip, paper_dir, tmp_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_INCLUDEGRAPHICS_RE = re.compile(
|
||||
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
|
||||
)
|
||||
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
|
||||
|
||||
|
||||
async def extract_images_from_source(arxiv_id: str) -> int:
|
||||
"""从 LaTeX 源码中提取图片文件。
|
||||
|
||||
流程:
|
||||
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
|
||||
2. 扫描 .tex 文件中的 \\includegraphics
|
||||
3. 复制图片到 data/papers/{arxiv_id}/images/
|
||||
4. 清理源码临时文件
|
||||
|
||||
Returns:
|
||||
提取的图片数量
|
||||
"""
|
||||
tmp_source = tmp_dir(arxiv_id) / "source"
|
||||
images_dest = paper_dir(arxiv_id) / "images"
|
||||
|
||||
try:
|
||||
# 下载源码 zip(如果还没下载)
|
||||
if not tmp_source.exists():
|
||||
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
|
||||
await download_source_zip(arxiv_id, source_url, tmp_source)
|
||||
|
||||
if not tmp_source.exists():
|
||||
return 0
|
||||
|
||||
# 扫描 .tex 文件,收集图片路径
|
||||
image_paths: set[str] = set()
|
||||
for tex_file in tmp_source.rglob("*.tex"):
|
||||
try:
|
||||
content = tex_file.read_text(encoding="utf-8", errors="replace")
|
||||
for match in _INCLUDEGRAPHICS_RE.finditer(content):
|
||||
img_path = match.group(1).strip()
|
||||
image_paths.add(img_path)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not image_paths:
|
||||
return 0
|
||||
|
||||
# 查找并复制图片
|
||||
images_dest.mkdir(parents=True, exist_ok=True)
|
||||
copied = 0
|
||||
for img_rel in image_paths:
|
||||
# 尝试在源码目录中找到文件
|
||||
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
|
||||
candidate = tmp_source / (img_rel + ext)
|
||||
if candidate.is_file():
|
||||
dest_name = candidate.name
|
||||
# 避免文件名冲突
|
||||
dest = images_dest / dest_name
|
||||
if dest.exists():
|
||||
stem = dest.stem
|
||||
suffix = dest.suffix
|
||||
dest = images_dest / f"{stem}_{copied}{suffix}"
|
||||
shutil.copy2(candidate, dest)
|
||||
copied += 1
|
||||
break
|
||||
|
||||
if copied > 0:
|
||||
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
|
||||
return copied
|
||||
|
||||
except Exception:
|
||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
return 0
|
||||
@@ -0,0 +1,105 @@
|
||||
"""PDF 下载与源码下载 — 从 arXiv 下载论文 PDF 和 LaTeX 源码包。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.utils import PAPERS_DIR, TMP_DIR, make_http_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class PdfDownloadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# ── 路径工具 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def paper_dir(arxiv_id: str) -> Path:
|
||||
return PAPERS_DIR / arxiv_id
|
||||
|
||||
|
||||
def tmp_dir(arxiv_id: str) -> Path:
|
||||
return TMP_DIR / arxiv_id
|
||||
|
||||
|
||||
# ── PDF 下载 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
|
||||
if not pdf_url:
|
||||
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
|
||||
|
||||
dest_dir = tmp_dir(arxiv_id)
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest = dest_dir / "paper.pdf"
|
||||
|
||||
try:
|
||||
async with make_http_client(follow_redirects=True) as client:
|
||||
resp = await client.get(pdf_url)
|
||||
resp.raise_for_status()
|
||||
dest.write_bytes(resp.content)
|
||||
except Exception as exc:
|
||||
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
|
||||
|
||||
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
|
||||
return dest
|
||||
|
||||
|
||||
# ── 源码下载 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def download_source_zip(arxiv_id: str, source_url: str, dest_dir: Path) -> None:
|
||||
"""下载 arXiv 源码并解压。"""
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
zip_path = tmp_dir(arxiv_id) / "source.zip"
|
||||
|
||||
try:
|
||||
async with make_http_client(follow_redirects=True) as client:
|
||||
resp = await client.get(source_url)
|
||||
resp.raise_for_status()
|
||||
zip_path.write_bytes(resp.content)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
|
||||
return
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(dest_dir)
|
||||
logger.debug("Extracted source for %s", arxiv_id)
|
||||
except zipfile.BadZipFile:
|
||||
# 可能是 tar.gz
|
||||
import tarfile
|
||||
try:
|
||||
with tarfile.open(zip_path, "r:*") as tf:
|
||||
tf.extractall(dest_dir, filter="data")
|
||||
logger.debug("Extracted source (tar) for %s", arxiv_id)
|
||||
except Exception:
|
||||
logger.warning("Cannot extract source for %s", arxiv_id)
|
||||
except Exception:
|
||||
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
|
||||
finally:
|
||||
if zip_path.exists():
|
||||
zip_path.unlink()
|
||||
|
||||
|
||||
# ── 临时文件清理 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def cleanup_tmp(arxiv_id: str) -> None:
|
||||
"""清理 data/tmp/{arxiv_id}/ 目录。"""
|
||||
td = tmp_dir(arxiv_id)
|
||||
if td.exists():
|
||||
try:
|
||||
shutil.rmtree(td)
|
||||
logger.debug("Cleaned tmp: %s", arxiv_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
|
||||
@@ -0,0 +1,160 @@
|
||||
"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class PiTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PiProcessError(Exception):
|
||||
def __init__(self, returncode: int, stderr: str):
|
||||
self.returncode = returncode
|
||||
self.stderr = stderr
|
||||
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
|
||||
|
||||
|
||||
class JsonNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# ── meta.json ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def write_meta_json(paper) -> Path:
|
||||
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
|
||||
from app.services.pdf_downloader import paper_dir
|
||||
|
||||
d = paper_dir(paper.arxiv_id)
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
meta_path = d / "meta.json"
|
||||
|
||||
authors = [a.name for a in paper.authors]
|
||||
tags = [t.tag for t in paper.tags]
|
||||
meta = {
|
||||
"arxiv_id": paper.arxiv_id,
|
||||
"title_en": paper.title_en,
|
||||
"abstract": paper.abstract or "",
|
||||
"published_at": paper.published_at.isoformat() if paper.published_at else None,
|
||||
"authors": authors,
|
||||
"tags": tags,
|
||||
"upvotes": paper.upvotes,
|
||||
}
|
||||
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
return meta_path
|
||||
|
||||
|
||||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def call_pi(meta_path: Path, pdf_path: Path) -> str:
|
||||
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
|
||||
arxiv_id = meta_path.parent.name
|
||||
cmd = [
|
||||
settings.PI_BIN,
|
||||
"-p",
|
||||
"--no-tools",
|
||||
"--skill",
|
||||
settings.SUMMARY_SKILL,
|
||||
"请深度解读以下论文,并按指定 JSON schema 输出:",
|
||||
f"@{meta_path}",
|
||||
f"@{pdf_path}",
|
||||
]
|
||||
logger.info("Calling pi for %s", arxiv_id)
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
proc.communicate(),
|
||||
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
raise PiTimeoutError(
|
||||
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
|
||||
)
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
|
||||
|
||||
return stdout.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
# ── JSON 提取 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def extract_json(raw_output: str) -> dict:
|
||||
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
|
||||
# 策略 1:整体直接解析
|
||||
stripped = raw_output.strip()
|
||||
try:
|
||||
result = json.loads(stripped)
|
||||
if isinstance(result, dict) and "title_zh" in result:
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 策略 2:提取 ```json ... ``` 代码块
|
||||
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
|
||||
for match in fence_pattern.finditer(raw_output):
|
||||
try:
|
||||
result = json.loads(match.group(1).strip())
|
||||
if isinstance(result, dict) and "title_zh" in result:
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# 策略 3:匹配包含 title_zh 的最大 {...} 块
|
||||
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
|
||||
for match in brace_pattern.finditer(raw_output):
|
||||
try:
|
||||
return json.loads(match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# 更宽松:找到最大的 { ... } 平衡块
|
||||
best = None
|
||||
best_len = 0
|
||||
for i, ch in enumerate(raw_output):
|
||||
if ch != "{":
|
||||
continue
|
||||
depth = 0
|
||||
for j in range(i, len(raw_output)):
|
||||
if raw_output[j] == "{":
|
||||
depth += 1
|
||||
elif raw_output[j] == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
candidate = raw_output[i : j + 1]
|
||||
if len(candidate) > best_len:
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
if isinstance(parsed, dict):
|
||||
best = parsed
|
||||
best_len = len(candidate)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
break
|
||||
|
||||
if best is not None:
|
||||
return best
|
||||
|
||||
raise JsonNotFoundError("no JSON object found in pi output")
|
||||
+39
-362
@@ -1,18 +1,15 @@
|
||||
"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。"""
|
||||
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.config import settings
|
||||
@@ -25,216 +22,31 @@ from app.models import (
|
||||
SummaryStatus,
|
||||
TaskLock,
|
||||
)
|
||||
from app.services.image_extractor import extract_images_from_source
|
||||
from app.services.pdf_downloader import (
|
||||
PdfDownloadError,
|
||||
cleanup_tmp,
|
||||
download_pdf,
|
||||
paper_dir,
|
||||
)
|
||||
from app.services.pi_client import (
|
||||
JsonNotFoundError,
|
||||
PiProcessError,
|
||||
PiTimeoutError,
|
||||
call_pi,
|
||||
extract_json,
|
||||
write_meta_json,
|
||||
)
|
||||
from app.services.schemas import (
|
||||
SummarySchema,
|
||||
assess_quality,
|
||||
classify_validation_error,
|
||||
flatten_for_db,
|
||||
)
|
||||
from app.utils import PAPERS_DIR, release_lock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class PdfDownloadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PiTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PiProcessError(Exception):
|
||||
def __init__(self, returncode: int, stderr: str):
|
||||
self.returncode = returncode
|
||||
self.stderr = stderr
|
||||
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
|
||||
|
||||
|
||||
class JsonNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# ── 路径工具 ────────────────────────────────────────────────────────────
|
||||
|
||||
_DATA_DIR = Path("data")
|
||||
_PAPERS_DIR = _DATA_DIR / "papers"
|
||||
_TMP_DIR = _DATA_DIR / "tmp"
|
||||
|
||||
|
||||
def _paper_dir(arxiv_id: str) -> Path:
|
||||
return _PAPERS_DIR / arxiv_id
|
||||
|
||||
|
||||
def _tmp_dir(arxiv_id: str) -> Path:
|
||||
return _TMP_DIR / arxiv_id
|
||||
|
||||
|
||||
# ── PDF 下载 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
|
||||
if not pdf_url:
|
||||
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
|
||||
|
||||
tmp = _tmp_dir(arxiv_id)
|
||||
tmp.mkdir(parents=True, exist_ok=True)
|
||||
dest = tmp / "paper.pdf"
|
||||
|
||||
transport = None
|
||||
if settings.http_proxy:
|
||||
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=settings.HTTP_TIMEOUT_SECONDS,
|
||||
headers={"User-Agent": settings.HTTP_USER_AGENT},
|
||||
transport=transport,
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
resp = await client.get(pdf_url)
|
||||
resp.raise_for_status()
|
||||
dest.write_bytes(resp.content)
|
||||
except Exception as exc:
|
||||
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
|
||||
|
||||
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
|
||||
return dest
|
||||
|
||||
|
||||
# ── meta.json ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _write_meta_json(paper: Paper) -> Path:
|
||||
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
|
||||
d = _paper_dir(paper.arxiv_id)
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
meta_path = d / "meta.json"
|
||||
|
||||
authors = [a.name for a in paper.authors]
|
||||
tags = [t.tag for t in paper.tags]
|
||||
meta = {
|
||||
"arxiv_id": paper.arxiv_id,
|
||||
"title_en": paper.title_en,
|
||||
"abstract": paper.abstract or "",
|
||||
"published_at": paper.published_at.isoformat() if paper.published_at else None,
|
||||
"authors": authors,
|
||||
"tags": tags,
|
||||
"upvotes": paper.upvotes,
|
||||
}
|
||||
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
return meta_path
|
||||
|
||||
|
||||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
|
||||
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
|
||||
cmd = [
|
||||
settings.PI_BIN,
|
||||
"-p",
|
||||
"--no-tools",
|
||||
"--skill",
|
||||
settings.SUMMARY_SKILL,
|
||||
"请深度解读以下论文,并按指定 JSON schema 输出:",
|
||||
f"@{meta_path}",
|
||||
f"@{pdf_path}",
|
||||
]
|
||||
logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4]))
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
proc.communicate(),
|
||||
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
raise PiTimeoutError(
|
||||
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
|
||||
)
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
|
||||
|
||||
return stdout.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def paper_id_from_path(meta_path: Path) -> str:
|
||||
"""从 meta.json 路径反推 arxiv_id。"""
|
||||
return meta_path.parent.name
|
||||
|
||||
|
||||
# ── JSON 提取 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _extract_json(raw_output: str) -> dict:
|
||||
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
|
||||
# 策略 1:整体直接解析
|
||||
stripped = raw_output.strip()
|
||||
try:
|
||||
result = json.loads(stripped)
|
||||
if isinstance(result, dict) and "title_zh" in result:
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 策略 2:提取 ```json ... ``` 代码块
|
||||
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
|
||||
for match in fence_pattern.finditer(raw_output):
|
||||
try:
|
||||
result = json.loads(match.group(1).strip())
|
||||
if isinstance(result, dict) and "title_zh" in result:
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# 策略 3:匹配包含 title_zh 的最大 {...} 块
|
||||
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
|
||||
# 先尝试一层嵌套;如果没命中再用更宽松的策略
|
||||
for match in brace_pattern.finditer(raw_output):
|
||||
try:
|
||||
return json.loads(match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# 更宽松:找到最大的 { ... } 平衡块
|
||||
best = None
|
||||
best_len = 0
|
||||
for i, ch in enumerate(raw_output):
|
||||
if ch != "{":
|
||||
continue
|
||||
depth = 0
|
||||
for j in range(i, len(raw_output)):
|
||||
if raw_output[j] == "{":
|
||||
depth += 1
|
||||
elif raw_output[j] == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
candidate = raw_output[i : j + 1]
|
||||
if len(candidate) > best_len:
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
if isinstance(parsed, dict):
|
||||
best = parsed
|
||||
best_len = len(candidate)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
break
|
||||
|
||||
if best is not None:
|
||||
return best
|
||||
|
||||
raise JsonNotFoundError("no JSON object found in pi output")
|
||||
|
||||
|
||||
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -284,6 +96,8 @@ def _update_summary_in_db(
|
||||
raw_output: str,
|
||||
) -> None:
|
||||
"""将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。"""
|
||||
from sqlalchemy import text
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# 1. paper_summaries:upsert
|
||||
@@ -298,9 +112,9 @@ def _update_summary_in_db(
|
||||
# 2. papers 表
|
||||
paper.title_zh = schema.title_zh
|
||||
paper.summary_quality = quality
|
||||
paper_dir = _paper_dir(paper.arxiv_id)
|
||||
paper.summary_path = str(paper_dir / "summary.json")
|
||||
paper.raw_output_path = str(paper_dir / "raw_output.txt")
|
||||
p_dir = paper_dir(paper.arxiv_id)
|
||||
paper.summary_path = str(p_dir / "summary.json")
|
||||
paper.raw_output_path = str(p_dir / "raw_output.txt")
|
||||
|
||||
# 3. AI 标签
|
||||
existing_tag_names = {t.tag for t in paper.tags}
|
||||
@@ -332,7 +146,7 @@ def _update_summary_in_db(
|
||||
|
||||
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
|
||||
"""保存 summary.json 和 raw_output.txt。"""
|
||||
d = _paper_dir(arxiv_id)
|
||||
d = paper_dir(arxiv_id)
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
(d / "summary.json").write_text(
|
||||
schema.model_dump_json(ensure_ascii=False, indent=2),
|
||||
@@ -343,143 +157,11 @@ def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
|
||||
|
||||
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
|
||||
"""仅保存 raw_output.txt(失败时)。"""
|
||||
d = _paper_dir(arxiv_id)
|
||||
d = paper_dir(arxiv_id)
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
|
||||
|
||||
|
||||
def _cleanup_tmp(arxiv_id: str) -> None:
|
||||
"""清理 data/tmp/{arxiv_id}/ 目录。"""
|
||||
tmp = _tmp_dir(arxiv_id)
|
||||
if tmp.exists():
|
||||
try:
|
||||
shutil.rmtree(tmp)
|
||||
logger.debug("Cleaned tmp: %s", arxiv_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
|
||||
|
||||
|
||||
# ── LaTeX 图片提取(Phase 5)───────────────────────────────────────────
|
||||
|
||||
_INCLUDEGRAPHICS_RE = re.compile(
|
||||
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
|
||||
)
|
||||
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
|
||||
|
||||
|
||||
async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
|
||||
"""从 LaTeX 源码中提取图片文件。
|
||||
|
||||
流程:
|
||||
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
|
||||
2. 扫描 .tex 文件中的 \\includegraphics
|
||||
3. 复制图片到 data/papers/{arxiv_id}/images/
|
||||
4. 清理源码临时文件
|
||||
|
||||
Returns:
|
||||
提取的图片数量
|
||||
"""
|
||||
tmp_source = _tmp_dir(arxiv_id) / "source"
|
||||
images_dest = _paper_dir(arxiv_id) / "images"
|
||||
|
||||
try:
|
||||
# 下载源码 zip(如果还没下载)
|
||||
if not tmp_source.exists():
|
||||
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
|
||||
await _download_source_zip(arxiv_id, source_url, tmp_source)
|
||||
|
||||
if not tmp_source.exists():
|
||||
return 0
|
||||
|
||||
# 扫描 .tex 文件,收集图片路径
|
||||
image_paths: set[str] = set()
|
||||
for tex_file in tmp_source.rglob("*.tex"):
|
||||
try:
|
||||
content = tex_file.read_text(encoding="utf-8", errors="replace")
|
||||
for match in _INCLUDEGRAPHICS_RE.finditer(content):
|
||||
img_path = match.group(1).strip()
|
||||
image_paths.add(img_path)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not image_paths:
|
||||
return 0
|
||||
|
||||
# 查找并复制图片
|
||||
images_dest.mkdir(parents=True, exist_ok=True)
|
||||
copied = 0
|
||||
for img_rel in image_paths:
|
||||
# 尝试在源码目录中找到文件
|
||||
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
|
||||
candidate = tmp_source / (img_rel + ext)
|
||||
if candidate.is_file():
|
||||
dest_name = candidate.name
|
||||
# 避免文件名冲突
|
||||
dest = images_dest / dest_name
|
||||
if dest.exists():
|
||||
stem = dest.stem
|
||||
suffix = dest.suffix
|
||||
dest = images_dest / f"{stem}_{copied}{suffix}"
|
||||
shutil.copy2(candidate, dest)
|
||||
copied += 1
|
||||
break
|
||||
|
||||
if copied > 0:
|
||||
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
|
||||
return copied
|
||||
|
||||
except Exception:
|
||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
return 0
|
||||
|
||||
|
||||
async def _download_source_zip(
|
||||
arxiv_id: str, source_url: str, dest_dir: Path
|
||||
) -> None:
|
||||
"""下载 arXiv 源码并解压。"""
|
||||
import zipfile
|
||||
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
zip_path = _tmp_dir(arxiv_id) / "source.zip"
|
||||
|
||||
transport = None
|
||||
if settings.http_proxy:
|
||||
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=settings.HTTP_TIMEOUT_SECONDS,
|
||||
headers={"User-Agent": settings.HTTP_USER_AGENT},
|
||||
transport=transport,
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
resp = await client.get(source_url)
|
||||
resp.raise_for_status()
|
||||
zip_path.write_bytes(resp.content)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
|
||||
return
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(dest_dir)
|
||||
logger.debug("Extracted source for %s", arxiv_id)
|
||||
except zipfile.BadZipFile:
|
||||
# 可能是 tar.gz
|
||||
import tarfile
|
||||
try:
|
||||
with tarfile.open(zip_path, "r:*") as tf:
|
||||
tf.extractall(dest_dir)
|
||||
logger.debug("Extracted source (tar) for %s", arxiv_id)
|
||||
except Exception:
|
||||
logger.warning("Cannot extract source for %s", arxiv_id)
|
||||
except Exception:
|
||||
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
|
||||
finally:
|
||||
if zip_path.exists():
|
||||
zip_path.unlink()
|
||||
|
||||
|
||||
# ── 单篇总结 ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -491,6 +173,8 @@ async def summarize_one(
|
||||
force: bool = False,
|
||||
) -> dict:
|
||||
"""总结单篇论文的完整流程。"""
|
||||
import asyncio
|
||||
|
||||
arxiv_id = paper.arxiv_id
|
||||
|
||||
# 获取或创建 summary_status
|
||||
@@ -520,6 +204,8 @@ async def summarize_one(
|
||||
|
||||
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
||||
"""实际的单篇总结执行(在 semaphore 保护下)。"""
|
||||
import asyncio
|
||||
|
||||
arxiv_id = paper.arxiv_id
|
||||
status = paper.summary_status
|
||||
now = datetime.now(timezone.utc)
|
||||
@@ -532,16 +218,16 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
||||
raw_output = ""
|
||||
try:
|
||||
# 写 meta.json
|
||||
meta_path = _write_meta_json(paper)
|
||||
meta_path = write_meta_json(paper)
|
||||
|
||||
# 下载 PDF
|
||||
await _download_pdf(arxiv_id, paper.pdf_url)
|
||||
await download_pdf(arxiv_id, paper.pdf_url)
|
||||
|
||||
# 调用 pi
|
||||
raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf")
|
||||
raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf")
|
||||
|
||||
# 提取 JSON
|
||||
json_data = _extract_json(raw_output)
|
||||
json_data = extract_json(raw_output)
|
||||
|
||||
# Pydantic 校验
|
||||
schema = SummarySchema.model_validate(json_data)
|
||||
@@ -564,7 +250,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
||||
|
||||
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
|
||||
try:
|
||||
await _extract_images_from_source(arxiv_id)
|
||||
await extract_images_from_source(arxiv_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||
|
||||
@@ -625,7 +311,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
||||
}
|
||||
|
||||
finally:
|
||||
_cleanup_tmp(arxiv_id)
|
||||
cleanup_tmp(arxiv_id)
|
||||
|
||||
|
||||
# ── 单篇入口 ────────────────────────────────────────────────────────────
|
||||
@@ -690,6 +376,8 @@ async def summarize_batch(
|
||||
|
||||
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# TaskLock 防重入
|
||||
@@ -741,7 +429,7 @@ async def summarize_batch(
|
||||
log_entry.papers_found = 0
|
||||
log_entry.papers_new = 0
|
||||
log_entry.completed_at = datetime.now(timezone.utc)
|
||||
_release_lock(db, lock)
|
||||
release_lock(db, lock)
|
||||
return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}
|
||||
|
||||
# 并发控制
|
||||
@@ -813,15 +501,4 @@ async def summarize_batch(
|
||||
return {"status": "failed", "error": str(exc)}
|
||||
|
||||
finally:
|
||||
_release_lock(db, lock)
|
||||
|
||||
|
||||
def _release_lock(db: Session, lock: TaskLock) -> None:
|
||||
"""释放 TaskLock。"""
|
||||
try:
|
||||
lock.status = "finished"
|
||||
lock.released_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
logger.warning("Failed to release summarize lock", exc_info=True)
|
||||
release_lock(db, lock)
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
"""趋势统计服务 — 按日论文数量、热门标签、Upvotes 分布、总结完成率。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def get_trends_data(db: Session) -> dict:
|
||||
"""从 DB 聚合趋势数据。"""
|
||||
thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
|
||||
|
||||
# 1. 按日论文数量(近 30 天)
|
||||
daily_rows = db.execute(text("""
|
||||
SELECT paper_date, COUNT(*) as cnt
|
||||
FROM papers
|
||||
WHERE paper_date >= :start_date
|
||||
GROUP BY paper_date
|
||||
ORDER BY paper_date ASC
|
||||
"""), {"start_date": thirty_days_ago}).fetchall()
|
||||
daily_counts = [
|
||||
{"date": str(row[0]), "count": row[1]}
|
||||
for row in daily_rows
|
||||
]
|
||||
|
||||
# 2. 热门标签 Top 20
|
||||
tag_rows = db.execute(text("""
|
||||
SELECT tag, COUNT(*) as cnt
|
||||
FROM paper_tags
|
||||
GROUP BY tag
|
||||
ORDER BY cnt DESC
|
||||
LIMIT 20
|
||||
""")).fetchall()
|
||||
top_tags = [
|
||||
{"tag": row[0], "count": row[1]}
|
||||
for row in tag_rows
|
||||
]
|
||||
|
||||
# 3. Upvotes 分布
|
||||
upvote_rows = db.execute(text("""
|
||||
SELECT
|
||||
CASE
|
||||
WHEN upvotes >= 100 THEN '100+'
|
||||
WHEN upvotes >= 50 THEN '50-99'
|
||||
WHEN upvotes >= 20 THEN '20-49'
|
||||
WHEN upvotes >= 10 THEN '10-19'
|
||||
WHEN upvotes >= 5 THEN '5-9'
|
||||
ELSE '0-4'
|
||||
END as bucket,
|
||||
COUNT(*) as cnt
|
||||
FROM papers
|
||||
GROUP BY bucket
|
||||
ORDER BY MIN(upvotes) DESC
|
||||
""")).fetchall()
|
||||
upvotes_dist = [
|
||||
{"range": row[0], "count": row[1]}
|
||||
for row in upvote_rows
|
||||
]
|
||||
|
||||
# 4. 总结完成率
|
||||
summary_rows = db.execute(text("""
|
||||
SELECT
|
||||
COALESCE(ss.status, 'none') as status,
|
||||
COUNT(*) as cnt
|
||||
FROM papers p
|
||||
LEFT JOIN summary_status ss ON ss.paper_id = p.id
|
||||
GROUP BY status
|
||||
""")).fetchall()
|
||||
summary_completion = [
|
||||
{"status": row[0], "count": row[1]}
|
||||
for row in summary_rows
|
||||
]
|
||||
|
||||
return {
|
||||
"daily_counts": daily_counts,
|
||||
"top_tags": top_tags,
|
||||
"upvotes_dist": upvotes_dist,
|
||||
"summary_completion": summary_completion,
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
"""用户数据服务 — 收藏、阅读状态、个人笔记。无账号体系,数据写入本地 SQLite。"""
|
||||
"""用户数据服务 — 收藏、阅读状态、个人笔记、阅读列表查询。无账号体系,数据写入本地 SQLite。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.models import Paper, UserBookmark, UserNote, UserReadingStatus
|
||||
from app.models import Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
|
||||
|
||||
# ── 收藏 ──────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -113,3 +114,47 @@ def save_note(db: Session, arxiv_id: str, content: str) -> dict:
|
||||
"content": content,
|
||||
"updated_at": now.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ── 阅读列表 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def query_reading_list(
|
||||
db: Session,
|
||||
filter_type: str,
|
||||
tag: str | None,
|
||||
) -> list[Paper]:
|
||||
"""根据筛选条件查询阅读列表。"""
|
||||
# 基础:有任意用户数据的论文
|
||||
base = db.query(Paper).filter(
|
||||
or_(
|
||||
Paper.bookmark.has(),
|
||||
Paper.reading_status.has(),
|
||||
Paper.note.has(),
|
||||
)
|
||||
)
|
||||
|
||||
# 应用筛选
|
||||
if filter_type == "has_note":
|
||||
base = base.filter(Paper.note.has())
|
||||
elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
|
||||
base = base.filter(
|
||||
Paper.reading_status.has(UserReadingStatus.status == filter_type)
|
||||
)
|
||||
|
||||
# 应用标签
|
||||
if tag:
|
||||
base = base.filter(Paper.tags.any(PaperTag.tag == tag))
|
||||
|
||||
return (
|
||||
base.options(
|
||||
joinedload(Paper.authors),
|
||||
joinedload(Paper.tags),
|
||||
joinedload(Paper.summary_status),
|
||||
joinedload(Paper.bookmark),
|
||||
joinedload(Paper.reading_status),
|
||||
joinedload(Paper.note),
|
||||
)
|
||||
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user