refactor: restructure services and add image/pdf extraction utilities

- Add image_extractor, pdf_downloader, pi_client, trends services
- Add shared utils module
- Refactor summarizer, embedder, routes for cleaner separation
- Update tests to match new service structure
This commit is contained in:
2026-06-06 00:00:55 +08:00
parent ba9afa212c
commit 85c4cfb9e8
22 changed files with 843 additions and 780 deletions
+9 -30
View File
@@ -6,7 +6,6 @@ from datetime import date, datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -17,10 +16,10 @@ from app.models import CrawlLog, DataDeleteJob, TaskLock
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
from app.services.crawler import crawl_daily
from app.services.summarizer import summarize_batch, summarize_single
from app.utils import release_lock, templates, today_str
router = APIRouter(prefix="/admin", tags=["admin"])
security = HTTPBearer()
templates = Jinja2Templates(directory="app/templates")
async def verify_admin(
@@ -32,7 +31,7 @@ async def verify_admin(
return credentials.credentials
# ── 请求模型 ──────────────────────────────────────────────────────────────
# ── 请求模型 ──────────────────────────────────────────────────────────
class DeleteRequest(BaseModel):
@@ -49,7 +48,7 @@ class DeleteRequest(BaseModel):
return v
# ── 抓取 ──────────────────────────────────────────────────────────────────
# ── 抓取 ──────────────────────────────────────────────────────────────
@router.post("/crawl")
@@ -59,12 +58,7 @@ async def admin_crawl(
date: str | None = Query(None, description="YYYY-MM-DD,默认今天"),
):
"""手动抓取指定日期,默认今天。"""
# 计算 target_date
from zoneinfo import ZoneInfo
tz = ZoneInfo(settings.APP_TIMEZONE)
today = datetime.now(tz).strftime("%Y-%m-%d")
target_date = date or today
target_date = date or today_str()
# TaskLock 防重入
now = datetime.now(timezone.utc)
@@ -88,10 +82,10 @@ async def admin_crawl(
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
finally:
_release_lock(db, lock)
release_lock(db, lock)
# ── 总结 ──────────────────────────────────────────────────────────────────
# ── 总结 ──────────────────────────────────────────────────────────────
@router.post("/summarize")
@@ -119,7 +113,7 @@ async def admin_summarize_single(
return result
# ── 清理 ──────────────────────────────────────────────────────────────────
# ── 清理 ──────────────────────────────────────────────────────────────
@router.post("/cleanup")
@@ -155,7 +149,7 @@ async def admin_cleanup(
raise HTTPException(status_code=500, detail=str(exc))
# ── 删除 ──────────────────────────────────────────────────────────────────
# ── 删除 ──────────────────────────────────────────────────────────────
@router.post("/delete")
@@ -177,7 +171,7 @@ async def admin_delete(
return result
# ── 日志 ──────────────────────────────────────────────────────────────────
# ── 日志 ──────────────────────────────────────────────────────────────
@router.get("/logs")
@@ -189,7 +183,6 @@ async def admin_logs(
per_page: int = Query(20, ge=1, le=100),
):
"""查看任务日志(CrawlLog + DataDeleteJob)。"""
# 查询 crawl_logs
crawl_logs = (
db.execute(
select(CrawlLog)
@@ -201,7 +194,6 @@ async def admin_logs(
.all()
)
# 查询 delete_jobs
delete_jobs = (
db.execute(
select(DataDeleteJob)
@@ -223,16 +215,3 @@ async def admin_logs(
"per_page": per_page,
},
)
# ── 工具函数 ──────────────────────────────────────────────────────────────
def _release_lock(db: Session, lock: TaskLock) -> None:
"""释放 TaskLock。"""
try:
lock.status = "finished"
lock.released_at = datetime.now(timezone.utc)
db.commit()
except Exception:
db.rollback()
+1 -6
View File
@@ -2,19 +2,14 @@
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
from app.database import get_db
from app.models import Paper
logger = logging.getLogger(__name__)
from app.utils import templates
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/compare")
+7 -21
View File
@@ -3,34 +3,27 @@
from __future__ import annotations
import logging
from datetime import date, datetime, timedelta
from datetime import date, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session, joinedload
from app.config import settings
from app.database import get_db
from app.models import Paper
from app.utils import templates, today_str
logger = logging.getLogger(__name__)
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
def _today() -> str:
tz = ZoneInfo(settings.APP_TIMEZONE)
return datetime.now(tz).strftime("%Y-%m-%d")
@router.get("/")
def index(request: Request):
"""重定向到 /day/{today}"""
return RedirectResponse(url=f"/day/{_today()}")
return RedirectResponse(url=f"/day/{today_str()}")
@router.get("/day/{date_str}")
@@ -43,7 +36,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
prev_day = (target - timedelta(days=1)).isoformat()
next_day = (target + timedelta(days=1)).isoformat()
today_str = _today()
today = today_str()
papers = (
db.query(Paper)
@@ -74,7 +67,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
"current_date": date_str,
"prev_day": prev_day,
"next_day": next_day,
"today": today_str,
"today": today,
"available_dates": available_dates,
"page_title": f"{date_str} 论文列表",
},
@@ -146,16 +139,9 @@ def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict
return []
try:
from app.services.embedder import search_similar
# 用论文的 arxiv_id 从 ChromaDB 查询
col = None
try:
from app.services.embedder import get_collection
col = get_collection()
except Exception:
return []
from app.services.embedder import get_collection
col = get_collection()
if col is None:
return []
+7 -61
View File
@@ -2,14 +2,11 @@
from __future__ import annotations
import math
from datetime import date, datetime, timedelta, timezone
from zoneinfo import ZoneInfo
from datetime import date, timedelta
from xml.sax.saxutils import escape
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import Response
from fastapi.templating import Jinja2Templates
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload
@@ -17,9 +14,10 @@ from app.config import settings
from app.database import get_db
from app.models import Paper, PaperTag, UserReadingStatus
from app.services.searcher import get_all_tags, search_papers
from app.services.user_data import query_reading_list
from app.utils import templates, today_str
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
# ── 搜索页 ────────────────────────────────────────────────────────────
@@ -56,7 +54,7 @@ def search_page(
"total_pages": result["total_pages"],
"all_tags": all_tags,
"page_title": f"搜索: {q}" if q else "搜索",
"today": _today_str(),
"today": today_str(),
},
)
@@ -114,7 +112,7 @@ def reading_list_page(
db: Session = Depends(get_db),
):
"""阅读列表页面。"""
papers = _query_reading_list(db, filter, tag or None)
papers = query_reading_list(db, filter, tag or None)
all_tags = get_all_tags(db)
return templates.TemplateResponse(
@@ -126,54 +124,11 @@ def reading_list_page(
"current_tag": tag,
"all_tags": all_tags,
"page_title": "阅读列表",
"today": _today_str(),
"today": today_str(),
},
)
def _query_reading_list(
db: Session,
filter_type: str,
tag: str | None,
) -> list[Paper]:
"""根据筛选条件查询阅读列表。"""
from sqlalchemy import or_
# 基础:有任意用户数据的论文
base = db.query(Paper).filter(
or_(
Paper.bookmark.has(),
Paper.reading_status.has(),
Paper.note.has(),
)
)
# 应用筛选
if filter_type == "has_note":
base = base.filter(Paper.note.has())
elif filter_type in ("unread", "skimmed", "read_summary", "read_full"):
base = base.filter(
Paper.reading_status.has(UserReadingStatus.status == filter_type)
)
# 应用标签
if tag:
base = base.filter(Paper.tags.any(PaperTag.tag == tag))
return (
base.options(
joinedload(Paper.authors),
joinedload(Paper.tags),
joinedload(Paper.summary_status),
joinedload(Paper.bookmark),
joinedload(Paper.reading_status),
joinedload(Paper.note),
)
.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
.all()
)
# ── RSS Feed ──────────────────────────────────────────────────────────
@@ -216,7 +171,7 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
lines.append(f" <title>{escape(channel_title)}</title>")
lines.append(f" <link>{escape(base_url)}</link>")
lines.append(" <description>HuggingFace Daily Papers — 中文论文导览站</description>")
lines.append(f" <language>zh-CN</language>")
lines.append(" <language>zh-CN</language>")
for paper in papers:
title_text = paper.title_zh or paper.title_en
@@ -245,12 +200,3 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
lines.append(" </channel>")
lines.append("</rss>")
return "\n".join(lines)
# ── 工具函数 ──────────────────────────────────────────────────────────
def _today_str() -> str:
"""当前日期字符串(按 APP_TIMEZONE)。"""
tz = ZoneInfo(settings.APP_TIMEZONE)
return datetime.now(tz).strftime("%Y-%m-%d")
+5 -92
View File
@@ -2,34 +2,27 @@
from __future__ import annotations
import logging
from datetime import date, timedelta
from fastapi import APIRouter, Depends, Request
from fastapi.templating import Jinja2Templates
from sqlalchemy import func, text
from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db
logger = logging.getLogger(__name__)
from app.services.trends import get_trends_data
from app.utils import templates, today_str
router = APIRouter()
templates = Jinja2Templates(directory="app/templates")
@router.get("/trends")
def trends_page(request: Request, db: Session = Depends(get_db)):
"""趋势看板页面。"""
stats = _get_trends_data(db)
stats = get_trends_data(db)
return templates.TemplateResponse(
request,
"trends.html",
{
"page_title": "趋势看板",
"stats": stats,
"today": _today_str(),
"today": today_str(),
},
)
@@ -37,84 +30,4 @@ def trends_page(request: Request, db: Session = Depends(get_db)):
@router.get("/api/stats/trends")
def trends_api(db: Session = Depends(get_db)):
"""趋势数据 JSON API。"""
return _get_trends_data(db)
def _get_trends_data(db: Session) -> dict:
"""从 DB 聚合趋势数据。"""
thirty_days_ago = (date.today() - timedelta(days=30)).isoformat()
# 1. 按日论文数量(近 30 天)
daily_rows = db.execute(text("""
SELECT paper_date, COUNT(*) as cnt
FROM papers
WHERE paper_date >= :start_date
GROUP BY paper_date
ORDER BY paper_date ASC
"""), {"start_date": thirty_days_ago}).fetchall()
daily_counts = [
{"date": str(row[0]), "count": row[1]}
for row in daily_rows
]
# 2. 热门标签 Top 20
tag_rows = db.execute(text("""
SELECT tag, COUNT(*) as cnt
FROM paper_tags
GROUP BY tag
ORDER BY cnt DESC
LIMIT 20
""")).fetchall()
top_tags = [
{"tag": row[0], "count": row[1]}
for row in tag_rows
]
# 3. Upvotes 分布
upvote_rows = db.execute(text("""
SELECT
CASE
WHEN upvotes >= 100 THEN '100+'
WHEN upvotes >= 50 THEN '50-99'
WHEN upvotes >= 20 THEN '20-49'
WHEN upvotes >= 10 THEN '10-19'
WHEN upvotes >= 5 THEN '5-9'
ELSE '0-4'
END as bucket,
COUNT(*) as cnt
FROM papers
GROUP BY bucket
ORDER BY MIN(upvotes) DESC
""")).fetchall()
upvotes_dist = [
{"range": row[0], "count": row[1]}
for row in upvote_rows
]
# 4. 总结完成率
summary_rows = db.execute(text("""
SELECT
COALESCE(ss.status, 'none') as status,
COUNT(*) as cnt
FROM papers p
LEFT JOIN summary_status ss ON ss.paper_id = p.id
GROUP BY status
""")).fetchall()
summary_completion = [
{"status": row[0], "count": row[1]}
for row in summary_rows
]
return {
"daily_counts": daily_counts,
"top_tags": top_tags,
"upvotes_dist": upvotes_dist,
"summary_completion": summary_completion,
}
def _today_str() -> str:
from datetime import datetime
from zoneinfo import ZoneInfo
tz = ZoneInfo(settings.APP_TIMEZONE)
return datetime.now(tz).strftime("%Y-%m-%d")
return get_trends_data(db)