refactor: restructure services and add image/pdf extraction utilities

- Add image_extractor, pdf_downloader, pi_client, trends services
- Add shared utils module
- Refactor summarizer, embedder, routes for cleaner separation
- Update tests to match new service structure
This commit is contained in:
2026-06-06 00:00:55 +08:00
parent ba9afa212c
commit 85c4cfb9e8
22 changed files with 843 additions and 780 deletions
+33 -1
View File
@@ -1,6 +1,6 @@
"""数据库引擎、会话工厂、初始化。"""
from sqlalchemy import event, create_engine
from sqlalchemy import event, create_engine, text
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from app.config import settings
@@ -10,6 +10,27 @@ class Base(DeclarativeBase):
pass
# ── FTS5 和索引 DDL(与 ORM 模型分开管理)───────────────────────────────
FTS5_CREATE_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS papers_fts USING fts5(
title_en,
title_zh,
abstract,
authors,
tags,
summary_text,
tokenize='unicode61'
);
"""
FTS5_TRIGGER_INDEX = """
-- partial index for task_locks running
CREATE UNIQUE INDEX IF NOT EXISTS uq_task_locks_running
ON task_locks(task, lock_key) WHERE status = 'running';
"""
def _make_engine():
"""创建 SQLite 引擎,启用 foreign_keys。"""
engine = create_engine(
@@ -39,3 +60,14 @@ def get_db():
yield db
finally:
db.close()
def init_db(engine):
"""创建所有 ORM 表 + FTS5 虚拟表。"""
from app.models import Base # noqa: F811 — 避免循环导入,延迟导入
Base.metadata.create_all(engine)
with engine.connect() as conn:
conn.execute(text(FTS5_CREATE_SQL))
conn.execute(text(FTS5_TRIGGER_INDEX))
conn.commit()
+21 -20
View File
@@ -2,14 +2,13 @@
import logging
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from starlette.staticfiles import StaticFiles as StarletteStaticFiles
from app.config import settings
from app.database import engine
from app.models import init_db
from app.database import engine, init_db
from app.routes.admin import router as admin_router
from app.routes.compare import router as compare_router
from app.routes.pages import router as pages_router
@@ -24,11 +23,30 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:启动与关闭。"""
# ── startup ──
from app.services.scheduler import start_scheduler
from app.services.embedder import init_chroma
start_scheduler()
init_chroma()
yield
# ── shutdown ──
from app.services.scheduler import stop_scheduler
stop_scheduler()
def create_app() -> FastAPI:
app = FastAPI(
title="HF Daily Papers",
description="HuggingFace Daily Papers — 中文论文导览站",
version="0.1.0",
lifespan=lifespan,
)
# 确保数据目录存在
@@ -65,23 +83,6 @@ def create_app() -> FastAPI:
app.include_router(trends_router)
app.include_router(compare_router)
# 调度器(Phase 4
@app.on_event("startup")
async def _start_scheduler():
from app.services.scheduler import start_scheduler
start_scheduler()
# Phase 5: 初始化 ChromaDB
@app.on_event("startup")
async def _init_chroma():
from app.services.embedder import init_chroma
init_chroma()
@app.on_event("shutdown")
async def _stop_scheduler():
from app.services.scheduler import stop_scheduler
stop_scheduler()
return app
+1 -31
View File
@@ -1,4 +1,4 @@
"""SQLAlchemy ORM 模型 — papers, authors, tags, summaries, FTS5, logs, locks, user data"""
"""SQLAlchemy ORM 模型 — papers, authors, tags, summaries, user data, logs, locks。"""
from datetime import date, datetime
@@ -13,7 +13,6 @@ from sqlalchemy import (
String,
Text,
UniqueConstraint,
text,
)
from sqlalchemy.orm import relationship
@@ -204,32 +203,3 @@ class DataDeleteJob(Base):
error = Column(Text)
started_at = Column(DateTime, nullable=False)
completed_at = Column(DateTime)
# ── FTS5 索引初始化 SQL(普通虚拟表,由应用层维护)──────────────────────
FTS5_CREATE_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS papers_fts USING fts5(
title_en,
title_zh,
abstract,
authors,
tags,
summary_text,
tokenize='unicode61'
);
"""
FTS5_TRIGGER_INDEX = """
-- partial index for task_locks running
CREATE UNIQUE INDEX IF NOT EXISTS uq_task_locks_running
ON task_locks(task, lock_key) WHERE status = 'running';
"""
def init_db(engine):
"""创建所有 ORM 表 + FTS5 虚拟表。"""
Base.metadata.create_all(engine)
with engine.connect() as conn:
conn.execute(text(FTS5_CREATE_SQL))
conn.execute(text(FTS5_TRIGGER_INDEX))
conn.commit()
+9 -30
View File
@@ -6,7 +6,6 @@ from datetime import date, datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -17,10 +16,10 @@ from app.models import CrawlLog, DataDeleteJob, TaskLock
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
from app.services.crawler import crawl_daily
from app.services.summarizer import summarize_batch, summarize_single
from app.utils import release_lock, templates, today_str
router = APIRouter(prefix="/admin", tags=["admin"])
security = HTTPBearer()
templates = Jinja2Templates(directory="app/templates")
async def verify_admin(
@@ -32,7 +31,7 @@ async def verify_admin(
return credentials.credentials
# ── 请求模型 ──────────────────────────────────────────────────────────────
# ── 请求模型 ──────────────────────────────────────────────────────────
class DeleteRequest(BaseModel):
@@ -49,7 +48,7 @@ class DeleteRequest(BaseModel):
return v
# ── 抓取 ──────────────────────────────────────────────────────────────────
# ── 抓取 ──────────────────────────────────────────────────────────────
@router.post("/crawl")
@@ -59,12 +58,7 @@ async def admin_crawl(
date: str | None = Query(None, description="YYYY-MM-DD,默认今天"),
):
"""手动抓取指定日期,默认今天。"""
# 计算 target_date
from zoneinfo import ZoneInfo
tz = ZoneInfo(settings.APP_TIMEZONE)
today = datetime.now(tz).strftime("%Y-%m-%d")
target_date = date or today
target_date = date or today_str()
# TaskLock 防重入
now = datetime.now(timezone.utc)
@@ -88,10 +82,10 @@ async def admin_crawl(
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
finally:
_release_lock(db, lock)
release_lock(db, lock)
# ── 总结 ──────────────────────────────────────────────────────────────────
# ── 总结 ──────────────────────────────────────────────────────────────
@router.post("/summarize")
@@ -119,7 +113,7 @@ async def admin_summarize_single(
return result
# ── 清理 ──────────────────────────────────────────────────────────────────
# ── 清理 ──────────────────────────────────────────────────────────────
@router.post("/cleanup")
@@ -155,7 +149,7 @@ async def admin_cleanup(
raise HTTPException(status_code=500, detail=str(exc))
# ── 删除 ──────────────────────────────────────────────────────────────────
# ── 删除 ──────────────────────────────────────────────────────────────
@router.post("/delete")
@@ -177,7 +171,7 @@ async def admin_delete(
return result
# ── 日志 ──────────────────────────────────────────────────────────────────
# ── 日志 ──────────────────────────────────────────────────────────────
@router.get("/logs")
@@ -189,7 +183,6 @@ async def admin_logs(
per_page: int = Query(20, ge=1, le=100),
):
"""查看任务日志(CrawlLog + DataDeleteJob)。"""
# 查询 crawl_logs
crawl_logs = (
db.execute(
select(CrawlLog)
@@ -201,7 +194,6 @@ async def admin_logs(
.all()
)
# 查询 delete_jobs
delete_jobs = (
db.execute(
select(DataDeleteJob)
@@ -223,16 +215,3 @@ async def admin_logs(
"per_page": per_page,
},
)
# ── 工具函数 ──────────────────────────────────────────────────────────────
def _release_lock(db: Session, lock: TaskLock) -> None:
"""释放 TaskLock。"""
try:
lock.status = "finished"
lock.released_at = datetime.now(timezone.utc)
db.commit()
except Exception:
db.rollback()
+1 -6
View File
@@ -2,19 +2,14 @@
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
from app.database import get_db
from app.models import Paper
logger = logging.getLogger(__name__)
from app.utils import templates
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/compare")
+6 -20
View File
@@ -3,34 +3,27 @@
from __future__ import annotations
import logging
from datetime import date, datetime, timedelta
from datetime import date, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import Paper
from app.utils import templates, today_str
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
def _today() -> str:
tz = ZoneInfo(settings.APP_TIMEZONE)
return datetime.now(tz).strftime("%Y-%m-%d")
@router.get("/")
def index(request: Request):
"""重定向到 /day/{today}"""
return RedirectResponse(url=f"/day/{_today()}")
return RedirectResponse(url=f"/day/{today_str()}")
@router.get("/day/{date_str}")
@@ -43,7 +36,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
prev_day = (target - timedelta(days=1)).isoformat()
next_day = (target + timedelta(days=1)).isoformat()
today_str = _today()
today = today_str()
papers = (
db.query(Paper)
@@ -74,7 +67,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
"current_date": date_str,
"prev_day": prev_day,
"next_day": next_day,
"today": today_str,
"today": today,
"available_dates": available_dates,
"page_title": f"{date_str} 论文列表",
},
@@ -145,17 +138,10 @@ def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict
if not settings.CHROMA_ENABLED:
return []
try:
from app.services.embedder import search_similar
# 用论文的 arxiv_id 从 ChromaDB 查询
col = None
try:
from app.services.embedder import get_collection
col = get_collection()
except Exception:
return []
col = get_collection()
if col is None:
return []
+7 -61
View File
@@ -2,14 +2,11 @@
from __future__ import annotations
import math
from datetime import date, datetime, timedelta, timezone
from zoneinfo import ZoneInfo
from datetime import date, timedelta
from xml.sax.saxutils import escape
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import Response
from fastapi.templating import Jinja2Templates
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload
@@ -17,9 +14,10 @@ from app.config import settings
from app.database import get_db
from app.models import Paper, PaperTag, UserReadingStatus
from app.services.searcher import get_all_tags, search_papers
from app.services.user_data import query_reading_list
from app.utils import templates, today_str
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
# ── 搜索页 ────────────────────────────────────────────────────────────
@@ -56,7 +54,7 @@ def search_page(
"total_pages": result["total_pages"],
"all_tags": all_tags,
"page_title": f"搜索: {q}" if q else "搜索",
"today": _today_str(),
"today": today_str(),
},
)
@@ -114,7 +112,7 @@ def reading_list_page(
db: Session = Depends(get_db),
):
"""阅读列表页面。"""
papers = _query_reading_list(db, filter, tag or None)
papers = query_reading_list(db, filter, tag or None)
all_tags = get_all_tags(db)
return templates.TemplateResponse(
@@ -126,54 +124,11 @@ def reading_list_page(
"current_tag": tag,
"all_tags": all_tags,
"page_title": "阅读列表",
"today": _today_str(),
"today": today_str(),
},
)
def _query_reading_list(
db: Session,
filter_type: str,
tag: str | None,
) -> list[Paper]:
"""根据筛选条件查询阅读列表。"""
from sqlalchemy import or_
# 基础:有任意用户数据的论文
base = db.query(Paper).filter(
or_(
Paper.bookmark.has(),
Paper.reading_status.has(),
Paper.note.has(),
)
)
# 应用筛选
if filter_type == "has_note":
base = base.filter(Paper.note.has())
elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
base = base.filter(
Paper.reading_status.has(UserReadingStatus.status == filter_type)
)
# 应用标签
if tag:
base = base.filter(Paper.tags.any(PaperTag.tag == tag))
return (
base.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
joinedload(Paper.note),
)
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
.all()
)
# ── RSS Feed ──────────────────────────────────────────────────────────
@@ -216,7 +171,7 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
lines.append(f" <title>{escape(channel_title)}</title>")
lines.append(f" <link>{escape(base_url)}</link>")
lines.append(" <description>HuggingFace Daily Papers — 中文论文导览站</description>")
lines.append(f" <language>zh-CN</language>")
lines.append(" <language>zh-CN</language>")
for paper in papers:
title_text = paper.title_zh or paper.title_en
@@ -245,12 +200,3 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
lines.append(" </channel>")
lines.append("</rss>")
return "\n".join(lines)
# ── 工具函数 ──────────────────────────────────────────────────────────
def _today_str() -> str:
"""当前日期字符串(按 APP_TIMEZONE)。"""
tz = ZoneInfo(settings.APP_TIMEZONE)
return datetime.now(tz).strftime("%Y-%m-%d")
+5 -92
View File
@@ -2,34 +2,27 @@
from __future__ import annotations
import logging
from datetime import date, timedelta
from fastapi import APIRouter, Depends, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy import func, text
from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db
logger = logging.getLogger(__name__)
from app.services.trends import get_trends_data
from app.utils import templates, today_str
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/trends")
def trends_page(request: Request, db: Session = Depends(get_db)):
"""趋势看板页面。"""
stats = _get_trends_data(db)
stats = get_trends_data(db)
return templates.TemplateResponse(
request,
"trends.html",
{
"page_title": "趋势看板",
"stats": stats,
"today": _today_str(),
"today": today_str(),
},
)
@@ -37,84 +30,4 @@ def trends_page(request: Request, db: Session = Depends(get_db)):
@router.get("/api/stats/trends")
def trends_api(db: Session = Depends(get_db)):
"""趋势数据 JSON API。"""
return _get_trends_data(db)
def _get_trends_data(db: Session) -> dict:
"""从 DB 聚合趋势数据。"""
thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
# 1. 按日论文数量(近 30 天)
daily_rows = db.execute(text("""
SELECT paper_date, COUNT(*) as cnt
FROM papers
WHERE paper_date >= :start_date
GROUP BY paper_date
ORDER BY paper_date ASC
"""), {"start_date": thirty_days_ago}).fetchall()
daily_counts = [
{"date": str(row[0]), "count": row[1]}
for row in daily_rows
]
# 2. 热门标签 Top 20
tag_rows = db.execute(text("""
SELECT tag, COUNT(*) as cnt
FROM paper_tags
GROUP BY tag
ORDER BY cnt DESC
LIMIT 20
""")).fetchall()
top_tags = [
{"tag": row[0], "count": row[1]}
for row in tag_rows
]
# 3. Upvotes 分布
upvote_rows = db.execute(text("""
SELECT
CASE
WHEN upvotes >= 100 THEN '100+'
WHEN upvotes >= 50 THEN '50-99'
WHEN upvotes >= 20 THEN '20-49'
WHEN upvotes >= 10 THEN '10-19'
WHEN upvotes >= 5 THEN '5-9'
ELSE '0-4'
END as bucket,
COUNT(*) as cnt
FROM papers
GROUP BY bucket
ORDER BY MIN(upvotes) DESC
""")).fetchall()
upvotes_dist = [
{"range": row[0], "count": row[1]}
for row in upvote_rows
]
# 4. 总结完成率
summary_rows = db.execute(text("""
SELECT
COALESCE(ss.status, 'none') as status,
COUNT(*) as cnt
FROM papers p
LEFT JOIN summary_status ss ON ss.paper_id = p.id
GROUP BY status
""")).fetchall()
summary_completion = [
{"status": row[0], "count": row[1]}
for row in summary_rows
]
return {
"daily_counts": daily_counts,
"top_tags": top_tags,
"upvotes_dist": upvotes_dist,
"summary_completion": summary_completion,
}
def _today_str() -> str:
from datetime import datetime
from zoneinfo import ZoneInfo
tz = ZoneInfo(settings.APP_TIMEZONE)
return datetime.now(tz).strftime("%Y-%m-%d")
return get_trends_data(db)
+5 -8
View File
@@ -16,13 +16,10 @@ from app.models import (
Paper,
TaskLock,
)
from app.utils import PAPERS_DIR, TMP_DIR
logger = logging.getLogger(__name__)
_DATA_DIR = Path("data")
_TMP_DIR = _DATA_DIR / "tmp"
_PAPERS_DIR = _DATA_DIR / "papers"
# 临时文件最大保留时间(小时)
_MAX_TMP_AGE_HOURS = 24
@@ -39,7 +36,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
Returns:
清理统计 {"scanned": int, "removed": int, "errors": list[str]}
"""
if not _TMP_DIR.exists():
if not TMP_DIR.exists():
return {"scanned": 0, "removed": 0, "errors": []}
now = datetime.now(timezone.utc)
@@ -48,7 +45,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
removed = 0
errors: list[str] = []
for entry in _TMP_DIR.iterdir():
for entry in TMP_DIR.iterdir():
if not entry.is_dir():
continue
scanned += 1
@@ -147,13 +144,13 @@ async def delete_papers_by_date_range(
logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True)
# 2. 删除本地文件 data/papers/{arxiv_id}/
paper_dir = _PAPERS_DIR / arxiv_id
paper_dir = PAPERS_DIR / arxiv_id
if paper_dir.exists():
shutil.rmtree(paper_dir)
logger.debug("Removed paper dir: %s", paper_dir)
# 3. 删除临时文件 data/tmp/{arxiv_id}/
tmp_dir = _TMP_DIR / arxiv_id
tmp_dir = TMP_DIR / arxiv_id
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
logger.debug("Removed tmp dir: %s", tmp_dir)
+2 -9
View File
@@ -16,6 +16,7 @@ from app.models import (
PaperTag,
SummaryStatus,
)
from app.utils import make_http_client
logger = logging.getLogger(__name__)
@@ -34,15 +35,7 @@ async def fetch_daily(target_date: str, top_n: int | None = None) -> list[dict]:
url = f"{settings.HF_API_BASE}/daily_papers"
params = {"date": target_date}
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
) as client:
async with make_http_client() as client:
for attempt in range(1, settings.HTTP_MAX_RETRIES + 1):
try:
logger.info("Fetching HF Daily Papers: date=%s attempt=%d", target_date, attempt)
+44 -28
View File
@@ -5,8 +5,6 @@ from __future__ import annotations
import logging
from pathlib import Path
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
@@ -14,66 +12,82 @@ from app.models import Paper
logger = logging.getLogger(__name__)
# ── 单例客户端和 collection ─────────────────────────────────────────────
_client = None
_collection = None
# ── ChromaDB 管理器(替代全局可变状态)──────────────────────────────────
def _chroma_dir() -> Path:
return Path(settings.CHROMA_DIR)
class ChromaManager:
"""封装 ChromaDB 客户端和 collection 的生命周期。"""
def __init__(self) -> None:
self._client = None
self._collection = None
def init_chroma() -> None:
def init(self) -> None:
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
global _client, _collection
if not settings.CHROMA_ENABLED:
logger.debug("ChromaDB disabled, skip init")
return
if _client is not None:
if self._client is not None:
return
try:
import chromadb
chroma_path = _chroma_dir()
chroma_path = Path(settings.CHROMA_DIR)
chroma_path.mkdir(parents=True, exist_ok=True)
_client = chromadb.PersistentClient(path=str(chroma_path))
_collection = _get_or_create_collection()
self._client = chromadb.PersistentClient(path=str(chroma_path))
self._collection = self._get_or_create_collection()
logger.info("ChromaDB initialized at %s", chroma_path)
except Exception:
logger.exception("Failed to initialize ChromaDB")
_client = None
_collection = None
def _get_or_create_collection():
"""获取或创建 papers_embeddings collection,维度不匹配时记录日志并跳过。"""
import chromadb
self._client = None
self._collection = None
def _get_or_create_collection(self):
"""获取或创建 papers_embeddings collection。"""
try:
col = _client.get_collection("papers_embeddings")
col = self._client.get_collection("papers_embeddings")
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
return col
except Exception:
pass
col = _client.create_collection(
col = self._client.create_collection(
name="papers_embeddings",
metadata={"hnsw:space": "cosine"},
)
logger.info("ChromaDB collection 'papers_embeddings' created")
return col
def get_collection(self):
"""返回当前 collection,未初始化则自动初始化。"""
if not settings.CHROMA_ENABLED:
return None
if self._collection is None:
self.init()
return self._collection
def reset(self) -> None:
"""重置状态(供测试使用)。"""
self._client = None
self._collection = None
# 模块级单例
_chroma = ChromaManager()
def init_chroma() -> None:
"""初始化 ChromaDB(供 lifespan 调用)。"""
_chroma.init()
def get_collection():
"""返回当前 collection,未初始化则返回 None。"""
if not settings.CHROMA_ENABLED:
return None
if _collection is None:
init_chroma()
return _collection
return _chroma.get_collection()
# ── Embedding API 调用 ──────────────────────────────────────────────────
@@ -90,6 +104,8 @@ def _get_embedding(text: str) -> list[float] | None:
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
return None
from app.utils import make_http_client
url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
headers = {"Content-Type": "application/json"}
if settings.EMBED_API_KEY:
@@ -101,7 +117,7 @@ def _get_embedding(text: str) -> list[float] | None:
}
try:
with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
with make_http_client(sync=True) as client:
resp = client.post(url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
+83
View File
@@ -0,0 +1,83 @@
"""LaTeX 图片提取 — 从 arXiv 源码中扫描 \\includegraphics 并提取图片文件。"""
from __future__ import annotations
import logging
import re
import shutil
from pathlib import Path
from app.services.pdf_downloader import download_source_zip, paper_dir, tmp_dir
logger = logging.getLogger(__name__)
_INCLUDEGRAPHICS_RE = re.compile(
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
async def extract_images_from_source(arxiv_id: str) -> int:
"""从 LaTeX 源码中提取图片文件。
流程:
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
2. 扫描 .tex 文件中的 \\includegraphics
3. 复制图片到 data/papers/{arxiv_id}/images/
4. 清理源码临时文件
Returns:
提取的图片数量
"""
tmp_source = tmp_dir(arxiv_id) / "source"
images_dest = paper_dir(arxiv_id) / "images"
try:
# 下载源码 zip(如果还没下载)
if not tmp_source.exists():
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
await download_source_zip(arxiv_id, source_url, tmp_source)
if not tmp_source.exists():
return 0
# 扫描 .tex 文件,收集图片路径
image_paths: set[str] = set()
for tex_file in tmp_source.rglob("*.tex"):
try:
content = tex_file.read_text(encoding="utf-8", errors="replace")
for match in _INCLUDEGRAPHICS_RE.finditer(content):
img_path = match.group(1).strip()
image_paths.add(img_path)
except Exception:
continue
if not image_paths:
return 0
# 查找并复制图片
images_dest.mkdir(parents=True, exist_ok=True)
copied = 0
for img_rel in image_paths:
# 尝试在源码目录中找到文件
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
candidate = tmp_source / (img_rel + ext)
if candidate.is_file():
dest_name = candidate.name
# 避免文件名冲突
dest = images_dest / dest_name
if dest.exists():
stem = dest.stem
suffix = dest.suffix
dest = images_dest / f"{stem}_{copied}{suffix}"
shutil.copy2(candidate, dest)
copied += 1
break
if copied > 0:
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
return copied
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
return 0
+105
View File
@@ -0,0 +1,105 @@
"""PDF 下载与源码下载 — 从 arXiv 下载论文 PDF 和 LaTeX 源码包。"""
from __future__ import annotations
import logging
import shutil
import zipfile
from pathlib import Path
from app.utils import PAPERS_DIR, TMP_DIR, make_http_client
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PdfDownloadError(Exception):
pass
# ── 路径工具 ────────────────────────────────────────────────────────────
def paper_dir(arxiv_id: str) -> Path:
return PAPERS_DIR / arxiv_id
def tmp_dir(arxiv_id: str) -> Path:
return TMP_DIR / arxiv_id
# ── PDF 下载 ────────────────────────────────────────────────────────────
async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
if not pdf_url:
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
dest_dir = tmp_dir(arxiv_id)
dest_dir.mkdir(parents=True, exist_ok=True)
dest = dest_dir / "paper.pdf"
try:
async with make_http_client(follow_redirects=True) as client:
resp = await client.get(pdf_url)
resp.raise_for_status()
dest.write_bytes(resp.content)
except Exception as exc:
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
return dest
# ── 源码下载 ────────────────────────────────────────────────────────────
async def download_source_zip(arxiv_id: str, source_url: str, dest_dir: Path) -> None:
"""下载 arXiv 源码并解压。"""
dest_dir.mkdir(parents=True, exist_ok=True)
zip_path = tmp_dir(arxiv_id) / "source.zip"
try:
async with make_http_client(follow_redirects=True) as client:
resp = await client.get(source_url)
resp.raise_for_status()
zip_path.write_bytes(resp.content)
except Exception as exc:
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
return
try:
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(dest_dir)
logger.debug("Extracted source for %s", arxiv_id)
except zipfile.BadZipFile:
# 可能是 tar.gz
import tarfile
try:
with tarfile.open(zip_path, "r:*") as tf:
tf.extractall(dest_dir, filter="data")
logger.debug("Extracted source (tar) for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
finally:
if zip_path.exists():
zip_path.unlink()
# ── 临时文件清理 ────────────────────────────────────────────────────────
def cleanup_tmp(arxiv_id: str) -> None:
"""清理 data/tmp/{arxiv_id}/ 目录。"""
td = tmp_dir(arxiv_id)
if td.exists():
try:
shutil.rmtree(td)
logger.debug("Cleaned tmp: %s", arxiv_id)
except Exception:
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
+160
View File
@@ -0,0 +1,160 @@
"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。"""
from __future__ import annotations
import asyncio
import json
import logging
import re
from pathlib import Path
from app.config import settings
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PiTimeoutError(Exception):
pass
class PiProcessError(Exception):
def __init__(self, returncode: int, stderr: str):
self.returncode = returncode
self.stderr = stderr
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
class JsonNotFoundError(Exception):
pass
# ── meta.json ───────────────────────────────────────────────────────────
def write_meta_json(paper) -> Path:
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
from app.services.pdf_downloader import paper_dir
d = paper_dir(paper.arxiv_id)
d.mkdir(parents=True, exist_ok=True)
meta_path = d / "meta.json"
authors = [a.name for a in paper.authors]
tags = [t.tag for t in paper.tags]
meta = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"abstract": paper.abstract or "",
"published_at": paper.published_at.isoformat() if paper.published_at else None,
"authors": authors,
"tags": tags,
"upvotes": paper.upvotes,
}
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
return meta_path
# ── pi CLI 调用 ────────────────────────────────────────────────────────
async def call_pi(meta_path: Path, pdf_path: Path) -> str:
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
arxiv_id = meta_path.parent.name
cmd = [
settings.PI_BIN,
"-p",
"--no-tools",
"--skill",
settings.SUMMARY_SKILL,
"请深度解读以下论文,并按指定 JSON schema 输出:",
f"@{meta_path}",
f"@{pdf_path}",
]
logger.info("Calling pi for %s", arxiv_id)
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(),
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise PiTimeoutError(
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
)
if proc.returncode != 0:
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
return stdout.decode("utf-8", errors="replace")
# ── JSON 提取 ──────────────────────────────────────────────────────────
def extract_json(raw_output: str) -> dict:
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
# 策略 1:整体直接解析
stripped = raw_output.strip()
try:
result = json.loads(stripped)
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
pass
# 策略 2:提取 ```json ... ``` 代码块
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
for match in fence_pattern.finditer(raw_output):
try:
result = json.loads(match.group(1).strip())
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
continue
# 策略 3:匹配包含 title_zh 的最大 {...} 块
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
for match in brace_pattern.finditer(raw_output):
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
continue
# 更宽松:找到最大的 { ... } 平衡块
best = None
best_len = 0
for i, ch in enumerate(raw_output):
if ch != "{":
continue
depth = 0
for j in range(i, len(raw_output)):
if raw_output[j] == "{":
depth += 1
elif raw_output[j] == "}":
depth -= 1
if depth == 0:
candidate = raw_output[i : j + 1]
if len(candidate) > best_len:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
best = parsed
best_len = len(candidate)
except json.JSONDecodeError:
pass
break
if best is not None:
return best
raise JsonNotFoundError("no JSON object found in pi output")
+39 -362
View File
@@ -1,18 +1,15 @@
"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结"""
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import shutil
from datetime import datetime, timezone
from pathlib import Path
import httpx
from pydantic import ValidationError
from sqlalchemy import select, text
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from app.config import settings
@@ -25,216 +22,31 @@ from app.models import (
SummaryStatus,
TaskLock,
)
from app.services.image_extractor import extract_images_from_source
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp,
download_pdf,
paper_dir,
)
from app.services.pi_client import (
JsonNotFoundError,
PiProcessError,
PiTimeoutError,
call_pi,
extract_json,
write_meta_json,
)
from app.services.schemas import (
SummarySchema,
assess_quality,
classify_validation_error,
flatten_for_db,
)
from app.utils import PAPERS_DIR, release_lock
logger = logging.getLogger(__name__)
# ── 自定义异常 ──────────────────────────────────────────────────────────
class PdfDownloadError(Exception):
pass
class PiTimeoutError(Exception):
pass
class PiProcessError(Exception):
def __init__(self, returncode: int, stderr: str):
self.returncode = returncode
self.stderr = stderr
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
class JsonNotFoundError(Exception):
pass
# ── 路径工具 ────────────────────────────────────────────────────────────
_DATA_DIR = Path("data")
_PAPERS_DIR = _DATA_DIR / "papers"
_TMP_DIR = _DATA_DIR / "tmp"
def _paper_dir(arxiv_id: str) -> Path:
return _PAPERS_DIR / arxiv_id
def _tmp_dir(arxiv_id: str) -> Path:
return _TMP_DIR / arxiv_id
# ── PDF 下载 ────────────────────────────────────────────────────────────
async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
if not pdf_url:
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
tmp = _tmp_dir(arxiv_id)
tmp.mkdir(parents=True, exist_ok=True)
dest = tmp / "paper.pdf"
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
try:
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
follow_redirects=True,
) as client:
resp = await client.get(pdf_url)
resp.raise_for_status()
dest.write_bytes(resp.content)
except Exception as exc:
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
return dest
# ── meta.json ───────────────────────────────────────────────────────────
def _write_meta_json(paper: Paper) -> Path:
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
d = _paper_dir(paper.arxiv_id)
d.mkdir(parents=True, exist_ok=True)
meta_path = d / "meta.json"
authors = [a.name for a in paper.authors]
tags = [t.tag for t in paper.tags]
meta = {
"arxiv_id": paper.arxiv_id,
"title_en": paper.title_en,
"abstract": paper.abstract or "",
"published_at": paper.published_at.isoformat() if paper.published_at else None,
"authors": authors,
"tags": tags,
"upvotes": paper.upvotes,
}
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
return meta_path
# ── pi CLI 调用 ────────────────────────────────────────────────────────
async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
cmd = [
settings.PI_BIN,
"-p",
"--no-tools",
"--skill",
settings.SUMMARY_SKILL,
"请深度解读以下论文,并按指定 JSON schema 输出:",
f"@{meta_path}",
f"@{pdf_path}",
]
logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4]))
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(),
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise PiTimeoutError(
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
)
if proc.returncode != 0:
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
return stdout.decode("utf-8", errors="replace")
def paper_id_from_path(meta_path: Path) -> str:
"""从 meta.json 路径反推 arxiv_id。"""
return meta_path.parent.name
# ── JSON 提取 ──────────────────────────────────────────────────────────
def _extract_json(raw_output: str) -> dict:
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
# 策略 1:整体直接解析
stripped = raw_output.strip()
try:
result = json.loads(stripped)
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
pass
# 策略 2:提取 ```json ... ``` 代码块
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
for match in fence_pattern.finditer(raw_output):
try:
result = json.loads(match.group(1).strip())
if isinstance(result, dict) and "title_zh" in result:
return result
except json.JSONDecodeError:
continue
# 策略 3:匹配包含 title_zh 的最大 {...} 块
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
# 先尝试一层嵌套;如果没命中再用更宽松的策略
for match in brace_pattern.finditer(raw_output):
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
continue
# 更宽松:找到最大的 { ... } 平衡块
best = None
best_len = 0
for i, ch in enumerate(raw_output):
if ch != "{":
continue
depth = 0
for j in range(i, len(raw_output)):
if raw_output[j] == "{":
depth += 1
elif raw_output[j] == "}":
depth -= 1
if depth == 0:
candidate = raw_output[i : j + 1]
if len(candidate) > best_len:
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
best = parsed
best_len = len(candidate)
except json.JSONDecodeError:
pass
break
if best is not None:
return best
raise JsonNotFoundError("no JSON object found in pi output")
# ── 错误分类 ────────────────────────────────────────────────────────────
@@ -284,6 +96,8 @@ def _update_summary_in_db(
raw_output: str,
) -> None:
"""将校验后的总结写入 DBpaper_summaries + papers + paper_tags + FTS5。"""
from sqlalchemy import text
now = datetime.now(timezone.utc)
# 1. paper_summariesupsert
@@ -298,9 +112,9 @@ def _update_summary_in_db(
# 2. papers 表
paper.title_zh = schema.title_zh
paper.summary_quality = quality
paper_dir = _paper_dir(paper.arxiv_id)
paper.summary_path = str(paper_dir / "summary.json")
paper.raw_output_path = str(paper_dir / "raw_output.txt")
p_dir = paper_dir(paper.arxiv_id)
paper.summary_path = str(p_dir / "summary.json")
paper.raw_output_path = str(p_dir / "raw_output.txt")
# 3. AI 标签
existing_tag_names = {t.tag for t in paper.tags}
@@ -332,7 +146,7 @@ def _update_summary_in_db(
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
"""保存 summary.json 和 raw_output.txt。"""
d = _paper_dir(arxiv_id)
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "summary.json").write_text(
schema.model_dump_json(ensure_ascii=False, indent=2),
@@ -343,143 +157,11 @@ def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
"""仅保存 raw_output.txt(失败时)。"""
d = _paper_dir(arxiv_id)
d = paper_dir(arxiv_id)
d.mkdir(parents=True, exist_ok=True)
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
def _cleanup_tmp(arxiv_id: str) -> None:
"""清理 data/tmp/{arxiv_id}/ 目录。"""
tmp = _tmp_dir(arxiv_id)
if tmp.exists():
try:
shutil.rmtree(tmp)
logger.debug("Cleaned tmp: %s", arxiv_id)
except Exception:
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
# ── LaTeX 图片提取(Phase 5)───────────────────────────────────────────
_INCLUDEGRAPHICS_RE = re.compile(
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
"""从 LaTeX 源码中提取图片文件。
流程:
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
2. 扫描 .tex 文件中的 \\includegraphics
3. 复制图片到 data/papers/{arxiv_id}/images/
4. 清理源码临时文件
Returns:
提取的图片数量
"""
tmp_source = _tmp_dir(arxiv_id) / "source"
images_dest = _paper_dir(arxiv_id) / "images"
try:
# 下载源码 zip(如果还没下载)
if not tmp_source.exists():
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
await _download_source_zip(arxiv_id, source_url, tmp_source)
if not tmp_source.exists():
return 0
# 扫描 .tex 文件,收集图片路径
image_paths: set[str] = set()
for tex_file in tmp_source.rglob("*.tex"):
try:
content = tex_file.read_text(encoding="utf-8", errors="replace")
for match in _INCLUDEGRAPHICS_RE.finditer(content):
img_path = match.group(1).strip()
image_paths.add(img_path)
except Exception:
continue
if not image_paths:
return 0
# 查找并复制图片
images_dest.mkdir(parents=True, exist_ok=True)
copied = 0
for img_rel in image_paths:
# 尝试在源码目录中找到文件
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
candidate = tmp_source / (img_rel + ext)
if candidate.is_file():
dest_name = candidate.name
# 避免文件名冲突
dest = images_dest / dest_name
if dest.exists():
stem = dest.stem
suffix = dest.suffix
dest = images_dest / f"{stem}_{copied}{suffix}"
shutil.copy2(candidate, dest)
copied += 1
break
if copied > 0:
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
return copied
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
return 0
async def _download_source_zip(
arxiv_id: str, source_url: str, dest_dir: Path
) -> None:
"""下载 arXiv 源码并解压。"""
import zipfile
dest_dir.mkdir(parents=True, exist_ok=True)
zip_path = _tmp_dir(arxiv_id) / "source.zip"
transport = None
if settings.http_proxy:
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
try:
async with httpx.AsyncClient(
timeout=settings.HTTP_TIMEOUT_SECONDS,
headers={"User-Agent": settings.HTTP_USER_AGENT},
transport=transport,
follow_redirects=True,
) as client:
resp = await client.get(source_url)
resp.raise_for_status()
zip_path.write_bytes(resp.content)
except Exception as exc:
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
return
try:
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(dest_dir)
logger.debug("Extracted source for %s", arxiv_id)
except zipfile.BadZipFile:
# 可能是 tar.gz
import tarfile
try:
with tarfile.open(zip_path, "r:*") as tf:
tf.extractall(dest_dir)
logger.debug("Extracted source (tar) for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id)
except Exception:
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
finally:
if zip_path.exists():
zip_path.unlink()
# ── 单篇总结 ────────────────────────────────────────────────────────────
@@ -491,6 +173,8 @@ async def summarize_one(
force: bool = False,
) -> dict:
"""总结单篇论文的完整流程。"""
import asyncio
arxiv_id = paper.arxiv_id
# 获取或创建 summary_status
@@ -520,6 +204,8 @@ async def summarize_one(
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
"""实际的单篇总结执行(在 semaphore 保护下)。"""
import asyncio
arxiv_id = paper.arxiv_id
status = paper.summary_status
now = datetime.now(timezone.utc)
@@ -532,16 +218,16 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
raw_output = ""
try:
# 写 meta.json
meta_path = _write_meta_json(paper)
meta_path = write_meta_json(paper)
# 下载 PDF
await _download_pdf(arxiv_id, paper.pdf_url)
await download_pdf(arxiv_id, paper.pdf_url)
# 调用 pi
raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf")
raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf")
# 提取 JSON
json_data = _extract_json(raw_output)
json_data = extract_json(raw_output)
# Pydantic 校验
schema = SummarySchema.model_validate(json_data)
@@ -564,7 +250,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
try:
await _extract_images_from_source(arxiv_id)
await extract_images_from_source(arxiv_id)
except Exception:
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
@@ -625,7 +311,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
}
finally:
_cleanup_tmp(arxiv_id)
cleanup_tmp(arxiv_id)
# ── 单篇入口 ────────────────────────────────────────────────────────────
@@ -690,6 +376,8 @@ async def summarize_batch(
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
"""
import asyncio
now = datetime.now(timezone.utc)
# TaskLock 防重入
@@ -741,7 +429,7 @@ async def summarize_batch(
log_entry.papers_found = 0
log_entry.papers_new = 0
log_entry.completed_at = datetime.now(timezone.utc)
_release_lock(db, lock)
release_lock(db, lock)
return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}
# 并发控制
@@ -813,15 +501,4 @@ async def summarize_batch(
return {"status": "failed", "error": str(exc)}
finally:
_release_lock(db, lock)
def _release_lock(db: Session, lock: TaskLock) -> None:
"""释放 TaskLock。"""
try:
lock.status = "finished"
lock.released_at = datetime.now(timezone.utc)
db.commit()
except Exception:
db.rollback()
logger.warning("Failed to release summarize lock", exc_info=True)
release_lock(db, lock)
+81
View File
@@ -0,0 +1,81 @@
"""趋势统计服务 — 按日论文数量、热门标签、Upvotes 分布、总结完成率。"""
from __future__ import annotations
from datetime import date, timedelta
from sqlalchemy import text
from sqlalchemy.orm import Session
def get_trends_data(db: Session) -> dict:
"""从 DB 聚合趋势数据。"""
thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
# 1. 按日论文数量(近 30 天)
daily_rows = db.execute(text("""
SELECT paper_date, COUNT(*) as cnt
FROM papers
WHERE paper_date >= :start_date
GROUP BY paper_date
ORDER BY paper_date ASC
"""), {"start_date": thirty_days_ago}).fetchall()
daily_counts = [
{"date": str(row[0]), "count": row[1]}
for row in daily_rows
]
# 2. 热门标签 Top 20
tag_rows = db.execute(text("""
SELECT tag, COUNT(*) as cnt
FROM paper_tags
GROUP BY tag
ORDER BY cnt DESC
LIMIT 20
""")).fetchall()
top_tags = [
{"tag": row[0], "count": row[1]}
for row in tag_rows
]
# 3. Upvotes 分布
upvote_rows = db.execute(text("""
SELECT
CASE
WHEN upvotes >= 100 THEN '100+'
WHEN upvotes >= 50 THEN '50-99'
WHEN upvotes >= 20 THEN '20-49'
WHEN upvotes >= 10 THEN '10-19'
WHEN upvotes >= 5 THEN '5-9'
ELSE '0-4'
END as bucket,
COUNT(*) as cnt
FROM papers
GROUP BY bucket
ORDER BY MIN(upvotes) DESC
""")).fetchall()
upvotes_dist = [
{"range": row[0], "count": row[1]}
for row in upvote_rows
]
# 4. 总结完成率
summary_rows = db.execute(text("""
SELECT
COALESCE(ss.status, 'none') as status,
COUNT(*) as cnt
FROM papers p
LEFT JOIN summary_status ss ON ss.paper_id = p.id
GROUP BY status
""")).fetchall()
summary_completion = [
{"status": row[0], "count": row[1]}
for row in summary_rows
]
return {
"daily_counts": daily_counts,
"top_tags": top_tags,
"upvotes_dist": upvotes_dist,
"summary_completion": summary_completion,
}
+48 -3
View File
@@ -1,12 +1,13 @@
"""用户数据服务 — 收藏、阅读状态、个人笔记。无账号体系,数据写入本地 SQLite。"""
"""用户数据服务 — 收藏、阅读状态、个人笔记、阅读列表查询。无账号体系,数据写入本地 SQLite。"""
from __future__ import annotations
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from sqlalchemy import or_
from sqlalchemy.orm import Session, joinedload
from app.models import Paper, UserBookmark, UserNote, UserReadingStatus
from app.models import Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
# ── 收藏 ──────────────────────────────────────────────────────────────
@@ -113,3 +114,47 @@ def save_note(db: Session, arxiv_id: str, content: str) -> dict:
"content": content,
"updated_at": now.isoformat(),
}
# ── 阅读列表 ──────────────────────────────────────────────────────────
def query_reading_list(
db: Session,
filter_type: str,
tag: str | None,
) -> list[Paper]:
"""根据筛选条件查询阅读列表。"""
# 基础:有任意用户数据的论文
base = db.query(Paper).filter(
or_(
Paper.bookmark.has(),
Paper.reading_status.has(),
Paper.note.has(),
)
)
# 应用筛选
if filter_type == "has_note":
base = base.filter(Paper.note.has())
elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
base = base.filter(
Paper.reading_status.has(UserReadingStatus.status == filter_type)
)
# 应用标签
if tag:
base = base.filter(Paper.tags.any(PaperTag.tag == tag))
return (
base.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
joinedload(Paper.note),
)
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
.all()
)
+73
View File
@@ -0,0 +1,73 @@
"""公共工具 — 消除各模块间的重复代码。"""
from __future__ import annotations
from datetime import datetime, timezone
from pathlib import Path
from zoneinfo import ZoneInfo
import httpx
from fastapi.templating import Jinja2Templates
from app.config import settings
# ── 路径常量 ──────────────────────────────────────────────────────────
DATA_DIR = Path("data")
PAPERS_DIR = DATA_DIR / "papers"
TMP_DIR = DATA_DIR / "tmp"
# ── 模板单例 ──────────────────────────────────────────────────────────
templates = Jinja2Templates(directory="app/templates")
# ── 时区工具 ──────────────────────────────────────────────────────────
def today_str() -> str:
"""当前日期字符串(按 APP_TIMEZONE)。"""
tz = ZoneInfo(settings.APP_TIMEZONE)
return datetime.now(tz).strftime("%Y-%m-%d")
# ── 锁释放 ────────────────────────────────────────────────────────────
def release_lock(db, lock) -> None:
"""释放 TaskLock。"""
try:
lock.status = "finished"
lock.released_at = datetime.now(timezone.utc)
db.commit()
except Exception:
db.rollback()
# ── HTTP 客户端工厂 ───────────────────────────────────────────────────
def make_http_client(*, sync: bool = False, follow_redirects: bool = False, **kwargs) -> httpx.AsyncClient | httpx.Client:
"""创建带 proxy 和默认配置的 httpx 客户端。
Args:
sync: True 返回同步 ClientFalse 返回 AsyncClient
follow_redirects: 是否跟随重定向
**kwargs: 覆盖默认参数
"""
defaults: dict = {
"timeout": settings.HTTP_TIMEOUT_SECONDS,
"headers": {"User-Agent": settings.HTTP_USER_AGENT},
"follow_redirects": follow_redirects,
}
if settings.http_proxy:
defaults["transport"] = (
httpx.HTTPTransport(proxy=settings.http_proxy)
if sync
else httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
)
defaults.update(kwargs)
if sync:
return httpx.Client(**defaults)
return httpx.AsyncClient(**defaults)
+1 -1
View File
@@ -15,13 +15,13 @@ from sqlalchemy.pool import StaticPool
from app.database import get_db
from app.main import create_app
from app.database import init_db
from app.models import (
Paper,
PaperAuthor,
PaperSummary,
PaperTag,
SummaryStatus,
init_db,
)
+5 -5
View File
@@ -141,7 +141,7 @@ class TestCleanupTmp:
old_mtime = time.time() - 25 * 3600
os.utime(old_dir, (old_mtime, old_mtime))
monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir)
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
@@ -158,7 +158,7 @@ class TestCleanupTmp:
recent_dir.mkdir()
(recent_dir / "paper.pdf").write_text("fake pdf")
monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir)
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
@@ -168,7 +168,7 @@ class TestCleanupTmp:
def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
"""data/tmp/ 不存在时安全返回。"""
monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_path / "nonexistent")
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent")
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
assert result["scanned"] == 0
@@ -187,7 +187,7 @@ class TestCleanupTmp:
recent_dir = tmp_dir / "2401.new"
recent_dir.mkdir()
monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir)
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
from app.services.cleaner import cleanup_tmp
result = cleanup_tmp()
@@ -318,7 +318,7 @@ class TestDeletePapersByDateRange:
(papers_dir / "2401.10001").mkdir()
(papers_dir / "2401.10001" / "meta.json").write_text("{}")
monkeypatch.setattr("app.services.cleaner._PAPERS_DIR", papers_dir)
monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir)
result = await delete_papers_by_date_range(
db_session,
+42 -39
View File
@@ -125,10 +125,9 @@ class TestEmbedderInit:
"""CHROMA_ENABLED=false 时不初始化。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
emb._client = None
emb._collection = None
emb._chroma.reset()
emb.init_chroma()
assert emb._client is None
assert emb._chroma._client is None
def test_chroma_init_success(self, monkeypatch, tmp_path):
"""CHROMA_ENABLED=true 时初始化成功。"""
@@ -136,23 +135,20 @@ class TestEmbedderInit:
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
import app.services.embedder as emb
emb._client = None
emb._collection = None
emb._chroma.reset()
emb.init_chroma()
assert emb._client is not None
assert emb._collection is not None
assert emb._chroma._client is not None
assert emb._chroma._collection is not None
# 清理
emb._client = None
emb._collection = None
emb._chroma.reset()
def test_get_collection_returns_none_when_disabled(self, monkeypatch):
"""CHROMA_ENABLED=false 时 get_collection 返回 None。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
emb._client = None
emb._collection = None
emb._chroma.reset()
assert emb.get_collection() is None
@@ -163,8 +159,7 @@ class TestEmbedderIndexing:
"""CHROMA_ENABLED=false 时 index_paper 返回 False。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
emb._client = None
emb._collection = None
emb._chroma.reset()
assert emb.index_paper("test-id") is False
def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
@@ -175,22 +170,19 @@ class TestEmbedderIndexing:
monkeypatch.setattr(settings, "EMBED_MODEL", "")
import app.services.embedder as emb
emb._client = None
emb._collection = None
emb._chroma.reset()
emb.init_chroma()
result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
assert result is False
emb._client = None
emb._collection = None
emb._chroma.reset()
def test_index_batch_disabled(self, monkeypatch):
"""CHROMA_ENABLED=false 时 index_batch 返回全失败。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
emb._client = None
emb._collection = None
emb._chroma.reset()
result = emb.index_batch(["a", "b"])
assert result["success"] == 0
assert result["failed"] == 2
@@ -206,16 +198,14 @@ class TestEmbedderIndexing:
"""CHROMA_ENABLED=false 时 delete_paper 返回 False。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
emb._client = None
emb._collection = None
emb._chroma.reset()
assert emb.delete_paper("test-id") is False
def test_search_similar_disabled(self, monkeypatch):
"""CHROMA_ENABLED=false 时 search_similar 返回空列表。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
emb._client = None
emb._collection = None
emb._chroma.reset()
assert emb.search_similar("test query") == []
@@ -427,8 +417,8 @@ class TestTrendsDashboard:
from unittest.mock import patch as upatch
import app.routes.trends as trends_mod
# monkeypatch _get_trends_data 中的 date.today
with upatch("app.routes.trends.date") as mock_date:
# monkeypatch get_trends_data 中的 date.today
with upatch("app.services.trends.date") as mock_date:
mock_date.today.return_value = date(2024, 1, 20)
mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
@@ -528,15 +518,17 @@ class TestImageExtraction:
@pytest.mark.asyncio
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
"""源码目录不存在时返回 0。"""
monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
from app.services.summarizer import _extract_images_from_source
result = await _extract_images_from_source("2401.99999")
monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x)
monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x)
from app.services.image_extractor import extract_images_from_source
result = await extract_images_from_source("2401.99999")
assert result == 0
@pytest.mark.asyncio
async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
"""从 .tex 文件中提取图片。"""
from app.services.image_extractor import extract_images_from_source
tmp_source = tmp_path / "tmp" / "2401.00001" / "source"
tmp_source.mkdir(parents=True)
@@ -559,11 +551,16 @@ class TestImageExtraction:
(tmp_source / "main.tex").write_text(tex_content)
papers_dir = tmp_path / "papers" / "2401.00001"
monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
from app.services.summarizer import _extract_images_from_source
result = await _extract_images_from_source("2401.00001")
# Mock download_source_zip to avoid real network call (source dir already exists)
async def _noop_download(*args, **kwargs):
pass
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
result = await extract_images_from_source("2401.00001")
assert result == 2
dest_images = papers_dir / "images"
@@ -574,15 +571,22 @@ class TestImageExtraction:
@pytest.mark.asyncio
async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
""".tex 文件无图片时返回 0。"""
from app.services.image_extractor import extract_images_from_source
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
tmp_source.mkdir(parents=True)
(tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
from app.services.summarizer import _extract_images_from_source
result = await _extract_images_from_source("2401.00002")
# Mock download_source_zip to avoid real network call
async def _noop_download(*args, **kwargs):
pass
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
result = await extract_images_from_source("2401.00002")
assert result == 0
@@ -644,8 +648,7 @@ class TestGracefulDegradation:
"""CHROMA 关闭时删除论文正常工作。"""
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
import app.services.embedder as emb
emb._client = None
emb._collection = None
emb._chroma.reset()
from app.services.cleaner import delete_papers_by_date_range
result = await delete_papers_by_date_range(
+45 -37
View File
@@ -27,14 +27,7 @@ from app.services.schemas import (
flatten_for_db,
)
from app.services.summarizer import (
JsonNotFoundError,
PdfDownloadError,
PiProcessError,
PiTimeoutError,
_call_pi,
_classify_error,
_cleanup_tmp,
_extract_json,
_save_files,
_save_raw_output_only,
_update_summary_in_db,
@@ -42,6 +35,17 @@ from app.services.summarizer import (
summarize_one,
summarize_single,
)
from app.services.pi_client import (
JsonNotFoundError,
PiProcessError,
PiTimeoutError,
call_pi as _call_pi,
extract_json as _extract_json,
)
from app.services.pdf_downloader import (
PdfDownloadError,
cleanup_tmp as _cleanup_tmp,
)
# ═══════════════════════════════════════════════════════════════════════
@@ -287,7 +291,7 @@ class TestFileOperations:
def test_save_files(self, tmp_path, sample_summary_dict):
schema = SummarySchema.model_validate(sample_summary_dict)
with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
_save_files("2401.12345", schema, "raw output text")
paper_dir = tmp_path / "2401.12345"
@@ -297,7 +301,7 @@ class TestFileOperations:
assert saved["title_zh"] == "测试论文中文标题"
def test_save_raw_output_only(self, tmp_path):
with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
_save_raw_output_only("2401.12345", "raw output")
paper_dir = tmp_path / "2401.12345"
assert (paper_dir / "raw_output.txt").exists()
@@ -307,13 +311,13 @@ class TestFileOperations:
tmp_paper = tmp_path / "2401.12345"
tmp_paper.mkdir()
(tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake")
with patch("app.services.summarizer._TMP_DIR", tmp_path):
with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
_cleanup_tmp("2401.12345")
assert not tmp_paper.exists()
def test_cleanup_tmp_nonexistent(self, tmp_path):
"""清理不存在的目录不报错。"""
with patch("app.services.summarizer._TMP_DIR", tmp_path):
with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
_cleanup_tmp("nonexistent") # 不抛异常
@@ -329,9 +333,11 @@ class TestSummarizeOneFlow:
def _patch_paths(self, tmp_path):
"""将 data 目录重定向到 tmp_path。"""
with (
patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
patch("app.services.summarizer._DATA_DIR", tmp_path),
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
):
yield
@@ -341,8 +347,8 @@ class TestSummarizeOneFlow:
):
"""pending → processing → done 全流程。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
result = await summarize_one(db_session, sample_paper)
@@ -374,7 +380,7 @@ class TestSummarizeOneFlow:
"""PDF 下载失败 → error_type=pdf_download_failedtmp 被清理。"""
with (
patch(
"app.services.summarizer._download_pdf",
"app.services.summarizer.download_pdf",
new_callable=AsyncMock,
side_effect=PdfDownloadError("network error"),
),
@@ -392,9 +398,9 @@ class TestSummarizeOneFlow:
async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
"""pi 超时 → timeout 错误,retry_count 递增。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer._call_pi",
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
side_effect=PiTimeoutError("timeout after 300s"),
),
@@ -409,9 +415,9 @@ class TestSummarizeOneFlow:
async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
"""pi 输出无 JSON → json_not_found。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer._call_pi",
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value="No JSON in this output at all.",
),
@@ -436,9 +442,9 @@ class TestSummarizeOneFlow:
bad_output = f"```json\n{bad_json}\n```"
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer._call_pi",
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value=bad_output,
),
@@ -464,9 +470,9 @@ class TestSummarizeOneFlow:
):
"""失败时仍保存 raw_output.txt。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch(
"app.services.summarizer._call_pi",
"app.services.summarizer.call_pi",
new_callable=AsyncMock,
return_value="Some output without JSON",
),
@@ -483,8 +489,8 @@ class TestSummarizeOneFlow:
):
"""成功后清理 tmp 目录。"""
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
await summarize_one(db_session, sample_paper)
@@ -498,7 +504,7 @@ class TestSummarizeOneFlow:
"""失败后也清理 tmp 目录。"""
with (
patch(
"app.services.summarizer._download_pdf",
"app.services.summarizer.download_pdf",
new_callable=AsyncMock,
side_effect=PdfDownloadError("fail"),
),
@@ -529,9 +535,11 @@ class TestBatchSummarize:
@pytest.fixture
def _patch_paths(self, tmp_path):
with (
patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
patch("app.services.summarizer._DATA_DIR", tmp_path),
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
):
yield
@@ -561,8 +569,8 @@ class TestBatchSummarize:
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
result = await summarize_batch(
db_session, _session_factory=_TestSession
@@ -612,8 +620,8 @@ class TestBatchSummarize:
return mock_pi_output
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", side_effect=_mock_call_pi),
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.call_pi", side_effect=_mock_call_pi),
):
result = await summarize_batch(
db_session, _session_factory=_TestSession
@@ -646,8 +654,8 @@ class TestBatchSummarize:
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
with (
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
):
await summarize_batch(
db_session, _session_factory=_TestSession