21f16e6756
- Split summarizer into summary_generator and summary_persister modules - Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection - Add layout_detector service for PicoDet-S_layout_3cls integration - Add exceptions module with ConflictError and NotFoundError - Improve admin dashboard with better statistics and task management - Add design review document with system optimization suggestions - Add new tests for crawler, pdf_downloader, pipeline, and summary_utils - Update dependencies and configuration - Clean up dead code and improve error handling
206 lines
6.2 KiB
Python
206 lines
6.2 KiB
Python
"""公共工具 — 消除各模块间的重复代码。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from datetime import date, datetime, timedelta, timezone
|
||
from pathlib import Path
|
||
from typing import Any
|
||
from zoneinfo import ZoneInfo
|
||
|
||
import bleach
|
||
|
||
import httpx
|
||
from fastapi.templating import Jinja2Templates
|
||
|
||
from app.config import settings
|
||
|
||
# ── 路径常量 ──────────────────────────────────────────────────────────
|
||
|
||
# Relative base path; the two paths below are direct children of it.
DATA_DIR = Path("data")
PAPERS_DIR = DATA_DIR.joinpath("papers")
TMP_DIR = DATA_DIR.joinpath("tmp")
|
||
|
||
# ── 模板单例 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class _Templates(Jinja2Templates):
    """Jinja2Templates subclass that adds ``is_admin`` to every template context."""

    def TemplateResponse(self, request, name, context=None, **kwargs):
        """Render ``name`` with ``is_admin`` defaulted from the request session.

        Args:
            request: Incoming request; its ``session`` mapping is always read.
            name: Template name, forwarded unchanged to the parent class.
            context: Optional template context. A falsy value is replaced by a
                new dict; otherwise the caller's mapping is updated in place.
            **kwargs: Forwarded unchanged to the parent ``TemplateResponse``.

        Returns:
            Whatever the parent ``TemplateResponse`` returns.
        """
        ctx = context if context else {}
        # Read the session flag first; an ``is_admin`` already supplied by the
        # caller takes precedence over it.
        session_flag = request.session.get("is_admin", False)
        ctx.setdefault("is_admin", session_flag)
        return super().TemplateResponse(request, name, ctx, **kwargs)


# Module-level instance for importers to share.
templates = _Templates(directory="app/templates")
|
||
|
||
|
||
# ── 时区工具 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
def utc_now() -> datetime:
    """Return the current moment as a timezone-aware ``datetime`` in UTC."""
    return datetime.now(tz=timezone.utc)
|
||
|
||
|
||
def today_str() -> str:
    """Return today's date as ``YYYY-MM-DD`` in the ``settings.APP_TIMEZONE`` zone."""
    local_now = datetime.now(ZoneInfo(settings.APP_TIMEZONE))
    return local_now.strftime("%Y-%m-%d")
|
||
|
||
|
||
def yesterday_str() -> str:
    """Return yesterday's date in ISO format, relative to ``settings.APP_TIMEZONE``."""
    local_today = datetime.now(ZoneInfo(settings.APP_TIMEZONE)).date()
    one_day = timedelta(days=1)
    return (local_today - one_day).isoformat()
|
||
|
||
|
||
def recent_date_strs(n: int) -> list[str]:
    """Return the last ``n`` dates as ISO strings, newest first.

    The first entry is today in the ``settings.APP_TIMEZONE`` zone and each
    following entry is one day earlier. ``n <= 0`` yields an empty list.
    """
    zone = ZoneInfo(settings.APP_TIMEZONE)
    anchor = datetime.now(zone).date()
    dates: list[str] = []
    for offset in range(n):
        dates.append((anchor - timedelta(days=offset)).isoformat())
    return dates
|
||
|
||
|
||
def latest_paper_date(db) -> str:
    """Return the largest ``Paper.paper_date`` in the database as a string.

    Args:
        db: Object exposing ``scalar()``; presumably a SQLAlchemy session
            (confirm against callers).

    Returns:
        ``isoformat()`` of the value when it is a ``date`` instance, ``str()``
        of it otherwise, or ``today_str()`` when the query yields ``None``.
    """
    # Function-scope imports, as in the original (presumably to avoid an
    # import cycle with app.models; confirm).
    from sqlalchemy import func, select

    from app.models import Paper

    newest = db.scalar(select(func.max(Paper.paper_date)))
    if newest is None:
        return today_str()
    if isinstance(newest, date):
        return newest.isoformat()
    return str(newest)
|
||
|
||
|
||
# ── 锁释放 ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def release_lock(db, lock) -> None:
    """Mark a task lock as finished and commit the change (best effort).

    Args:
        db: Object exposing ``commit()`` and ``rollback()``; presumably a
            SQLAlchemy session (confirm against callers).
        lock: The lock record; its ``status`` and ``released_at`` attributes
            are assigned.

    A failure while updating or committing is logged with its traceback and
    the transaction is rolled back. The error is not re-raised.
    """
    import logging

    try:
        lock.status = "finished"
        lock.released_at = utc_now()
        db.commit()
    except Exception:
        # Still best effort, but record the failure: a silently unreleased
        # lock is otherwise very hard to diagnose.
        logging.getLogger(__name__).exception(
            "Failed to release task lock; rolling back"
        )
        db.rollback()
