refactor: restructure services and add image/pdf extraction utilities
- Add image_extractor, pdf_downloader, pi_client, trends services - Add shared utils module - Refactor summarizer, embedder, routes for cleaner separation - Update tests to match new service structure
This commit is contained in:
+33
-1
@@ -1,6 +1,6 @@
|
|||||||
"""数据库引擎、会话工厂、初始化。"""
|
"""数据库引擎、会话工厂、初始化。"""
|
||||||
|
|
||||||
from sqlalchemy import event, create_engine
|
from sqlalchemy import event, create_engine, text
|
||||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@@ -10,6 +10,27 @@ class Base(DeclarativeBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ── FTS5 和索引 DDL(与 ORM 模型分开管理)───────────────────────────────
|
||||||
|
|
||||||
|
FTS5_CREATE_SQL = """
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS papers_fts USING fts5(
|
||||||
|
title_en,
|
||||||
|
title_zh,
|
||||||
|
abstract,
|
||||||
|
authors,
|
||||||
|
tags,
|
||||||
|
summary_text,
|
||||||
|
tokenize='unicode61'
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
FTS5_TRIGGER_INDEX = """
|
||||||
|
-- partial index for task_locks running
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS uq_task_locks_running
|
||||||
|
ON task_locks(task, lock_key) WHERE status = 'running';
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _make_engine():
|
def _make_engine():
|
||||||
"""创建 SQLite 引擎,启用 foreign_keys。"""
|
"""创建 SQLite 引擎,启用 foreign_keys。"""
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
@@ -39,3 +60,14 @@ def get_db():
|
|||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def init_db(engine):
|
||||||
|
"""创建所有 ORM 表 + FTS5 虚拟表。"""
|
||||||
|
from app.models import Base # noqa: F811 — 避免循环导入,延迟导入
|
||||||
|
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(text(FTS5_CREATE_SQL))
|
||||||
|
conn.execute(text(FTS5_TRIGGER_INDEX))
|
||||||
|
conn.commit()
|
||||||
|
|||||||
+21
-20
@@ -2,14 +2,13 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from starlette.staticfiles import StaticFiles as StarletteStaticFiles
|
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import engine
|
from app.database import engine, init_db
|
||||||
from app.models import init_db
|
|
||||||
from app.routes.admin import router as admin_router
|
from app.routes.admin import router as admin_router
|
||||||
from app.routes.compare import router as compare_router
|
from app.routes.compare import router as compare_router
|
||||||
from app.routes.pages import router as pages_router
|
from app.routes.pages import router as pages_router
|
||||||
@@ -24,11 +23,30 @@ logging.basicConfig(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理:启动与关闭。"""
|
||||||
|
# ── startup ──
|
||||||
|
from app.services.scheduler import start_scheduler
|
||||||
|
from app.services.embedder import init_chroma
|
||||||
|
|
||||||
|
start_scheduler()
|
||||||
|
init_chroma()
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# ── shutdown ──
|
||||||
|
from app.services.scheduler import stop_scheduler
|
||||||
|
|
||||||
|
stop_scheduler()
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="HF Daily Papers",
|
title="HF Daily Papers",
|
||||||
description="HuggingFace Daily Papers — 中文论文导览站",
|
description="HuggingFace Daily Papers — 中文论文导览站",
|
||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 确保数据目录存在
|
# 确保数据目录存在
|
||||||
@@ -65,23 +83,6 @@ def create_app() -> FastAPI:
|
|||||||
app.include_router(trends_router)
|
app.include_router(trends_router)
|
||||||
app.include_router(compare_router)
|
app.include_router(compare_router)
|
||||||
|
|
||||||
# 调度器(Phase 4)
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def _start_scheduler():
|
|
||||||
from app.services.scheduler import start_scheduler
|
|
||||||
start_scheduler()
|
|
||||||
|
|
||||||
# Phase 5: 初始化 ChromaDB
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def _init_chroma():
|
|
||||||
from app.services.embedder import init_chroma
|
|
||||||
init_chroma()
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
|
||||||
async def _stop_scheduler():
|
|
||||||
from app.services.scheduler import stop_scheduler
|
|
||||||
stop_scheduler()
|
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+1
-31
@@ -1,4 +1,4 @@
|
|||||||
"""SQLAlchemy ORM 模型 — papers, authors, tags, summaries, FTS5, logs, locks, user data。"""
|
"""SQLAlchemy ORM 模型 — papers, authors, tags, summaries, user data, logs, locks。"""
|
||||||
|
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
|
|
||||||
@@ -13,7 +13,6 @@ from sqlalchemy import (
|
|||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
UniqueConstraint,
|
UniqueConstraint,
|
||||||
text,
|
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
@@ -204,32 +203,3 @@ class DataDeleteJob(Base):
|
|||||||
error = Column(Text)
|
error = Column(Text)
|
||||||
started_at = Column(DateTime, nullable=False)
|
started_at = Column(DateTime, nullable=False)
|
||||||
completed_at = Column(DateTime)
|
completed_at = Column(DateTime)
|
||||||
|
|
||||||
|
|
||||||
# ── FTS5 索引初始化 SQL(普通虚拟表,由应用层维护)──────────────────────
|
|
||||||
FTS5_CREATE_SQL = """
|
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS papers_fts USING fts5(
|
|
||||||
title_en,
|
|
||||||
title_zh,
|
|
||||||
abstract,
|
|
||||||
authors,
|
|
||||||
tags,
|
|
||||||
summary_text,
|
|
||||||
tokenize='unicode61'
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
|
|
||||||
FTS5_TRIGGER_INDEX = """
|
|
||||||
-- partial index for task_locks running
|
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS uq_task_locks_running
|
|
||||||
ON task_locks(task, lock_key) WHERE status = 'running';
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def init_db(engine):
|
|
||||||
"""创建所有 ORM 表 + FTS5 虚拟表。"""
|
|
||||||
Base.metadata.create_all(engine)
|
|
||||||
with engine.connect() as conn:
|
|
||||||
conn.execute(text(FTS5_CREATE_SQL))
|
|
||||||
conn.execute(text(FTS5_TRIGGER_INDEX))
|
|
||||||
conn.commit()
|
|
||||||
|
|||||||
+9
-30
@@ -6,7 +6,6 @@ from datetime import date, datetime, timezone
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -17,10 +16,10 @@ from app.models import CrawlLog, DataDeleteJob, TaskLock
|
|||||||
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
||||||
from app.services.crawler import crawl_daily
|
from app.services.crawler import crawl_daily
|
||||||
from app.services.summarizer import summarize_batch, summarize_single
|
from app.services.summarizer import summarize_batch, summarize_single
|
||||||
|
from app.utils import release_lock, templates, today_str
|
||||||
|
|
||||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
templates = Jinja2Templates(directory="app/templates")
|
|
||||||
|
|
||||||
|
|
||||||
async def verify_admin(
|
async def verify_admin(
|
||||||
@@ -32,7 +31,7 @@ async def verify_admin(
|
|||||||
return credentials.credentials
|
return credentials.credentials
|
||||||
|
|
||||||
|
|
||||||
# ── 请求模型 ──────────────────────────────────────────────────────────────
|
# ── 请求模型 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class DeleteRequest(BaseModel):
|
class DeleteRequest(BaseModel):
|
||||||
@@ -49,7 +48,7 @@ class DeleteRequest(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
# ── 抓取 ──────────────────────────────────────────────────────────────────
|
# ── 抓取 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/crawl")
|
@router.post("/crawl")
|
||||||
@@ -59,12 +58,7 @@ async def admin_crawl(
|
|||||||
date: str | None = Query(None, description="YYYY-MM-DD,默认今天"),
|
date: str | None = Query(None, description="YYYY-MM-DD,默认今天"),
|
||||||
):
|
):
|
||||||
"""手动抓取指定日期,默认今天。"""
|
"""手动抓取指定日期,默认今天。"""
|
||||||
# 计算 target_date
|
target_date = date or today_str()
|
||||||
from zoneinfo import ZoneInfo
|
|
||||||
|
|
||||||
tz = ZoneInfo(settings.APP_TIMEZONE)
|
|
||||||
today = datetime.now(tz).strftime("%Y-%m-%d")
|
|
||||||
target_date = date or today
|
|
||||||
|
|
||||||
# TaskLock 防重入
|
# TaskLock 防重入
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
@@ -88,10 +82,10 @@ async def admin_crawl(
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise HTTPException(status_code=500, detail=str(exc))
|
raise HTTPException(status_code=500, detail=str(exc))
|
||||||
finally:
|
finally:
|
||||||
_release_lock(db, lock)
|
release_lock(db, lock)
|
||||||
|
|
||||||
|
|
||||||
# ── 总结 ──────────────────────────────────────────────────────────────────
|
# ── 总结 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/summarize")
|
@router.post("/summarize")
|
||||||
@@ -119,7 +113,7 @@ async def admin_summarize_single(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# ── 清理 ──────────────────────────────────────────────────────────────────
|
# ── 清理 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/cleanup")
|
@router.post("/cleanup")
|
||||||
@@ -155,7 +149,7 @@ async def admin_cleanup(
|
|||||||
raise HTTPException(status_code=500, detail=str(exc))
|
raise HTTPException(status_code=500, detail=str(exc))
|
||||||
|
|
||||||
|
|
||||||
# ── 删除 ──────────────────────────────────────────────────────────────────
|
# ── 删除 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/delete")
|
@router.post("/delete")
|
||||||
@@ -177,7 +171,7 @@ async def admin_delete(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# ── 日志 ──────────────────────────────────────────────────────────────────
|
# ── 日志 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.get("/logs")
|
@router.get("/logs")
|
||||||
@@ -189,7 +183,6 @@ async def admin_logs(
|
|||||||
per_page: int = Query(20, ge=1, le=100),
|
per_page: int = Query(20, ge=1, le=100),
|
||||||
):
|
):
|
||||||
"""查看任务日志(CrawlLog + DataDeleteJob)。"""
|
"""查看任务日志(CrawlLog + DataDeleteJob)。"""
|
||||||
# 查询 crawl_logs
|
|
||||||
crawl_logs = (
|
crawl_logs = (
|
||||||
db.execute(
|
db.execute(
|
||||||
select(CrawlLog)
|
select(CrawlLog)
|
||||||
@@ -201,7 +194,6 @@ async def admin_logs(
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 查询 delete_jobs
|
|
||||||
delete_jobs = (
|
delete_jobs = (
|
||||||
db.execute(
|
db.execute(
|
||||||
select(DataDeleteJob)
|
select(DataDeleteJob)
|
||||||
@@ -223,16 +215,3 @@ async def admin_logs(
|
|||||||
"per_page": per_page,
|
"per_page": per_page,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── 工具函数 ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _release_lock(db: Session, lock: TaskLock) -> None:
|
|
||||||
"""释放 TaskLock。"""
|
|
||||||
try:
|
|
||||||
lock.status = "finished"
|
|
||||||
lock.released_at = datetime.now(timezone.utc)
|
|
||||||
db.commit()
|
|
||||||
except Exception:
|
|
||||||
db.rollback()
|
|
||||||
|
|||||||
@@ -2,19 +2,14 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models import Paper
|
from app.models import Paper
|
||||||
|
from app.utils import templates
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
templates = Jinja2Templates(directory="app/templates")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/compare")
|
@router.get("/compare")
|
||||||
|
|||||||
+7
-21
@@ -3,34 +3,27 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from zoneinfo import ZoneInfo
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models import Paper
|
from app.models import Paper
|
||||||
|
from app.utils import templates, today_str
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
templates = Jinja2Templates(directory="app/templates")
|
|
||||||
|
|
||||||
|
|
||||||
def _today() -> str:
|
|
||||||
tz = ZoneInfo(settings.APP_TIMEZONE)
|
|
||||||
return datetime.now(tz).strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
def index(request: Request):
|
def index(request: Request):
|
||||||
"""重定向到 /day/{today}。"""
|
"""重定向到 /day/{today}。"""
|
||||||
return RedirectResponse(url=f"/day/{_today()}")
|
return RedirectResponse(url=f"/day/{today_str()}")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/day/{date_str}")
|
@router.get("/day/{date_str}")
|
||||||
@@ -43,7 +36,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
|
|||||||
|
|
||||||
prev_day = (target - timedelta(days=1)).isoformat()
|
prev_day = (target - timedelta(days=1)).isoformat()
|
||||||
next_day = (target + timedelta(days=1)).isoformat()
|
next_day = (target + timedelta(days=1)).isoformat()
|
||||||
today_str = _today()
|
today = today_str()
|
||||||
|
|
||||||
papers = (
|
papers = (
|
||||||
db.query(Paper)
|
db.query(Paper)
|
||||||
@@ -74,7 +67,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
|
|||||||
"current_date": date_str,
|
"current_date": date_str,
|
||||||
"prev_day": prev_day,
|
"prev_day": prev_day,
|
||||||
"next_day": next_day,
|
"next_day": next_day,
|
||||||
"today": today_str,
|
"today": today,
|
||||||
"available_dates": available_dates,
|
"available_dates": available_dates,
|
||||||
"page_title": f"{date_str} 论文列表",
|
"page_title": f"{date_str} 论文列表",
|
||||||
},
|
},
|
||||||
@@ -146,16 +139,9 @@ def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.services.embedder import search_similar
|
from app.services.embedder import get_collection
|
||||||
|
|
||||||
# 用论文的 arxiv_id 从 ChromaDB 查询
|
|
||||||
col = None
|
|
||||||
try:
|
|
||||||
from app.services.embedder import get_collection
|
|
||||||
col = get_collection()
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
col = get_collection()
|
||||||
if col is None:
|
if col is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|||||||
+7
-61
@@ -2,14 +2,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
from datetime import date, timedelta
|
||||||
from datetime import date, datetime, timedelta, timezone
|
|
||||||
from zoneinfo import ZoneInfo
|
|
||||||
from xml.sax.saxutils import escape
|
from xml.sax.saxutils import escape
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
@@ -17,9 +14,10 @@ from app.config import settings
|
|||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models import Paper, PaperTag, UserReadingStatus
|
from app.models import Paper, PaperTag, UserReadingStatus
|
||||||
from app.services.searcher import get_all_tags, search_papers
|
from app.services.searcher import get_all_tags, search_papers
|
||||||
|
from app.services.user_data import query_reading_list
|
||||||
|
from app.utils import templates, today_str
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
templates = Jinja2Templates(directory="app/templates")
|
|
||||||
|
|
||||||
|
|
||||||
# ── 搜索页 ────────────────────────────────────────────────────────────
|
# ── 搜索页 ────────────────────────────────────────────────────────────
|
||||||
@@ -56,7 +54,7 @@ def search_page(
|
|||||||
"total_pages": result["total_pages"],
|
"total_pages": result["total_pages"],
|
||||||
"all_tags": all_tags,
|
"all_tags": all_tags,
|
||||||
"page_title": f"搜索: {q}" if q else "搜索",
|
"page_title": f"搜索: {q}" if q else "搜索",
|
||||||
"today": _today_str(),
|
"today": today_str(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -114,7 +112,7 @@ def reading_list_page(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""阅读列表页面。"""
|
"""阅读列表页面。"""
|
||||||
papers = _query_reading_list(db, filter, tag or None)
|
papers = query_reading_list(db, filter, tag or None)
|
||||||
all_tags = get_all_tags(db)
|
all_tags = get_all_tags(db)
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
@@ -126,54 +124,11 @@ def reading_list_page(
|
|||||||
"current_tag": tag,
|
"current_tag": tag,
|
||||||
"all_tags": all_tags,
|
"all_tags": all_tags,
|
||||||
"page_title": "阅读列表",
|
"page_title": "阅读列表",
|
||||||
"today": _today_str(),
|
"today": today_str(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _query_reading_list(
|
|
||||||
db: Session,
|
|
||||||
filter_type: str,
|
|
||||||
tag: str | None,
|
|
||||||
) -> list[Paper]:
|
|
||||||
"""根据筛选条件查询阅读列表。"""
|
|
||||||
from sqlalchemy import or_
|
|
||||||
|
|
||||||
# 基础:有任意用户数据的论文
|
|
||||||
base = db.query(Paper).filter(
|
|
||||||
or_(
|
|
||||||
Paper.bookmark.has(),
|
|
||||||
Paper.reading_status.has(),
|
|
||||||
Paper.note.has(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 应用筛选
|
|
||||||
if filter_type == "has_note":
|
|
||||||
base = base.filter(Paper.note.has())
|
|
||||||
elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
|
|
||||||
base = base.filter(
|
|
||||||
Paper.reading_status.has(UserReadingStatus.status == filter_type)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 应用标签
|
|
||||||
if tag:
|
|
||||||
base = base.filter(Paper.tags.any(PaperTag.tag == tag))
|
|
||||||
|
|
||||||
return (
|
|
||||||
base.options(
|
|
||||||
joinedload(Paper.authors),
|
|
||||||
joinedload(Paper.tags),
|
|
||||||
joinedload(Paper.summary_status),
|
|
||||||
joinedload(Paper.bookmark),
|
|
||||||
joinedload(Paper.reading_status),
|
|
||||||
joinedload(Paper.note),
|
|
||||||
)
|
|
||||||
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── RSS Feed ──────────────────────────────────────────────────────────
|
# ── RSS Feed ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -216,7 +171,7 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
|
|||||||
lines.append(f" <title>{escape(channel_title)}</title>")
|
lines.append(f" <title>{escape(channel_title)}</title>")
|
||||||
lines.append(f" <link>{escape(base_url)}</link>")
|
lines.append(f" <link>{escape(base_url)}</link>")
|
||||||
lines.append(" <description>HuggingFace Daily Papers — 中文论文导览站</description>")
|
lines.append(" <description>HuggingFace Daily Papers — 中文论文导览站</description>")
|
||||||
lines.append(f" <language>zh-CN</language>")
|
lines.append(" <language>zh-CN</language>")
|
||||||
|
|
||||||
for paper in papers:
|
for paper in papers:
|
||||||
title_text = paper.title_zh or paper.title_en
|
title_text = paper.title_zh or paper.title_en
|
||||||
@@ -245,12 +200,3 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
|
|||||||
lines.append(" </channel>")
|
lines.append(" </channel>")
|
||||||
lines.append("</rss>")
|
lines.append("</rss>")
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
# ── 工具函数 ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _today_str() -> str:
|
|
||||||
"""当前日期字符串(按 APP_TIMEZONE)。"""
|
|
||||||
tz = ZoneInfo(settings.APP_TIMEZONE)
|
|
||||||
return datetime.now(tz).strftime("%Y-%m-%d")
|
|
||||||
|
|||||||
+5
-92
@@ -2,34 +2,27 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import date, timedelta
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from sqlalchemy import func, text
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
from app.services.trends import get_trends_data
|
||||||
logger = logging.getLogger(__name__)
|
from app.utils import templates, today_str
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
templates = Jinja2Templates(directory="app/templates")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/trends")
|
@router.get("/trends")
|
||||||
def trends_page(request: Request, db: Session = Depends(get_db)):
|
def trends_page(request: Request, db: Session = Depends(get_db)):
|
||||||
"""趋势看板页面。"""
|
"""趋势看板页面。"""
|
||||||
stats = _get_trends_data(db)
|
stats = get_trends_data(db)
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
"trends.html",
|
"trends.html",
|
||||||
{
|
{
|
||||||
"page_title": "趋势看板",
|
"page_title": "趋势看板",
|
||||||
"stats": stats,
|
"stats": stats,
|
||||||
"today": _today_str(),
|
"today": today_str(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -37,84 +30,4 @@ def trends_page(request: Request, db: Session = Depends(get_db)):
|
|||||||
@router.get("/api/stats/trends")
|
@router.get("/api/stats/trends")
|
||||||
def trends_api(db: Session = Depends(get_db)):
|
def trends_api(db: Session = Depends(get_db)):
|
||||||
"""趋势数据 JSON API。"""
|
"""趋势数据 JSON API。"""
|
||||||
return _get_trends_data(db)
|
return get_trends_data(db)
|
||||||
|
|
||||||
|
|
||||||
def _get_trends_data(db: Session) -> dict:
|
|
||||||
"""从 DB 聚合趋势数据。"""
|
|
||||||
thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
|
|
||||||
|
|
||||||
# 1. 按日论文数量(近 30 天)
|
|
||||||
daily_rows = db.execute(text("""
|
|
||||||
SELECT paper_date, COUNT(*) as cnt
|
|
||||||
FROM papers
|
|
||||||
WHERE paper_date >= :start_date
|
|
||||||
GROUP BY paper_date
|
|
||||||
ORDER BY paper_date ASC
|
|
||||||
"""), {"start_date": thirty_days_ago}).fetchall()
|
|
||||||
daily_counts = [
|
|
||||||
{"date": str(row[0]), "count": row[1]}
|
|
||||||
for row in daily_rows
|
|
||||||
]
|
|
||||||
|
|
||||||
# 2. 热门标签 Top 20
|
|
||||||
tag_rows = db.execute(text("""
|
|
||||||
SELECT tag, COUNT(*) as cnt
|
|
||||||
FROM paper_tags
|
|
||||||
GROUP BY tag
|
|
||||||
ORDER BY cnt DESC
|
|
||||||
LIMIT 20
|
|
||||||
""")).fetchall()
|
|
||||||
top_tags = [
|
|
||||||
{"tag": row[0], "count": row[1]}
|
|
||||||
for row in tag_rows
|
|
||||||
]
|
|
||||||
|
|
||||||
# 3. Upvotes 分布
|
|
||||||
upvote_rows = db.execute(text("""
|
|
||||||
SELECT
|
|
||||||
CASE
|
|
||||||
WHEN upvotes >= 100 THEN '100+'
|
|
||||||
WHEN upvotes >= 50 THEN '50-99'
|
|
||||||
WHEN upvotes >= 20 THEN '20-49'
|
|
||||||
WHEN upvotes >= 10 THEN '10-19'
|
|
||||||
WHEN upvotes >= 5 THEN '5-9'
|
|
||||||
ELSE '0-4'
|
|
||||||
END as bucket,
|
|
||||||
COUNT(*) as cnt
|
|
||||||
FROM papers
|
|
||||||
GROUP BY bucket
|
|
||||||
ORDER BY MIN(upvotes) DESC
|
|
||||||
""")).fetchall()
|
|
||||||
upvotes_dist = [
|
|
||||||
{"range": row[0], "count": row[1]}
|
|
||||||
for row in upvote_rows
|
|
||||||
]
|
|
||||||
|
|
||||||
# 4. 总结完成率
|
|
||||||
summary_rows = db.execute(text("""
|
|
||||||
SELECT
|
|
||||||
COALESCE(ss.status, 'none') as status,
|
|
||||||
COUNT(*) as cnt
|
|
||||||
FROM papers p
|
|
||||||
LEFT JOIN summary_status ss ON ss.paper_id = p.id
|
|
||||||
GROUP BY status
|
|
||||||
""")).fetchall()
|
|
||||||
summary_completion = [
|
|
||||||
{"status": row[0], "count": row[1]}
|
|
||||||
for row in summary_rows
|
|
||||||
]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"daily_counts": daily_counts,
|
|
||||||
"top_tags": top_tags,
|
|
||||||
"upvotes_dist": upvotes_dist,
|
|
||||||
"summary_completion": summary_completion,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _today_str() -> str:
|
|
||||||
from datetime import datetime
|
|
||||||
from zoneinfo import ZoneInfo
|
|
||||||
tz = ZoneInfo(settings.APP_TIMEZONE)
|
|
||||||
return datetime.now(tz).strftime("%Y-%m-%d")
|
|
||||||
|
|||||||
@@ -16,13 +16,10 @@ from app.models import (
|
|||||||
Paper,
|
Paper,
|
||||||
TaskLock,
|
TaskLock,
|
||||||
)
|
)
|
||||||
|
from app.utils import PAPERS_DIR, TMP_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_DATA_DIR = Path("data")
|
|
||||||
_TMP_DIR = _DATA_DIR / "tmp"
|
|
||||||
_PAPERS_DIR = _DATA_DIR / "papers"
|
|
||||||
|
|
||||||
# 临时文件最大保留时间(小时)
|
# 临时文件最大保留时间(小时)
|
||||||
_MAX_TMP_AGE_HOURS = 24
|
_MAX_TMP_AGE_HOURS = 24
|
||||||
|
|
||||||
@@ -39,7 +36,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
|
|||||||
Returns:
|
Returns:
|
||||||
清理统计 {"scanned": int, "removed": int, "errors": list[str]}
|
清理统计 {"scanned": int, "removed": int, "errors": list[str]}
|
||||||
"""
|
"""
|
||||||
if not _TMP_DIR.exists():
|
if not TMP_DIR.exists():
|
||||||
return {"scanned": 0, "removed": 0, "errors": []}
|
return {"scanned": 0, "removed": 0, "errors": []}
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
@@ -48,7 +45,7 @@ def cleanup_tmp(max_age_hours: int = _MAX_TMP_AGE_HOURS) -> dict:
|
|||||||
removed = 0
|
removed = 0
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
|
|
||||||
for entry in _TMP_DIR.iterdir():
|
for entry in TMP_DIR.iterdir():
|
||||||
if not entry.is_dir():
|
if not entry.is_dir():
|
||||||
continue
|
continue
|
||||||
scanned += 1
|
scanned += 1
|
||||||
@@ -147,13 +144,13 @@ async def delete_papers_by_date_range(
|
|||||||
logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True)
|
logger.warning("Failed to delete %s from ChromaDB", arxiv_id, exc_info=True)
|
||||||
|
|
||||||
# 2. 删除本地文件 data/papers/{arxiv_id}/
|
# 2. 删除本地文件 data/papers/{arxiv_id}/
|
||||||
paper_dir = _PAPERS_DIR / arxiv_id
|
paper_dir = PAPERS_DIR / arxiv_id
|
||||||
if paper_dir.exists():
|
if paper_dir.exists():
|
||||||
shutil.rmtree(paper_dir)
|
shutil.rmtree(paper_dir)
|
||||||
logger.debug("Removed paper dir: %s", paper_dir)
|
logger.debug("Removed paper dir: %s", paper_dir)
|
||||||
|
|
||||||
# 3. 删除临时文件 data/tmp/{arxiv_id}/
|
# 3. 删除临时文件 data/tmp/{arxiv_id}/
|
||||||
tmp_dir = _TMP_DIR / arxiv_id
|
tmp_dir = TMP_DIR / arxiv_id
|
||||||
if tmp_dir.exists():
|
if tmp_dir.exists():
|
||||||
shutil.rmtree(tmp_dir)
|
shutil.rmtree(tmp_dir)
|
||||||
logger.debug("Removed tmp dir: %s", tmp_dir)
|
logger.debug("Removed tmp dir: %s", tmp_dir)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from app.models import (
|
|||||||
PaperTag,
|
PaperTag,
|
||||||
SummaryStatus,
|
SummaryStatus,
|
||||||
)
|
)
|
||||||
|
from app.utils import make_http_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -34,15 +35,7 @@ async def fetch_daily(target_date: str, top_n: int | None = None) -> list[dict]:
|
|||||||
url = f"{settings.HF_API_BASE}/daily_papers"
|
url = f"{settings.HF_API_BASE}/daily_papers"
|
||||||
params = {"date": target_date}
|
params = {"date": target_date}
|
||||||
|
|
||||||
transport = None
|
async with make_http_client() as client:
|
||||||
if settings.http_proxy:
|
|
||||||
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
|
||||||
timeout=settings.HTTP_TIMEOUT_SECONDS,
|
|
||||||
headers={"User-Agent": settings.HTTP_USER_AGENT},
|
|
||||||
transport=transport,
|
|
||||||
) as client:
|
|
||||||
for attempt in range(1, settings.HTTP_MAX_RETRIES + 1):
|
for attempt in range(1, settings.HTTP_MAX_RETRIES + 1):
|
||||||
try:
|
try:
|
||||||
logger.info("Fetching HF Daily Papers: date=%s attempt=%d", target_date, attempt)
|
logger.info("Fetching HF Daily Papers: date=%s attempt=%d", target_date, attempt)
|
||||||
|
|||||||
+70
-54
@@ -5,8 +5,6 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import httpx
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@@ -14,66 +12,82 @@ from app.models import Paper
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ── 单例客户端和 collection ─────────────────────────────────────────────
|
|
||||||
_client = None
|
# ── ChromaDB 管理器(替代全局可变状态)──────────────────────────────────
|
||||||
_collection = None
|
|
||||||
|
|
||||||
|
|
||||||
def _chroma_dir() -> Path:
|
class ChromaManager:
|
||||||
return Path(settings.CHROMA_DIR)
|
"""封装 ChromaDB 客户端和 collection 的生命周期。"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._client = None
|
||||||
|
self._collection = None
|
||||||
|
|
||||||
|
def init(self) -> None:
|
||||||
|
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
|
||||||
|
if not settings.CHROMA_ENABLED:
|
||||||
|
logger.debug("ChromaDB disabled, skip init")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._client is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
chroma_path = Path(settings.CHROMA_DIR)
|
||||||
|
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self._client = chromadb.PersistentClient(path=str(chroma_path))
|
||||||
|
self._collection = self._get_or_create_collection()
|
||||||
|
logger.info("ChromaDB initialized at %s", chroma_path)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to initialize ChromaDB")
|
||||||
|
self._client = None
|
||||||
|
self._collection = None
|
||||||
|
|
||||||
|
def _get_or_create_collection(self):
|
||||||
|
"""获取或创建 papers_embeddings collection。"""
|
||||||
|
try:
|
||||||
|
col = self._client.get_collection("papers_embeddings")
|
||||||
|
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
|
||||||
|
return col
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
col = self._client.create_collection(
|
||||||
|
name="papers_embeddings",
|
||||||
|
metadata={"hnsw:space": "cosine"},
|
||||||
|
)
|
||||||
|
logger.info("ChromaDB collection 'papers_embeddings' created")
|
||||||
|
return col
|
||||||
|
|
||||||
|
def get_collection(self):
|
||||||
|
"""返回当前 collection,未初始化则自动初始化。"""
|
||||||
|
if not settings.CHROMA_ENABLED:
|
||||||
|
return None
|
||||||
|
if self._collection is None:
|
||||||
|
self.init()
|
||||||
|
return self._collection
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""重置状态(供测试使用)。"""
|
||||||
|
self._client = None
|
||||||
|
self._collection = None
|
||||||
|
|
||||||
|
|
||||||
|
# 模块级单例
|
||||||
|
_chroma = ChromaManager()
|
||||||
|
|
||||||
|
|
||||||
def init_chroma() -> None:
|
def init_chroma() -> None:
|
||||||
"""CHROMA_ENABLED=true 时初始化 ChromaDB 持久客户端和 collection。"""
|
"""初始化 ChromaDB(供 lifespan 调用)。"""
|
||||||
global _client, _collection
|
_chroma.init()
|
||||||
if not settings.CHROMA_ENABLED:
|
|
||||||
logger.debug("ChromaDB disabled, skip init")
|
|
||||||
return
|
|
||||||
|
|
||||||
if _client is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
import chromadb
|
|
||||||
|
|
||||||
chroma_path = _chroma_dir()
|
|
||||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
_client = chromadb.PersistentClient(path=str(chroma_path))
|
|
||||||
_collection = _get_or_create_collection()
|
|
||||||
logger.info("ChromaDB initialized at %s", chroma_path)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to initialize ChromaDB")
|
|
||||||
_client = None
|
|
||||||
_collection = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_or_create_collection():
|
|
||||||
"""获取或创建 papers_embeddings collection,维度不匹配时记录日志并跳过。"""
|
|
||||||
import chromadb
|
|
||||||
|
|
||||||
try:
|
|
||||||
col = _client.get_collection("papers_embeddings")
|
|
||||||
logger.info("ChromaDB collection 'papers_embeddings' loaded, count=%d", col.count())
|
|
||||||
return col
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
col = _client.create_collection(
|
|
||||||
name="papers_embeddings",
|
|
||||||
metadata={"hnsw:space": "cosine"},
|
|
||||||
)
|
|
||||||
logger.info("ChromaDB collection 'papers_embeddings' created")
|
|
||||||
return col
|
|
||||||
|
|
||||||
|
|
||||||
def get_collection():
|
def get_collection():
|
||||||
"""返回当前 collection,未初始化则返回 None。"""
|
"""返回当前 collection,未初始化则返回 None。"""
|
||||||
if not settings.CHROMA_ENABLED:
|
return _chroma.get_collection()
|
||||||
return None
|
|
||||||
if _collection is None:
|
|
||||||
init_chroma()
|
|
||||||
return _collection
|
|
||||||
|
|
||||||
|
|
||||||
# ── Embedding API 调用 ──────────────────────────────────────────────────
|
# ── Embedding API 调用 ──────────────────────────────────────────────────
|
||||||
@@ -90,6 +104,8 @@ def _get_embedding(text: str) -> list[float] | None:
|
|||||||
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
|
logger.warning("EMBED_API_BASE or EMBED_MODEL not configured, skip embedding")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
from app.utils import make_http_client
|
||||||
|
|
||||||
url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
|
url = f"{settings.EMBED_API_BASE.rstrip('/')}/v1/embeddings"
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
if settings.EMBED_API_KEY:
|
if settings.EMBED_API_KEY:
|
||||||
@@ -101,7 +117,7 @@ def _get_embedding(text: str) -> list[float] | None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with httpx.Client(timeout=settings.HTTP_TIMEOUT_SECONDS) as client:
|
with make_http_client(sync=True) as client:
|
||||||
resp = client.post(url, json=payload, headers=headers)
|
resp = client.post(url, json=payload, headers=headers)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|||||||
@@ -0,0 +1,83 @@
|
|||||||
|
"""LaTeX 图片提取 — 从 arXiv 源码中扫描 \\includegraphics 并提取图片文件。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from app.services.pdf_downloader import download_source_zip, paper_dir, tmp_dir
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_INCLUDEGRAPHICS_RE = re.compile(
|
||||||
|
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
|
||||||
|
)
|
||||||
|
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_images_from_source(arxiv_id: str) -> int:
|
||||||
|
"""从 LaTeX 源码中提取图片文件。
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
|
||||||
|
2. 扫描 .tex 文件中的 \\includegraphics
|
||||||
|
3. 复制图片到 data/papers/{arxiv_id}/images/
|
||||||
|
4. 清理源码临时文件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提取的图片数量
|
||||||
|
"""
|
||||||
|
tmp_source = tmp_dir(arxiv_id) / "source"
|
||||||
|
images_dest = paper_dir(arxiv_id) / "images"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 下载源码 zip(如果还没下载)
|
||||||
|
if not tmp_source.exists():
|
||||||
|
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
|
||||||
|
await download_source_zip(arxiv_id, source_url, tmp_source)
|
||||||
|
|
||||||
|
if not tmp_source.exists():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 扫描 .tex 文件,收集图片路径
|
||||||
|
image_paths: set[str] = set()
|
||||||
|
for tex_file in tmp_source.rglob("*.tex"):
|
||||||
|
try:
|
||||||
|
content = tex_file.read_text(encoding="utf-8", errors="replace")
|
||||||
|
for match in _INCLUDEGRAPHICS_RE.finditer(content):
|
||||||
|
img_path = match.group(1).strip()
|
||||||
|
image_paths.add(img_path)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not image_paths:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 查找并复制图片
|
||||||
|
images_dest.mkdir(parents=True, exist_ok=True)
|
||||||
|
copied = 0
|
||||||
|
for img_rel in image_paths:
|
||||||
|
# 尝试在源码目录中找到文件
|
||||||
|
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
|
||||||
|
candidate = tmp_source / (img_rel + ext)
|
||||||
|
if candidate.is_file():
|
||||||
|
dest_name = candidate.name
|
||||||
|
# 避免文件名冲突
|
||||||
|
dest = images_dest / dest_name
|
||||||
|
if dest.exists():
|
||||||
|
stem = dest.stem
|
||||||
|
suffix = dest.suffix
|
||||||
|
dest = images_dest / f"{stem}_{copied}{suffix}"
|
||||||
|
shutil.copy2(candidate, dest)
|
||||||
|
copied += 1
|
||||||
|
break
|
||||||
|
|
||||||
|
if copied > 0:
|
||||||
|
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
|
||||||
|
return copied
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||||
|
return 0
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
"""PDF 下载与源码下载 — 从 arXiv 下载论文 PDF 和 LaTeX 源码包。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import zipfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from app.utils import PAPERS_DIR, TMP_DIR, make_http_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class PdfDownloadError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ── 路径工具 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def paper_dir(arxiv_id: str) -> Path:
|
||||||
|
return PAPERS_DIR / arxiv_id
|
||||||
|
|
||||||
|
|
||||||
|
def tmp_dir(arxiv_id: str) -> Path:
|
||||||
|
return TMP_DIR / arxiv_id
|
||||||
|
|
||||||
|
|
||||||
|
# ── PDF 下载 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
||||||
|
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
|
||||||
|
if not pdf_url:
|
||||||
|
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
|
||||||
|
|
||||||
|
dest_dir = tmp_dir(arxiv_id)
|
||||||
|
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
dest = dest_dir / "paper.pdf"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with make_http_client(follow_redirects=True) as client:
|
||||||
|
resp = await client.get(pdf_url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
dest.write_bytes(resp.content)
|
||||||
|
except Exception as exc:
|
||||||
|
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
|
||||||
|
|
||||||
|
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
|
||||||
|
return dest
|
||||||
|
|
||||||
|
|
||||||
|
# ── 源码下载 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def download_source_zip(arxiv_id: str, source_url: str, dest_dir: Path) -> None:
|
||||||
|
"""下载 arXiv 源码并解压。"""
|
||||||
|
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
zip_path = tmp_dir(arxiv_id) / "source.zip"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with make_http_client(follow_redirects=True) as client:
|
||||||
|
resp = await client.get(source_url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
zip_path.write_bytes(resp.content)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||||
|
zf.extractall(dest_dir)
|
||||||
|
logger.debug("Extracted source for %s", arxiv_id)
|
||||||
|
except zipfile.BadZipFile:
|
||||||
|
# 可能是 tar.gz
|
||||||
|
import tarfile
|
||||||
|
try:
|
||||||
|
with tarfile.open(zip_path, "r:*") as tf:
|
||||||
|
tf.extractall(dest_dir, filter="data")
|
||||||
|
logger.debug("Extracted source (tar) for %s", arxiv_id)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Cannot extract source for %s", arxiv_id)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
|
||||||
|
finally:
|
||||||
|
if zip_path.exists():
|
||||||
|
zip_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 临时文件清理 ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_tmp(arxiv_id: str) -> None:
|
||||||
|
"""清理 data/tmp/{arxiv_id}/ 目录。"""
|
||||||
|
td = tmp_dir(arxiv_id)
|
||||||
|
if td.exists():
|
||||||
|
try:
|
||||||
|
shutil.rmtree(td)
|
||||||
|
logger.debug("Cleaned tmp: %s", arxiv_id)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
|
||||||
@@ -0,0 +1,160 @@
|
|||||||
|
"""pi CLI 调用与 JSON 提取 — 调用 pi 生成总结,从输出中提取结构化 JSON。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class PiTimeoutError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PiProcessError(Exception):
|
||||||
|
def __init__(self, returncode: int, stderr: str):
|
||||||
|
self.returncode = returncode
|
||||||
|
self.stderr = stderr
|
||||||
|
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
|
||||||
|
|
||||||
|
|
||||||
|
class JsonNotFoundError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ── meta.json ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def write_meta_json(paper) -> Path:
|
||||||
|
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
|
||||||
|
from app.services.pdf_downloader import paper_dir
|
||||||
|
|
||||||
|
d = paper_dir(paper.arxiv_id)
|
||||||
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
|
meta_path = d / "meta.json"
|
||||||
|
|
||||||
|
authors = [a.name for a in paper.authors]
|
||||||
|
tags = [t.tag for t in paper.tags]
|
||||||
|
meta = {
|
||||||
|
"arxiv_id": paper.arxiv_id,
|
||||||
|
"title_en": paper.title_en,
|
||||||
|
"abstract": paper.abstract or "",
|
||||||
|
"published_at": paper.published_at.isoformat() if paper.published_at else None,
|
||||||
|
"authors": authors,
|
||||||
|
"tags": tags,
|
||||||
|
"upvotes": paper.upvotes,
|
||||||
|
}
|
||||||
|
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||||
|
return meta_path
|
||||||
|
|
||||||
|
|
||||||
|
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def call_pi(meta_path: Path, pdf_path: Path) -> str:
|
||||||
|
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
|
||||||
|
arxiv_id = meta_path.parent.name
|
||||||
|
cmd = [
|
||||||
|
settings.PI_BIN,
|
||||||
|
"-p",
|
||||||
|
"--no-tools",
|
||||||
|
"--skill",
|
||||||
|
settings.SUMMARY_SKILL,
|
||||||
|
"请深度解读以下论文,并按指定 JSON schema 输出:",
|
||||||
|
f"@{meta_path}",
|
||||||
|
f"@{pdf_path}",
|
||||||
|
]
|
||||||
|
logger.info("Calling pi for %s", arxiv_id)
|
||||||
|
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*cmd,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(
|
||||||
|
proc.communicate(),
|
||||||
|
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
proc.kill()
|
||||||
|
await proc.wait()
|
||||||
|
raise PiTimeoutError(
|
||||||
|
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
if proc.returncode != 0:
|
||||||
|
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
|
||||||
|
|
||||||
|
return stdout.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
|
||||||
|
# ── JSON 提取 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def extract_json(raw_output: str) -> dict:
|
||||||
|
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
|
||||||
|
# 策略 1:整体直接解析
|
||||||
|
stripped = raw_output.strip()
|
||||||
|
try:
|
||||||
|
result = json.loads(stripped)
|
||||||
|
if isinstance(result, dict) and "title_zh" in result:
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 策略 2:提取 ```json ... ``` 代码块
|
||||||
|
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
|
||||||
|
for match in fence_pattern.finditer(raw_output):
|
||||||
|
try:
|
||||||
|
result = json.loads(match.group(1).strip())
|
||||||
|
if isinstance(result, dict) and "title_zh" in result:
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 策略 3:匹配包含 title_zh 的最大 {...} 块
|
||||||
|
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
|
||||||
|
for match in brace_pattern.finditer(raw_output):
|
||||||
|
try:
|
||||||
|
return json.loads(match.group(0))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 更宽松:找到最大的 { ... } 平衡块
|
||||||
|
best = None
|
||||||
|
best_len = 0
|
||||||
|
for i, ch in enumerate(raw_output):
|
||||||
|
if ch != "{":
|
||||||
|
continue
|
||||||
|
depth = 0
|
||||||
|
for j in range(i, len(raw_output)):
|
||||||
|
if raw_output[j] == "{":
|
||||||
|
depth += 1
|
||||||
|
elif raw_output[j] == "}":
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
candidate = raw_output[i : j + 1]
|
||||||
|
if len(candidate) > best_len:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(candidate)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
best = parsed
|
||||||
|
best_len = len(candidate)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
break
|
||||||
|
|
||||||
|
if best is not None:
|
||||||
|
return best
|
||||||
|
|
||||||
|
raise JsonNotFoundError("no JSON object found in pi output")
|
||||||
+39
-362
@@ -1,18 +1,15 @@
|
|||||||
"""AI 总结服务 — 调用 pi CLI 生成论文中文结构化总结。"""
|
"""AI 总结编排服务 — 协调 PDF 下载、pi CLI 调用、JSON 校验、DB 写入、语义索引。"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
import shutil
|
import shutil
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import httpx
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sqlalchemy import select, text
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@@ -25,216 +22,31 @@ from app.models import (
|
|||||||
SummaryStatus,
|
SummaryStatus,
|
||||||
TaskLock,
|
TaskLock,
|
||||||
)
|
)
|
||||||
|
from app.services.image_extractor import extract_images_from_source
|
||||||
|
from app.services.pdf_downloader import (
|
||||||
|
PdfDownloadError,
|
||||||
|
cleanup_tmp,
|
||||||
|
download_pdf,
|
||||||
|
paper_dir,
|
||||||
|
)
|
||||||
|
from app.services.pi_client import (
|
||||||
|
JsonNotFoundError,
|
||||||
|
PiProcessError,
|
||||||
|
PiTimeoutError,
|
||||||
|
call_pi,
|
||||||
|
extract_json,
|
||||||
|
write_meta_json,
|
||||||
|
)
|
||||||
from app.services.schemas import (
|
from app.services.schemas import (
|
||||||
SummarySchema,
|
SummarySchema,
|
||||||
assess_quality,
|
assess_quality,
|
||||||
classify_validation_error,
|
classify_validation_error,
|
||||||
flatten_for_db,
|
flatten_for_db,
|
||||||
)
|
)
|
||||||
|
from app.utils import PAPERS_DIR, release_lock
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ── 自定义异常 ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PdfDownloadError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PiTimeoutError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PiProcessError(Exception):
|
|
||||||
def __init__(self, returncode: int, stderr: str):
|
|
||||||
self.returncode = returncode
|
|
||||||
self.stderr = stderr
|
|
||||||
super().__init__(f"pi exited with code {returncode}: {stderr[:500]}")
|
|
||||||
|
|
||||||
|
|
||||||
class JsonNotFoundError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# ── 路径工具 ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_DATA_DIR = Path("data")
|
|
||||||
_PAPERS_DIR = _DATA_DIR / "papers"
|
|
||||||
_TMP_DIR = _DATA_DIR / "tmp"
|
|
||||||
|
|
||||||
|
|
||||||
def _paper_dir(arxiv_id: str) -> Path:
|
|
||||||
return _PAPERS_DIR / arxiv_id
|
|
||||||
|
|
||||||
|
|
||||||
def _tmp_dir(arxiv_id: str) -> Path:
|
|
||||||
return _TMP_DIR / arxiv_id
|
|
||||||
|
|
||||||
|
|
||||||
# ── PDF 下载 ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _download_pdf(arxiv_id: str, pdf_url: str) -> Path:
|
|
||||||
"""下载 PDF 到 data/tmp/{arxiv_id}/paper.pdf。"""
|
|
||||||
if not pdf_url:
|
|
||||||
raise PdfDownloadError(f"no pdf_url for {arxiv_id}")
|
|
||||||
|
|
||||||
tmp = _tmp_dir(arxiv_id)
|
|
||||||
tmp.mkdir(parents=True, exist_ok=True)
|
|
||||||
dest = tmp / "paper.pdf"
|
|
||||||
|
|
||||||
transport = None
|
|
||||||
if settings.http_proxy:
|
|
||||||
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(
|
|
||||||
timeout=settings.HTTP_TIMEOUT_SECONDS,
|
|
||||||
headers={"User-Agent": settings.HTTP_USER_AGENT},
|
|
||||||
transport=transport,
|
|
||||||
follow_redirects=True,
|
|
||||||
) as client:
|
|
||||||
resp = await client.get(pdf_url)
|
|
||||||
resp.raise_for_status()
|
|
||||||
dest.write_bytes(resp.content)
|
|
||||||
except Exception as exc:
|
|
||||||
raise PdfDownloadError(f"failed to download PDF for {arxiv_id}: {exc}") from exc
|
|
||||||
|
|
||||||
logger.info("Downloaded PDF: %s (%d bytes)", arxiv_id, dest.stat().st_size)
|
|
||||||
return dest
|
|
||||||
|
|
||||||
|
|
||||||
# ── meta.json ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _write_meta_json(paper: Paper) -> Path:
|
|
||||||
"""写入 data/papers/{arxiv_id}/meta.json,返回路径。"""
|
|
||||||
d = _paper_dir(paper.arxiv_id)
|
|
||||||
d.mkdir(parents=True, exist_ok=True)
|
|
||||||
meta_path = d / "meta.json"
|
|
||||||
|
|
||||||
authors = [a.name for a in paper.authors]
|
|
||||||
tags = [t.tag for t in paper.tags]
|
|
||||||
meta = {
|
|
||||||
"arxiv_id": paper.arxiv_id,
|
|
||||||
"title_en": paper.title_en,
|
|
||||||
"abstract": paper.abstract or "",
|
|
||||||
"published_at": paper.published_at.isoformat() if paper.published_at else None,
|
|
||||||
"authors": authors,
|
|
||||||
"tags": tags,
|
|
||||||
"upvotes": paper.upvotes,
|
|
||||||
}
|
|
||||||
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
||||||
return meta_path
|
|
||||||
|
|
||||||
|
|
||||||
# ── pi CLI 调用 ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _call_pi(meta_path: Path, pdf_path: Path) -> str:
|
|
||||||
"""调用 pi CLI 非交互模式,返回 stdout 文本。"""
|
|
||||||
cmd = [
|
|
||||||
settings.PI_BIN,
|
|
||||||
"-p",
|
|
||||||
"--no-tools",
|
|
||||||
"--skill",
|
|
||||||
settings.SUMMARY_SKILL,
|
|
||||||
"请深度解读以下论文,并按指定 JSON schema 输出:",
|
|
||||||
f"@{meta_path}",
|
|
||||||
f"@{pdf_path}",
|
|
||||||
]
|
|
||||||
logger.info("Calling pi: %s %s", paper_id_from_path(meta_path), " ".join(cmd[:4]))
|
|
||||||
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
|
||||||
*cmd,
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
stdout, stderr = await asyncio.wait_for(
|
|
||||||
proc.communicate(),
|
|
||||||
timeout=settings.SUMMARY_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
proc.kill()
|
|
||||||
await proc.wait()
|
|
||||||
raise PiTimeoutError(
|
|
||||||
f"pi timed out after {settings.SUMMARY_TIMEOUT_SECONDS}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
if proc.returncode != 0:
|
|
||||||
raise PiProcessError(proc.returncode, stderr.decode("utf-8", errors="replace"))
|
|
||||||
|
|
||||||
return stdout.decode("utf-8", errors="replace")
|
|
||||||
|
|
||||||
|
|
||||||
def paper_id_from_path(meta_path: Path) -> str:
|
|
||||||
"""从 meta.json 路径反推 arxiv_id。"""
|
|
||||||
return meta_path.parent.name
|
|
||||||
|
|
||||||
|
|
||||||
# ── JSON 提取 ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_json(raw_output: str) -> dict:
|
|
||||||
"""从 pi 输出中提取 JSON dict。三步策略:直接解析 → 代码块 → 最大花括号块。"""
|
|
||||||
# 策略 1:整体直接解析
|
|
||||||
stripped = raw_output.strip()
|
|
||||||
try:
|
|
||||||
result = json.loads(stripped)
|
|
||||||
if isinstance(result, dict) and "title_zh" in result:
|
|
||||||
return result
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 策略 2:提取 ```json ... ``` 代码块
|
|
||||||
fence_pattern = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL)
|
|
||||||
for match in fence_pattern.finditer(raw_output):
|
|
||||||
try:
|
|
||||||
result = json.loads(match.group(1).strip())
|
|
||||||
if isinstance(result, dict) and "title_zh" in result:
|
|
||||||
return result
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 策略 3:匹配包含 title_zh 的最大 {...} 块
|
|
||||||
brace_pattern = re.compile(r"\{[^{}]*\"title_zh\"[^{}]*\}", re.DOTALL)
|
|
||||||
# 先尝试一层嵌套;如果没命中再用更宽松的策略
|
|
||||||
for match in brace_pattern.finditer(raw_output):
|
|
||||||
try:
|
|
||||||
return json.loads(match.group(0))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 更宽松:找到最大的 { ... } 平衡块
|
|
||||||
best = None
|
|
||||||
best_len = 0
|
|
||||||
for i, ch in enumerate(raw_output):
|
|
||||||
if ch != "{":
|
|
||||||
continue
|
|
||||||
depth = 0
|
|
||||||
for j in range(i, len(raw_output)):
|
|
||||||
if raw_output[j] == "{":
|
|
||||||
depth += 1
|
|
||||||
elif raw_output[j] == "}":
|
|
||||||
depth -= 1
|
|
||||||
if depth == 0:
|
|
||||||
candidate = raw_output[i : j + 1]
|
|
||||||
if len(candidate) > best_len:
|
|
||||||
try:
|
|
||||||
parsed = json.loads(candidate)
|
|
||||||
if isinstance(parsed, dict):
|
|
||||||
best = parsed
|
|
||||||
best_len = len(candidate)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
break
|
|
||||||
|
|
||||||
if best is not None:
|
|
||||||
return best
|
|
||||||
|
|
||||||
raise JsonNotFoundError("no JSON object found in pi output")
|
|
||||||
|
|
||||||
|
|
||||||
# ── 错误分类 ────────────────────────────────────────────────────────────
|
# ── 错误分类 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -284,6 +96,8 @@ def _update_summary_in_db(
|
|||||||
raw_output: str,
|
raw_output: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。"""
|
"""将校验后的总结写入 DB:paper_summaries + papers + paper_tags + FTS5。"""
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# 1. paper_summaries:upsert
|
# 1. paper_summaries:upsert
|
||||||
@@ -298,9 +112,9 @@ def _update_summary_in_db(
|
|||||||
# 2. papers 表
|
# 2. papers 表
|
||||||
paper.title_zh = schema.title_zh
|
paper.title_zh = schema.title_zh
|
||||||
paper.summary_quality = quality
|
paper.summary_quality = quality
|
||||||
paper_dir = _paper_dir(paper.arxiv_id)
|
p_dir = paper_dir(paper.arxiv_id)
|
||||||
paper.summary_path = str(paper_dir / "summary.json")
|
paper.summary_path = str(p_dir / "summary.json")
|
||||||
paper.raw_output_path = str(paper_dir / "raw_output.txt")
|
paper.raw_output_path = str(p_dir / "raw_output.txt")
|
||||||
|
|
||||||
# 3. AI 标签
|
# 3. AI 标签
|
||||||
existing_tag_names = {t.tag for t in paper.tags}
|
existing_tag_names = {t.tag for t in paper.tags}
|
||||||
@@ -332,7 +146,7 @@ def _update_summary_in_db(
|
|||||||
|
|
||||||
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
|
def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
|
||||||
"""保存 summary.json 和 raw_output.txt。"""
|
"""保存 summary.json 和 raw_output.txt。"""
|
||||||
d = _paper_dir(arxiv_id)
|
d = paper_dir(arxiv_id)
|
||||||
d.mkdir(parents=True, exist_ok=True)
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
(d / "summary.json").write_text(
|
(d / "summary.json").write_text(
|
||||||
schema.model_dump_json(ensure_ascii=False, indent=2),
|
schema.model_dump_json(ensure_ascii=False, indent=2),
|
||||||
@@ -343,143 +157,11 @@ def _save_files(arxiv_id: str, schema: SummarySchema, raw_output: str) -> None:
|
|||||||
|
|
||||||
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
|
def _save_raw_output_only(arxiv_id: str, raw_output: str) -> None:
|
||||||
"""仅保存 raw_output.txt(失败时)。"""
|
"""仅保存 raw_output.txt(失败时)。"""
|
||||||
d = _paper_dir(arxiv_id)
|
d = paper_dir(arxiv_id)
|
||||||
d.mkdir(parents=True, exist_ok=True)
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
|
(d / "raw_output.txt").write_text(raw_output, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_tmp(arxiv_id: str) -> None:
|
|
||||||
"""清理 data/tmp/{arxiv_id}/ 目录。"""
|
|
||||||
tmp = _tmp_dir(arxiv_id)
|
|
||||||
if tmp.exists():
|
|
||||||
try:
|
|
||||||
shutil.rmtree(tmp)
|
|
||||||
logger.debug("Cleaned tmp: %s", arxiv_id)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to clean tmp for %s", arxiv_id, exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
# ── LaTeX 图片提取(Phase 5)───────────────────────────────────────────
|
|
||||||
|
|
||||||
_INCLUDEGRAPHICS_RE = re.compile(
|
|
||||||
r"\\includegraphics\s*(?:\[[^\]]*\])?\s*\{([^}]+)\}", re.MULTILINE
|
|
||||||
)
|
|
||||||
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".svg", ".pdf", ".eps"}
|
|
||||||
|
|
||||||
|
|
||||||
async def _extract_images_from_source(arxiv_id: str, tmp_source: Path | None = None) -> int:
|
|
||||||
"""从 LaTeX 源码中提取图片文件。
|
|
||||||
|
|
||||||
流程:
|
|
||||||
1. 下载源码 zip 到 data/tmp/{arxiv_id}/source/
|
|
||||||
2. 扫描 .tex 文件中的 \\includegraphics
|
|
||||||
3. 复制图片到 data/papers/{arxiv_id}/images/
|
|
||||||
4. 清理源码临时文件
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
提取的图片数量
|
|
||||||
"""
|
|
||||||
tmp_source = _tmp_dir(arxiv_id) / "source"
|
|
||||||
images_dest = _paper_dir(arxiv_id) / "images"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 下载源码 zip(如果还没下载)
|
|
||||||
if not tmp_source.exists():
|
|
||||||
source_url = f"https://arxiv.org/e-print/{arxiv_id}"
|
|
||||||
await _download_source_zip(arxiv_id, source_url, tmp_source)
|
|
||||||
|
|
||||||
if not tmp_source.exists():
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# 扫描 .tex 文件,收集图片路径
|
|
||||||
image_paths: set[str] = set()
|
|
||||||
for tex_file in tmp_source.rglob("*.tex"):
|
|
||||||
try:
|
|
||||||
content = tex_file.read_text(encoding="utf-8", errors="replace")
|
|
||||||
for match in _INCLUDEGRAPHICS_RE.finditer(content):
|
|
||||||
img_path = match.group(1).strip()
|
|
||||||
image_paths.add(img_path)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not image_paths:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# 查找并复制图片
|
|
||||||
images_dest.mkdir(parents=True, exist_ok=True)
|
|
||||||
copied = 0
|
|
||||||
for img_rel in image_paths:
|
|
||||||
# 尝试在源码目录中找到文件
|
|
||||||
for ext in ("", ".png", ".jpg", ".jpeg", ".gif", ".pdf", ".eps"):
|
|
||||||
candidate = tmp_source / (img_rel + ext)
|
|
||||||
if candidate.is_file():
|
|
||||||
dest_name = candidate.name
|
|
||||||
# 避免文件名冲突
|
|
||||||
dest = images_dest / dest_name
|
|
||||||
if dest.exists():
|
|
||||||
stem = dest.stem
|
|
||||||
suffix = dest.suffix
|
|
||||||
dest = images_dest / f"{stem}_{copied}{suffix}"
|
|
||||||
shutil.copy2(candidate, dest)
|
|
||||||
copied += 1
|
|
||||||
break
|
|
||||||
|
|
||||||
if copied > 0:
|
|
||||||
logger.info("Extracted %d images from source for %s", copied, arxiv_id)
|
|
||||||
return copied
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
async def _download_source_zip(
|
|
||||||
arxiv_id: str, source_url: str, dest_dir: Path
|
|
||||||
) -> None:
|
|
||||||
"""下载 arXiv 源码并解压。"""
|
|
||||||
import zipfile
|
|
||||||
|
|
||||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
zip_path = _tmp_dir(arxiv_id) / "source.zip"
|
|
||||||
|
|
||||||
transport = None
|
|
||||||
if settings.http_proxy:
|
|
||||||
transport = httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(
|
|
||||||
timeout=settings.HTTP_TIMEOUT_SECONDS,
|
|
||||||
headers={"User-Agent": settings.HTTP_USER_AGENT},
|
|
||||||
transport=transport,
|
|
||||||
follow_redirects=True,
|
|
||||||
) as client:
|
|
||||||
resp = await client.get(source_url)
|
|
||||||
resp.raise_for_status()
|
|
||||||
zip_path.write_bytes(resp.content)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("Failed to download source for %s: %s", arxiv_id, exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
|
||||||
zf.extractall(dest_dir)
|
|
||||||
logger.debug("Extracted source for %s", arxiv_id)
|
|
||||||
except zipfile.BadZipFile:
|
|
||||||
# 可能是 tar.gz
|
|
||||||
import tarfile
|
|
||||||
try:
|
|
||||||
with tarfile.open(zip_path, "r:*") as tf:
|
|
||||||
tf.extractall(dest_dir)
|
|
||||||
logger.debug("Extracted source (tar) for %s", arxiv_id)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Cannot extract source for %s", arxiv_id)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Cannot extract source for %s", arxiv_id, exc_info=True)
|
|
||||||
finally:
|
|
||||||
if zip_path.exists():
|
|
||||||
zip_path.unlink()
|
|
||||||
|
|
||||||
|
|
||||||
# ── 单篇总结 ────────────────────────────────────────────────────────────
|
# ── 单篇总结 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -491,6 +173,8 @@ async def summarize_one(
|
|||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""总结单篇论文的完整流程。"""
|
"""总结单篇论文的完整流程。"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
arxiv_id = paper.arxiv_id
|
arxiv_id = paper.arxiv_id
|
||||||
|
|
||||||
# 获取或创建 summary_status
|
# 获取或创建 summary_status
|
||||||
@@ -520,6 +204,8 @@ async def summarize_one(
|
|||||||
|
|
||||||
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
||||||
"""实际的单篇总结执行(在 semaphore 保护下)。"""
|
"""实际的单篇总结执行(在 semaphore 保护下)。"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
arxiv_id = paper.arxiv_id
|
arxiv_id = paper.arxiv_id
|
||||||
status = paper.summary_status
|
status = paper.summary_status
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
@@ -532,16 +218,16 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
|||||||
raw_output = ""
|
raw_output = ""
|
||||||
try:
|
try:
|
||||||
# 写 meta.json
|
# 写 meta.json
|
||||||
meta_path = _write_meta_json(paper)
|
meta_path = write_meta_json(paper)
|
||||||
|
|
||||||
# 下载 PDF
|
# 下载 PDF
|
||||||
await _download_pdf(arxiv_id, paper.pdf_url)
|
await download_pdf(arxiv_id, paper.pdf_url)
|
||||||
|
|
||||||
# 调用 pi
|
# 调用 pi
|
||||||
raw_output = await _call_pi(meta_path, _tmp_dir(arxiv_id) / "paper.pdf")
|
raw_output = await call_pi(meta_path, Path("data/tmp") / arxiv_id / "paper.pdf")
|
||||||
|
|
||||||
# 提取 JSON
|
# 提取 JSON
|
||||||
json_data = _extract_json(raw_output)
|
json_data = extract_json(raw_output)
|
||||||
|
|
||||||
# Pydantic 校验
|
# Pydantic 校验
|
||||||
schema = SummarySchema.model_validate(json_data)
|
schema = SummarySchema.model_validate(json_data)
|
||||||
@@ -564,7 +250,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
|||||||
|
|
||||||
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
|
# Phase 5: LaTeX 图片提取(可选增强,失败不影响总结)
|
||||||
try:
|
try:
|
||||||
await _extract_images_from_source(arxiv_id)
|
await extract_images_from_source(arxiv_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
logger.warning("Failed to extract images for %s", arxiv_id, exc_info=True)
|
||||||
|
|
||||||
@@ -625,7 +311,7 @@ async def _do_summarize_one(db: Session, paper: Paper) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
_cleanup_tmp(arxiv_id)
|
cleanup_tmp(arxiv_id)
|
||||||
|
|
||||||
|
|
||||||
# ── 单篇入口 ────────────────────────────────────────────────────────────
|
# ── 单篇入口 ────────────────────────────────────────────────────────────
|
||||||
@@ -690,6 +376,8 @@ async def summarize_batch(
|
|||||||
|
|
||||||
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
|
_session_factory: 可选的 session 工厂,测试时注入内存 DB 的 session。
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# TaskLock 防重入
|
# TaskLock 防重入
|
||||||
@@ -741,7 +429,7 @@ async def summarize_batch(
|
|||||||
log_entry.papers_found = 0
|
log_entry.papers_found = 0
|
||||||
log_entry.papers_new = 0
|
log_entry.papers_new = 0
|
||||||
log_entry.completed_at = datetime.now(timezone.utc)
|
log_entry.completed_at = datetime.now(timezone.utc)
|
||||||
_release_lock(db, lock)
|
release_lock(db, lock)
|
||||||
return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}
|
return {"status": "success", "done": 0, "failed": 0, "skipped": 0, "total": 0}
|
||||||
|
|
||||||
# 并发控制
|
# 并发控制
|
||||||
@@ -813,15 +501,4 @@ async def summarize_batch(
|
|||||||
return {"status": "failed", "error": str(exc)}
|
return {"status": "failed", "error": str(exc)}
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
_release_lock(db, lock)
|
release_lock(db, lock)
|
||||||
|
|
||||||
|
|
||||||
def _release_lock(db: Session, lock: TaskLock) -> None:
|
|
||||||
"""释放 TaskLock。"""
|
|
||||||
try:
|
|
||||||
lock.status = "finished"
|
|
||||||
lock.released_at = datetime.now(timezone.utc)
|
|
||||||
db.commit()
|
|
||||||
except Exception:
|
|
||||||
db.rollback()
|
|
||||||
logger.warning("Failed to release summarize lock", exc_info=True)
|
|
||||||
|
|||||||
@@ -0,0 +1,81 @@
|
|||||||
|
"""趋势统计服务 — 按日论文数量、热门标签、Upvotes 分布、总结完成率。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
def get_trends_data(db: Session) -> dict:
|
||||||
|
"""从 DB 聚合趋势数据。"""
|
||||||
|
thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
|
||||||
|
|
||||||
|
# 1. 按日论文数量(近 30 天)
|
||||||
|
daily_rows = db.execute(text("""
|
||||||
|
SELECT paper_date, COUNT(*) as cnt
|
||||||
|
FROM papers
|
||||||
|
WHERE paper_date >= :start_date
|
||||||
|
GROUP BY paper_date
|
||||||
|
ORDER BY paper_date ASC
|
||||||
|
"""), {"start_date": thirty_days_ago}).fetchall()
|
||||||
|
daily_counts = [
|
||||||
|
{"date": str(row[0]), "count": row[1]}
|
||||||
|
for row in daily_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
# 2. 热门标签 Top 20
|
||||||
|
tag_rows = db.execute(text("""
|
||||||
|
SELECT tag, COUNT(*) as cnt
|
||||||
|
FROM paper_tags
|
||||||
|
GROUP BY tag
|
||||||
|
ORDER BY cnt DESC
|
||||||
|
LIMIT 20
|
||||||
|
""")).fetchall()
|
||||||
|
top_tags = [
|
||||||
|
{"tag": row[0], "count": row[1]}
|
||||||
|
for row in tag_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
# 3. Upvotes 分布
|
||||||
|
upvote_rows = db.execute(text("""
|
||||||
|
SELECT
|
||||||
|
CASE
|
||||||
|
WHEN upvotes >= 100 THEN '100+'
|
||||||
|
WHEN upvotes >= 50 THEN '50-99'
|
||||||
|
WHEN upvotes >= 20 THEN '20-49'
|
||||||
|
WHEN upvotes >= 10 THEN '10-19'
|
||||||
|
WHEN upvotes >= 5 THEN '5-9'
|
||||||
|
ELSE '0-4'
|
||||||
|
END as bucket,
|
||||||
|
COUNT(*) as cnt
|
||||||
|
FROM papers
|
||||||
|
GROUP BY bucket
|
||||||
|
ORDER BY MIN(upvotes) DESC
|
||||||
|
""")).fetchall()
|
||||||
|
upvotes_dist = [
|
||||||
|
{"range": row[0], "count": row[1]}
|
||||||
|
for row in upvote_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
# 4. 总结完成率
|
||||||
|
summary_rows = db.execute(text("""
|
||||||
|
SELECT
|
||||||
|
COALESCE(ss.status, 'none') as status,
|
||||||
|
COUNT(*) as cnt
|
||||||
|
FROM papers p
|
||||||
|
LEFT JOIN summary_status ss ON ss.paper_id = p.id
|
||||||
|
GROUP BY status
|
||||||
|
""")).fetchall()
|
||||||
|
summary_completion = [
|
||||||
|
{"status": row[0], "count": row[1]}
|
||||||
|
for row in summary_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"daily_counts": daily_counts,
|
||||||
|
"top_tags": top_tags,
|
||||||
|
"upvotes_dist": upvotes_dist,
|
||||||
|
"summary_completion": summary_completion,
|
||||||
|
}
|
||||||
@@ -1,12 +1,13 @@
|
|||||||
"""用户数据服务 — 收藏、阅读状态、个人笔记。无账号体系,数据写入本地 SQLite。"""
|
"""用户数据服务 — 收藏、阅读状态、个人笔记、阅读列表查询。无账号体系,数据写入本地 SQLite。"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy import or_
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
from app.models import Paper, UserBookmark, UserNote, UserReadingStatus
|
from app.models import Paper, PaperTag, UserBookmark, UserNote, UserReadingStatus
|
||||||
|
|
||||||
# ── 收藏 ──────────────────────────────────────────────────────────────
|
# ── 收藏 ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -113,3 +114,47 @@ def save_note(db: Session, arxiv_id: str, content: str) -> dict:
|
|||||||
"content": content,
|
"content": content,
|
||||||
"updated_at": now.isoformat(),
|
"updated_at": now.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── 阅读列表 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def query_reading_list(
|
||||||
|
db: Session,
|
||||||
|
filter_type: str,
|
||||||
|
tag: str | None,
|
||||||
|
) -> list[Paper]:
|
||||||
|
"""根据筛选条件查询阅读列表。"""
|
||||||
|
# 基础:有任意用户数据的论文
|
||||||
|
base = db.query(Paper).filter(
|
||||||
|
or_(
|
||||||
|
Paper.bookmark.has(),
|
||||||
|
Paper.reading_status.has(),
|
||||||
|
Paper.note.has(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 应用筛选
|
||||||
|
if filter_type == "has_note":
|
||||||
|
base = base.filter(Paper.note.has())
|
||||||
|
elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
|
||||||
|
base = base.filter(
|
||||||
|
Paper.reading_status.has(UserReadingStatus.status == filter_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 应用标签
|
||||||
|
if tag:
|
||||||
|
base = base.filter(Paper.tags.any(PaperTag.tag == tag))
|
||||||
|
|
||||||
|
return (
|
||||||
|
base.options(
|
||||||
|
joinedload(Paper.authors),
|
||||||
|
joinedload(Paper.tags),
|
||||||
|
joinedload(Paper.summary_status),
|
||||||
|
joinedload(Paper.bookmark),
|
||||||
|
joinedload(Paper.reading_status),
|
||||||
|
joinedload(Paper.note),
|
||||||
|
)
|
||||||
|
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,73 @@
|
|||||||
|
"""公共工具 — 消除各模块间的重复代码。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
# ── 路径常量 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
DATA_DIR = Path("data")
|
||||||
|
PAPERS_DIR = DATA_DIR / "papers"
|
||||||
|
TMP_DIR = DATA_DIR / "tmp"
|
||||||
|
|
||||||
|
# ── 模板单例 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
templates = Jinja2Templates(directory="app/templates")
|
||||||
|
|
||||||
|
|
||||||
|
# ── 时区工具 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def today_str() -> str:
|
||||||
|
"""当前日期字符串(按 APP_TIMEZONE)。"""
|
||||||
|
tz = ZoneInfo(settings.APP_TIMEZONE)
|
||||||
|
return datetime.now(tz).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
|
||||||
|
# ── 锁释放 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def release_lock(db, lock) -> None:
|
||||||
|
"""释放 TaskLock。"""
|
||||||
|
try:
|
||||||
|
lock.status = "finished"
|
||||||
|
lock.released_at = datetime.now(timezone.utc)
|
||||||
|
db.commit()
|
||||||
|
except Exception:
|
||||||
|
db.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
# ── HTTP 客户端工厂 ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def make_http_client(*, sync: bool = False, follow_redirects: bool = False, **kwargs) -> httpx.AsyncClient | httpx.Client:
|
||||||
|
"""创建带 proxy 和默认配置的 httpx 客户端。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sync: True 返回同步 Client,False 返回 AsyncClient
|
||||||
|
follow_redirects: 是否跟随重定向
|
||||||
|
**kwargs: 覆盖默认参数
|
||||||
|
"""
|
||||||
|
defaults: dict = {
|
||||||
|
"timeout": settings.HTTP_TIMEOUT_SECONDS,
|
||||||
|
"headers": {"User-Agent": settings.HTTP_USER_AGENT},
|
||||||
|
"follow_redirects": follow_redirects,
|
||||||
|
}
|
||||||
|
if settings.http_proxy:
|
||||||
|
defaults["transport"] = (
|
||||||
|
httpx.HTTPTransport(proxy=settings.http_proxy)
|
||||||
|
if sync
|
||||||
|
else httpx.AsyncHTTPTransport(proxy=settings.http_proxy)
|
||||||
|
)
|
||||||
|
defaults.update(kwargs)
|
||||||
|
|
||||||
|
if sync:
|
||||||
|
return httpx.Client(**defaults)
|
||||||
|
return httpx.AsyncClient(**defaults)
|
||||||
+1
-1
@@ -15,13 +15,13 @@ from sqlalchemy.pool import StaticPool
|
|||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.main import create_app
|
from app.main import create_app
|
||||||
|
from app.database import init_db
|
||||||
from app.models import (
|
from app.models import (
|
||||||
Paper,
|
Paper,
|
||||||
PaperAuthor,
|
PaperAuthor,
|
||||||
PaperSummary,
|
PaperSummary,
|
||||||
PaperTag,
|
PaperTag,
|
||||||
SummaryStatus,
|
SummaryStatus,
|
||||||
init_db,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class TestCleanupTmp:
|
|||||||
old_mtime = time.time() - 25 * 3600
|
old_mtime = time.time() - 25 * 3600
|
||||||
os.utime(old_dir, (old_mtime, old_mtime))
|
os.utime(old_dir, (old_mtime, old_mtime))
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir)
|
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||||
from app.services.cleaner import cleanup_tmp
|
from app.services.cleaner import cleanup_tmp
|
||||||
result = cleanup_tmp()
|
result = cleanup_tmp()
|
||||||
|
|
||||||
@@ -158,7 +158,7 @@ class TestCleanupTmp:
|
|||||||
recent_dir.mkdir()
|
recent_dir.mkdir()
|
||||||
(recent_dir / "paper.pdf").write_text("fake pdf")
|
(recent_dir / "paper.pdf").write_text("fake pdf")
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir)
|
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||||
from app.services.cleaner import cleanup_tmp
|
from app.services.cleaner import cleanup_tmp
|
||||||
result = cleanup_tmp()
|
result = cleanup_tmp()
|
||||||
|
|
||||||
@@ -168,7 +168,7 @@ class TestCleanupTmp:
|
|||||||
|
|
||||||
def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
|
def test_cleanup_empty_dir(self, tmp_path, monkeypatch):
|
||||||
"""data/tmp/ 不存在时安全返回。"""
|
"""data/tmp/ 不存在时安全返回。"""
|
||||||
monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_path / "nonexistent")
|
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_path / "nonexistent")
|
||||||
from app.services.cleaner import cleanup_tmp
|
from app.services.cleaner import cleanup_tmp
|
||||||
result = cleanup_tmp()
|
result = cleanup_tmp()
|
||||||
assert result["scanned"] == 0
|
assert result["scanned"] == 0
|
||||||
@@ -187,7 +187,7 @@ class TestCleanupTmp:
|
|||||||
recent_dir = tmp_dir / "2401.new"
|
recent_dir = tmp_dir / "2401.new"
|
||||||
recent_dir.mkdir()
|
recent_dir.mkdir()
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.cleaner._TMP_DIR", tmp_dir)
|
monkeypatch.setattr("app.services.cleaner.TMP_DIR", tmp_dir)
|
||||||
from app.services.cleaner import cleanup_tmp
|
from app.services.cleaner import cleanup_tmp
|
||||||
result = cleanup_tmp()
|
result = cleanup_tmp()
|
||||||
|
|
||||||
@@ -318,7 +318,7 @@ class TestDeletePapersByDateRange:
|
|||||||
(papers_dir / "2401.10001").mkdir()
|
(papers_dir / "2401.10001").mkdir()
|
||||||
(papers_dir / "2401.10001" / "meta.json").write_text("{}")
|
(papers_dir / "2401.10001" / "meta.json").write_text("{}")
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.cleaner._PAPERS_DIR", papers_dir)
|
monkeypatch.setattr("app.services.cleaner.PAPERS_DIR", papers_dir)
|
||||||
|
|
||||||
result = await delete_papers_by_date_range(
|
result = await delete_papers_by_date_range(
|
||||||
db_session,
|
db_session,
|
||||||
|
|||||||
+42
-39
@@ -125,10 +125,9 @@ class TestEmbedderInit:
|
|||||||
"""CHROMA_ENABLED=false 时不初始化。"""
|
"""CHROMA_ENABLED=false 时不初始化。"""
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
import app.services.embedder as emb
|
import app.services.embedder as emb
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
emb.init_chroma()
|
emb.init_chroma()
|
||||||
assert emb._client is None
|
assert emb._chroma._client is None
|
||||||
|
|
||||||
def test_chroma_init_success(self, monkeypatch, tmp_path):
|
def test_chroma_init_success(self, monkeypatch, tmp_path):
|
||||||
"""CHROMA_ENABLED=true 时初始化成功。"""
|
"""CHROMA_ENABLED=true 时初始化成功。"""
|
||||||
@@ -136,23 +135,20 @@ class TestEmbedderInit:
|
|||||||
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
monkeypatch.setattr(settings, "CHROMA_DIR", str(tmp_path / "chroma"))
|
||||||
|
|
||||||
import app.services.embedder as emb
|
import app.services.embedder as emb
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
emb.init_chroma()
|
emb.init_chroma()
|
||||||
|
|
||||||
assert emb._client is not None
|
assert emb._chroma._client is not None
|
||||||
assert emb._collection is not None
|
assert emb._chroma._collection is not None
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
|
|
||||||
def test_get_collection_returns_none_when_disabled(self, monkeypatch):
|
def test_get_collection_returns_none_when_disabled(self, monkeypatch):
|
||||||
"""CHROMA_ENABLED=false 时 get_collection 返回 None。"""
|
"""CHROMA_ENABLED=false 时 get_collection 返回 None。"""
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
import app.services.embedder as emb
|
import app.services.embedder as emb
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
assert emb.get_collection() is None
|
assert emb.get_collection() is None
|
||||||
|
|
||||||
|
|
||||||
@@ -163,8 +159,7 @@ class TestEmbedderIndexing:
|
|||||||
"""CHROMA_ENABLED=false 时 index_paper 返回 False。"""
|
"""CHROMA_ENABLED=false 时 index_paper 返回 False。"""
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
import app.services.embedder as emb
|
import app.services.embedder as emb
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
assert emb.index_paper("test-id") is False
|
assert emb.index_paper("test-id") is False
|
||||||
|
|
||||||
def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
|
def test_index_paper_no_api_config(self, monkeypatch, tmp_path):
|
||||||
@@ -175,22 +170,19 @@ class TestEmbedderIndexing:
|
|||||||
monkeypatch.setattr(settings, "EMBED_MODEL", "")
|
monkeypatch.setattr(settings, "EMBED_MODEL", "")
|
||||||
|
|
||||||
import app.services.embedder as emb
|
import app.services.embedder as emb
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
emb.init_chroma()
|
emb.init_chroma()
|
||||||
|
|
||||||
result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
|
result = emb.index_paper("test-id", {"title_zh": "测试", "title_en": "Test"})
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
|
|
||||||
def test_index_batch_disabled(self, monkeypatch):
|
def test_index_batch_disabled(self, monkeypatch):
|
||||||
"""CHROMA_ENABLED=false 时 index_batch 返回全失败。"""
|
"""CHROMA_ENABLED=false 时 index_batch 返回全失败。"""
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
import app.services.embedder as emb
|
import app.services.embedder as emb
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
result = emb.index_batch(["a", "b"])
|
result = emb.index_batch(["a", "b"])
|
||||||
assert result["success"] == 0
|
assert result["success"] == 0
|
||||||
assert result["failed"] == 2
|
assert result["failed"] == 2
|
||||||
@@ -206,16 +198,14 @@ class TestEmbedderIndexing:
|
|||||||
"""CHROMA_ENABLED=false 时 delete_paper 返回 False。"""
|
"""CHROMA_ENABLED=false 时 delete_paper 返回 False。"""
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
import app.services.embedder as emb
|
import app.services.embedder as emb
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
assert emb.delete_paper("test-id") is False
|
assert emb.delete_paper("test-id") is False
|
||||||
|
|
||||||
def test_search_similar_disabled(self, monkeypatch):
|
def test_search_similar_disabled(self, monkeypatch):
|
||||||
"""CHROMA_ENABLED=false 时 search_similar 返回空列表。"""
|
"""CHROMA_ENABLED=false 时 search_similar 返回空列表。"""
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
import app.services.embedder as emb
|
import app.services.embedder as emb
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
assert emb.search_similar("test query") == []
|
assert emb.search_similar("test query") == []
|
||||||
|
|
||||||
|
|
||||||
@@ -427,8 +417,8 @@ class TestTrendsDashboard:
|
|||||||
from unittest.mock import patch as upatch
|
from unittest.mock import patch as upatch
|
||||||
import app.routes.trends as trends_mod
|
import app.routes.trends as trends_mod
|
||||||
|
|
||||||
# monkeypatch _get_trends_data 中的 date.today
|
# monkeypatch get_trends_data 中的 date.today
|
||||||
with upatch("app.routes.trends.date") as mock_date:
|
with upatch("app.services.trends.date") as mock_date:
|
||||||
mock_date.today.return_value = date(2024, 1, 20)
|
mock_date.today.return_value = date(2024, 1, 20)
|
||||||
mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
|
mock_date.side_effect = lambda *a, **kw: date(*a, **kw)
|
||||||
|
|
||||||
@@ -528,15 +518,17 @@ class TestImageExtraction:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
|
async def test_extract_images_from_source_no_dir(self, monkeypatch, tmp_path):
|
||||||
"""源码目录不存在时返回 0。"""
|
"""源码目录不存在时返回 0。"""
|
||||||
monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
|
monkeypatch.setattr("app.services.pdf_downloader.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
||||||
monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
|
monkeypatch.setattr("app.services.pdf_downloader.paper_dir", lambda x: tmp_path / "papers" / x)
|
||||||
from app.services.summarizer import _extract_images_from_source
|
from app.services.image_extractor import extract_images_from_source
|
||||||
result = await _extract_images_from_source("2401.99999")
|
result = await extract_images_from_source("2401.99999")
|
||||||
assert result == 0
|
assert result == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
|
async def test_extract_images_from_tex(self, monkeypatch, tmp_path):
|
||||||
"""从 .tex 文件中提取图片。"""
|
"""从 .tex 文件中提取图片。"""
|
||||||
|
from app.services.image_extractor import extract_images_from_source
|
||||||
|
|
||||||
tmp_source = tmp_path / "tmp" / "2401.00001" / "source"
|
tmp_source = tmp_path / "tmp" / "2401.00001" / "source"
|
||||||
tmp_source.mkdir(parents=True)
|
tmp_source.mkdir(parents=True)
|
||||||
|
|
||||||
@@ -559,11 +551,16 @@ class TestImageExtraction:
|
|||||||
(tmp_source / "main.tex").write_text(tex_content)
|
(tmp_source / "main.tex").write_text(tex_content)
|
||||||
|
|
||||||
papers_dir = tmp_path / "papers" / "2401.00001"
|
papers_dir = tmp_path / "papers" / "2401.00001"
|
||||||
monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
|
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
||||||
monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
|
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
|
||||||
|
|
||||||
from app.services.summarizer import _extract_images_from_source
|
# Mock download_source_zip to avoid real network call (source dir already exists)
|
||||||
result = await _extract_images_from_source("2401.00001")
|
async def _noop_download(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
|
||||||
|
|
||||||
|
result = await extract_images_from_source("2401.00001")
|
||||||
|
|
||||||
assert result == 2
|
assert result == 2
|
||||||
dest_images = papers_dir / "images"
|
dest_images = papers_dir / "images"
|
||||||
@@ -574,15 +571,22 @@ class TestImageExtraction:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
|
async def test_extract_images_empty_tex(self, monkeypatch, tmp_path):
|
||||||
""".tex 文件无图片时返回 0。"""
|
""".tex 文件无图片时返回 0。"""
|
||||||
|
from app.services.image_extractor import extract_images_from_source
|
||||||
|
|
||||||
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
|
tmp_source = tmp_path / "tmp" / "2401.00002" / "source"
|
||||||
tmp_source.mkdir(parents=True)
|
tmp_source.mkdir(parents=True)
|
||||||
(tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
|
(tmp_source / "main.tex").write_text(r"\documentclass{article}\begin{document}Hello\end{document}")
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.summarizer._tmp_dir", lambda x: tmp_path / "tmp" / x)
|
monkeypatch.setattr("app.services.image_extractor.tmp_dir", lambda x: tmp_path / "tmp" / x)
|
||||||
monkeypatch.setattr("app.services.summarizer._paper_dir", lambda x: tmp_path / "papers" / x)
|
monkeypatch.setattr("app.services.image_extractor.paper_dir", lambda x: tmp_path / "papers" / x)
|
||||||
|
|
||||||
from app.services.summarizer import _extract_images_from_source
|
# Mock download_source_zip to avoid real network call
|
||||||
result = await _extract_images_from_source("2401.00002")
|
async def _noop_download(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.image_extractor.download_source_zip", _noop_download)
|
||||||
|
|
||||||
|
result = await extract_images_from_source("2401.00002")
|
||||||
assert result == 0
|
assert result == 0
|
||||||
|
|
||||||
|
|
||||||
@@ -644,8 +648,7 @@ class TestGracefulDegradation:
|
|||||||
"""CHROMA 关闭时删除论文正常工作。"""
|
"""CHROMA 关闭时删除论文正常工作。"""
|
||||||
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
monkeypatch.setattr(settings, "CHROMA_ENABLED", False)
|
||||||
import app.services.embedder as emb
|
import app.services.embedder as emb
|
||||||
emb._client = None
|
emb._chroma.reset()
|
||||||
emb._collection = None
|
|
||||||
|
|
||||||
from app.services.cleaner import delete_papers_by_date_range
|
from app.services.cleaner import delete_papers_by_date_range
|
||||||
result = await delete_papers_by_date_range(
|
result = await delete_papers_by_date_range(
|
||||||
|
|||||||
+45
-37
@@ -27,14 +27,7 @@ from app.services.schemas import (
|
|||||||
flatten_for_db,
|
flatten_for_db,
|
||||||
)
|
)
|
||||||
from app.services.summarizer import (
|
from app.services.summarizer import (
|
||||||
JsonNotFoundError,
|
|
||||||
PdfDownloadError,
|
|
||||||
PiProcessError,
|
|
||||||
PiTimeoutError,
|
|
||||||
_call_pi,
|
|
||||||
_classify_error,
|
_classify_error,
|
||||||
_cleanup_tmp,
|
|
||||||
_extract_json,
|
|
||||||
_save_files,
|
_save_files,
|
||||||
_save_raw_output_only,
|
_save_raw_output_only,
|
||||||
_update_summary_in_db,
|
_update_summary_in_db,
|
||||||
@@ -42,6 +35,17 @@ from app.services.summarizer import (
|
|||||||
summarize_one,
|
summarize_one,
|
||||||
summarize_single,
|
summarize_single,
|
||||||
)
|
)
|
||||||
|
from app.services.pi_client import (
|
||||||
|
JsonNotFoundError,
|
||||||
|
PiProcessError,
|
||||||
|
PiTimeoutError,
|
||||||
|
call_pi as _call_pi,
|
||||||
|
extract_json as _extract_json,
|
||||||
|
)
|
||||||
|
from app.services.pdf_downloader import (
|
||||||
|
PdfDownloadError,
|
||||||
|
cleanup_tmp as _cleanup_tmp,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
@@ -287,7 +291,7 @@ class TestFileOperations:
|
|||||||
|
|
||||||
def test_save_files(self, tmp_path, sample_summary_dict):
|
def test_save_files(self, tmp_path, sample_summary_dict):
|
||||||
schema = SummarySchema.model_validate(sample_summary_dict)
|
schema = SummarySchema.model_validate(sample_summary_dict)
|
||||||
with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
|
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
|
||||||
_save_files("2401.12345", schema, "raw output text")
|
_save_files("2401.12345", schema, "raw output text")
|
||||||
|
|
||||||
paper_dir = tmp_path / "2401.12345"
|
paper_dir = tmp_path / "2401.12345"
|
||||||
@@ -297,7 +301,7 @@ class TestFileOperations:
|
|||||||
assert saved["title_zh"] == "测试论文中文标题"
|
assert saved["title_zh"] == "测试论文中文标题"
|
||||||
|
|
||||||
def test_save_raw_output_only(self, tmp_path):
|
def test_save_raw_output_only(self, tmp_path):
|
||||||
with patch("app.services.summarizer._PAPERS_DIR", tmp_path):
|
with patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / aid):
|
||||||
_save_raw_output_only("2401.12345", "raw output")
|
_save_raw_output_only("2401.12345", "raw output")
|
||||||
paper_dir = tmp_path / "2401.12345"
|
paper_dir = tmp_path / "2401.12345"
|
||||||
assert (paper_dir / "raw_output.txt").exists()
|
assert (paper_dir / "raw_output.txt").exists()
|
||||||
@@ -307,13 +311,13 @@ class TestFileOperations:
|
|||||||
tmp_paper = tmp_path / "2401.12345"
|
tmp_paper = tmp_path / "2401.12345"
|
||||||
tmp_paper.mkdir()
|
tmp_paper.mkdir()
|
||||||
(tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake")
|
(tmp_paper / "paper.pdf").write_bytes(b"%PDF-fake")
|
||||||
with patch("app.services.summarizer._TMP_DIR", tmp_path):
|
with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
|
||||||
_cleanup_tmp("2401.12345")
|
_cleanup_tmp("2401.12345")
|
||||||
assert not tmp_paper.exists()
|
assert not tmp_paper.exists()
|
||||||
|
|
||||||
def test_cleanup_tmp_nonexistent(self, tmp_path):
|
def test_cleanup_tmp_nonexistent(self, tmp_path):
|
||||||
"""清理不存在的目录不报错。"""
|
"""清理不存在的目录不报错。"""
|
||||||
with patch("app.services.summarizer._TMP_DIR", tmp_path):
|
with patch("app.services.pdf_downloader.TMP_DIR", tmp_path):
|
||||||
_cleanup_tmp("nonexistent") # 不抛异常
|
_cleanup_tmp("nonexistent") # 不抛异常
|
||||||
|
|
||||||
|
|
||||||
@@ -329,9 +333,11 @@ class TestSummarizeOneFlow:
|
|||||||
def _patch_paths(self, tmp_path):
|
def _patch_paths(self, tmp_path):
|
||||||
"""将 data 目录重定向到 tmp_path。"""
|
"""将 data 目录重定向到 tmp_path。"""
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
|
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
|
||||||
patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
|
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
||||||
patch("app.services.summarizer._DATA_DIR", tmp_path),
|
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
||||||
|
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
||||||
|
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -341,8 +347,8 @@ class TestSummarizeOneFlow:
|
|||||||
):
|
):
|
||||||
"""pending → processing → done 全流程。"""
|
"""pending → processing → done 全流程。"""
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
||||||
):
|
):
|
||||||
result = await summarize_one(db_session, sample_paper)
|
result = await summarize_one(db_session, sample_paper)
|
||||||
|
|
||||||
@@ -374,7 +380,7 @@ class TestSummarizeOneFlow:
|
|||||||
"""PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。"""
|
"""PDF 下载失败 → error_type=pdf_download_failed,tmp 被清理。"""
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer._download_pdf",
|
"app.services.summarizer.download_pdf",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=PdfDownloadError("network error"),
|
side_effect=PdfDownloadError("network error"),
|
||||||
),
|
),
|
||||||
@@ -392,9 +398,9 @@ class TestSummarizeOneFlow:
|
|||||||
async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
|
async def test_pi_timeout(self, db_session, sample_paper, _patch_paths):
|
||||||
"""pi 超时 → timeout 错误,retry_count 递增。"""
|
"""pi 超时 → timeout 错误,retry_count 递增。"""
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer._call_pi",
|
"app.services.summarizer.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=PiTimeoutError("timeout after 300s"),
|
side_effect=PiTimeoutError("timeout after 300s"),
|
||||||
),
|
),
|
||||||
@@ -409,9 +415,9 @@ class TestSummarizeOneFlow:
|
|||||||
async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
|
async def test_json_not_found(self, db_session, sample_paper, _patch_paths):
|
||||||
"""pi 输出无 JSON → json_not_found。"""
|
"""pi 输出无 JSON → json_not_found。"""
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer._call_pi",
|
"app.services.summarizer.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value="No JSON in this output at all.",
|
return_value="No JSON in this output at all.",
|
||||||
),
|
),
|
||||||
@@ -436,9 +442,9 @@ class TestSummarizeOneFlow:
|
|||||||
bad_output = f"```json\n{bad_json}\n```"
|
bad_output = f"```json\n{bad_json}\n```"
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer._call_pi",
|
"app.services.summarizer.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=bad_output,
|
return_value=bad_output,
|
||||||
),
|
),
|
||||||
@@ -464,9 +470,9 @@ class TestSummarizeOneFlow:
|
|||||||
):
|
):
|
||||||
"""失败时仍保存 raw_output.txt。"""
|
"""失败时仍保存 raw_output.txt。"""
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer._call_pi",
|
"app.services.summarizer.call_pi",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value="Some output without JSON",
|
return_value="Some output without JSON",
|
||||||
),
|
),
|
||||||
@@ -483,8 +489,8 @@ class TestSummarizeOneFlow:
|
|||||||
):
|
):
|
||||||
"""成功后清理 tmp 目录。"""
|
"""成功后清理 tmp 目录。"""
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
||||||
):
|
):
|
||||||
await summarize_one(db_session, sample_paper)
|
await summarize_one(db_session, sample_paper)
|
||||||
|
|
||||||
@@ -498,7 +504,7 @@ class TestSummarizeOneFlow:
|
|||||||
"""失败后也清理 tmp 目录。"""
|
"""失败后也清理 tmp 目录。"""
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.summarizer._download_pdf",
|
"app.services.summarizer.download_pdf",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=PdfDownloadError("fail"),
|
side_effect=PdfDownloadError("fail"),
|
||||||
),
|
),
|
||||||
@@ -529,9 +535,11 @@ class TestBatchSummarize:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def _patch_paths(self, tmp_path):
|
def _patch_paths(self, tmp_path):
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._PAPERS_DIR", tmp_path / "papers"),
|
patch("app.services.summarizer.paper_dir", lambda aid: tmp_path / "papers" / aid),
|
||||||
patch("app.services.summarizer._TMP_DIR", tmp_path / "tmp"),
|
patch("app.services.pdf_downloader.PAPERS_DIR", tmp_path / "papers"),
|
||||||
patch("app.services.summarizer._DATA_DIR", tmp_path),
|
patch("app.services.pdf_downloader.TMP_DIR", tmp_path / "tmp"),
|
||||||
|
patch("app.utils.PAPERS_DIR", tmp_path / "papers"),
|
||||||
|
patch("app.utils.TMP_DIR", tmp_path / "tmp"),
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -561,8 +569,8 @@ class TestBatchSummarize:
|
|||||||
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
|
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
||||||
):
|
):
|
||||||
result = await summarize_batch(
|
result = await summarize_batch(
|
||||||
db_session, _session_factory=_TestSession
|
db_session, _session_factory=_TestSession
|
||||||
@@ -612,8 +620,8 @@ class TestBatchSummarize:
|
|||||||
return mock_pi_output
|
return mock_pi_output
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch("app.services.summarizer._call_pi", side_effect=_mock_call_pi),
|
patch("app.services.summarizer.call_pi", side_effect=_mock_call_pi),
|
||||||
):
|
):
|
||||||
result = await summarize_batch(
|
result = await summarize_batch(
|
||||||
db_session, _session_factory=_TestSession
|
db_session, _session_factory=_TestSession
|
||||||
@@ -646,8 +654,8 @@ class TestBatchSummarize:
|
|||||||
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
|
_TestSession = _sm(bind=db_engine, autoflush=False, autocommit=False)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.services.summarizer._download_pdf", new_callable=AsyncMock),
|
patch("app.services.summarizer.download_pdf", new_callable=AsyncMock),
|
||||||
patch("app.services.summarizer._call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
patch("app.services.summarizer.call_pi", new_callable=AsyncMock, return_value=mock_pi_output),
|
||||||
):
|
):
|
||||||
await summarize_batch(
|
await summarize_batch(
|
||||||
db_session, _session_factory=_TestSession
|
db_session, _session_factory=_TestSession
|
||||||
|
|||||||
Reference in New Issue
Block a user