refactor: restructure services and add image/pdf extraction utilities
- Add image_extractor, pdf_downloader, pi_client, trends services
- Add shared utils module
- Refactor summarizer, embedder, routes for cleaner separation
- Update tests to match new service structure
This commit is contained in:
+9
-30
@@ -6,7 +6,6 @@ from datetime import date, datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -17,10 +16,10 @@ from app.models import CrawlLog, DataDeleteJob, TaskLock
|
||||
from app.services.cleaner import cleanup_tmp, delete_papers_by_date_range
|
||||
from app.services.crawler import crawl_daily
|
||||
from app.services.summarizer import summarize_batch, summarize_single
|
||||
from app.utils import release_lock, templates, today_str
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
security = HTTPBearer()
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
async def verify_admin(
|
||||
@@ -32,7 +31,7 @@ async def verify_admin(
|
||||
return credentials.credentials
|
||||
|
||||
|
||||
# ── 请求模型 ──────────────────────────────────────────────────────────────
|
||||
# ── 请求模型 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class DeleteRequest(BaseModel):
|
||||
@@ -49,7 +48,7 @@ class DeleteRequest(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
# ── 抓取 ──────────────────────────────────────────────────────────────────
|
||||
# ── 抓取 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/crawl")
|
||||
@@ -59,12 +58,7 @@ async def admin_crawl(
|
||||
date: str | None = Query(None, description="YYYY-MM-DD,默认今天"),
|
||||
):
|
||||
"""手动抓取指定日期,默认今天。"""
|
||||
# 计算 target_date
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
tz = ZoneInfo(settings.APP_TIMEZONE)
|
||||
today = datetime.now(tz).strftime("%Y-%m-%d")
|
||||
target_date = date or today
|
||||
target_date = date or today_str()
|
||||
|
||||
# TaskLock 防重入
|
||||
now = datetime.now(timezone.utc)
|
||||
@@ -88,10 +82,10 @@ async def admin_crawl(
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
finally:
|
||||
_release_lock(db, lock)
|
||||
release_lock(db, lock)
|
||||
|
||||
|
||||
# ── 总结 ──────────────────────────────────────────────────────────────────
|
||||
# ── 总结 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/summarize")
|
||||
@@ -119,7 +113,7 @@ async def admin_summarize_single(
|
||||
return result
|
||||
|
||||
|
||||
# ── 清理 ──────────────────────────────────────────────────────────────────
|
||||
# ── 清理 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/cleanup")
|
||||
@@ -155,7 +149,7 @@ async def admin_cleanup(
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
# ── 删除 ──────────────────────────────────────────────────────────────────
|
||||
# ── 删除 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/delete")
|
||||
@@ -177,7 +171,7 @@ async def admin_delete(
|
||||
return result
|
||||
|
||||
|
||||
# ── 日志 ──────────────────────────────────────────────────────────────────
|
||||
# ── 日志 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
@@ -189,7 +183,6 @@ async def admin_logs(
|
||||
per_page: int = Query(20, ge=1, le=100),
|
||||
):
|
||||
"""查看任务日志(CrawlLog + DataDeleteJob)。"""
|
||||
# 查询 crawl_logs
|
||||
crawl_logs = (
|
||||
db.execute(
|
||||
select(CrawlLog)
|
||||
@@ -201,7 +194,6 @@ async def admin_logs(
|
||||
.all()
|
||||
)
|
||||
|
||||
# 查询 delete_jobs
|
||||
delete_jobs = (
|
||||
db.execute(
|
||||
select(DataDeleteJob)
|
||||
@@ -223,16 +215,3 @@ async def admin_logs(
|
||||
"per_page": per_page,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── 工具函数 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _release_lock(db: Session, lock: TaskLock) -> None:
    """Mark a TaskLock as finished and commit the change (best effort).

    Sets ``lock.status`` to ``"finished"`` and ``lock.released_at`` to the
    current UTC time, then commits ``db``.

    Args:
        db: Session used to persist the lock update.
        lock: The lock row to release.

    If updating or committing fails, the error is logged with its traceback
    and the session is rolled back. The exception is not re-raised, so a
    failed release does not mask whatever the caller was doing.
    """
    import logging

    try:
        lock.status = "finished"
        lock.released_at = datetime.now(timezone.utc)
        db.commit()
    except Exception:
        # Keep the release best-effort, but leave a trace: a lock that
        # could not be released would otherwise fail silently.
        logging.getLogger(__name__).exception(
            "Failed to release TaskLock; rolling back"
        )
        db.rollback()
|
||||
|
||||
@@ -2,19 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Paper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from app.utils import templates
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
@router.get("/compare")
|
||||
|
||||
+7
-21
@@ -3,34 +3,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from datetime import date, timedelta
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models import Paper
|
||||
from app.utils import templates, today_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
def _today() -> str:
    """Return the current date in ``settings.APP_TIMEZONE`` as ``YYYY-MM-DD``."""
    now_local = datetime.now(tz=ZoneInfo(settings.APP_TIMEZONE))
    return f"{now_local:%Y-%m-%d}"
|
||||
|
||||
|
||||
@router.get("/")
|
||||
def index(request: Request):
|
||||
"""重定向到 /day/{today}。"""
|
||||
return RedirectResponse(url=f"/day/{_today()}")
|
||||
return RedirectResponse(url=f"/day/{today_str()}")
|
||||
|
||||
|
||||
@router.get("/day/{date_str}")
|
||||
@@ -43,7 +36,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
|
||||
|
||||
prev_day = (target - timedelta(days=1)).isoformat()
|
||||
next_day = (target + timedelta(days=1)).isoformat()
|
||||
today_str = _today()
|
||||
today = today_str()
|
||||
|
||||
papers = (
|
||||
db.query(Paper)
|
||||
@@ -74,7 +67,7 @@ def day_page(date_str: str, request: Request, db: Session = Depends(get_db)):
|
||||
"current_date": date_str,
|
||||
"prev_day": prev_day,
|
||||
"next_day": next_day,
|
||||
"today": today_str,
|
||||
"today": today,
|
||||
"available_dates": available_dates,
|
||||
"page_title": f"{date_str} 论文列表",
|
||||
},
|
||||
@@ -146,16 +139,9 @@ def _get_similar_papers(db: Session, arxiv_id: str, top_k: int = 6) -> list[dict
|
||||
return []
|
||||
|
||||
try:
|
||||
from app.services.embedder import search_similar
|
||||
|
||||
# 用论文的 arxiv_id 从 ChromaDB 查询
|
||||
col = None
|
||||
try:
|
||||
from app.services.embedder import get_collection
|
||||
col = get_collection()
|
||||
except Exception:
|
||||
return []
|
||||
from app.services.embedder import get_collection
|
||||
|
||||
col = get_collection()
|
||||
if col is None:
|
||||
return []
|
||||
|
||||
|
||||
+7
-61
@@ -2,14 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from zoneinfo import ZoneInfo
|
||||
from datetime import date, timedelta
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import Response
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
@@ -17,9 +14,10 @@ from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models import Paper, PaperTag, UserReadingStatus
|
||||
from app.services.searcher import get_all_tags, search_papers
|
||||
from app.services.user_data import query_reading_list
|
||||
from app.utils import templates, today_str
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
# ── 搜索页 ────────────────────────────────────────────────────────────
|
||||
@@ -56,7 +54,7 @@ def search_page(
|
||||
"total_pages": result["total_pages"],
|
||||
"all_tags": all_tags,
|
||||
"page_title": f"搜索: {q}" if q else "搜索",
|
||||
"today": _today_str(),
|
||||
"today": today_str(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -114,7 +112,7 @@ def reading_list_page(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""阅读列表页面。"""
|
||||
papers = _query_reading_list(db, filter, tag or None)
|
||||
papers = query_reading_list(db, filter, tag or None)
|
||||
all_tags = get_all_tags(db)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
@@ -126,54 +124,11 @@ def reading_list_page(
|
||||
"current_tag": tag,
|
||||
"all_tags": all_tags,
|
||||
"page_title": "阅读列表",
|
||||
"today": _today_str(),
|
||||
"today": today_str(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _query_reading_list(
    db: Session,
    filter_type: str,
    tag: str | None,
) -> list[Paper]:
    """Query the reading-list papers, narrowed by status filter and tag.

    A paper is included when it has a bookmark, a reading status, or a note.

    Args:
        db: Session used to run the query.
        filter_type: ``"has_note"`` keeps only papers with a note; one of
            ``"unread"``, ``"skimmed"``, ``"read_summary"`` or
            ``"read_full"`` keeps only papers whose reading status equals
            that value; any other value adds no further restriction.
        tag: When truthy, keep only papers carrying this tag.

    Returns:
        Matching papers ordered by ``paper_date`` then ``upvotes``, both
        descending, with authors, tags, summary status, bookmark, reading
        status and note eagerly loaded.
    """
    from sqlalchemy import or_

    status_filters = ("unread", "skimmed", "read_summary", "read_full")

    # Baseline: papers that carry any kind of user data.
    criteria = [
        or_(
            Paper.bookmark.has(),
            Paper.reading_status.has(),
            Paper.note.has(),
        )
    ]

    # Optional narrowing by the requested filter.
    if filter_type == "has_note":
        criteria.append(Paper.note.has())
    elif filter_type in status_filters:
        criteria.append(
            Paper.reading_status.has(UserReadingStatus.status == filter_type)
        )

    # Optional narrowing by tag.
    if tag:
        criteria.append(Paper.tags.any(PaperTag.tag == tag))

    eager_loads = [
        joinedload(Paper.authors),
        joinedload(Paper.tags),
        joinedload(Paper.summary_status),
        joinedload(Paper.bookmark),
        joinedload(Paper.reading_status),
        joinedload(Paper.note),
    ]

    query = db.query(Paper).filter(*criteria).options(*eager_loads)
    query = query.order_by(Paper.paper_date.desc(), Paper.upvotes.desc())
    return query.all()
|
||||
|
||||
|
||||
# ── RSS Feed ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -216,7 +171,7 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
|
||||
lines.append(f" <title>{escape(channel_title)}</title>")
|
||||
lines.append(f" <link>{escape(base_url)}</link>")
|
||||
lines.append(" <description>HuggingFace Daily Papers — 中文论文导览站</description>")
|
||||
lines.append(f" <language>zh-CN</language>")
|
||||
lines.append(" <language>zh-CN</language>")
|
||||
|
||||
for paper in papers:
|
||||
title_text = paper.title_zh or paper.title_en
|
||||
@@ -245,12 +200,3 @@ def _generate_rss_xml(papers: list[Paper], base_url: str, tag: str | None) -> st
|
||||
lines.append(" </channel>")
|
||||
lines.append("</rss>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ── 工具函数 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _today_str() -> str:
    """Return today's date as ``YYYY-MM-DD``, evaluated in ``settings.APP_TIMEZONE``."""
    app_zone = ZoneInfo(settings.APP_TIMEZONE)
    return format(datetime.now(tz=app_zone), "%Y-%m-%d")
|
||||
|
||||
+5
-92
@@ -2,34 +2,27 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy import func, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from app.services.trends import get_trends_data
|
||||
from app.utils import templates, today_str
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
@router.get("/trends")
|
||||
def trends_page(request: Request, db: Session = Depends(get_db)):
|
||||
"""趋势看板页面。"""
|
||||
stats = _get_trends_data(db)
|
||||
stats = get_trends_data(db)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"trends.html",
|
||||
{
|
||||
"page_title": "趋势看板",
|
||||
"stats": stats,
|
||||
"today": _today_str(),
|
||||
"today": today_str(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -37,84 +30,4 @@ def trends_page(request: Request, db: Session = Depends(get_db)):
|
||||
@router.get("/api/stats/trends")
|
||||
def trends_api(db: Session = Depends(get_db)):
|
||||
"""趋势数据 JSON API。"""
|
||||
return _get_trends_data(db)
|
||||
|
||||
|
||||
def _get_trends_data(db: Session) -> dict:
    """Aggregate the trend statistics from the database.

    Runs four raw SQL queries, in order, and returns their results as a dict:

    - ``daily_counts``: paper count per ``paper_date`` from 30 days ago
      onward, oldest first, as ``{"date", "count"}`` items.
    - ``top_tags``: the 20 most frequent tags, as ``{"tag", "count"}`` items.
    - ``upvotes_dist``: paper count per upvote bucket, as
      ``{"range", "count"}`` items.
    - ``summary_completion``: paper count per summary status (``'none'``
      when a paper has no ``summary_status`` row), as
      ``{"status", "count"}`` items.
    """

    def fetch(sql: str, params: dict | None = None):
        # Run one raw SQL statement and return all of its rows.
        statement = text(sql)
        if params is None:
            return db.execute(statement).fetchall()
        return db.execute(statement, params).fetchall()

    # NOTE(review): date.today() is the server-local date, not necessarily
    # the APP_TIMEZONE date - confirm this is intended.
    window_start = (date.today() - timedelta(days=30)).isoformat()

    # 1. Papers per day (last 30 days)
    daily_counts = []
    for paper_date, count in fetch(
        """
        SELECT paper_date, COUNT(*) as cnt
        FROM papers
        WHERE paper_date >= :start_date
        GROUP BY paper_date
        ORDER BY paper_date ASC
    """,
        {"start_date": window_start},
    ):
        daily_counts.append({"date": str(paper_date), "count": count})

    # 2. Top 20 tags
    top_tags = []
    for tag_name, count in fetch(
        """
        SELECT tag, COUNT(*) as cnt
        FROM paper_tags
        GROUP BY tag
        ORDER BY cnt DESC
        LIMIT 20
    """
    ):
        top_tags.append({"tag": tag_name, "count": count})

    # 3. Upvotes distribution
    upvotes_dist = []
    for bucket, count in fetch(
        """
        SELECT
            CASE
                WHEN upvotes >= 100 THEN '100+'
                WHEN upvotes >= 50 THEN '50-99'
                WHEN upvotes >= 20 THEN '20-49'
                WHEN upvotes >= 10 THEN '10-19'
                WHEN upvotes >= 5 THEN '5-9'
                ELSE '0-4'
            END as bucket,
            COUNT(*) as cnt
        FROM papers
        GROUP BY bucket
        ORDER BY MIN(upvotes) DESC
    """
    ):
        upvotes_dist.append({"range": bucket, "count": count})

    # 4. Summary completion
    summary_completion = []
    for status, count in fetch(
        """
        SELECT
            COALESCE(ss.status, 'none') as status,
            COUNT(*) as cnt
        FROM papers p
        LEFT JOIN summary_status ss ON ss.paper_id = p.id
        GROUP BY status
    """
    ):
        summary_completion.append({"status": status, "count": count})

    return {
        "daily_counts": daily_counts,
        "top_tags": top_tags,
        "upvotes_dist": upvotes_dist,
        "summary_completion": summary_completion,
    }
|
||||
|
||||
|
||||
def _today_str() -> str:
    """Return the current date in ``settings.APP_TIMEZONE`` as ``YYYY-MM-DD``."""
    # Imported locally, as in the original block.
    from datetime import datetime
    from zoneinfo import ZoneInfo

    return datetime.now(tz=ZoneInfo(settings.APP_TIMEZONE)).strftime("%Y-%m-%d")
|
||||
return get_trends_data(db)
|
||||
|
||||
Reference in New Issue
Block a user