|
||
|
||
|
||
# ── HTTP 客户端工厂 ───────────────────────────────────────────────────
|
||
|
||
|
||
def make_http_client(
    *, sync: bool = False, follow_redirects: bool = False, **kwargs
) -> httpx.AsyncClient | httpx.Client:
    """Build an httpx client preconfigured from application settings.

    Args:
        sync: When true return ``httpx.Client``, otherwise ``httpx.AsyncClient``.
        follow_redirects: Passed to the client as ``follow_redirects``.
        **kwargs: Extra client arguments. Each key replaces the matching
            default wholesale (for example, passing ``headers`` drops the
            default ``User-Agent`` header, and passing ``transport`` replaces
            the proxy transport).

    Returns:
        A new client using ``settings.HTTP_TIMEOUT_SECONDS`` as timeout and
        ``settings.HTTP_USER_AGENT`` as ``User-Agent``. When
        ``settings.http_proxy`` is truthy, the client gets a transport routed
        through that proxy.
    """
    client_cls = httpx.Client if sync else httpx.AsyncClient

    options: dict = {
        "timeout": settings.HTTP_TIMEOUT_SECONDS,
        "headers": {"User-Agent": settings.HTTP_USER_AGENT},
        "follow_redirects": follow_redirects,
    }

    proxy_url = settings.http_proxy
    if proxy_url:
        # The transport flavour must match the client flavour.
        transport_cls = httpx.HTTPTransport if sync else httpx.AsyncHTTPTransport
        options["transport"] = transport_cls(proxy=proxy_url)

    # Caller-supplied arguments win over the defaults.
    return client_cls(**{**options, **kwargs})
|
||
|
||
|
||
# ── JSON 安全解析 ──────────────────────────────────────────────────────
|
||
|
||
|
||
def safe_json_loads(text: str | None, default: Any = None) -> Any:
    """Parse a JSON string, returning ``default`` instead of raising.

    Args:
        text: JSON document to parse. Any falsy value (``None``, ``""``)
            yields ``default`` without attempting to parse.
        default: Value returned when ``text`` is falsy or cannot be parsed.

    Returns:
        The decoded object, or ``default`` on failure.
    """
    if not text:
        return default
    try:
        return json.loads(text)
    except (ValueError, TypeError, RecursionError):
        # ValueError covers json.JSONDecodeError (its subclass); TypeError
        # covers non-string input. RecursionError is raised for pathologically
        # nested documents and previously escaped this function.
        return default
|
||
|
||
|
||
# ── HTML 清洗 ──────────────────────────────────────────────────────────
|
||
|
||
# HTML tags and per-tag attributes permitted in AI-generated content.
_ALLOWED_TAGS = {
    # Paragraphs, breaks and quoted/preformatted text
    "p", "br", "blockquote", "pre", "code",
    # Inline emphasis and decoration
    "strong", "b", "em", "i", "u", "s", "del", "sup", "sub", "span",
    # Headings (h1/h2 are not included)
    "h3", "h4", "h5", "h6",
    # Lists
    "ul", "ol", "li",
    # Links
    "a",
    # Tables
    "table", "thead", "tbody", "tr", "th", "td",
}
_ALLOWED_ATTRS = {
    "a": {"href", "title"},
    "th": {"colspan", "rowspan"},
    "td": {"colspan", "rowspan"},
    "span": {"class"},
}
|
||
|
||
|
||
def sanitize_html(text: str | None) -> str:
    """Clean AI-generated HTML down to the module's tag/attribute allow-lists.

    The input is passed through ``bleach.clean`` with ``_ALLOWED_TAGS`` and
    ``_ALLOWED_ATTRS``. ``strip=True`` removes disallowed markup rather than
    escaping it. No ``protocols`` argument is passed, so link-scheme filtering
    follows bleach's defaults.

    Args:
        text: Raw HTML, or ``None``.

    Returns:
        The cleaned HTML, or ``""`` when ``text`` is falsy.
    """
    if text:
        return bleach.clean(
            text,
            tags=_ALLOWED_TAGS,
            attributes=_ALLOWED_ATTRS,
            strip=True,
        )
    return ""
|
||
|
||
|
||
# ── 错误消息截断 ────────────────────────────────────────────────────────
|
||
|
||
# Default number of characters kept by truncate_error().
_ERROR_TRUNCATE_LIMIT = 500


def truncate_error(exc: Exception | str, limit: int = _ERROR_TRUNCATE_LIMIT) -> str:
    """Return ``str(exc)``, shortened when it is longer than ``limit``.

    Args:
        exc: Exception or message to render.
        limit: Maximum number of characters of the message to keep.

    Returns:
        The message unchanged when it fits. Otherwise its first ``limit``
        characters followed by ``"... (<N> chars total)"``, where ``N`` is the
        original length; the suffix is added on top of ``limit``.
    """
    message = str(exc)
    total = len(message)
    if total > limit:
        return f"{message[:limit]}... ({total} chars total)"
    return message
|